diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 1f8dafe7ea..273d627fad 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -26,7 +26,7 @@ class SlavedClientIpStore(BaseSlavedStore):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000
- )
+ ) # type: Cache[tuple, int]
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 98b34f2223..14458bc20f 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -13,12 +13,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import enum
import functools
import inspect
import logging
import threading
-from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
+from typing import (
+ Any,
+ Callable,
+ Generic,
+ Iterable,
+ MutableMapping,
+ Optional,
+ Tuple,
+ TypeVar,
+ Union,
+ cast,
+)
from weakref import WeakValueDictionary
from prometheus_client import Gauge
@@ -38,6 +49,8 @@ logger = logging.getLogger(__name__)
CacheKey = Union[Tuple, Any]
F = TypeVar("F", bound=Callable[..., Any])
+KT = TypeVar("KT")
+VT = TypeVar("VT")
class _CachedFunction(Generic[F]):
@@ -61,13 +74,19 @@ cache_pending_metric = Gauge(
["name"],
)
-_CacheSentinel = object()
+
+class _Sentinel(enum.Enum):
+ # defining a sentinel in this way allows mypy to correctly handle the
+ # type of a dictionary lookup.
+ sentinel = object()
class CacheEntry:
__slots__ = ["deferred", "callbacks", "invalidated"]
- def __init__(self, deferred, callbacks):
+ def __init__(
+ self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
+ ):
self.deferred = deferred
self.callbacks = set(callbacks)
self.invalidated = False
@@ -80,7 +99,13 @@ class CacheEntry:
self.callbacks.clear()
-class Cache:
+class Cache(Generic[KT, VT]):
+ """Wraps an LruCache, adding support for Deferred results.
+
+ It expects that each entry added with set() will be a Deferred; likewise get()
+ may return an ObservableDeferred.
+ """
+
__slots__ = (
"cache",
"name",
@@ -103,19 +128,23 @@ class Cache:
Args:
name: The name of the cache
max_entries: Maximum amount of entries that the cache will hold
- keylen: The length of the tuple used as the cache key
+ keylen: The length of the tuple used as the cache key. Ignored unless
+ `tree` is True.
tree: Use a TreeCache instead of a dict as the underlying cache type
iterable: If True, count each item in the cached object as an entry,
rather than each cached object
apply_cache_factor_from_config: Whether cache factors specified in the
config file affect `max_entries`
-
- Returns:
- Cache
"""
cache_type = TreeCache if tree else dict
- self._pending_deferred_cache = cache_type()
+ # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
+ self._pending_deferred_cache = (
+ cache_type()
+ ) # type: MutableMapping[KT, CacheEntry]
+
+ # cache is used for completed results and maps to the result itself, rather than
+ # a Deferred.
self.cache = LruCache(
max_size=max_entries,
keylen=keylen,
@@ -155,7 +184,13 @@ class Cache:
"Cache objects can only be accessed from the main thread"
)
- def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
+ def get(
+ self,
+ key: KT,
+ default=_Sentinel.sentinel,
+ callback: Optional[Callable[[], None]] = None,
+ update_metrics: bool = True,
+ ):
"""Looks the key up in the caches.
Args:
@@ -166,30 +201,32 @@ class Cache:
update_metrics (bool): whether to update the cache hit rate metrics
Returns:
- Either an ObservableDeferred or the raw result
+ Either an ObservableDeferred or the result itself
"""
callbacks = [callback] if callback else []
- val = self._pending_deferred_cache.get(key, _CacheSentinel)
- if val is not _CacheSentinel:
+ val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
+ if val is not _Sentinel.sentinel:
val.callbacks.update(callbacks)
if update_metrics:
self.metrics.inc_hits()
return val.deferred
- val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
- if val is not _CacheSentinel:
+ val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks)
+ if val is not _Sentinel.sentinel:
self.metrics.inc_hits()
return val
if update_metrics:
self.metrics.inc_misses()
- if default is _CacheSentinel:
+ if default is _Sentinel.sentinel:
raise KeyError()
else:
return default
- def set(self, key, value, callback=None):
+ def set(
+ self, key: KT, value: defer.Deferred, callback: Optional[Callable[[], None]] = None
+ ) -> ObservableDeferred:
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")
@@ -248,7 +285,7 @@ class Cache:
observer.addCallbacks(cb, eb)
return observable
- def prefill(self, key, value, callback=None):
+ def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
callbacks = [callback] if callback else []
self.cache.set(key, value, callbacks=callbacks)
@@ -267,7 +304,7 @@ class Cache:
if entry:
entry.invalidate()
- def invalidate_many(self, key):
+ def invalidate_many(self, key: KT):
self.check_thread()
if not isinstance(key, tuple):
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
@@ -275,7 +312,7 @@ class Cache:
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, as above
- entry_dict = self._pending_deferred_cache.pop(key, None)
+ entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None)
if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict):
entry.invalidate()
@@ -396,7 +433,7 @@ class CacheDescriptor(_CacheDescriptorBase):
keylen=self.num_args,
tree=self.tree,
iterable=self.iterable,
- )
+ ) # type: Cache[Tuple, Any]
def get_cache_key_gen(args, kwargs):
"""Given some args/kwargs return a generator that resolves into
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 4bc1a67b58..33eae2b7c4 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -64,7 +64,8 @@ class LruCache:
Args:
max_size: The maximum amount of entries the cache can hold
- keylen: The length of the tuple used as the cache key
+ keylen: The length of the tuple used as the cache key. Ignored unless
+ cache_type is `TreeCache`.
cache_type (type):
type of underlying cache to be used. Typically one of dict
|