summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/16326.misc1
-rw-r--r--synapse/events/builder.py2
-rw-r--r--synapse/handlers/federation_event.py8
-rw-r--r--synapse/storage/controllers/persist_events.py9
-rw-r--r--synapse/storage/databases/main/event_federation.py8
-rw-r--r--synapse/storage/databases/main/events.py2
-rw-r--r--tests/handlers/test_presence.py2
-rw-r--r--tests/replication/storage/test_events.py4
-rw-r--r--tests/replication/tcp/streams/test_events.py10
-rw-r--r--tests/replication/test_federation_sender_shard.py2
-rw-r--r--tests/storage/test_cleanup_extrems.py14
-rw-r--r--tests/test_federation.py26
12 files changed, 48 insertions, 40 deletions
diff --git a/changelog.d/16326.misc b/changelog.d/16326.misc
new file mode 100644
index 0000000000..93ceaeafc9
--- /dev/null
+++ b/changelog.d/16326.misc
@@ -0,0 +1 @@
+Improve type hints.
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 1165c017ba..43469b170f 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -103,7 +103,7 @@ class EventBuilder:
 
     async def build(
         self,
-        prev_event_ids: StrCollection,
+        prev_event_ids: List[str],
         auth_event_ids: Optional[List[str]],
         depth: Optional[int] = None,
     ) -> EventBase:
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index d32d224d56..eedde97ab0 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -723,12 +723,11 @@ class FederationEventHandler:
         if not prevs - seen:
             return
 
-        latest_list = await self._store.get_latest_event_ids_in_room(room_id)
+        latest_frozen = await self._store.get_latest_event_ids_in_room(room_id)
 
         # We add the prev events that we have seen to the latest
         # list to ensure the remote server doesn't give them to us
-        latest = set(latest_list)
-        latest |= seen
+        latest = seen | latest_frozen
 
         logger.info(
             "Requesting missing events between %s and %s",
@@ -1976,8 +1975,7 @@ class FederationEventHandler:
             # partial and full state and may not be accurate.
             return
 
-        extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
-        extrem_ids = set(extrem_ids_list)
+        extrem_ids = await self._store.get_latest_event_ids_in_room(event.room_id)
         prev_event_ids = set(event.prev_event_ids())
 
         if extrem_ids == prev_event_ids:
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index 6864f93090..f39ae2d635 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -19,6 +19,7 @@ import logging
 from collections import deque
 from typing import (
     TYPE_CHECKING,
+    AbstractSet,
     Any,
     Awaitable,
     Callable,
@@ -618,7 +619,7 @@ class EventsPersistenceStorageController:
                         )
 
                     for room_id, ev_ctx_rm in events_by_room.items():
-                        latest_event_ids = set(
+                        latest_event_ids = (
                             await self.main_store.get_latest_event_ids_in_room(room_id)
                         )
                         new_latest_event_ids = await self._calculate_new_extremities(
@@ -740,7 +741,7 @@ class EventsPersistenceStorageController:
         self,
         room_id: str,
         event_contexts: List[Tuple[EventBase, EventContext]],
-        latest_event_ids: Collection[str],
+        latest_event_ids: AbstractSet[str],
     ) -> Set[str]:
         """Calculates the new forward extremities for a room given events to
         persist.
@@ -758,8 +759,6 @@ class EventsPersistenceStorageController:
             and not event.internal_metadata.is_soft_failed()
         ]
 
-        latest_event_ids = set(latest_event_ids)
-
         # start with the existing forward extremities
         result = set(latest_event_ids)
 
@@ -798,7 +797,7 @@ class EventsPersistenceStorageController:
         self,
         room_id: str,
         events_context: List[Tuple[EventBase, EventContext]],
-        old_latest_event_ids: Set[str],
+        old_latest_event_ids: AbstractSet[str],
         new_latest_event_ids: Set[str],
     ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]], Set[str]]:
         """Calculate the current state dict after adding some new events to
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 09de8f55e2..afffa54985 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -19,6 +19,7 @@ from typing import (
     TYPE_CHECKING,
     Collection,
     Dict,
+    FrozenSet,
     Iterable,
     List,
     Optional,
@@ -47,7 +48,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.signatures import SignatureWorkerStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.types import JsonDict, StrCollection, StrSequence
+from synapse.types import JsonDict, StrCollection
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
@@ -1179,13 +1180,14 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
 
     @cached(max_entries=5000, iterable=True)
-    async def get_latest_event_ids_in_room(self, room_id: str) -> StrSequence:
-        return await self.db_pool.simple_select_onecol(
+    async def get_latest_event_ids_in_room(self, room_id: str) -> FrozenSet[str]:
+        event_ids = await self.db_pool.simple_select_onecol(
             table="event_forward_extremities",
             keyvalues={"room_id": room_id},
             retcol="event_id",
             desc="get_latest_event_ids_in_room",
         )
+        return frozenset(event_ids)
 
     async def get_min_depth(self, room_id: str) -> Optional[int]:
         """For the given room, get the minimum depth we have seen for it."""
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 0c1ed75240..bc8474a589 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -222,7 +222,7 @@ class PersistEventsStore:
 
             for room_id, latest_event_ids in new_forward_extremities.items():
                 self.store.get_latest_event_ids_in_room.prefill(
-                    (room_id,), list(latest_event_ids)
+                    (room_id,), frozenset(latest_event_ids)
                 )
 
     async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 638787b029..41c8c44e02 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -1858,7 +1858,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         )
 
         event = self.get_success(
-            builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None)
+            builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None)
         )
 
         self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event))
