summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/16356.misc1
-rw-r--r--synapse/api/filtering.py8
-rw-r--r--synapse/federation/federation_client.py4
-rw-r--r--synapse/handlers/federation_event.py2
-rw-r--r--synapse/handlers/relations.py14
-rw-r--r--synapse/rest/client/filter.py4
-rw-r--r--synapse/storage/controllers/state.py2
-rw-r--r--synapse/storage/databases/main/filtering.py4
-rw-r--r--synapse/storage/databases/main/relations.py4
-rw-r--r--synapse/storage/databases/main/roommember.py10
-rw-r--r--tests/util/caches/test_descriptors.py35
11 files changed, 52 insertions, 36 deletions
diff --git a/changelog.d/16356.misc b/changelog.d/16356.misc
new file mode 100644
index 0000000000..93ceaeafc9
--- /dev/null
+++ b/changelog.d/16356.misc
@@ -0,0 +1 @@
+Improve type hints.
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 0995ecbe83..74ee8e9f3f 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -37,7 +37,7 @@ from synapse.api.constants import EduTypes, EventContentFields
 from synapse.api.errors import SynapseError
 from synapse.api.presence import UserPresenceState
 from synapse.events import EventBase, relation_from_event
-from synapse.types import JsonDict, RoomID, UserID
+from synapse.types import JsonDict, JsonMapping, RoomID, UserID
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -191,7 +191,7 @@ FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict)
 
 
 class FilterCollection:
-    def __init__(self, hs: "HomeServer", filter_json: JsonDict):
+    def __init__(self, hs: "HomeServer", filter_json: JsonMapping):
         self._filter_json = filter_json
 
         room_filter_json = self._filter_json.get("room", {})
@@ -219,7 +219,7 @@ class FilterCollection:
     def __repr__(self) -> str:
         return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
 
-    def get_filter_json(self) -> JsonDict:
+    def get_filter_json(self) -> JsonMapping:
         return self._filter_json
 
     def timeline_limit(self) -> int:
@@ -313,7 +313,7 @@ class FilterCollection:
 
 
 class Filter:
-    def __init__(self, hs: "HomeServer", filter_json: JsonDict):
+    def __init__(self, hs: "HomeServer", filter_json: JsonMapping):
         self._hs = hs
         self._store = hs.get_datastores().main
         self.filter_json = filter_json
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 607013f121..c8bc46415d 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -64,7 +64,7 @@ from synapse.federation.transport.client import SendJoinResponse
 from synapse.http.client import is_unknown_endpoint
 from synapse.http.types import QueryParams
 from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import JsonDict, StrCollection, UserID, get_domain_from_id
 from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
@@ -1704,7 +1704,7 @@ class FederationClient(FederationBase):
     async def timestamp_to_event(
         self,
         *,
-        destinations: List[str],
+        destinations: StrCollection,
         room_id: str,
         timestamp: int,
         direction: Direction,
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index eedde97ab0..7c62cdfaef 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -1538,7 +1538,7 @@ class FederationEventHandler:
             logger.exception("Failed to resync device for %s", sender)
 
     async def backfill_event_id(
-        self, destinations: List[str], room_id: str, event_id: str
+        self, destinations: StrCollection, room_id: str, event_id: str
     ) -> PulledPduInfo:
         """Backfill a single event and persist it as a non-outlier which means
         we also pull in all of the state and auth events necessary for it.
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index db97f7aede..9b13448cdd 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -13,7 +13,17 @@
 # limitations under the License.
 import enum
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, Optional
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    FrozenSet,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+)
 
 import attr
 
@@ -245,7 +255,7 @@ class RelationsHandler:
 
     async def get_references_for_events(
         self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
-    ) -> Dict[str, List[_RelatedEvent]]:
+    ) -> Mapping[str, Sequence[_RelatedEvent]]:
         """Get a list of references to the given events.
 
         Args:
diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py
index 5da1e511a2..b5879496db 100644
--- a/synapse/rest/client/filter.py
+++ b/synapse/rest/client/filter.py
@@ -19,7 +19,7 @@ from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseErro
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.http.site import SynapseRequest
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, JsonMapping, UserID
 
 from ._base import client_patterns, set_timeline_upper_limit
 
@@ -41,7 +41,7 @@ class GetFilterRestServlet(RestServlet):
 
     async def on_GET(
         self, request: SynapseRequest, user_id: str, filter_id: str
-    ) -> Tuple[int, JsonDict]:
+    ) -> Tuple[int, JsonMapping]:
         target_user = UserID.from_string(user_id)
         requester = await self.auth.get_user_by_req(request)
 
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 278c7832ba..10d219c045 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -582,7 +582,7 @@ class StateStorageController:
 
     @trace
     @tag_args
