summary refs log tree commit diff
path: root/tests/storage/test_state.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test_state.py')
-rw-r--r--tests/storage/test_state.py46
1 files changed, 30 insertions, 16 deletions
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 5564161750..d4e6d4236c 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -16,10 +16,15 @@ import logging
 
 from frozendict import frozendict
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase
+from synapse.server import HomeServer
 from synapse.storage.state import StateFilter
-from synapse.types import RoomID, UserID
+from synapse.types import JsonDict, RoomID, StateMap, UserID
+from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase, TestCase
 
@@ -27,7 +32,7 @@ logger = logging.getLogger(__name__)
 
 
 class StateStoreTestCase(HomeserverTestCase):
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
         self.storage = hs.get_storage_controllers()
         self.state_datastore = self.storage.state.stores.state
@@ -48,7 +53,9 @@ class StateStoreTestCase(HomeserverTestCase):
             )
         )
 
-    def inject_state_event(self, room, sender, typ, state_key, content):
+    def inject_state_event(
+        self, room: RoomID, sender: UserID, typ: str, state_key: str, content: JsonDict
+    ) -> EventBase:
         builder = self.event_builder_factory.for_room_version(
             RoomVersions.V1,
             {
@@ -64,24 +71,29 @@ class StateStoreTestCase(HomeserverTestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
+        assert self.storage.persistence is not None
         self.get_success(self.storage.persistence.persist_event(event, context))
 
         return event
 
-    def assertStateMapEqual(self, s1, s2):
+    def assertStateMapEqual(
+        self, s1: StateMap[EventBase], s2: StateMap[EventBase]
+    ) -> None:
         for t in s1:
             # just compare event IDs for simplicity
             self.assertEqual(s1[t].event_id, s2[t].event_id)
         self.assertEqual(len(s1), len(s2))
 
-    def test_get_state_groups_ids(self):
+    def test_get_state_groups_ids(self) -> None:
         e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
         e2 = self.inject_state_event(
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
         state_group_map = self.get_success(
-            self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
+            self.storage.state.get_state_groups_ids(
+                self.room.to_string(), [e2.event_id]
+            )
         )
         self.assertEqual(len(state_group_map), 1)
         state_map = list(state_group_map.values())[0]
@@ -90,21 +102,21 @@ class StateStoreTestCase(HomeserverTestCase):
             {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
         )
 
-    def test_get_state_groups(self):
+    def test_get_state_groups(self) -> None:
         e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
         e2 = self.inject_state_event(
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
         state_group_map = self.get_success(
-            self.storage.state.get_state_groups(self.room, [e2.event_id])
+            self.storage.state.get_state_groups(self.room.to_string(), [e2.event_id])
         )
         self.assertEqual(len(state_group_map), 1)
         state_list = list(state_group_map.values())[0]
 
         self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
 
-    def test_get_state_for_event(self):
+    def test_get_state_for_event(self) -> None:
         # this defaults to a linear DAG as each new injection defaults to whatever
         # forward extremities are currently in the DB for this room.
         e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
@@ -487,14 +499,16 @@ class StateStoreTestCase(HomeserverTestCase):
 class StateFilterDifferenceTestCase(TestCase):
     def assert_difference(
         self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
-    ):
+    ) -> None:
         self.assertEqual(
             minuend.approx_difference(subtrahend),
             expected,
             f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
         )
 
-    def test_state_filter_difference_no_include_other_minus_no_include_other(self):
+    def test_state_filter_difference_no_include_other_minus_no_include_other(
+        self,
+    ) -> None:
         """
         Tests the StateFilter.approx_difference method
         where, in a.approx_difference(b), both a and b do not have the
@@ -610,7 +624,7 @@ class StateFilterDifferenceTestCase(TestCase):
             ),
         )
 
-    def test_state_filter_difference_include_other_minus_no_include_other(self):
+    def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
         """
         Tests the StateFilter.approx_difference method
         where, in a.approx_difference(b), only a has the include_others flag set.
@@ -739,7 +753,7 @@ class StateFilterDifferenceTestCase(TestCase):
             ),
         )
 
-    def test_state_filter_difference_include_other_minus_include_other(self):
+    def test_state_filter_difference_include_other_minus_include_other(self) -> None:
         """
         Tests the StateFilter.approx_difference method
         where, in a.approx_difference(b), both a and b have the include_others
@@ -864,7 +878,7 @@ class StateFilterDifferenceTestCase(TestCase):
             ),
         )
 
-    def test_state_filter_difference_no_include_other_minus_include_other(self):
+    def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
         """
         Tests the StateFilter.approx_difference method
         where, in a.approx_difference(b), only b has the include_others flag set.
@@ -979,7 +993,7 @@ class StateFilterDifferenceTestCase(TestCase):
             ),
         )
 
-    def test_state_filter_difference_simple_cases(self):
+    def test_state_filter_difference_simple_cases(self) -> None:
         """
         Tests some very simple cases of the StateFilter approx_difference,
         that are not explicitly tested by the more in-depth tests.
@@ -995,7 +1009,7 @@ class StateFilterDifferenceTestCase(TestCase):
 
 
 class StateFilterTestCase(TestCase):
-    def test_return_expanded(self):
+    def test_return_expanded(self) -> None:
         """
         Tests the behaviour of the return_expanded() function that expands
         StateFilters to include more state types (for the sake of cache hit rate).