summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2021-10-06 11:20:49 +0100
committerGitHub <noreply@github.com>2021-10-06 11:20:49 +0100
commitf8d0f72b27e158738f3c75a38399b967f7478011 (patch)
treead1848e462ac985017f62f96734cfd3164c7768f /synapse/util
parentRemove "reference" wording according Synapse homeserver (#10971) (diff)
downloadsynapse-f8d0f72b27e158738f3c75a38399b967f7478011.tar.xz
More types for synapse.util, part 1 (#10888)
The following modules now pass `disallow_untyped_defs`:

* synapse.util.caches.cached_call 
* synapse.util.caches.lrucache
* synapse.util.caches.response_cache 
* synapse.util.caches.stream_change_cache
* synapse.util.caches.ttlcache pass
* synapse.util.daemonize
* synapse.util.patch_inline_callbacks pass `no-untyped-defs`
* synapse.util.versionstring

Additional typing in synapse.util.metrics. Didn't get this to pass `no-untyped-defs`, think I'll need to watch #10847
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/caches/cached_call.py2
-rw-r--r--synapse/util/caches/deferred_cache.py11
-rw-r--r--synapse/util/caches/lrucache.py57
-rw-r--r--synapse/util/caches/response_cache.py6
-rw-r--r--synapse/util/caches/stream_change_cache.py6
-rw-r--r--synapse/util/caches/ttlcache.py12
-rw-r--r--synapse/util/daemonize.py8
-rw-r--r--synapse/util/metrics.py27
-rw-r--r--synapse/util/patch_inline_callbacks.py28
-rw-r--r--synapse/util/versionstring.py25
10 files changed, 109 insertions, 73 deletions
diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py
index e58dd91eda..470f4f91a5 100644
--- a/synapse/util/caches/cached_call.py
+++ b/synapse/util/caches/cached_call.py
@@ -85,7 +85,7 @@ class CachedCall(Generic[TV]):
             # result in the deferred, since `awaiting` a deferred destroys its result.
             # (Also, if it's a Failure, GCing the deferred would log a critical error
             # about unhandled Failures)
-            def got_result(r):
+            def got_result(r: Union[TV, Failure]) -> None:
                 self._result = r
 
             self._deferred.addBoth(got_result)
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 6262efe072..da502aec11 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -31,6 +31,7 @@ from prometheus_client import Gauge
 
 from twisted.internet import defer
 from twisted.python import failure
+from twisted.python.failure import Failure
 
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.lrucache import LruCache
@@ -112,7 +113,7 @@ class DeferredCache(Generic[KT, VT]):
         self.thread: Optional[threading.Thread] = None
 
     @property
-    def max_entries(self):
+    def max_entries(self) -> int:
         return self.cache.max_size
 
     def check_thread(self) -> None:
@@ -258,7 +259,7 @@ class DeferredCache(Generic[KT, VT]):
 
             return False
 
-        def cb(result) -> None:
+        def cb(result: VT) -> None:
             if compare_and_pop():
                 self.cache.set(key, result, entry.callbacks)
             else:
@@ -270,7 +271,7 @@ class DeferredCache(Generic[KT, VT]):
                 # not have been. Either way, let's double-check now.
                 entry.invalidate()
 
-        def eb(_fail) -> None:
+        def eb(_fail: Failure) -> None:
             compare_and_pop()
             entry.invalidate()
 
@@ -284,11 +285,11 @@ class DeferredCache(Generic[KT, VT]):
 
     def prefill(
         self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
-    ):
+    ) -> None:
         callbacks = [callback] if callback else []
         self.cache.set(key, value, callbacks=callbacks)
 
-    def invalidate(self, key):
+    def invalidate(self, key) -> None:
         """Delete a key, or tree of entries
 
         If the cache is backed by a regular dict, then "key" must be of
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 4ff62b403f..a0a7a9de32 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -52,7 +52,7 @@ logger = logging.getLogger(__name__)
 try:
     from pympler.asizeof import Asizer
 
-    def _get_size_of(val: Any, *, recurse=True) -> int:
+    def _get_size_of(val: Any, *, recurse: bool = True) -> int:
         """Get an estimate of the size in bytes of the object.
 
         Args:
@@ -71,7 +71,7 @@ try:
 
 except ImportError:
 
-    def _get_size_of(val: Any, *, recurse=True) -> int:
+    def _get_size_of(val: Any, *, recurse: bool = True) -> int:
         return 0
 
 
@@ -85,15 +85,6 @@ VT = TypeVar("VT")
 # a general type var, distinct from either KT or VT
 T = TypeVar("T")
 
-
-def enumerate_leaves(node, depth):
-    if depth == 0:
-        yield node
-    else:
-        for n in node.values():
-            yield from enumerate_leaves(n, depth - 1)
-
-
 P = TypeVar("P")
 
 
@@ -102,7 +93,7 @@ class _TimedListNode(ListNode[P]):
 
     __slots__ = ["last_access_ts_secs"]
 
-    def update_last_access(self, clock: Clock):
+    def update_last_access(self, clock: Clock) -> None:
         self.last_access_ts_secs = int(clock.time())
 
 
@@ -115,7 +106,7 @@ GLOBAL_ROOT = ListNode["_Node"].create_root_node()
 
 
 @wrap_as_background_process("LruCache._expire_old_entries")
-async def _expire_old_entries(clock: Clock, expiry_seconds: int):
+async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None:
     """Walks the global cache list to find cache entries that haven't been
     accessed in the given number of seconds.
     """
@@ -163,7 +154,7 @@ async def _expire_old_entries(clock: Clock, expiry_seconds: int):
     logger.info("Dropped %d items from caches", i)
 
 
-def setup_expire_lru_cache_entries(hs: "HomeServer"):
+def setup_expire_lru_cache_entries(hs: "HomeServer") -> None:
     """Start a background job that expires all cache entries if they have not
     been accessed for the given number of seconds.
     """
@@ -183,7 +174,7 @@ def setup_expire_lru_cache_entries(hs: "HomeServer"):
     )
 
 