-    async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]:
+    async def get_current_hosts_in_room_ordered(self, room_id: str) -> Tuple[str, ...]:
         """Get current hosts in room based on current state.
 
         Blocks until we have full state for the given room. This only happens for rooms
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 047de6283a..7d94685caf 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -25,7 +25,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.engines import PostgresEngine
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, JsonMapping, UserID
 from synapse.util.caches.descriptors import cached
 
 if TYPE_CHECKING:
@@ -145,7 +145,7 @@ class FilteringWorkerStore(SQLBaseStore):
     @cached(num_args=2)
     async def get_user_filter(
         self, user_id: UserID, filter_id: Union[int, str]
-    ) -> JsonDict:
+    ) -> JsonMapping:
         # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
         # with a coherent error message rather than 500 M_UNKNOWN.
         try:
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 6ba9c9651f..b67f780c10 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -465,7 +465,7 @@ class RelationsWorkerStore(SQLBaseStore):
     @cachedList(cached_method_name="get_references_for_event", list_name="event_ids")
     async def get_references_for_events(
         self, event_ids: Collection[str]
-    ) -> Mapping[str, Optional[List[_RelatedEvent]]]:
+    ) -> Mapping[str, Optional[Sequence[_RelatedEvent]]]:
         """Get a list of references to the given events.
 
         Args:
@@ -931,7 +931,7 @@ class RelationsWorkerStore(SQLBaseStore):
         room_id: str,
         limit: int = 5,
         from_token: Optional[ThreadsNextBatch] = None,
-    ) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
+    ) -> Tuple[Sequence[str], Optional[ThreadsNextBatch]]:
         """Get a list of thread IDs, ordered by topological ordering of their
         latest reply.
 
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 7b503dd697..3755773faa 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -984,7 +984,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
         )
 
     @cached(iterable=True, max_entries=10000)
-    async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]:
+    async def get_current_hosts_in_room_ordered(self, room_id: str) -> Tuple[str, ...]:
         """
         Get current hosts in room based on current state.
 
@@ -1013,12 +1013,14 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
             # `get_users_in_room` rather than funky SQL.
 
             domains = await self.get_current_hosts_in_room(room_id)
-            return list(domains)
+            return tuple(domains)
 
         # For PostgreSQL we can use a regex to pull out the domains from the
         # joined users in `current_state_events` via regex.
 
-        def get_current_hosts_in_room_ordered_txn(txn: LoggingTransaction) -> List[str]:
+        def get_current_hosts_in_room_ordered_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[str, ...]:
             # Returns a list of servers currently joined in the room sorted by
             # longest in the room first (aka. with the lowest depth). The
             # heuristic of sorting by servers who have been in the room the
