diff --git a/changelog.d/9730.misc b/changelog.d/9730.misc
new file mode 100644
index 0000000000..8063059b0b
--- /dev/null
+++ b/changelog.d/9730.misc
@@ -0,0 +1 @@
+Add type hints to expiring cache.
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index afdb5bf2fa..55533d7501 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -102,7 +102,7 @@ class FederationClient(FederationBase):
max_len=1000,
expiry_ms=120 * 1000,
reset_expiry_on_get=False,
- )
+ ) # type: ExpiringCache[str, EventBase]
def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache"""
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 54293d0b9c..7e76db3e2a 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -631,7 +631,7 @@ class DeviceListUpdater:
max_len=10000,
expiry_ms=30 * 60 * 1000,
iterable=True,
- )
+ ) # type: ExpiringCache[str, Set[str]]
# Attempt to resync out of sync device lists every 30s.
self._resync_retry_in_progress = False
@@ -760,7 +760,7 @@ class DeviceListUpdater:
"""Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta.
"""
- seen_updates = self._seen_updates.get(user_id, set())
+ seen_updates = self._seen_updates.get(user_id, set()) # type: Set[str]
extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 739653a3fa..92b18378fc 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -38,7 +38,6 @@ from synapse.types import (
)
from synapse.util import json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer
-from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
if TYPE_CHECKING:
@@ -1292,17 +1291,6 @@ class SigningKeyEduUpdater:
# user_id -> list of updates waiting to be handled.
self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
- # Recently seen stream ids. We don't bother keeping these in the DB,
- # but they're useful to have them about to reduce the number of spurious
- # resyncs.
- self._seen_updates = ExpiringCache(
- cache_name="signing_key_update_edu",
- clock=self.clock,
- max_len=10000,
- expiry_ms=30 * 60 * 1000,
- iterable=True,
- )
-
async def incoming_signing_key_update(
self, origin: str, edu_content: JsonDict
) -> None:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 7b356ba7e5..ff11266c67 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -252,13 +252,13 @@ class SyncHandler:
self.storage = hs.get_storage()
self.state_store = self.storage.state
- # ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
+ # ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
self.lazy_loaded_members_cache = ExpiringCache(
"lazy_loaded_members_cache",
self.clock,
max_len=0,
expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
- )
+ ) # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]]
async def wait_for_sync_for_user(
self,
@@ -733,8 +733,10 @@ class SyncHandler:
def get_lazy_loaded_members_cache(
self, cache_key: Tuple[str, Optional[str]]
- ) -> LruCache:
- cache = self.lazy_loaded_members_cache.get(cache_key)
+ ) -> LruCache[str, str]:
+ cache = self.lazy_loaded_members_cache.get(
+ cache_key
+ ) # type: Optional[LruCache[str, str]]
if cache is None:
logger.debug("creating LruCache for %r", cache_key)
cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
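
The hunk above also nails down the get-or-create idiom: with the new overloads on ExpiringCache.get (see the expiringcache.py changes below), the one-argument call types as an Optional, which the None check narrows before a fresh LruCache is created. A minimal self-contained sketch of the same idiom, with plain dicts standing in for both cache classes and all names illustrative:

    from typing import Dict, Optional, Tuple

    # (user_id, device_id) -> per-device member cache; device_id may be None.
    _caches: Dict[Tuple[str, Optional[str]], Dict[str, str]] = {}

    def get_member_cache(key: Tuple[str, Optional[str]]) -> Dict[str, str]:
        cache = _caches.get(key)  # mypy: Optional[Dict[str, str]]
        if cache is None:  # narrow the Optional before use
            cache = {}
            _caches[key] = cache
        return cache

    # The inner cache maps a member's user_id to a membership event id.
    get_member_cache(("@alice:example.org", None))["@bob:example.org"] = "$event_id"
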
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index c4ed9dfdb4..814145a04a 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -175,7 +175,7 @@ class PreviewUrlResource(DirectServeJsonResource):
clock=self.clock,
# don't spider URLs more often than once an hour
expiry_ms=ONE_HOUR,
- )
+ ) # type: ExpiringCache[str, ObservableDeferred]
if self._worker_run_media_background_jobs:
self._cleaner_loop = self.clock.looping_call(
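
The value type here documents what this cache holds: the in-flight preview work per URL, so concurrent requests for the same page share one download rather than spidering it twice. A rough self-contained sketch of that deduplication pattern, using an asyncio.Task as a stand-in for synapse's ObservableDeferred (all names hypothetical; the real cache also expires entries, omitted here):

    import asyncio
    from typing import Dict

    _in_flight: Dict[str, "asyncio.Task[str]"] = {}

    async def _generate_preview(url: str) -> str:
        await asyncio.sleep(0.01)  # pretend to fetch and parse the page
        return "preview of " + url

    async def get_preview(url: str) -> str:
        task = _in_flight.get(url)
        if task is None:
            # First caller kicks off the work; later callers share the task.
            task = asyncio.ensure_future(_generate_preview(url))
            _in_flight[url] = task
        return await task

    async def main() -> None:
        first, second = await asyncio.gather(
            get_preview("https://example.com"), get_preview("https://example.com")
        )
        assert first == second

    asyncio.run(main())
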
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index c3d6e80c49..c0f79ffdc8 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -22,6 +22,7 @@ from typing import (
Callable,
DefaultDict,
Dict,
+ FrozenSet,
Iterable,
List,
Optional,
@@ -515,7 +516,7 @@ class StateResolutionHandler:
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
iterable=True,
reset_expiry_on_get=True,
- )
+ ) # type: ExpiringCache[FrozenSet[int], _StateCacheEntry]
#
# stuff for tracking time spent on state-res by room
@@ -536,7 +537,7 @@ class StateResolutionHandler:
state_groups_ids: Dict[int, StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
- ):
+ ) -> _StateCacheEntry:
"""Resolves conflicts between a set of state groups
Always generates a new state group (unless we hit the cache), so should
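
The new key type is FrozenSet[int] because the handler looks resolutions up by the set of state group ids involved, and a cache key must be hashable and order-insensitive; frozenset is both, while a plain set is unhashable. A tiny illustration, with a dict standing in for the ExpiringCache:

    from typing import Dict, FrozenSet

    cache: Dict[FrozenSet[int], str] = {}
    cache[frozenset({101, 102, 103})] = "resolved state entry"

    # The same group of state group ids, in any order, hits the same entry.
    assert cache[frozenset({103, 101, 102})] == "resolved state entry"
    # A plain set would fail: TypeError: unhashable type: 'set'.
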
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index e15f7ee698..4dc3477e89 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -15,40 +15,50 @@
import logging
from collections import OrderedDict
+from typing import Any, Generic, Optional, TypeVar, Union, overload
+
+import attr
+from typing_extensions import Literal
from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import Clock
from synapse.util.caches import register_cache
logger = logging.getLogger(__name__)
-SENTINEL = object()
+SENTINEL = object() # type: Any
+
+T = TypeVar("T")
+KT = TypeVar("KT")
+VT = TypeVar("VT")
-class ExpiringCache:
+
+class ExpiringCache(Generic[KT, VT]):
def __init__(
self,
- cache_name,
- clock,
- max_len=0,
- expiry_ms=0,
- reset_expiry_on_get=False,
- iterable=False,
+ cache_name: str,
+ clock: Clock,
+ max_len: int = 0,
+ expiry_ms: int = 0,
+ reset_expiry_on_get: bool = False,
+ iterable: bool = False,
):
"""
Args:
- cache_name (str): Name of this cache, used for logging.
- clock (Clock)
- max_len (int): Max size of dict. If the dict grows larger than this
+ cache_name: Name of this cache, used for logging.
+        clock: The Clock used to track when entries expire.
+ max_len: Max size of dict. If the dict grows larger than this
then the oldest items get automatically evicted. Default is 0,
which indicates there is no max limit.
- expiry_ms (int): How long before an item is evicted from the cache
+ expiry_ms: How long before an item is evicted from the cache
in milliseconds. Default is 0, indicating items never get
evicted based on time.
- reset_expiry_on_get (bool): If true, will reset the expiry time for
+ reset_expiry_on_get: If true, will reset the expiry time for
an item on access. Defaults to False.
- iterable (bool): If true, the size is calculated by summing the
+ iterable: If true, the size is calculated by summing the
sizes of all entries, rather than the number of entries.
"""
self._cache_name = cache_name
@@ -62,7 +72,7 @@ class ExpiringCache:
self._expiry_ms = expiry_ms
self._reset_expiry_on_get = reset_expiry_on_get
- self._cache = OrderedDict()
+ self._cache = OrderedDict() # type: OrderedDict[KT, _CacheEntry]
self.iterable = iterable
@@ -79,12 +89,12 @@ class ExpiringCache:
self._clock.looping_call(f, self._expiry_ms / 2)
- def __setitem__(self, key, value):
+ def __setitem__(self, key: KT, value: VT) -> None:
now = self._clock.time_msec()
self._cache[key] = _CacheEntry(now, value)
self.evict()
- def evict(self):
+ def evict(self) -> None:
# Evict if there are now too many items
while self._max_size and len(self) > self._max_size:
_key, value = self._cache.popitem(last=False)
@@ -93,7 +103,7 @@ class ExpiringCache:
else:
self.metrics.inc_evictions()
- def __getitem__(self, key):
+ def __getitem__(self, key: KT) -> VT:
try:
entry = self._cache[key]
self.metrics.inc_hits()
@@ -106,7 +116,7 @@ class ExpiringCache:
return entry.value
- def pop(self, key, default=SENTINEL):
+ def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:
"""Removes and returns the value with the given key from the cache.
If the key isn't in the cache then `default` will be returned if
@@ -115,29 +125,40 @@ class ExpiringCache:
Identical functionality to `dict.pop(..)`.
"""
- value = self._cache.pop(key, default)
+ value = self._cache.pop(key, SENTINEL)
+ # The key was not found.
if value is SENTINEL:
- raise KeyError(key)
+ if default is SENTINEL:
+ raise KeyError(key)
+ return default
- return value
+ return value.value
- def __contains__(self, key):
+ def __contains__(self, key: KT) -> bool:
return key in self._cache
- def get(self, key, default=None):
+ @overload
+ def get(self, key: KT, default: Literal[None] = None) -> Optional[VT]:
+ ...
+
+ @overload
+ def get(self, key: KT, default: T) -> Union[VT, T]:
+ ...
+
+ def get(self, key: KT, default: Optional[T] = None) -> Union[VT, Optional[T]]:
try:
return self[key]
except KeyError:
return default
- def setdefault(self, key, value):
+ def setdefault(self, key: KT, value: VT) -> VT:
try:
return self[key]
except KeyError:
self[key] = value
return value
- def _prune_cache(self):
+ def _prune_cache(self) -> None:
if not self._expiry_ms:
# zero expiry time means don't expire. This should never get called
# since we have this check in start too.
@@ -166,7 +187,7 @@ class ExpiringCache:
len(self),
)
- def __len__(self):
+ def __len__(self) -> int:
if self.iterable:
return sum(len(entry.value) for entry in self._cache.values())
else:
@@ -190,9 +211,7 @@ class ExpiringCache:
return False
+@attr.s(slots=True)
class _CacheEntry:
- __slots__ = ["time", "value"]
-
- def __init__(self, time, value):
- self.time = time
- self.value = value
+ time = attr.ib(type=int)
+ value = attr.ib()
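
Taken together, the expiringcache.py changes follow a standard recipe for typing a dict-like container: make the class Generic over key and value types, overload get() so the one-argument form narrows to Optional, and use a module-level sentinel (typed Any) so pop() can tell "no default supplied" apart from an explicit default of None. The sentinel rework also fixes a bug visible in the removed lines: the old pop() returned the internal _CacheEntry wrapper on a hit rather than the cached value. A self-contained sketch of the same recipe; TinyCache is illustrative, not synapse's ExpiringCache:

    from typing import Any, Dict, Generic, Optional, TypeVar, Union, overload

    KT = TypeVar("KT")
    VT = TypeVar("VT")
    T = TypeVar("T")

    _SENTINEL = object()  # type: Any

    class TinyCache(Generic[KT, VT]):
        def __init__(self) -> None:
            self._data: Dict[KT, VT] = {}

        def __setitem__(self, key: KT, value: VT) -> None:
            self._data[key] = value

        def __getitem__(self, key: KT) -> VT:
            return self._data[key]

        # The patch spells the first overload's default as Literal[None];
        # a plain None annotation is equivalent here.
        @overload
        def get(self, key: KT, default: None = None) -> Optional[VT]:
            ...

        @overload
        def get(self, key: KT, default: T) -> Union[VT, T]:
            ...

        def get(self, key: KT, default: Optional[T] = None) -> Union[VT, Optional[T]]:
            try:
                return self[key]
            except KeyError:
                return default

        def pop(self, key: KT, default: T = _SENTINEL) -> Union[VT, T]:
            # Pop with an internal sentinel so "no default given" is
            # distinguishable from default=None, mirroring dict.pop().
            value = self._data.pop(key, _SENTINEL)
            if value is _SENTINEL:
                if default is _SENTINEL:
                    raise KeyError(key)
                return default
            return value

    cache: TinyCache[str, int] = TinyCache()
    cache["a"] = 1
    assert cache.get("a") == 1     # mypy sees Optional[int]
    assert cache.get("b", 0) == 0  # mypy sees int
    assert cache.pop("a") == 1
    assert cache.pop("a", None) is None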