-class _Node:
+class _Node(Generic[KT, VT]):
     __slots__ = [
         "_list_node",
         "_global_list_node",
@@ -197,8 +188,8 @@ class _Node:
     def __init__(
         self,
         root: "ListNode[_Node]",
-        key,
-        value,
+        key: KT,
+        value: VT,
         cache: "weakref.ReferenceType[LruCache]",
         clock: Clock,
         callbacks: Collection[Callable[[], None]] = (),
@@ -409,7 +400,7 @@ class LruCache(Generic[KT, VT]):
 
         def synchronized(f: FT) -> FT:
             @wraps(f)
-            def inner(*args, **kwargs):
+            def inner(*args: Any, **kwargs: Any) -> Any:
                 with lock:
                     return f(*args, **kwargs)
 
@@ -418,17 +409,19 @@ class LruCache(Generic[KT, VT]):
         cached_cache_len = [0]
         if size_callback is not None:
 
-            def cache_len():
+            def cache_len() -> int:
                 return cached_cache_len[0]
 
         else:
 
-            def cache_len():
+            def cache_len() -> int:
                 return len(cache)
 
         self.len = synchronized(cache_len)
 
-        def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()):
+        def add_node(
+            key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
+        ) -> None:
             node = _Node(
                 list_root,
                 key,
@@ -446,7 +439,7 @@ class LruCache(Generic[KT, VT]):
             if caches.TRACK_MEMORY_USAGE and metrics:
                 metrics.inc_memory_usage(node.memory)
 
-        def move_node_to_front(node: _Node):
+        def move_node_to_front(node: _Node) -> None:
             node.move_to_front(real_clock, list_root)
 
         def delete_node(node: _Node) -> int:
@@ -488,7 +481,7 @@ class LruCache(Generic[KT, VT]):
             default: Optional[T] = None,
             callbacks: Collection[Callable[[], None]] = (),
             update_metrics: bool = True,
-        ):
+        ) -> Union[None, T, VT]:
             node = cache.get(key, None)
             if node is not None:
                 move_node_to_front(node)
