summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/util/caches/deferred_cache.py9
-rw-r--r--synapse/util/caches/lrucache.py42
-rw-r--r--synapse/util/linked_list.py4
-rw-r--r--synapse/util/versionstring.py82
4 files changed, 60 insertions, 77 deletions
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 3c4cc093af..377c9a282a 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -22,6 +22,7 @@ from typing import (
     Iterable,
     MutableMapping,
     Optional,
+    Sized,
     TypeVar,
     Union,
     cast,
@@ -104,7 +105,13 @@ class DeferredCache(Generic[KT, VT]):
             max_size=max_entries,
             cache_name=name,
             cache_type=cache_type,
-            size_callback=(lambda d: len(d) or 1) if iterable else None,
+            size_callback=(
+                (lambda d: len(cast(Sized, d)) or 1)
+                # Argument 1 to "len" has incompatible type "VT"; expected "Sized"
+                # We trust that `VT` is `Sized` when `iterable` is `True`
+                if iterable
+                else None
+            ),
             metrics_collection_callback=metrics_cb,
             apply_cache_factor_from_config=apply_cache_factor_from_config,
             prune_unread_entries=prune_unread_entries,
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index a0a7a9de32..eb96f7e665 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -15,14 +15,15 @@
 import logging
 import threading
 import weakref
+from enum import Enum
 from functools import wraps
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
     Collection,
+    Dict,
     Generic,
-    Iterable,
     List,
     Optional,
     Type,
@@ -190,7 +191,7 @@ class _Node(Generic[KT, VT]):
         root: "ListNode[_Node]",
         key: KT,
         value: VT,
-        cache: "weakref.ReferenceType[LruCache]",
+        cache: "weakref.ReferenceType[LruCache[KT, VT]]",
         clock: Clock,
         callbacks: Collection[Callable[[], None]] = (),
         prune_unread_entries: bool = True,
@@ -270,7 +271,10 @@ class _Node(Generic[KT, VT]):
         removed from all lists.
         """
         cache = self._cache()
-        if not cache or not cache.pop(self.key, None):
+        if (
+            cache is None
+            or cache.pop(self.key, _Sentinel.sentinel) is _Sentinel.sentinel
+        ):
             # `cache.pop` should call `drop_from_lists()`, unless this Node had
             # already been removed from the cache.
             self.drop_from_lists()
@@ -290,6 +294,12 @@ class _Node(Generic[KT, VT]):
             self._global_list_node.update_last_access(clock)
 
 
+class _Sentinel(Enum):
+    # defining a sentinel in this way allows mypy to correctly handle the
+    # type of a dictionary lookup.
+    sentinel = object()
+
+
 class LruCache(Generic[KT, VT]):
     """
     Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
@@ -302,7 +312,7 @@ class LruCache(Generic[KT, VT]):
         max_size: int,
         cache_name: Optional[str] = None,
         cache_type: Type[Union[dict, TreeCache]] = dict,
-        size_callback: Optional[Callable] = None,
+        size_callback: Optional[Callable[[VT], int]] = None,
         metrics_collection_callback: Optional[Callable[[], None]] = None,
         apply_cache_factor_from_config: bool = True,
         clock: Optional[Clock] = None,
@@ -339,7 +349,7 @@ class LruCache(Generic[KT, VT]):
         else:
             real_clock = clock
 
-        cache = cache_type()
+        cache: Union[Dict[KT, _Node[KT, VT]], TreeCache] = cache_type()
         self.cache = cache  # Used for introspection.
         self.apply_cache_factor_from_config = apply_cache_factor_from_config
 
@@ -374,7 +384,7 @@ class LruCache(Generic[KT, VT]):
         # creating more each time we create a `_Node`.
         weak_ref_to_self = weakref.ref(self)
 
-        list_root = ListNode[_Node].create_root_node()
+        list_root = ListNode[_Node[KT, VT]].create_root_node()
 
         lock = threading.Lock()
 
@@ -422,7 +432,7 @@ class LruCache(Generic[KT, VT]):
         def add_node(
             key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
         ) -> None:
-            node = _Node(
+            node: _Node[KT, VT] = _Node(
                 list_root,
                 key,
                 value,
@@ -439,10 +449,10 @@ class LruCache(Generic[KT, VT]):
             if caches.TRACK_MEMORY_USAGE and metrics:
                 metrics.inc_memory_usage(node.memory)
 
-        def move_node_to_front(node: _Node) -> None:
+        def move_node_to_front(node: _Node[KT, VT]) -> None:
             node.move_to_front(real_clock, list_root)
 
-        def delete_node(node: _Node) -> int:
+        def delete_node(node: _Node[KT, VT]) -> int:
             node.drop_from_lists()
 
             deleted_len = 1
@@ -496,7 +506,7 @@ class LruCache(Generic[KT, VT]):
 
         @synchronized
         def cache_set(
-            key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()
+            key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
         ) -> None:
             node = cache.get(key, None)
             if node is not None:
@@ -590,8 +600,6 @@ class LruCache(Generic[KT, VT]):
         def cache_contains(key: KT) -> bool:
             return key in cache
 
-        self.sentinel = object()
-
         # make sure that we clear out any excess entries after we get resized.
         self._on_resize = evict
 
@@ -608,18 +616,18 @@ class LruCache(Generic[KT, VT]):
         self.clear = cache_clear
 
     def __getitem__(self, key: KT) -> VT:
-        result = self.get(key, self.sentinel)
-        if result is self.sentinel:
+        result = self.get(key, _Sentinel.sentinel)
+        if result is _Sentinel.sentinel:
             raise KeyError()
         else:
-            return cast(VT, result)
+            return result
 
     def __setitem__(self, key: KT, value: VT) -> None:
         self.set(key, value)
 
     def __delitem__(self, key: KT, value: VT) -> None:
-        result = self.pop(key, self.sentinel)
-        if result is self.sentinel:
+        result = self.pop(key, _Sentinel.sentinel)
+        if result is _Sentinel.sentinel:
             raise KeyError()
 
     def __len__(self) -> int:
diff --git a/synapse/util/linked_list.py b/synapse/util/linked_list.py
index 9f4be757ba..8efbf061aa 100644
--- a/synapse/util/linked_list.py
+++ b/synapse/util/linked_list.py
@@ -84,7 +84,7 @@ class ListNode(Generic[P]):
         # immediately rather than at the next GC.
         self.cache_entry = None
 
-    def move_after(self, node: "ListNode") -> None:
+    def move_after(self, node: "ListNode[P]") -> None:
         """Move this node from its current location in the list to after the
         given node.
         """
@@ -122,7 +122,7 @@ class ListNode(Generic[P]):
         self.prev_node = None
         self.next_node = None
 
-    def _refs_insert_after(self, node: "ListNode") -> None:
+    def _refs_insert_after(self, node: "ListNode[P]") -> None:
         """Internal method to insert the node after the given node."""
 
         # This method should only be called when we're not already in the list.
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index 899ee0adc8..c144ff62c1 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -1,4 +1,5 @@
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,10 +30,11 @@ def get_version_string(module: ModuleType) -> str:
     If called on a module not in a git checkout will return `__version__`.
 
     Args:
-        module (module)
+        module: The module to check the version of. Must declare a __version__
+            attribute.
 
     Returns:
-        str
+        The module version (as a string).
     """
 
     cached_version = version_cache.get(module)
@@ -44,71 +46,37 @@ def get_version_string(module: ModuleType) -> str:
     version_string = module.__version__  # type: ignore[attr-defined]
 
     try:
-        null = open(os.devnull, "w")
         cwd = os.path.dirname(os.path.abspath(module.__file__))
 
-        try:
-            git_branch = (
-                subprocess.check_output(
-                    ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=null, cwd=cwd
+        def _run_git_command(prefix: str, *params: str) -> str:
+            try:
+                result = (
+                    subprocess.check_output(
+                        ["git", *params], stderr=subprocess.DEVNULL, cwd=cwd
+                    )
+                    .strip()
+                    .decode("ascii")
                 )
-                .strip()
-                .decode("ascii")
-            )
-            git_branch = "b=" + git_branch
-        except (subprocess.CalledProcessError, FileNotFoundError):
-            # FileNotFoundError can arise when git is not installed
-            git_branch = ""
-
-        try:
-            git_tag = (
-                subprocess.check_output(
-                    ["git", "describe", "--exact-match"], stderr=null, cwd=cwd
-                )
-                .strip()
-                .decode("ascii")
-            )
-            git_tag = "t=" + git_tag
-        except (subprocess.CalledProcessError, FileNotFoundError):
-            git_tag = ""
-
-        try:
-            git_commit = (
-                subprocess.check_output(
-                    ["git", "rev-parse", "--short", "HEAD"], stderr=null, cwd=cwd
-                )
-                .strip()
-                .decode("ascii")
-            )
-        except (subprocess.CalledProcessError, FileNotFoundError):
-            git_commit = ""
-
-        try:
-            dirty_string = "-this_is_a_dirty_checkout"
-            is_dirty = (
-                subprocess.check_output(
-                    ["git", "describe", "--dirty=" + dirty_string], stderr=null, cwd=cwd
-                )
-                .strip()
-                .decode("ascii")
-                .endswith(dirty_string)
-            )
+                return prefix + result
+            except (subprocess.CalledProcessError, FileNotFoundError):
+                return ""
 
-            git_dirty = "dirty" if is_dirty else ""
-        except (subprocess.CalledProcessError, FileNotFoundError):
-            git_dirty = ""
+        git_branch = _run_git_command("b=", "rev-parse", "--abbrev-ref", "HEAD")
+        git_tag = _run_git_command("t=", "describe", "--exact-match")
+        git_commit = _run_git_command("", "rev-parse", "--short", "HEAD")
+
+        dirty_string = "-this_is_a_dirty_checkout"
+        is_dirty = _run_git_command("", "describe", "--dirty=" + dirty_string).endswith(
+            dirty_string
+        )
+        git_dirty = "dirty" if is_dirty else ""
 
         if git_branch or git_tag or git_commit or git_dirty:
             git_version = ",".join(
                 s for s in (git_branch, git_tag, git_commit, git_dirty) if s
             )
 
-            version_string = "%s (%s)" % (
-                # If the __version__ attribute doesn't exist, we'll have failed
-                # loudly above.
-                module.__version__,  # type: ignore[attr-defined]
-                git_version,
-            )
+            version_string = f"{version_string} ({git_version})"
     except Exception as e:
         logger.info("Failed to check for git repository: %s", e)