summary refs log tree commit diff
path: root/tests/storage/test_event_federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test_event_federation.py')
-rw-r--r--tests/storage/test_event_federation.py127
1 files changed, 95 insertions, 32 deletions
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 59b8910907..7fd3e01364 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import datetime
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple, Union, cast
 
 import attr
 from parameterized import parameterized
@@ -26,9 +26,12 @@ from synapse.api.room_versions import (
     EventFormatVersions,
     RoomVersion,
 )
-from synapse.events import _EventInternalMetadata
+from synapse.events import EventBase, _EventInternalMetadata
+from synapse.rest import admin
+from synapse.rest.client import login, room
 from synapse.server import HomeServer
 from synapse.storage.database import LoggingTransaction
+from synapse.storage.types import Cursor
 from synapse.types import JsonDict
 from synapse.util import Clock, json_encoder
 
@@ -43,14 +46,20 @@ class _BackfillSetupInfo:
 
 
 class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
 
-    def test_get_prev_events_for_room(self):
+    def test_get_prev_events_for_room(self) -> None:
         room_id = "@ROOM:local"
 
         # add a bunch of events and hashes to act as forward extremities
-        def insert_event(txn, i):
+        def insert_event(txn: Cursor, i: int) -> None:
             event_id = "$event_%i:local" % i
 
             txn.execute(
@@ -82,12 +91,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         for i in range(0, 10):
             self.assertEqual("$event_%i:local" % (19 - i), r[i])
 
-    def test_get_rooms_with_many_extremities(self):
+    def test_get_rooms_with_many_extremities(self) -> None:
         room1 = "#room1"
         room2 = "#room2"
         room3 = "#room3"
 
-        def insert_event(txn, i, room_id):
+        def insert_event(txn: Cursor, i: int, room_id: str) -> None:
             event_id = "$event_%i:local" % i
             txn.execute(
                 (
@@ -147,7 +156,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         #     |   |
         #     K   J
 
-        auth_graph = {
+        auth_graph: Dict[str, List[str]] = {
             "a": ["e"],
             "b": ["e"],
             "c": ["g", "i"],
@@ -177,7 +186,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
         # Mark the room as maybe having a cover index.
 
-        def store_room(txn):
+        def store_room(txn: LoggingTransaction) -> None:
             self.store.db_pool.simple_insert_txn(
                 txn,
                 "rooms",
@@ -195,7 +204,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         # We rudely fiddle with the appropriate tables directly, as that's much
         # easier than constructing events properly.
 
-        def insert_event(txn):
+        def insert_event(txn: LoggingTransaction) -> None:
             stream_ordering = 0
 
             for event_id in auth_graph:
@@ -220,7 +229,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             self.hs.datastores.persist_events._persist_event_auth_chain_txn(
                 txn,
                 [
-                    FakeEvent(event_id, room_id, auth_graph[event_id])
+                    cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
                     for event_id in auth_graph
                 ],
             )
@@ -235,7 +244,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         return room_id
 
     @parameterized.expand([(True,), (False,)])
-    def test_auth_chain_ids(self, use_chain_cover_index: bool):
+    def test_auth_chain_ids(self, use_chain_cover_index: bool) -> None:
         room_id = self._setup_auth_chain(use_chain_cover_index)
 
         # a and b have the same auth chain.
@@ -300,7 +309,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         self.assertCountEqual(auth_chain_ids, ["i", "j"])
 
     @parameterized.expand([(True,), (False,)])
-    def test_auth_difference(self, use_chain_cover_index: bool):
+    def test_auth_difference(self, use_chain_cover_index: bool) -> None:
         room_id = self._setup_auth_chain(use_chain_cover_index)
 
         # Now actually test that various combinations give the right result:
@@ -345,7 +354,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         )
         self.assertSetEqual(difference, set())
 
-    def test_auth_difference_partial_cover(self):
+    def test_auth_difference_partial_cover(self) -> None:
         """Test that we correctly handle rooms where not all events have a chain
         cover calculated. This can happen in some obscure edge cases, including
         during the background update that calculates the chain cover for old
@@ -369,7 +378,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         #     |   |
         #     K   J
 
-        auth_graph = {
+        auth_graph: Dict[str, List[str]] = {
             "a": ["e"],
             "b": ["e"],
             "c": ["g", "i"],
@@ -400,7 +409,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         # We rudely fiddle with the appropriate tables directly, as that's much
         # easier than constructing events properly.
 
-        def insert_event(txn):
+        def insert_event(txn: LoggingTransaction) -> None:
             # First insert the room and mark it as having a chain cover.
             self.store.db_pool.simple_insert_txn(
                 txn,
@@ -439,7 +448,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             self.hs.datastores.persist_events._persist_event_auth_chain_txn(
                 txn,
                 [
-                    FakeEvent(event_id, room_id, auth_graph[event_id])
+                    cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
                     for event_id in auth_graph
                     if event_id != "b"
                 ],
@@ -457,7 +466,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
             self.hs.datastores.persist_events._persist_event_auth_chain_txn(
                 txn,
-                [FakeEvent("b", room_id, auth_graph["b"])],
+                [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
             )
 
             self.store.db_pool.simple_update_txn(
@@ -519,7 +528,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
     @parameterized.expand(
         [(room_version,) for room_version in KNOWN_ROOM_VERSIONS.values()]
     )
-    def test_prune_inbound_federation_queue(self, room_version: RoomVersion):
+    def test_prune_inbound_federation_queue(self, room_version: RoomVersion) -> None:
         """Test that pruning of inbound federation queues work"""
 
         room_id = "some_room_id"
@@ -678,7 +687,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
             stream_ordering += 1
 
-        def populate_db(txn: LoggingTransaction):
+        def populate_db(txn: LoggingTransaction) -> None:
             # Insert the room to satisfy the foreign key constraint of
             # `event_failed_pull_attempts`
             self.store.db_pool.simple_insert_txn(
@@ -752,7 +761,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
         return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
 
-    def test_get_backfill_points_in_room(self):
+    def test_get_backfill_points_in_room(self) -> None:
         """
         Test to make sure only backfill points that are older and come before
         the `current_depth` are returned.
@@ -779,7 +788,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
     def test_get_backfill_points_in_room_excludes_events_we_have_attempted(
         self,
-    ):
+    ) -> None:
         """
         Test to make sure that events we have attempted to backfill (and within
         backoff timeout duration) do not show up as an event to backfill again.
@@ -816,7 +825,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
     def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration(
         self,
-    ):
+    ) -> None:
         """
         Test to make sure after we fake attempt to backfill event "b3" many times,
         we can see retry and see the "b3" again after the backoff timeout duration
@@ -933,7 +942,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             "5": 7,
         }
 
-        def populate_db(txn: LoggingTransaction):
+        def populate_db(txn: LoggingTransaction) -> None:
             # Insert the room to satisfy the foreign key constraint of
             # `event_failed_pull_attempts`
             self.store.db_pool.simple_insert_txn(
@@ -988,7 +997,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
         return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
 
-    def test_get_insertion_event_backward_extremities_in_room(self):
+    def test_get_insertion_event_backward_extremities_in_room(self) -> None:
         """
         Test to make sure only insertion event backward extremities that are
         older and come before the `current_depth` are returned.
@@ -1019,7 +1028,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
     def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted(
         self,
-    ):
+    ) -> None:
         """
         Test to make sure that insertion events we have attempted to backfill
         (and within backoff timeout duration) do not show up as an event to
@@ -1052,7 +1061,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
     def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration(
         self,
-    ):
+    ) -> None:
         """
         Test to make sure after we fake attempt to backfill event
         "insertion_eventA" many times, we can see retry and see the
@@ -1122,20 +1131,74 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
         self.assertEqual(backfill_event_ids, ["insertion_eventA"])
 
+    def test_get_event_ids_to_not_pull_from_backoff(self) -> None:
+        """
+        Test to make sure only event IDs we should backoff from are returned.
+        """
+        # Create the room
+        user_id = self.register_user("alice", "test")
+        tok = self.login("alice", "test")
+        room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
 
-@attr.s
+        self.get_success(
+            self.store.record_event_failed_pull_attempt(
+                room_id, "$failed_event_id", "fake cause"
+            )
+        )
+
+        event_ids_to_backoff = self.get_success(
+            self.store.get_event_ids_to_not_pull_from_backoff(
+                room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"]
+            )
+        )
+
+        self.assertEqual(event_ids_to_backoff, ["$failed_event_id"])
+
+    def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration(
+        self,
+    ) -> None:
+        """
+        Test to make sure no event IDs are returned after the backoff duration has
+        elapsed.
+        """
+        # Create the room
+        user_id = self.register_user("alice", "test")
+        tok = self.login("alice", "test")
+        room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+        self.get_success(
+            self.store.record_event_failed_pull_attempt(
+                room_id, "$failed_event_id", "fake cause"
+            )
+        )
+
+        # Now advance time by 2 hours so we wait long enough for the single failed
+        # attempt (2^1 hours).
+        self.reactor.advance(datetime.timedelta(hours=2).total_seconds())
+
+        event_ids_to_backoff = self.get_success(
+            self.store.get_event_ids_to_not_pull_from_backoff(
+                room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"]
+            )
+        )
+        # Since this function only returns events we should backoff from, time has
+        # elapsed past the backoff range so there is no events to backoff from.
+        self.assertEqual(event_ids_to_backoff, [])
+
+
+@attr.s(auto_attribs=True)
 class FakeEvent:
-    event_id = attr.ib()
-    room_id = attr.ib()
-    auth_events = attr.ib()
+    event_id: str
+    room_id: str
+    auth_events: List[str]
 
     type = "foo"
     state_key = "foo"
 
     internal_metadata = _EventInternalMetadata({})
 
-    def auth_event_ids(self):
+    def auth_event_ids(self) -> List[str]:
         return self.auth_events
 
-    def is_state(self):
+    def is_state(self) -> bool:
         return True