@@ -502,7 +495,9 @@ class LruCache(Generic[KT, VT]):
                 return default
 
         @synchronized
-        def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()):
+        def cache_set(
+            key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()
+        ) -> None:
             node = cache.get(key, None)
             if node is not None:
                 # We sometimes store large objects, e.g. dicts, which cause
@@ -547,7 +542,7 @@ class LruCache(Generic[KT, VT]):
             ...
 
         @synchronized
-        def cache_pop(key: KT, default: Optional[T] = None):
+        def cache_pop(key: KT, default: Optional[T] = None) -> Union[None, T, VT]:
             node = cache.get(key, None)
             if node:
                 delete_node(node)
@@ -612,25 +607,25 @@ class LruCache(Generic[KT, VT]):
         self.contains = cache_contains
         self.clear = cache_clear
 
-    def __getitem__(self, key):
+    def __getitem__(self, key: KT) -> VT:
         result = self.get(key, self.sentinel)
         if result is self.sentinel:
             raise KeyError()
         else:
-            return result
+            return cast(VT, result)
 
-    def __setitem__(self, key, value):
+    def __setitem__(self, key: KT, value: VT) -> None:
         self.set(key, value)
 
-    def __delitem__(self, key, value):
+    def __delitem__(self, key: KT, value: VT) -> None:
         result = self.pop(key, self.sentinel)
         if result is self.sentinel:
             raise KeyError()
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.len()
 
-    def __contains__(self, key):
+    def __contains__(self, key: KT) -> bool:
         return self.contains(key)
 
     def set_cache_factor(self, factor: float) -> bool:
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index ed7204336f..88ccf44337 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -104,8 +104,8 @@ class ResponseCache(Generic[KV]):
             return None
 
     def _set(
-        self, context: ResponseCacheContext[KV], deferred: defer.Deferred
-    ) -> defer.Deferred:
+        self, context: ResponseCacheContext[KV], deferred: "defer.Deferred[RV]"
+    ) -> "defer.Deferred[RV]":
         """Set the entry for the given key to the given deferred.
 
         *deferred* should run its callbacks in the sentinel logcontext (ie,
@@ -126,7 +126,7 @@ class ResponseCache(Generic[KV]):
         key = context.cache_key
         self.pending_result_cache[key] = result
 
-        def on_complete(r):
+        def on_complete(r: RV) -> RV:
             # if this cache has a non-zero timeout, and the callback has not cleared
             # the should_cache bit, we leave it in the cache for now and schedule
             # its removal later.
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 27b1da235e..330709b8b7 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -40,10 +40,10 @@ class StreamChangeCache:
         self,
         name: str,
         current_stream_pos: int,
-        max_size=10000,
+        max_size: int = 10000,
         prefilled_cache: Optional[Mapping[EntityType, int]] = None,
-    ):
-        self._original_max_size = max_size
+    ) -> None:
+        self._original_max_size: int = max_size
         self._max_size = math.floor(max_size)
         self._entity_to_key: Dict[EntityType, int] = {}
 
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 46afe3f934..0b9ac26b69 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -159,12 +159,12 @@ class TTLCache(Generic[KT, VT]):
             del self._expiry_list[0]
 
 
-@attr.s(frozen=True, slots=True)
-class _CacheEntry:
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class _CacheEntry:  # Should be Generic[KT, VT]. See python-attrs/attrs#313
     """TTLCache entry"""
 
     # expiry_time is the first attribute, so that entries are sorted by expiry.
-    expiry_time = attr.ib(type=float)
-    ttl = attr.ib(type=float)
-    key = attr.ib()
-    value = attr.ib()
+    expiry_time: float
+    ttl: float
+    key: Any  # should be KT
+    value: Any  # should be VT
diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py
index f1a351cfd4..de04f34e4e 100644
--- a/synapse/util/daemonize.py
+++ b/synapse/util/daemonize.py
@@ -19,6 +19,8 @@ import logging
 import os
 import signal
 import sys
+from types import FrameType, TracebackType
+from typing import NoReturn, Type
 
 
 def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -> None:
@@ -97,7 +99,9 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
     # (we don't normally expect reactor.run to raise any exceptions, but this will
     # also catch any other uncaught exceptions before we get that far.)
 
-    def excepthook(type_, value, traceback):
+    def excepthook(
+        type_: Type[BaseException], value: BaseException, traceback: TracebackType
+    ) -> None:
         logger.critical("Unhanded exception", exc_info=(type_, value, traceback))
 
     sys.excepthook = excepthook
@@ -119,7 +123,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
         sys.exit(1)
 
     # write a log line on SIGTERM.
-    def sigterm(signum, frame):
+    def sigterm(signum: signal.Signals, frame: FrameType) -> NoReturn:
         logger.warning("Caught signal %s. Stopping daemon." % signum)
         sys.exit(0)
 
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 1b82dca81b..1e784b3f1f 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -14,9 +14,11 @@
 
 import logging
 from functools import wraps
-from typing import Any, Callable, Optional, TypeVar, cast
+from types import TracebackType
+from typing import Any, Callable, Optional, Type, TypeVar, cast
 
 from prometheus_client import Counter
+from typing_extensions import Protocol
 
 from synapse.logging.context import (
     ContextResourceUsage,
@@ -24,6 +26,7 @@ from synapse.logging.context import (
     current_context,
 )
 from synapse.metrics import InFlightGauge
+from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
 
@@ -64,6 +67,10 @@ in_flight = InFlightGauge(
 T = TypeVar("T", bound=Callable[..., Any])
 
 
+class HasClock(Protocol):
+    clock: Clock
+
+
 def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
     """
     Used to decorate an async function with a `Measure` context manager.