diff --git a/tests/replication/storage/test_events.py b/tests/replication/storage/test_events.py
index af25815fa5..33c277a38a 100644
--- a/tests/replication/storage/test_events.py
+++ b/tests/replication/storage/test_events.py
@@ -90,7 +90,7 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
     def test_get_latest_event_ids_in_room(self) -> None:
         create = self.persist(type="m.room.create", key="", creator=USER_ID)
         self.replicate()
-        self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
+        self.check("get_latest_event_ids_in_room", (ROOM_ID,), {create.event_id})
 
         join = self.persist(
             type="m.room.member",
@@ -99,7 +99,7 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
             prev_events=[(create.event_id, {})],
         )
         self.replicate()
-        self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
+        self.check("get_latest_event_ids_in_room", (ROOM_ID,), {join.event_id})
 
     def test_redactions(self) -> None:
         self.persist(type="m.room.create", key="", creator=USER_ID)
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 65ef4bb160..128fc3e046 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, List, Optional, Sequence
+from typing import Any, List, Optional
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -139,7 +139,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         )
 
         # this is the point in the DAG where we make a fork
-        fork_point: Sequence[str] = self.get_success(
+        fork_point = self.get_success(
             self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
         )
 
@@ -294,7 +294,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         )
 
         # this is the point in the DAG where we make a fork
-        fork_point: Sequence[str] = self.get_success(
+        fork_point = self.get_success(
             self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
         )
 
@@ -316,14 +316,14 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.test_handler.received_rdata_rows.clear()
 
         # now roll back all that state by de-modding the users
-        prev_events = fork_point
+        prev_events = list(fork_point)
         pl_events = []
         for u in user_ids:
             pls["users"][u] = 0
             e = self.get_success(
                 inject_event(
                     self.hs,
-                    prev_event_ids=list(prev_events),
+                    prev_event_ids=prev_events,
                     type=EventTypes.PowerLevels,
                     state_key="",
                     sender=self.user_id,
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 9b28cd474f..59f4fdc70b 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -261,7 +261,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
 
         builder = factory.for_room_version(room_version, event_dict)
         join_event = self.get_success(
-            builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None)
+            builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None)
         )
 
         self.get_success(federation.on_send_membership_event(remote_server, join_event))
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 7de109966d..ceb9597dd3 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -120,7 +120,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
             self.store.get_latest_event_ids_in_room(self.room_id)
         )
 
