diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 5564161750..a433e70870 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -16,10 +16,15 @@ import logging
from frozendict import frozendict
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
-from synapse.storage.state import StateFilter
-from synapse.types import RoomID, UserID
+from synapse.events import EventBase
+from synapse.server import HomeServer
+from synapse.types import JsonDict, RoomID, StateMap, UserID
+from synapse.types.state import StateFilter
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase, TestCase
@@ -27,7 +32,7 @@ logger = logging.getLogger(__name__)
class StateStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.storage = hs.get_storage_controllers()
self.state_datastore = self.storage.state.stores.state
@@ -48,7 +53,9 @@ class StateStoreTestCase(HomeserverTestCase):
)
)
- def inject_state_event(self, room, sender, typ, state_key, content):
+ def inject_state_event(
+ self, room: RoomID, sender: UserID, typ: str, state_key: str, content: JsonDict
+ ) -> EventBase:
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
@@ -64,24 +71,29 @@ class StateStoreTestCase(HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
+ assert self.storage.persistence is not None
self.get_success(self.storage.persistence.persist_event(event, context))
return event
- def assertStateMapEqual(self, s1, s2):
+ def assertStateMapEqual(
+ self, s1: StateMap[EventBase], s2: StateMap[EventBase]
+ ) -> None:
for t in s1:
# just compare event IDs for simplicity
self.assertEqual(s1[t].event_id, s2[t].event_id)
self.assertEqual(len(s1), len(s2))
- def test_get_state_groups_ids(self):
+ def test_get_state_groups_ids(self) -> None:
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
state_group_map = self.get_success(
- self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
+ self.storage.state.get_state_groups_ids(
+ self.room.to_string(), [e2.event_id]
+ )
)
self.assertEqual(len(state_group_map), 1)
state_map = list(state_group_map.values())[0]
@@ -90,21 +102,21 @@ class StateStoreTestCase(HomeserverTestCase):
{(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
)
- def test_get_state_groups(self):
+ def test_get_state_groups(self) -> None:
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
state_group_map = self.get_success(
- self.storage.state.get_state_groups(self.room, [e2.event_id])
+ self.storage.state.get_state_groups(self.room.to_string(), [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0]
self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
- def test_get_state_for_event(self):
+ def test_get_state_for_event(self) -> None:
# this defaults to a linear DAG as each new injection defaults to whatever
# forward extremities are currently in the DB for this room.
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
@@ -487,14 +499,16 @@ class StateStoreTestCase(HomeserverTestCase):
class StateFilterDifferenceTestCase(TestCase):
def assert_difference(
self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
- ):
+ ) -> None:
self.assertEqual(
minuend.approx_difference(subtrahend),
expected,
f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
)
- def test_state_filter_difference_no_include_other_minus_no_include_other(self):
+ def test_state_filter_difference_no_include_other_minus_no_include_other(
+ self,
+ ) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b do not have the
@@ -610,7 +624,7 @@ class StateFilterDifferenceTestCase(TestCase):
),
)
- def test_state_filter_difference_include_other_minus_no_include_other(self):
+ def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only a has the include_others flag set.
@@ -739,7 +753,7 @@ class StateFilterDifferenceTestCase(TestCase):
),
)
- def test_state_filter_difference_include_other_minus_include_other(self):
+ def test_state_filter_difference_include_other_minus_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b have the include_others
@@ -864,7 +878,7 @@ class StateFilterDifferenceTestCase(TestCase):
),
)
- def test_state_filter_difference_no_include_other_minus_include_other(self):
+ def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only b has the include_others flag set.
@@ -979,7 +993,7 @@ class StateFilterDifferenceTestCase(TestCase):
),
)
- def test_state_filter_difference_simple_cases(self):
+ def test_state_filter_difference_simple_cases(self) -> None:
"""
Tests some very simple cases of the StateFilter approx_difference,
that are not explicitly tested by the more in-depth tests.
@@ -995,7 +1009,7 @@ class StateFilterDifferenceTestCase(TestCase):
class StateFilterTestCase(TestCase):
- def test_return_expanded(self):
+ def test_return_expanded(self) -> None:
"""
Tests the behaviour of the return_expanded() function that expands
StateFilters to include more state types (for the sake of cache hit rate).
|