@@ -86,7 +93,7 @@ def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
         block_name = func.__name__ if name is None else name
 
         @wraps(func)
-        async def measured_func(self, *args, **kwargs):
+        async def measured_func(self: HasClock, *args: Any, **kwargs: Any) -> Any:
             with Measure(self.clock, block_name):
                 r = await func(self, *args, **kwargs)
             return r
@@ -104,10 +111,10 @@ class Measure:
         "start",
     ]
 
-    def __init__(self, clock, name: str):
+    def __init__(self, clock: Clock, name: str) -> None:
         """
         Args:
-            clock: A n object with a "time()" method, which returns the current
+            clock: An object with a "time()" method, which returns the current
                 time in seconds.
             name: The name of the metric to report.
         """
@@ -124,7 +131,7 @@ class Measure:
             assert isinstance(curr_context, LoggingContext)
             parent_context = curr_context
         self._logging_context = LoggingContext(str(curr_context), parent_context)
-        self.start: Optional[int] = None
+        self.start: Optional[float] = None
 
     def __enter__(self) -> "Measure":
         if self.start is not None:
@@ -138,7 +145,12 @@ class Measure:
 
         return self
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
         if self.start is None:
             raise RuntimeError("Measure() block exited without being entered")
 
@@ -168,8 +180,9 @@ class Measure:
         """
         return self._logging_context.get_resource_usage()
 
-    def _update_in_flight(self, metrics):
+    def _update_in_flight(self, metrics) -> None:
         """Gets called when processing in flight metrics"""
+        assert self.start is not None
         duration = self.clock.time() - self.start
 
         metrics.real_time_max = max(metrics.real_time_max, duration)
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index 9dd010af3b..1f18654d47 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -14,7 +14,7 @@
 
 import functools
 import sys
-from typing import Any, Callable, List
+from typing import Any, Callable, Generator, List, TypeVar
 
 from twisted.internet import defer
 from twisted.internet.defer import Deferred
@@ -24,6 +24,9 @@ from twisted.python.failure import Failure
 _already_patched = False
 
 
+T = TypeVar("T")
+
+
 def do_patch() -> None:
     """
     Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