-        self.assertEqual(latest_event_ids, [event_id_4])
+        self.assertEqual(latest_event_ids, {event_id_4})
 
     def test_basic_cleanup(self) -> None:
         """Test that extremities are correctly calculated in the presence of
@@ -147,7 +147,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         latest_event_ids = self.get_success(
             self.store.get_latest_event_ids_in_room(self.room_id)
         )
-        self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b})
+        self.assertEqual(latest_event_ids, {event_id_a, event_id_b})
 
         # Run the background update and check it did the right thing
         self.run_background_update()
@@ -155,7 +155,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         latest_event_ids = self.get_success(
             self.store.get_latest_event_ids_in_room(self.room_id)
         )
-        self.assertEqual(latest_event_ids, [event_id_b])
+        self.assertEqual(latest_event_ids, {event_id_b})
 
     def test_chain_of_fail_cleanup(self) -> None:
         """Test that extremities are correctly calculated in the presence of
@@ -185,7 +185,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         latest_event_ids = self.get_success(
             self.store.get_latest_event_ids_in_room(self.room_id)
         )
-        self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b})
+        self.assertEqual(latest_event_ids, {event_id_a, event_id_b})
 
         # Run the background update and check it did the right thing
         self.run_background_update()
@@ -193,7 +193,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         latest_event_ids = self.get_success(
             self.store.get_latest_event_ids_in_room(self.room_id)
         )
-        self.assertEqual(latest_event_ids, [event_id_b])
+        self.assertEqual(latest_event_ids, {event_id_b})
 
     def test_forked_graph_cleanup(self) -> None:
         r"""Test that extremities are correctly calculated in the presence of
@@ -240,7 +240,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         latest_event_ids = self.get_success(
             self.store.get_latest_event_ids_in_room(self.room_id)
         )
-        self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b, event_id_c})
+        self.assertEqual(latest_event_ids, {event_id_a, event_id_b, event_id_c})
 
         # Run the background update and check it did the right thing
         self.run_background_update()
@@ -248,7 +248,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         latest_event_ids = self.get_success(
             self.store.get_latest_event_ids_in_room(self.room_id)
         )
-        self.assertEqual(set(latest_event_ids), {event_id_b, event_id_c})
+        self.assertEqual(latest_event_ids, {event_id_b, event_id_c})
 
 
 class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
diff --git a/tests/test_federation.py b/tests/test_federation.py
index f8ade6da38..1b0504709e 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -51,9 +51,15 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         self.store = self.hs.get_datastores().main
 
         # Figure out what the most recent event is
-        most_recent = self.get_success(
-            self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
-        )[0]
+        most_recent = next(
+            iter(
+                self.get_success(
+                    self.hs.get_datastores().main.get_latest_event_ids_in_room(
+                        self.room_id
+                    )
+                )
+            )
+        )
 
         join_event = make_event_from_dict(
             {
@@ -100,8 +106,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
 
         # Make sure we actually joined the room
         self.assertEqual(
-            self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0],
-            "$join:test.serv",
+            self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)),
+            {"$join:test.serv"},
         )
 
     def test_cant_hide_direct_ancestors(self) -> None:
@@ -127,9 +133,11 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         self.http_client.post_json = post_json
 
         # Figure out what the most recent event is
-        most_recent = self.get_success(
-            self.store.get_latest_event_ids_in_room(self.room_id)
-        )[0]
+        most_recent = next(
+            iter(
+                self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
+            )
+        )
 
         # Now lie about an event
         lying_event = make_event_from_dict(
@@ -165,7 +173,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
 
         # Make sure the invalid event isn't there
         extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
-        self.assertEqual(extrem[0], "$join:test.serv")
+        self.assertEqual(extrem, {"$join:test.serv"})
 
     def test_retry_device_list_resync(self) -> None:
         """Tests that device lists are marked as stale if they couldn't be synced, and