summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-04-06 08:58:18 -0400
committerGitHub <noreply@github.com>2021-04-06 08:58:18 -0400
commit44bb881096d7ad4d730e2dc31e0094c0324e0970 (patch)
treee5fdb161168dba08ec4f800ed5f0d44eb8a552cd
parentFix reported bugbear: too broad exception assertion (#9753) (diff)
downloadsynapse-44bb881096d7ad4d730e2dc31e0094c0324e0970.tar.xz
Add type hints to expiring cache. (#9730)
-rw-r--r--changelog.d/9730.misc1
-rw-r--r--synapse/federation/federation_client.py2
-rw-r--r--synapse/handlers/device.py4
-rw-r--r--synapse/handlers/e2e_keys.py12
-rw-r--r--synapse/handlers/sync.py10
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py2
-rw-r--r--synapse/state/__init__.py5
-rw-r--r--synapse/util/caches/expiringcache.py83
8 files changed, 65 insertions, 54 deletions
diff --git a/changelog.d/9730.misc b/changelog.d/9730.misc
new file mode 100644
index 0000000000..8063059b0b
--- /dev/null
+++ b/changelog.d/9730.misc
@@ -0,0 +1 @@
+Add type hints to expiring cache.
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index afdb5bf2fa..55533d7501 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -102,7 +102,7 @@ class FederationClient(FederationBase):
             max_len=1000,
             expiry_ms=120 * 1000,
             reset_expiry_on_get=False,
-        )
+        )  # type: ExpiringCache[str, EventBase]
 
     def _clear_tried_cache(self):
         """Clear pdu_destination_tried cache"""
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 54293d0b9c..7e76db3e2a 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -631,7 +631,7 @@ class DeviceListUpdater:
             max_len=10000,
             expiry_ms=30 * 60 * 1000,
             iterable=True,
-        )
+        )  # type: ExpiringCache[str, Set[str]]
 
         # Attempt to resync out of sync device lists every 30s.
         self._resync_retry_in_progress = False
@@ -760,7 +760,7 @@ class DeviceListUpdater:
         """Given a list of updates for a user figure out if we need to do a full
         resync, or whether we have enough data that we can just apply the delta.
         """
-        seen_updates = self._seen_updates.get(user_id, set())
+        seen_updates = self._seen_updates.get(user_id, set())  # type: Set[str]
 
         extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
 
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 739653a3fa..92b18378fc 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -38,7 +38,6 @@ from synapse.types import (
 )
 from synapse.util import json_decoder, unwrapFirstError
 from synapse.util.async_helpers import Linearizer
-from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
 
 if TYPE_CHECKING:
@@ -1292,17 +1291,6 @@ class SigningKeyEduUpdater:
         # user_id -> list of updates waiting to be handled.
         self._pending_updates = {}  # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
 
-        # Recently seen stream ids. We don't bother keeping these in the DB,
-        # but they're useful to have them about to reduce the number of spurious
-        # resyncs.
-        self._seen_updates = ExpiringCache(
-            cache_name="signing_key_update_edu",
-            clock=self.clock,
-            max_len=10000,
-            expiry_ms=30 * 60 * 1000,
-            iterable=True,
-        )
-
     async def incoming_signing_key_update(
         self, origin: str, edu_content: JsonDict
     ) -> None:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 7b356ba7e5..ff11266c67 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -252,13 +252,13 @@ class SyncHandler:
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
 
-        # ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
+        # ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
         self.lazy_loaded_members_cache = ExpiringCache(
             "lazy_loaded_members_cache",
             self.clock,
             max_len=0,
             expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
-        )
+        )  # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]]
 
     async def wait_for_sync_for_user(
         self,
@@ -733,8 +733,10 @@ class SyncHandler:
 
     def get_lazy_loaded_members_cache(
         self, cache_key: Tuple[str, Optional[str]]
-    ) -> LruCache:
-        cache = self.lazy_loaded_members_cache.get(cache_key)
+    ) -> LruCache[str, str]:
+        cache = self.lazy_loaded_members_cache.get(
+            cache_key
+        )  # type: Optional[LruCache[str, str]]
         if cache is None:
             logger.debug("creating LruCache for %r", cache_key)
             cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index c4ed9dfdb4..814145a04a 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -175,7 +175,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             clock=self.clock,
             # don't spider URLs more often than once an hour
             expiry_ms=ONE_HOUR,
-        )
+        )  # type: ExpiringCache[str, ObservableDeferred]
 
         if self._worker_run_media_background_jobs:
             self._cleaner_loop = self.clock.looping_call(
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index c3d6e80c49..c0f79ffdc8 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -22,6 +22,7 @@ from typing import (
     Callable,
     DefaultDict,
     Dict,
+    FrozenSet,
     Iterable,
     List,
     Optional,
@@ -515,7 +516,7 @@ class StateResolutionHandler:
             expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
             iterable=True,
             reset_expiry_on_get=True,
-        )
+        )  # type: ExpiringCache[FrozenSet[int], _StateCacheEntry]
 
         #
         # stuff for tracking time spent on state-res by room
@@ -536,7 +537,7 @@ class StateResolutionHandler:
         state_groups_ids: Dict[int, StateMap[str]],
         event_map: Optional[Dict[str, EventBase]],
         state_res_store: "StateResolutionStore",
