summary refs log tree commit diff
path: root/tests/storage/test_event_federation.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2022-12-09 12:36:32 -0500
committerGitHub <noreply@github.com>2022-12-09 12:36:32 -0500
commit3ac412b4e2f8c5ba11dc962b8a9d871c1efdce9b (patch)
treea08170e3c286e790b0c4596dc6d9ec884996c532 /tests/storage/test_event_federation.py
parentLimit the number of devices we delete at once (#14649) (diff)
downloadsynapse-3ac412b4e2f8c5ba11dc962b8a9d871c1efdce9b.tar.xz
Require types in tests.storage. (#14646)
Adds missing type hints to `tests.storage` package
and does not allow untyped definitions.
Diffstat (limited to 'tests/storage/test_event_federation.py')
-rw-r--r--tests/storage/test_event_federation.py71
1 files changed, 35 insertions, 36 deletions
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 853db930d6..7fd3e01364 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import datetime
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple, Union, cast
 
 import attr
 from parameterized import parameterized
@@ -26,11 +26,12 @@ from synapse.api.room_versions import (
     EventFormatVersions,
     RoomVersion,
 )
-from synapse.events import _EventInternalMetadata
+from synapse.events import EventBase, _EventInternalMetadata
 from synapse.rest import admin
 from synapse.rest.client import login, room
 from synapse.server import HomeServer
 from synapse.storage.database import LoggingTransaction
+from synapse.storage.types import Cursor
 from synapse.types import JsonDict
 from synapse.util import Clock, json_encoder
 
@@ -54,11 +55,11 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
 
-    def test_get_prev_events_for_room(self):
+    def test_get_prev_events_for_room(self) -> None:
         room_id = "@ROOM:local"
 
         # add a bunch of events and hashes to act as forward extremities
-        def insert_event(txn, i):
+        def insert_event(txn: Cursor, i: int) -> None:
             event_id = "$event_%i:local" % i
 
             txn.execute(
@@ -90,12 +91,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         for i in range(0, 10):
             self.assertEqual("$event_%i:local" % (19 - i), r[i])
 
-    def test_get_rooms_with_many_extremities(self):
+    def test_get_rooms_with_many_extremities(self) -> None:
         room1 = "#room1"
         room2 = "#room2"
         room3 = "#room3"
 
-        def insert_event(txn, i, room_id):
+        def insert_event(txn: Cursor, i: int, room_id: str) -> None:
             event_id = "$event_%i:local" % i
             txn.execute(
                 (
@@ -155,7 +156,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         #     |   |
         #     K   J
 
-        auth_graph = {
+        auth_graph: Dict[str, List[str]] = {
             "a": ["e"],
             "b": ["e"],
             "c": ["g", "i"],
@@ -185,7 +186,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
         # Mark the room as maybe having a cover index.
 
-        def store_room(txn):
+        def store_room(txn: LoggingTransaction) -> None:
             self.store.db_pool.simple_insert_txn(
                 txn,
                 "rooms",
@@ -203,7 +204,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         # We rudely fiddle with the appropriate tables directly, as that's much
         # easier than constructing events properly.
 
-        def insert_event(txn):
+        def insert_event(txn: LoggingTransaction) -> None:
             stream_ordering = 0
 
             for event_id in auth_graph:
@@ -228,7 +229,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             self.hs.datastores.persist_events._persist_event_auth_chain_txn(
                 txn,
                 [
-                    FakeEvent(event_id, room_id, auth_graph[event_id])
+                    cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
                     for event_id in auth_graph
                 ],
             )
@@ -243,7 +244,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         return room_id
 
     @parameterized.expand([(True,), (False,)])
-    def test_auth_chain_ids(self, use_chain_cover_index: bool):
+    def test_auth_chain_ids(self, use_chain_cover_index: bool) -> None:
         room_id = self._setup_auth_chain(use_chain_cover_index)
 
         # a and b have the same auth chain.
@@ -308,7 +309,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         self.assertCountEqual(auth_chain_ids, ["i", "j"])
 
     @parameterized.expand([(True,), (False,)])
-    def test_auth_difference(self, use_chain_cover_index: bool):
+    def test_auth_difference(self, use_chain_cover_index: bool) -> None:
         room_id = self._setup_auth_chain(use_chain_cover_index)
 
         # Now actually test that various combinations give the right result:
@@ -353,7 +354,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         )
         self.assertSetEqual(difference, set())
 
-    def test_auth_difference_partial_cover(self):
+    def test_auth_difference_partial_cover(self) -> None:
         """Test that we correctly handle rooms where not all events have a chain
         cover calculated. This can happen in some obscure edge cases, including
         during the background update that calculates the chain cover for old
@@ -377,7 +378,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         #     |   |
         #     K   J
 
-        auth_graph = {
+        auth_graph: Dict[str, List[str]] = {
             "a": ["e"],
             "b": ["e"],
             "c": ["g", "i"],
@@ -408,7 +409,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         # We rudely fiddle with the appropriate tables directly, as that's much
         # easier than constructing events properly.
 
-        def insert_event(txn):
+        def insert_event(txn: LoggingTransaction) -> None:
             # First insert the room and mark it as having a chain cover.
             self.store.db_pool.simple_insert_txn(
                 txn,
@@ -447,7 +448,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             self.hs.datastores.persist_events._persist_event_auth_chain_txn(
                 txn,
                 [
-                    FakeEvent(event_id, room_id, auth_graph[event_id])
+                    cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
                     for event_id in auth_graph
                     if event_id != "b"
                 ],
@@ -465,7 +466,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
             self.hs.datastores.persist_events._persist_event_auth_chain_txn(
                 txn,
-                [FakeEvent("b", room_id, auth_graph["b"])],
+                [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
             )
 
             self.store.db_pool.simple_update_txn(
@@ -527,7 +528,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
     @parameterized.expand(
         [(room_version,) for room_version in KNOWN_ROOM_VERSIONS.values()]
     )
-    def test_prune_inbound_federation_queue(self, room_version: RoomVersion):
+    def test_prune_inbound_federation_queue(self, room_version: RoomVersion) -> None:
         """Test that pruning of inbound federation queues work"""
 
         room_id = "some_room_id"
@@ -686,7 +687,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
             stream_ordering += 1
 
-        def populate_db(txn: LoggingTransaction):
+        def populate_db(txn: LoggingTransaction) -> None:
             # Insert the room to satisfy the foreign key constraint of
             # `event_failed_pull_attempts`
             self.store.db_pool.simple_insert_txn(
@@ -760,7 +761,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
         return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
 
-    def test_get_backfill_points_in_room(self):
+    def test_get_backfill_points_in_room(self) -> None:
         """
         Test to make sure only backfill points that are older and come before
         the `current_depth` are returned.
@@ -787,7 +788,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
     def test_get_backfill_points_in_room_excludes_events_we_have_attempted(
         self,
-    ):
+    ) -> None:
         """
         Test to make sure that events we have attempted to backfill (and within
         backoff timeout duration) do not show up as an event to backfill again.
@@ -824,7 +825,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
     def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration(
         self,
-    ):
+    ) -> None:
         """
         Test to make sure after we fake attempt to backfill event "b3" many times,
         we can see retry and see the "b3" again after the backoff timeout duration
@@ -941,7 +942,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             "5": 7,
         }
 
-        def populate_db(txn: LoggingTransaction):
+        def populate_db(txn: LoggingTransaction) -> None:
             # Insert the room to satisfy the foreign key constraint of
             # `event_failed_pull_attempts`
             self.store.db_pool.simple_insert_txn(
@@ -996,7 +997,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
         return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
 
-    def test_get_insertion_event_backward_extremities_in_room(self):
+    def test_get_insertion_event_backward_extremities_in_room(self) -> None:
         """
         Test to make sure only insertion event backward extremities that are
         older and come before the `current_depth` are returned.
@@ -1027,7 +1028,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
     def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted(
         self,
-    ):
+    ) -> None:
         """
         Test to make sure that insertion events we have attempted to backfill
         (and within backoff timeout duration) do not show up as an event to
@@ -1060,7 +1061,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
     def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration(
         self,
-    ):
+    ) -> None:
         """
         Test to make sure after we fake attempt to backfill event
         "insertion_eventA" many times, we can see retry and see the
@@ -1130,9 +1131,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
         self.assertEqual(backfill_event_ids, ["insertion_eventA"])
 
-    def test_get_event_ids_to_not_pull_from_backoff(
-        self,
-    ):
+    def test_get_event_ids_to_not_pull_from_backoff(self) -> None:
         """
         Test to make sure only event IDs we should backoff from are returned.
         """
@@ -1157,7 +1156,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
     def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration(
         self,
-    ):
+    ) -> None:
         """
         Test to make sure no event IDs are returned after the backoff duration has
         elapsed.
@@ -1187,19 +1186,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         self.assertEqual(event_ids_to_backoff, [])
 
 
-@attr.s
+@attr.s(auto_attribs=True)
 class FakeEvent:
-    event_id = attr.ib()
-    room_id = attr.ib()
-    auth_events = attr.ib()
+    event_id: str
+    room_id: str
+    auth_events: List[str]
 
     type = "foo"
     state_key = "foo"
 
     internal_metadata = _EventInternalMetadata({})
 
-    def auth_event_ids(self):
+    def auth_event_ids(self) -> List[str]:
         return self.auth_events
 
-    def is_state(self):
+    def is_state(self) -> bool:
         return True