@@ -1043,7 +1045,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
             """
             txn.execute(sql, (room_id,))
             # `server_domain` will be `NULL` for malformed MXIDs with no colons.
-            return [d for d, in txn if d is not None]
+            return tuple(d for d, in txn if d is not None)
 
         return await self.db_pool.runInteraction(
             "get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 168419f440..7e8725e610 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -15,10 +15,10 @@
 import logging
 from typing import (
     Any,
-    Dict,
     Generator,
     Iterable,
     List,
+    Mapping,
     NoReturn,
     Optional,
     Set,
@@ -96,7 +96,7 @@ class DescriptorTestCase(unittest.TestCase):
                 self.mock = mock.Mock()
 
             @descriptors.cached(num_args=1)
-            def fn(self, arg1: int, arg2: int) -> mock.Mock:
+            def fn(self, arg1: int, arg2: int) -> str:
                 return self.mock(arg1, arg2)
 
         obj = Cls()
@@ -228,8 +228,9 @@ class DescriptorTestCase(unittest.TestCase):
             call_count = 0
 
             @cached()
-            def fn(self, arg1: int) -> Optional[Deferred]:
+            def fn(self, arg1: int) -> Deferred:
                 self.call_count += 1
+                assert self.result is not None
                 return self.result
 
         obj = Cls()
@@ -401,21 +402,21 @@ class DescriptorTestCase(unittest.TestCase):
                 self.mock = mock.Mock()
 
             @descriptors.cached(iterable=True)
-            def fn(self, arg1: int, arg2: int) -> List[str]:
+            def fn(self, arg1: int, arg2: int) -> Tuple[str, ...]:
                 return self.mock(arg1, arg2)
 
         obj = Cls()
 
-        obj.mock.return_value = ["spam", "eggs"]
+        obj.mock.return_value = ("spam", "eggs")
         r = obj.fn(1, 2)
-        self.assertEqual(r.result, ["spam", "eggs"])
+        self.assertEqual(r.result, ("spam", "eggs"))
         obj.mock.assert_called_once_with(1, 2)
         obj.mock.reset_mock()
 
         # a call with different params should call the mock again
-        obj.mock.return_value = ["chips"]
+        obj.mock.return_value = ("chips",)
         r = obj.fn(1, 3)
-        self.assertEqual(r.result, ["chips"])
+        self.assertEqual(r.result, ("chips",))
         obj.mock.assert_called_once_with(1, 3)
         obj.mock.reset_mock()
 
@@ -423,9 +424,9 @@ class DescriptorTestCase(unittest.TestCase):
         self.assertEqual(len(obj.fn.cache.cache), 3)
 
         r = obj.fn(1, 2)
-        self.assertEqual(r.result, ["spam", "eggs"])
+        self.assertEqual(r.result, ("spam", "eggs"))
         r = obj.fn(1, 3)
-        self.assertEqual(r.result, ["chips"])
+        self.assertEqual(r.result, ("chips",))
         obj.mock.assert_not_called()
 
     def test_cache_iterable_with_sync_exception(self) -> None:
@@ -784,7 +785,9 @@ class CachedListDescriptorTestCase(unittest.TestCase):
                 pass
 
             @descriptors.cachedList(cached_method_name="fn", list_name="args1")
-            async def list_fn(self, args1: Iterable[int], arg2: int) -> Dict[int, str]:
+            async def list_fn(
+                self, args1: Iterable[int], arg2: int
+            ) -> Mapping[int, str]:
                 context = current_context()
                 assert isinstance(context, LoggingContext)
                 assert context.name == "c1"
@@ -847,11 +850,11 @@ class CachedListDescriptorTestCase(unittest.TestCase):
                 pass
 
             @descriptors.cachedList(cached_method_name="fn", list_name="args1")
-            def list_fn(self, args1: List[int]) -> "Deferred[dict]":
+            def list_fn(self, args1: List[int]) -> "Deferred[Mapping[int, str]]":
                 return self.mock(args1)
 
         obj = Cls()
-        deferred_result: "Deferred[dict]" = Deferred()
+        deferred_result: "Deferred[Mapping[int, str]]" = Deferred()
         obj.mock.return_value = deferred_result
 
         # start off several concurrent lookups of the same key
@@ -890,7 +893,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
                 pass
 
             @descriptors.cachedList(cached_method_name="fn", list_name="args1")
-            async def list_fn(self, args1: List[int], arg2: int) -> Dict[int, str]:
+            async def list_fn(self, args1: List[int], arg2: int) -> Mapping[int, str]:
                 # we want this to behave like an asynchronous function
                 await run_on_reactor()
                 return self.mock(args1, arg2)
@@ -929,7 +932,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
                 pass
 
             @cachedList(cached_method_name="fn", list_name="args")
-            async def list_fn(self, args: List[int]) -> Dict[int, str]:
+            async def list_fn(self, args: List[int]) -> Mapping[int, str]:
                 await complete_lookup
                 return {arg: str(arg) for arg in args}
 
@@ -964,7 +967,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
                 pass
 
             @cachedList(cached_method_name="fn", list_name="args")
-            async def list_fn(self, args: List[int]) -> Dict[int, str]:
+            async def list_fn(self, args: List[int]) -> Mapping[int, str]:
                 await make_deferred_yieldable(complete_lookup)
                 self.inner_context_was_finished = current_context().finished
                 return {arg: str(arg) for arg in args}