-    ):
+    ) -> _StateCacheEntry:
         """Resolves conflicts between a set of state groups
 
         Always generates a new state group (unless we hit the cache), so should
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index e15f7ee698..4dc3477e89 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -15,40 +15,50 @@
 
 import logging
 from collections import OrderedDict
+from typing import Any, Generic, Optional, TypeVar, Union, overload
+
+import attr
+from typing_extensions import Literal
 
 from synapse.config import cache as cache_config
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import Clock
 from synapse.util.caches import register_cache
 
 logger = logging.getLogger(__name__)
 
 
-SENTINEL = object()
+SENTINEL = object()  # type: Any
+
 
+T = TypeVar("T")
+KT = TypeVar("KT")
+VT = TypeVar("VT")
 
-class ExpiringCache:
+
+class ExpiringCache(Generic[KT, VT]):
     def __init__(
         self,
-        cache_name,
-        clock,
-        max_len=0,
-        expiry_ms=0,
-        reset_expiry_on_get=False,
-        iterable=False,
+        cache_name: str,
+        clock: Clock,
+        max_len: int = 0,
+        expiry_ms: int = 0,
+        reset_expiry_on_get: bool = False,
+        iterable: bool = False,
     ):
         """
         Args:
-            cache_name (str): Name of this cache, used for logging.
-            clock (Clock)
-            max_len (int): Max size of dict. If the dict grows larger than this
+            cache_name: Name of this cache, used for logging.
+            clock
+            max_len: Max size of dict. If the dict grows larger than this
                 then the oldest items get automatically evicted. Default is 0,
                 which indicates there is no max limit.
-            expiry_ms (int): How long before an item is evicted from the cache
+            expiry_ms: How long before an item is evicted from the cache
                 in milliseconds. Default is 0, indicating items never get
                 evicted based on time.
-            reset_expiry_on_get (bool): If true, will reset the expiry time for
+            reset_expiry_on_get: If true, will reset the expiry time for
                 an item on access. Defaults to False.
-            iterable (bool): If true, the size is calculated by summing the
+            iterable: If true, the size is calculated by summing the
                 sizes of all entries, rather than the number of entries.
         """
         self._cache_name = cache_name
@@ -62,7 +72,7 @@ class ExpiringCache:
         self._expiry_ms = expiry_ms
         self._reset_expiry_on_get = reset_expiry_on_get
 
-        self._cache = OrderedDict()
+        self._cache = OrderedDict()  # type: OrderedDict[KT, _CacheEntry]
 
         self.iterable = iterable
 
@@ -79,12 +89,12 @@ class ExpiringCache:
 
         self._clock.looping_call(f, self._expiry_ms / 2)
 
-    def __setitem__(self, key, value):
+    def __setitem__(self, key: KT, value: VT) -> None:
         now = self._clock.time_msec()
         self._cache[key] = _CacheEntry(now, value)
         self.evict()
 
-    def evict(self):
+    def evict(self) -> None:
         # Evict if there are now too many items
         while self._max_size and len(self) > self._max_size:
             _key, value = self._cache.popitem(last=False)
@@ -93,7 +103,7 @@ class ExpiringCache:
             else:
                 self.metrics.inc_evictions()
 
-    def __getitem__(self, key):
+    def __getitem__(self, key: KT) -> VT:
         try:
             entry = self._cache[key]
             self.metrics.inc_hits()
@@ -106,7 +116,7 @@ class ExpiringCache:
 
         return entry.value
 
-    def pop(self, key, default=SENTINEL):
+    def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:
         """Removes and returns the value with the given key from the cache.
 
         If the key isn't in the cache then `default` will be returned if
@@ -115,29 +125,40 @@ class ExpiringCache:
         Identical functionality to `dict.pop(..)`.
         """
 
-        value = self._cache.pop(key, default)
+        value = self._cache.pop(key, SENTINEL)
+        # The key was not found.
         if value is SENTINEL:
-            raise KeyError(key)
+            if default is SENTINEL:
+                raise KeyError(key)
+            return default
 
-        return value
+        return value.value
 
-    def __contains__(self, key):
+    def __contains__(self, key: KT) -> bool:
         return key in self._cache
 
-    def get(self, key, default=None):
+    @overload
+    def get(self, key: KT, default: Literal[None] = None) -> Optional[VT]:
+        ...
+
+    @overload
+    def get(self, key: KT, default: T) -> Union[VT, T]:
+        ...
+
+    def get(self, key: KT, default: Optional[T] = None) -> Union[VT, Optional[T]]:
         try:
             return self[key]
         except KeyError:
             return default
 
-    def setdefault(self, key, value):
+    def setdefault(self, key: KT, value: VT) -> VT:
         try:
             return self[key]
         except KeyError:
             self[key] = value
             return value
 
-    def _prune_cache(self):
+    def _prune_cache(self) -> None:
         if not self._expiry_ms:
             # zero expiry time means don't expire. This should never get called
             # since we have this check in start too.
@@ -166,7 +187,7 @@ class ExpiringCache:
             len(self),
         )
 
-    def __len__(self):
+    def __len__(self) -> int:
         if self.iterable:
             return sum(len(entry.value) for entry in self._cache.values())
         else:
@@ -190,9 +211,7 @@ class ExpiringCache:
         return False
 
 
+@attr.s(slots=True)
 class _CacheEntry:
-    __slots__ = ["time", "value"]
-
-    def __init__(self, time, value):
-        self.time = time
-        self.value = value
+    time = attr.ib(type=int)
+    value = attr.ib()