summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/16276.misc1
-rw-r--r--synapse/config/oembed.py2
-rw-r--r--synapse/storage/controllers/persist_events.py8
-rw-r--r--synapse/util/async_helpers.py25
-rw-r--r--synapse/util/caches/dictionary_cache.py10
-rw-r--r--synapse/util/caches/expiringcache.py20
-rw-r--r--synapse/util/caches/ttlcache.py10
7 files changed, 37 insertions, 39 deletions
diff --git a/changelog.d/16276.misc b/changelog.d/16276.misc
new file mode 100644
index 0000000000..93ceaeafc9
--- /dev/null
+++ b/changelog.d/16276.misc
@@ -0,0 +1 @@
+Improve type hints.
diff --git a/synapse/config/oembed.py b/synapse/config/oembed.py
index d7959639ee..59bc0b55f4 100644
--- a/synapse/config/oembed.py
+++ b/synapse/config/oembed.py
@@ -30,7 +30,7 @@ class OEmbedEndpointConfig:
     # The API endpoint to fetch.
     api_endpoint: str
     # The patterns to match.
-    url_patterns: List[Pattern]
+    url_patterns: List[Pattern[str]]
     # The supported formats.
     formats: Optional[List[str]]
 
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index abd1d149db..6864f93090 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -154,12 +154,13 @@ class _UpdateCurrentStateTask:
 
 
 _EventPersistQueueTask = Union[_PersistEventsTask, _UpdateCurrentStateTask]
+_PersistResult = TypeVar("_PersistResult")
 
 
 @attr.s(auto_attribs=True, slots=True)
-class _EventPersistQueueItem:
+class _EventPersistQueueItem(Generic[_PersistResult]):
     task: _EventPersistQueueTask
-    deferred: ObservableDeferred
+    deferred: ObservableDeferred[_PersistResult]
 
     parent_opentracing_span_contexts: List = attr.ib(factory=list)
     """A list of opentracing spans waiting for this batch"""
@@ -168,9 +169,6 @@ class _EventPersistQueueItem:
     """The opentracing span under which the persistence actually happened"""
 
 
-_PersistResult = TypeVar("_PersistResult")
-
-
 class _EventPeristenceQueue(Generic[_PersistResult]):
     """Queues up tasks so that they can be processed with only one concurrent
     transaction per room.
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 943ad54456..0cbeb0c365 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -19,6 +19,7 @@ import collections
 import inspect
 import itertools
 import logging
+import typing
 from contextlib import asynccontextmanager
 from typing import (
     Any,
@@ -29,6 +30,7 @@ from typing import (
     Collection,
     Coroutine,
     Dict,
+    Generator,
     Generic,
     Hashable,
     Iterable,
@@ -398,7 +400,7 @@ class _LinearizerEntry:
     # The number of things executing.
     count: int
     # Deferreds for the things blocked from executing.
-    deferreds: collections.OrderedDict
+    deferreds: typing.OrderedDict["defer.Deferred[None]", Literal[1]]
 
 
 class Linearizer:
@@ -717,30 +719,25 @@ def timeout_deferred(
     return new_d
 
 
-# This class can't be generic because it uses slots with attrs.
-# See: https://github.com/python-attrs/attrs/issues/313
 @attr.s(slots=True, frozen=True, auto_attribs=True)
-class DoneAwaitable:  # should be: Generic[R]
+class DoneAwaitable(Awaitable[R]):
     """Simple awaitable that returns the provided value."""
 
-    value: Any  # should be: R
+    value: R
 
-    def __await__(self) -> Any:
-        return self
-
-    def __iter__(self) -> "DoneAwaitable":
-        return self
-
-    def __next__(self) -> None:
-        raise StopIteration(self.value)
+    def __await__(self) -> Generator[Any, None, R]:
+        yield None
+        return self.value
 
 
 def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
     """Convert a value to an awaitable if not already an awaitable."""
     if inspect.isawaitable(value):
-        assert isinstance(value, Awaitable)
         return value
 
+    # For some reason mypy doesn't deduce that value is not Awaitable here, even though
+    # inspect.isawaitable returns a TypeGuard.
+    assert not isinstance(value, Awaitable)
     return DoneAwaitable(value)
 
 
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 5eaf70c7ab..2fbc7b1e6c 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -14,7 +14,7 @@
 import enum
 import logging
 import threading
-from typing import Any, Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union
+from typing import Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union
 
 import attr
 from typing_extensions import Literal
@@ -33,10 +33,8 @@ DKT = TypeVar("DKT")
 DV = TypeVar("DV")
 
 
-# This class can't be generic because it uses slots with attrs.
-# See: https://github.com/python-attrs/attrs/issues/313
 @attr.s(slots=True, frozen=True, auto_attribs=True)
-class DictionaryEntry:  # should be: Generic[DKT, DV].
+class DictionaryEntry(Generic[DKT, DV]):
     """Returned when getting an entry from the cache
 
     If `full` is true then `known_absent` will be the empty set.