@@ -37,15 +40,19 @@ def do_patch() -> None:
     if _already_patched:
         return
 
-    def new_inline_callbacks(f):
+    def new_inline_callbacks(
+        f: Callable[..., Generator["Deferred[object]", object, T]]
+    ) -> Callable[..., "Deferred[T]"]:
         @functools.wraps(f)
-        def wrapped(*args, **kwargs):
+        def wrapped(*args: Any, **kwargs: Any) -> "Deferred[T]":
             start_context = current_context()
             changes: List[str] = []
-            orig = orig_inline_callbacks(_check_yield_points(f, changes))
+            orig: Callable[..., "Deferred[T]"] = orig_inline_callbacks(
+                _check_yield_points(f, changes)
+            )
 
             try:
-                res = orig(*args, **kwargs)
+                res: "Deferred[T]" = orig(*args, **kwargs)
             except Exception:
                 if current_context() != start_context:
                     for err in changes:
@@ -84,7 +91,7 @@ def do_patch() -> None:
                 print(err, file=sys.stderr)
                 raise Exception(err)
 
-            def check_ctx(r):
+            def check_ctx(r: T) -> T:
                 if current_context() != start_context:
                     for err in changes:
                         print(err, file=sys.stderr)
@@ -107,7 +114,10 @@ def do_patch() -> None:
     _already_patched = True
 
 
-def _check_yield_points(f: Callable, changes: List[str]) -> Callable:
+def _check_yield_points(
+    f: Callable[..., Generator["Deferred[object]", object, T]],
+    changes: List[str],
+) -> Callable:
     """Wraps a generator that is about to be passed to defer.inlineCallbacks
     checking that after every yield the log contexts are correct.
 
@@ -127,7 +137,9 @@ def _check_yield_points(f: Callable, changes: List[str]) -> Callable:
     from synapse.logging.context import current_context
 
     @functools.wraps(f)
-    def check_yield_points_inner(*args, **kwargs):
+    def check_yield_points_inner(
+        *args: Any, **kwargs: Any
+    ) -> Generator["Deferred[object]", object, T]:
         gen = f(*args, **kwargs)
 
         last_yield_line_no = gen.gi_frame.f_lineno
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index 1c20b24bbe..899ee0adc8 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -15,14 +15,18 @@
 import logging
 import os
 import subprocess
+from types import ModuleType
+from typing import Dict
 
 logger = logging.getLogger(__name__)
 
+version_cache: Dict[ModuleType, str] = {}
 
-def get_version_string(module) -> str:
+
+def get_version_string(module: ModuleType) -> str:
     """Given a module calculate a git-aware version string for it.
 
-    If called on a module not in a git checkout will return `__verison__`.
+    If called on a module not in a git checkout will return `__version__`.
 
     Args:
         module (module)
@@ -31,11 +35,13 @@ def get_version_string(module) -> str:
         str
     """
 
-    cached_version = getattr(module, "_synapse_version_string_cache", None)
-    if cached_version:
+    cached_version = version_cache.get(module)
+    if cached_version is not None:
         return cached_version
 
-    version_string = module.__version__
+    # We want this to fail loudly with an AttributeError. Type-ignore this so
+    # mypy only considers the happy path.
+    version_string = module.__version__  # type: ignore[attr-defined]
 
     try:
         null = open(os.devnull, "w")
@@ -97,10 +103,15 @@ def get_version_string(module) -> str:
                 s for s in (git_branch, git_tag, git_commit, git_dirty) if s
             )
 
-            version_string = "%s (%s)" % (module.__version__, git_version)
+            version_string = "%s (%s)" % (
+                # If the __version__ attribute doesn't exist, we'll have failed
+                # loudly above.
+                module.__version__,  # type: ignore[attr-defined]
+                git_version,
+            )
     except Exception as e:
         logger.info("Failed to check for git repository: %s", e)
 
-    module._synapse_version_string_cache = version_string
+    version_cache[module] = version_string
 
     return version_string