From 7468723697e4d292315ce807b5000062a02b37be Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 16 Nov 2021 08:47:36 -0500 Subject: Add most missing type hints to synapse.util (#11328) --- synapse/util/caches/__init__.py | 32 +++++++++-------- synapse/util/caches/deferred_cache.py | 2 +- synapse/util/caches/descriptors.py | 67 ++++++++++++++++++++++------------- synapse/util/caches/expiringcache.py | 10 +++--- 4 files changed, 66 insertions(+), 45 deletions(-) (limited to 'synapse/util/caches') diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index df4d61e4b6..15debd6c46 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -17,7 +17,7 @@ import logging import typing from enum import Enum, auto from sys import intern -from typing import Callable, Dict, Optional, Sized +from typing import Any, Callable, Dict, List, Optional, Sized import attr from prometheus_client.core import Gauge @@ -58,20 +58,20 @@ class EvictionReason(Enum): time = auto() -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class CacheMetric: - _cache = attr.ib() - _cache_type = attr.ib(type=str) - _cache_name = attr.ib(type=str) - _collect_callback = attr.ib(type=Optional[Callable]) + _cache: Sized + _cache_type: str + _cache_name: str + _collect_callback: Optional[Callable] - hits = attr.ib(default=0) - misses = attr.ib(default=0) + hits: int = 0 + misses: int = 0 eviction_size_by_reason: typing.Counter[EvictionReason] = attr.ib( factory=collections.Counter ) - memory_usage = attr.ib(default=None) + memory_usage: Optional[int] = None def inc_hits(self) -> None: self.hits += 1 @@ -89,13 +89,14 @@ class CacheMetric: self.memory_usage += memory def dec_memory_usage(self, memory: int) -> None: + assert self.memory_usage is not None self.memory_usage -= memory def clear_memory_usage(self) -> None: if self.memory_usage is not None: self.memory_usage = 0 - def describe(self): + def describe(self) -> List[str]: return [] def collect(self) -> None: @@ -118,8 +119,9 @@ class CacheMetric: self.eviction_size_by_reason[reason] ) cache_total.labels(self._cache_name).set(self.hits + self.misses) - if getattr(self._cache, "max_size", None): - cache_max_size.labels(self._cache_name).set(self._cache.max_size) + max_size = getattr(self._cache, "max_size", None) + if max_size: + cache_max_size.labels(self._cache_name).set(max_size) if TRACK_MEMORY_USAGE: # self.memory_usage can be None if nothing has been inserted @@ -193,7 +195,7 @@ KNOWN_KEYS = { } -def intern_string(string): +def intern_string(string: Optional[str]) -> Optional[str]: """Takes a (potentially) unicode string and interns it if it's ascii""" if string is None: return None @@ -204,7 +206,7 @@ def intern_string(string): return string -def intern_dict(dictionary): +def intern_dict(dictionary: Dict[str, Any]) -> Dict[str, Any]: """Takes a dictionary and interns well known keys and their values""" return { KNOWN_KEYS.get(key, key): _intern_known_values(key, value) @@ -212,7 +214,7 @@ def intern_dict(dictionary): } -def _intern_known_values(key, value): +def _intern_known_values(key: str, value: Any) -> Any: intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key") if key in intern_keys: diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index da502aec11..3c4cc093af 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -289,7 +289,7 @@ class DeferredCache(Generic[KT, VT]): callbacks = [callback] if callback else [] self.cache.set(key, value, callbacks=callbacks) - def invalidate(self, key) -> None: + def invalidate(self, key: KT) -> None: """Delete a key, or tree of entries If the cache is backed by a regular dict, then "key" must be of diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index b9dcca17f1..375cd443f1 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -19,12 +19,15 @@ import logging from typing import ( Any, Callable, + Dict, Generic, + Hashable, Iterable, Mapping, Optional, Sequence, Tuple, + Type, TypeVar, Union, cast, @@ -32,6 +35,7 @@ from typing import ( from weakref import WeakValueDictionary from twisted.internet import defer +from twisted.python.failure import Failure from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError @@ -60,7 +64,12 @@ class _CachedFunction(Generic[F]): class _CacheDescriptorBase: - def __init__(self, orig: Callable[..., Any], num_args, cache_context=False): + def __init__( + self, + orig: Callable[..., Any], + num_args: Optional[int], + cache_context: bool = False, + ): self.orig = orig arg_spec = inspect.getfullargspec(orig) @@ -172,14 +181,14 @@ class LruCacheDescriptor(_CacheDescriptorBase): def __init__( self, - orig, + orig: Callable[..., Any], max_entries: int = 1000, cache_context: bool = False, ): super().__init__(orig, num_args=None, cache_context=cache_context) self.max_entries = max_entries - def __get__(self, obj, owner): + def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: cache: LruCache[CacheKey, Any] = LruCache( cache_name=self.orig.__name__, max_size=self.max_entries, @@ -189,7 +198,7 @@ class LruCacheDescriptor(_CacheDescriptorBase): sentinel = LruCacheDescriptor._Sentinel.sentinel @functools.wraps(self.orig) - def _wrapped(*args, **kwargs): + def _wrapped(*args: Any, **kwargs: Any) -> Any: invalidate_callback = kwargs.pop("on_invalidate", None) callbacks = (invalidate_callback,) if invalidate_callback else () @@ -245,19 +254,19 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): return r1 + r2 Args: - num_args (int): number of positional arguments (excluding ``self`` and + num_args: number of positional arguments (excluding ``self`` and ``cache_context``) to use as cache keys. Defaults to all named args of the function. """ def __init__( self, - orig, - max_entries=1000, - num_args=None, - tree=False, - cache_context=False, - iterable=False, + orig: Callable[..., Any], + max_entries: int = 1000, + num_args: Optional[int] = None, + tree: bool = False, + cache_context: bool = False, + iterable: bool = False, prune_unread_entries: bool = True, ): super().__init__(orig, num_args=num_args, cache_context=cache_context) @@ -272,7 +281,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): self.iterable = iterable self.prune_unread_entries = prune_unread_entries - def __get__(self, obj, owner): + def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: cache: DeferredCache[CacheKey, Any] = DeferredCache( name=self.orig.__name__, max_entries=self.max_entries, @@ -284,7 +293,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): get_cache_key = self.cache_key_builder @functools.wraps(self.orig) - def _wrapped(*args, **kwargs): + def _wrapped(*args: Any, **kwargs: Any) -> Any: # If we're passed a cache_context then we'll want to call its invalidate() # whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) @@ -335,13 +344,19 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): of results. """ - def __init__(self, orig, cached_method_name, list_name, num_args=None): + def __init__( + self, + orig: Callable[..., Any], + cached_method_name: str, + list_name: str, + num_args: Optional[int] = None, + ): """ Args: - orig (function) - cached_method_name (str): The name of the cached method. - list_name (str): Name of the argument which is the bulk lookup list - num_args (int): number of positional arguments (excluding ``self``, + orig + cached_method_name: The name of the cached method. + list_name: Name of the argument which is the bulk lookup list + num_args: number of positional arguments (excluding ``self``, but including list_name) to use as cache keys. Defaults to all named args of the function. """ @@ -360,13 +375,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): % (self.list_name, cached_method_name) ) - def __get__(self, obj, objtype=None): + def __get__( + self, obj: Optional[Any], objtype: Optional[Type] = None + ) -> Callable[..., Any]: cached_method = getattr(obj, self.cached_method_name) cache: DeferredCache[CacheKey, Any] = cached_method.cache num_args = cached_method.num_args @functools.wraps(self.orig) - def wrapped(*args, **kwargs): + def wrapped(*args: Any, **kwargs: Any) -> Any: # If we're passed a cache_context then we'll want to call its # invalidate() whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) @@ -377,7 +394,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): results = {} - def update_results_dict(res, arg): + def update_results_dict(res: Any, arg: Hashable) -> None: results[arg] = res # list of deferreds to wait for @@ -389,13 +406,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): # otherwise a tuple is used. if num_args == 1: - def arg_to_cache_key(arg): + def arg_to_cache_key(arg: Hashable) -> Hashable: return arg else: keylist = list(keyargs) - def arg_to_cache_key(arg): + def arg_to_cache_key(arg: Hashable) -> Hashable: keylist[self.list_pos] = arg return tuple(keylist) @@ -421,7 +438,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): key = arg_to_cache_key(arg) cache.set(key, deferred, callback=invalidate_callback) - def complete_all(res): + def complete_all(res: Dict[Hashable, Any]) -> None: # the wrapped function has completed. It returns a # a dict. We can now resolve the observable deferreds in # the cache and update our own result map. @@ -430,7 +447,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): deferreds_map[e].callback(val) results[e] = val - def errback(f): + def errback(f: Failure) -> Failure: # the wrapped function has failed. Invalidate any cache # entries we're supposed to be populating, and fail # their deferreds. diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index c3f72aa06d..6a7e534576 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -19,6 +19,8 @@ from typing import Any, Generic, Optional, TypeVar, Union, overload import attr from typing_extensions import Literal +from twisted.internet import defer + from synapse.config import cache as cache_config from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util import Clock @@ -81,7 +83,7 @@ class ExpiringCache(Generic[KT, VT]): # Don't bother starting the loop if things never expire return - def f(): + def f() -> "defer.Deferred[None]": return run_as_background_process( "prune_cache_%s" % self._cache_name, self._prune_cache ) @@ -210,7 +212,7 @@ class ExpiringCache(Generic[KT, VT]): return False -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class _CacheEntry: - time = attr.ib(type=int) - value = attr.ib() + time: int + value: Any -- cgit 1.5.1 From 84fac0f814f69645ff1ad564ef8294b31203dc95 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Wed, 17 Nov 2021 19:07:02 +0000 Subject: Add type annotations to `synapse.metrics` (#10847) --- changelog.d/10847.misc | 1 + mypy.ini | 3 + synapse/app/_base.py | 2 +- synapse/groups/attestations.py | 4 +- synapse/handlers/typing.py | 2 +- synapse/metrics/__init__.py | 101 ++++++++++++++++++-------- synapse/metrics/_exposition.py | 34 ++++----- synapse/metrics/background_process_metrics.py | 78 +++++++++++++++----- synapse/metrics/jemalloc.py | 10 ++- synapse/storage/database.py | 6 +- synapse/util/caches/expiringcache.py | 2 +- synapse/util/metrics.py | 15 ++-- 12 files changed, 173 insertions(+), 85 deletions(-) create mode 100644 changelog.d/10847.misc (limited to 'synapse/util/caches') diff --git a/changelog.d/10847.misc b/changelog.d/10847.misc new file mode 100644 index 0000000000..7933a38dca --- /dev/null +++ b/changelog.d/10847.misc @@ -0,0 +1 @@ +Add type annotations to `synapse.metrics`. diff --git a/mypy.ini b/mypy.ini index f32c6c41a3..308cfd95d8 100644 --- a/mypy.ini +++ b/mypy.ini @@ -160,6 +160,9 @@ disallow_untyped_defs = True [mypy-synapse.handlers.*] disallow_untyped_defs = True +[mypy-synapse.metrics.*] +disallow_untyped_defs = True + [mypy-synapse.push.*] disallow_untyped_defs = True diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 573bb487b2..807ee3d46e 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -402,7 +402,7 @@ async def start(hs: "HomeServer") -> None: if hasattr(signal, "SIGHUP"): @wrap_as_background_process("sighup") - def handle_sighup(*args: Any, **kwargs: Any) -> None: + async def handle_sighup(*args: Any, **kwargs: Any) -> None: # Tell systemd our state, if we're using it. This will silently fail if # we're not using systemd. sdnotify(b"RELOADING=1") diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index 53f99031b1..a87896e538 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -40,6 +40,8 @@ from typing import TYPE_CHECKING, Optional, Tuple from signedjson.sign import sign_json +from twisted.internet.defer import Deferred + from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import JsonDict, get_domain_from_id @@ -166,7 +168,7 @@ class GroupAttestionRenewer: return {} - def _start_renew_attestations(self) -> None: + def _start_renew_attestations(self) -> "Deferred[None]": return run_as_background_process("renew_attestations", self._renew_attestations) async def _renew_attestations(self) -> None: diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 22c6174821..1676ebd057 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -90,7 +90,7 @@ class FollowerTypingHandler: self.wheel_timer = WheelTimer(bucket_size=5000) @wrap_as_background_process("typing._handle_timeouts") - def _handle_timeouts(self) -> None: + async def _handle_timeouts(self) -> None: logger.debug("Checking for typing timeouts") now = self.clock.time_msec() diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 91ee5c8193..ceef57ad88 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -20,10 +20,25 @@ import os import platform import threading import time -from typing import Callable, Dict, Iterable, Mapping, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Generic, + Iterable, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) import attr -from prometheus_client import Counter, Gauge, Histogram +from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram, Metric from prometheus_client.core import ( REGISTRY, CounterMetricFamily, @@ -32,6 +47,7 @@ from prometheus_client.core import ( ) from twisted.internet import reactor +from twisted.internet.base import ReactorBase from twisted.python.threadpool import ThreadPool import synapse @@ -54,7 +70,7 @@ HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat") class RegistryProxy: @staticmethod - def collect(): + def collect() -> Iterable[Metric]: for metric in REGISTRY.collect(): if not metric.name.startswith("__"): yield metric @@ -74,7 +90,7 @@ class LaterGauge: ] ) - def collect(self): + def collect(self) -> Iterable[Metric]: g = GaugeMetricFamily(self.name, self.desc, labels=self.labels) @@ -93,10 +109,10 @@ class LaterGauge: yield g - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: self._register() - def _register(self): + def _register(self) -> None: if self.name in all_gauges.keys(): logger.warning("%s already registered, reregistering" % (self.name,)) REGISTRY.unregister(all_gauges.pop(self.name)) @@ -105,7 +121,12 @@ class LaterGauge: all_gauges[self.name] = self -class InFlightGauge: +# `MetricsEntry` only makes sense when it is a `Protocol`, +# but `Protocol` can't be used as a `TypeVar` bound. +MetricsEntry = TypeVar("MetricsEntry") + + +class InFlightGauge(Generic[MetricsEntry]): """Tracks number of things (e.g. requests, Measure blocks, etc) in flight at any given time. @@ -115,14 +136,19 @@ class InFlightGauge: callbacks. Args: - name (str) - desc (str) - labels (list[str]) - sub_metrics (list[str]): A list of sub metrics that the callbacks - will update. + name + desc + labels + sub_metrics: A list of sub metrics that the callbacks will update. """ - def __init__(self, name, desc, labels, sub_metrics): + def __init__( + self, + name: str, + desc: str, + labels: Sequence[str], + sub_metrics: Sequence[str], + ): self.name = name self.desc = desc self.labels = labels @@ -130,19 +156,25 @@ class InFlightGauge: # Create a class which have the sub_metrics values as attributes, which # default to 0 on initialization. Used to pass to registered callbacks. - self._metrics_class = attr.make_class( + self._metrics_class: Type[MetricsEntry] = attr.make_class( "_MetricsEntry", attrs={x: attr.ib(0) for x in sub_metrics}, slots=True ) # Counts number of in flight blocks for a given set of label values - self._registrations: Dict = {} + self._registrations: Dict[ + Tuple[str, ...], Set[Callable[[MetricsEntry], None]] + ] = {} # Protects access to _registrations self._lock = threading.Lock() self._register_with_collector() - def register(self, key, callback): + def register( + self, + key: Tuple[str, ...], + callback: Callable[[MetricsEntry], None], + ) -> None: """Registers that we've entered a new block with labels `key`. `callback` gets called each time the metrics are collected. The same @@ -158,13 +190,17 @@ class InFlightGauge: with self._lock: self._registrations.setdefault(key, set()).add(callback) - def unregister(self, key, callback): + def unregister( + self, + key: Tuple[str, ...], + callback: Callable[[MetricsEntry], None], + ) -> None: """Registers that we've exited a block with labels `key`.""" with self._lock: self._registrations.setdefault(key, set()).discard(callback) - def collect(self): + def collect(self) -> Iterable[Metric]: """Called by prometheus client when it reads metrics. Note: may be called by a separate thread. @@ -200,7 +236,7 @@ class InFlightGauge: gauge.add_metric(key, getattr(metrics, name)) yield gauge - def _register_with_collector(self): + def _register_with_collector(self) -> None: if self.name in all_gauges.keys(): logger.warning("%s already registered, reregistering" % (self.name,)) REGISTRY.unregister(all_gauges.pop(self.name)) @@ -230,7 +266,7 @@ class GaugeBucketCollector: name: str, documentation: str, buckets: Iterable[float], - registry=REGISTRY, + registry: CollectorRegistry = REGISTRY, ): """ Args: @@ -257,12 +293,12 @@ class GaugeBucketCollector: registry.register(self) - def collect(self): + def collect(self) -> Iterable[Metric]: # Don't report metrics unless we've already collected some data if self._metric is not None: yield self._metric - def update_data(self, values: Iterable[float]): + def update_data(self, values: Iterable[float]) -> None: """Update the data to be reported by the metric The existing data is cleared, and each measurement in the input is assigned @@ -304,7 +340,7 @@ class GaugeBucketCollector: class CPUMetrics: - def __init__(self): + def __init__(self) -> None: ticks_per_sec = 100 try: # Try and get the system config @@ -314,7 +350,7 @@ class CPUMetrics: self.ticks_per_sec = ticks_per_sec - def collect(self): + def collect(self) -> Iterable[Metric]: if not HAVE_PROC_SELF_STAT: return @@ -364,7 +400,7 @@ gc_time = Histogram( class GCCounts: - def collect(self): + def collect(self) -> Iterable[Metric]: cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"]) for n, m in enumerate(gc.get_count()): cm.add_metric([str(n)], m) @@ -382,7 +418,7 @@ if not running_on_pypy: class PyPyGCStats: - def collect(self): + def collect(self) -> Iterable[Metric]: # @stats is a pretty-printer object with __str__() returning a nice table, # plus some fields that contain data from that table. @@ -565,7 +601,7 @@ def register_threadpool(name: str, threadpool: ThreadPool) -> None: class ReactorLastSeenMetric: - def collect(self): + def collect(self) -> Iterable[Metric]: cm = GaugeMetricFamily( "python_twisted_reactor_last_seen", "Seconds since the Twisted reactor was last seen", @@ -584,9 +620,12 @@ MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0) _last_gc = [0.0, 0.0, 0.0] -def runUntilCurrentTimer(reactor, func): +F = TypeVar("F", bound=Callable[..., Any]) + + +def runUntilCurrentTimer(reactor: ReactorBase, func: F) -> F: @functools.wraps(func) - def f(*args, **kwargs): + def f(*args: Any, **kwargs: Any) -> Any: now = reactor.seconds() num_pending = 0 @@ -649,7 +688,7 @@ def runUntilCurrentTimer(reactor, func): return ret - return f + return cast(F, f) try: @@ -677,5 +716,5 @@ __all__ = [ "start_http_server", "LaterGauge", "InFlightGauge", - "BucketCollector", + "GaugeBucketCollector", ] diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py index bb9bcb5592..353d0a63b6 100644 --- a/synapse/metrics/_exposition.py +++ b/synapse/metrics/_exposition.py @@ -25,27 +25,25 @@ import math import threading from http.server import BaseHTTPRequestHandler, HTTPServer from socketserver import ThreadingMixIn -from typing import Dict, List +from typing import Any, Dict, List, Type, Union from urllib.parse import parse_qs, urlparse -from prometheus_client import REGISTRY +from prometheus_client import REGISTRY, CollectorRegistry +from prometheus_client.core import Sample from twisted.web.resource import Resource +from twisted.web.server import Request from synapse.util import caches CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8" -INF = float("inf") -MINUS_INF = float("-inf") - - -def floatToGoString(d): +def floatToGoString(d: Union[int, float]) -> str: d = float(d) - if d == INF: + if d == math.inf: return "+Inf" - elif d == MINUS_INF: + elif d == -math.inf: return "-Inf" elif math.isnan(d): return "NaN" @@ -60,7 +58,7 @@ def floatToGoString(d): return s -def sample_line(line, name): +def sample_line(line: Sample, name: str) -> str: if line.labels: labelstr = "{{{0}}}".format( ",".join( @@ -82,7 +80,7 @@ def sample_line(line, name): return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp) -def generate_latest(registry, emit_help=False): +def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> bytes: # Trigger the cache metrics to be rescraped, which updates the common # metrics but do not produce metrics themselves @@ -187,7 +185,7 @@ class MetricsHandler(BaseHTTPRequestHandler): registry = REGISTRY - def do_GET(self): + def do_GET(self) -> None: registry = self.registry params = parse_qs(urlparse(self.path).query) @@ -207,11 +205,11 @@ class MetricsHandler(BaseHTTPRequestHandler): self.end_headers() self.wfile.write(output) - def log_message(self, format, *args): + def log_message(self, format: str, *args: Any) -> None: """Log nothing.""" @classmethod - def factory(cls, registry): + def factory(cls, registry: CollectorRegistry) -> Type: """Returns a dynamic MetricsHandler class tied to the passed registry. """ @@ -236,7 +234,9 @@ class _ThreadingSimpleServer(ThreadingMixIn, HTTPServer): daemon_threads = True -def start_http_server(port, addr="", registry=REGISTRY): +def start_http_server( + port: int, addr: str = "", registry: CollectorRegistry = REGISTRY +) -> None: """Starts an HTTP server for prometheus metrics as a daemon thread""" CustomMetricsHandler = MetricsHandler.factory(registry) httpd = _ThreadingSimpleServer((addr, port), CustomMetricsHandler) @@ -252,10 +252,10 @@ class MetricsResource(Resource): isLeaf = True - def __init__(self, registry=REGISTRY): + def __init__(self, registry: CollectorRegistry = REGISTRY): self.registry = registry - def render_GET(self, request): + def render_GET(self, request: Request) -> bytes: request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii")) response = generate_latest(self.registry) request.setHeader(b"Content-Length", str(len(response))) diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 2ab599a334..53c508af91 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -15,19 +15,37 @@ import logging import threading from functools import wraps -from typing import TYPE_CHECKING, Dict, Optional, Set, Union +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + Iterable, + Optional, + Set, + Type, + TypeVar, + Union, + cast, +) +from prometheus_client import Metric from prometheus_client.core import REGISTRY, Counter, Gauge from twisted.internet import defer -from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.logging.context import ( + ContextResourceUsage, + LoggingContext, + PreserveLoggingContext, +) from synapse.logging.opentracing import ( SynapseTags, noop_context_manager, start_active_span, ) -from synapse.util.async_helpers import maybe_awaitable if TYPE_CHECKING: import resource @@ -116,7 +134,7 @@ class _Collector: before they are returned. """ - def collect(self): + def collect(self) -> Iterable[Metric]: global _background_processes_active_since_last_scrape # We swap out the _background_processes set with an empty one so that @@ -144,12 +162,12 @@ REGISTRY.register(_Collector()) class _BackgroundProcess: - def __init__(self, desc, ctx): + def __init__(self, desc: str, ctx: LoggingContext): self.desc = desc self._context = ctx - self._reported_stats = None + self._reported_stats: Optional[ContextResourceUsage] = None - def update_metrics(self): + def update_metrics(self) -> None: """Updates the metrics with values from this process.""" new_stats = self._context.get_resource_usage() if self._reported_stats is None: @@ -169,7 +187,16 @@ class _BackgroundProcess: ) -def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwargs): +R = TypeVar("R") + + +def run_as_background_process( + desc: str, + func: Callable[..., Awaitable[Optional[R]]], + *args: Any, + bg_start_span: bool = True, + **kwargs: Any, +) -> "defer.Deferred[Optional[R]]": """Run the given function in its own logcontext, with resource metrics This should be used to wrap processes which are fired off to run in the @@ -189,11 +216,13 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar args: positional args for func kwargs: keyword args for func - Returns: Deferred which returns the result of func, but note that it does not - follow the synapse logcontext rules. + Returns: + Deferred which returns the result of func, or `None` if func raises. + Note that the returned Deferred does not follow the synapse logcontext + rules. """ - async def run(): + async def run() -> Optional[R]: with _bg_metrics_lock: count = _background_process_counts.get(desc, 0) _background_process_counts[desc] = count + 1 @@ -210,12 +239,13 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar else: ctx = noop_context_manager() with ctx: - return await maybe_awaitable(func(*args, **kwargs)) + return await func(*args, **kwargs) except Exception: logger.exception( "Background process '%s' threw an exception", desc, ) + return None finally: _background_process_in_flight_count.labels(desc).dec() @@ -225,19 +255,24 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar return defer.ensureDeferred(run()) -def wrap_as_background_process(desc): +F = TypeVar("F", bound=Callable[..., Awaitable[Optional[Any]]]) + + +def wrap_as_background_process(desc: str) -> Callable[[F], F]: """Decorator that wraps a function that gets called as a background process. - Equivalent of calling the function with `run_as_background_process` + Equivalent to calling the function with `run_as_background_process` """ - def wrap_as_background_process_inner(func): + def wrap_as_background_process_inner(func: F) -> F: @wraps(func) - def wrap_as_background_process_inner_2(*args, **kwargs): + def wrap_as_background_process_inner_2( + *args: Any, **kwargs: Any + ) -> "defer.Deferred[Optional[R]]": return run_as_background_process(desc, func, *args, **kwargs) - return wrap_as_background_process_inner_2 + return cast(F, wrap_as_background_process_inner_2) return wrap_as_background_process_inner @@ -265,7 +300,7 @@ class BackgroundProcessLoggingContext(LoggingContext): super().__init__("%s-%s" % (name, instance_id)) self._proc = _BackgroundProcess(name, self) - def start(self, rusage: "Optional[resource.struct_rusage]"): + def start(self, rusage: "Optional[resource.struct_rusage]") -> None: """Log context has started running (again).""" super().start(rusage) @@ -276,7 +311,12 @@ class BackgroundProcessLoggingContext(LoggingContext): with _bg_metrics_lock: _background_processes_active_since_last_scrape.add(self._proc) - def __exit__(self, type, value, traceback) -> None: + def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: """Log context has finished.""" super().__exit__(type, value, traceback) diff --git a/synapse/metrics/jemalloc.py b/synapse/metrics/jemalloc.py index 29ab6c0229..98ed9c0829 100644 --- a/synapse/metrics/jemalloc.py +++ b/synapse/metrics/jemalloc.py @@ -16,14 +16,16 @@ import ctypes import logging import os import re -from typing import Optional +from typing import Iterable, Optional + +from prometheus_client import Metric from synapse.metrics import REGISTRY, GaugeMetricFamily logger = logging.getLogger(__name__) -def _setup_jemalloc_stats(): +def _setup_jemalloc_stats() -> None: """Checks to see if jemalloc is loaded, and hooks up a collector to record statistics exposed by jemalloc. """ @@ -135,7 +137,7 @@ def _setup_jemalloc_stats(): class JemallocCollector: """Metrics for internal jemalloc stats.""" - def collect(self): + def collect(self) -> Iterable[Metric]: _jemalloc_refresh_stats() g = GaugeMetricFamily( @@ -185,7 +187,7 @@ def _setup_jemalloc_stats(): logger.debug("Added jemalloc stats") -def setup_jemalloc_stats(): +def setup_jemalloc_stats() -> None: """Try to setup jemalloc stats, if jemalloc is loaded.""" try: diff --git a/synapse/storage/database.py b/synapse/storage/database.py index d4cab69ebf..0693d39006 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -188,7 +188,7 @@ class LoggingDatabaseConnection: # The type of entry which goes on our after_callbacks and exception_callbacks lists. -_CallbackListEntry = Tuple[Callable[..., None], Iterable[Any], Dict[str, Any]] +_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]] R = TypeVar("R") @@ -235,7 +235,7 @@ class LoggingTransaction: self.after_callbacks = after_callbacks self.exception_callbacks = exception_callbacks - def call_after(self, callback: Callable[..., None], *args: Any, **kwargs: Any): + def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any): """Call the given callback on the main twisted thread after the transaction has finished. Used to invalidate the caches on the correct thread. @@ -247,7 +247,7 @@ class LoggingTransaction: self.after_callbacks.append((callback, args, kwargs)) def call_on_exception( - self, callback: Callable[..., None], *args: Any, **kwargs: Any + self, callback: Callable[..., object], *args: Any, **kwargs: Any ): # if self.exception_callbacks is None, that means that whatever constructed the # LoggingTransaction isn't expecting there to be any callbacks; assert that diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 6a7e534576..67ee4c693b 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -159,7 +159,7 @@ class ExpiringCache(Generic[KT, VT]): self[key] = value return value - def _prune_cache(self) -> None: + async def _prune_cache(self) -> None: if not self._expiry_ms: # zero expiry time means don't expire. This should never get called # since we have this check in start too. diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index ad775dfc7d..98ee49af6e 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -56,8 +56,15 @@ block_db_sched_duration = Counter( "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"] ) + +# This is dynamically created in InFlightGauge.__init__. +class _InFlightMetric(Protocol): + real_time_max: float + real_time_sum: float + + # Tracks the number of blocks currently active -in_flight = InFlightGauge( +in_flight: InFlightGauge[_InFlightMetric] = InFlightGauge( "synapse_util_metrics_block_in_flight", "", labels=["block_name"], @@ -65,12 +72,6 @@ in_flight = InFlightGauge( ) -# This is dynamically created in InFlightGauge.__init__. -class _InFlightMetric(Protocol): - real_time_max: float - real_time_sum: float - - T = TypeVar("T", bound=Callable[..., Any]) -- cgit 1.5.1