@@ -50,8 +48,8 @@ class DictionaryEntry:  # should be: Generic[DKT, DV].
     """
 
     full: bool
-    known_absent: Set[Any]  # should be: Set[DKT]
-    value: Dict[Any, Any]  # should be: Dict[DKT, DV]
+    known_absent: Set[DKT]
+    value: Dict[DKT, DV]
 
     def __len__(self) -> int:
         return len(self.value)
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 01ad02af67..8e4c34039d 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -14,7 +14,7 @@
 
 import logging
 from collections import OrderedDict
-from typing import Any, Generic, Optional, TypeVar, Union, overload
+from typing import Any, Generic, Iterable, Optional, TypeVar, Union, overload
 
 import attr
 from typing_extensions import Literal
@@ -73,7 +73,7 @@ class ExpiringCache(Generic[KT, VT]):
         self._expiry_ms = expiry_ms
         self._reset_expiry_on_get = reset_expiry_on_get
 
-        self._cache: OrderedDict[KT, _CacheEntry] = OrderedDict()
+        self._cache: OrderedDict[KT, _CacheEntry[VT]] = OrderedDict()
 
         self.iterable = iterable
 
@@ -100,7 +100,10 @@ class ExpiringCache(Generic[KT, VT]):
         while self._max_size and len(self) > self._max_size:
             _key, value = self._cache.popitem(last=False)
             if self.iterable:
-                self.metrics.inc_evictions(EvictionReason.size, len(value.value))
+                # type-ignore, here and below: if self.iterable is true, then the value
+                # type VT should be Sized (i.e. have a __len__ method). We don't enforce
+                # this via the type system at present.
+                self.metrics.inc_evictions(EvictionReason.size, len(value.value))  # type: ignore[arg-type]
             else:
                 self.metrics.inc_evictions(EvictionReason.size)
 
@@ -134,7 +137,7 @@ class ExpiringCache(Generic[KT, VT]):
             return default
 
         if self.iterable:
-            self.metrics.inc_evictions(EvictionReason.invalidation, len(value.value))
+            self.metrics.inc_evictions(EvictionReason.invalidation, len(value.value))  # type: ignore[arg-type]
         else:
             self.metrics.inc_evictions(EvictionReason.invalidation)
 
@@ -182,7 +185,7 @@ class ExpiringCache(Generic[KT, VT]):
         for k in keys_to_delete:
             value = self._cache.pop(k)
             if self.iterable:
-                self.metrics.inc_evictions(EvictionReason.time, len(value.value))
+                self.metrics.inc_evictions(EvictionReason.time, len(value.value))  # type: ignore[arg-type]
             else:
                 self.metrics.inc_evictions(EvictionReason.time)
 
@@ -195,7 +198,8 @@ class ExpiringCache(Generic[KT, VT]):
 
     def __len__(self) -> int:
         if self.iterable:
-            return sum(len(entry.value) for entry in self._cache.values())
+            g: Iterable[int] = (len(entry.value) for entry in self._cache.values())  # type: ignore[arg-type]
+            return sum(g)
         else:
             return len(self._cache)
 
@@ -218,6 +222,6 @@ class ExpiringCache(Generic[KT, VT]):
 
 
 @attr.s(slots=True, auto_attribs=True)
-class _CacheEntry:
+class _CacheEntry(Generic[VT]):
     time: int
-    value: Any
+    value: VT
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index f6b3ee31e4..48a6e4a906 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -35,10 +35,10 @@ class TTLCache(Generic[KT, VT]):
 
     def __init__(self, cache_name: str, timer: Callable[[], float] = time.time):
         # map from key to _CacheEntry
-        self._data: Dict[KT, _CacheEntry] = {}
+        self._data: Dict[KT, _CacheEntry[KT, VT]] = {}
 
         # the _CacheEntries, sorted by expiry time
-        self._expiry_list: SortedList[_CacheEntry] = SortedList()
+        self._expiry_list: SortedList[_CacheEntry[KT, VT]] = SortedList()
 
         self._timer = timer
 
@@ -160,11 +160,11 @@ class TTLCache(Generic[KT, VT]):
 
 
 @attr.s(frozen=True, slots=True, auto_attribs=True)
-class _CacheEntry:  # Should be Generic[KT, VT]. See python-attrs/attrs#313
+class _CacheEntry(Generic[KT, VT]):
     """TTLCache entry"""
 
     # expiry_time is the first attribute, so that entries are sorted by expiry.
     expiry_time: float
     ttl: float
-    key: Any  # should be KT
-    value: Any  # should be VT
+    key: KT
+    value: VT