summary refs log tree commit diff
path: root/tests/replication/slave
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-02-06 09:55:00 -0500
committerGitHub <noreply@github.com>2023-02-06 09:55:00 -0500
commit156cd88eefe7db100e5cdba48174c709975b93ca (patch)
treebf4059f81c6ba16439ef6dfa19a4e016057da20d /tests/replication/slave
parentExpect type stubs from canonicaljson (#14992) (diff)
downloadsynapse-156cd88eefe7db100e5cdba48174c709975b93ca.tar.xz
Add missing type hints to tests.replication. (#14987)
Diffstat (limited to 'tests/replication/slave')
-rw-r--r--tests/replication/slave/storage/_base.py25
-rw-r--r--tests/replication/slave/storage/test_events.py85
2 files changed, 58 insertions, 52 deletions
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index c5705256e6..4c9b494344 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -13,35 +13,42 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any, Iterable, Optional
 from unittest.mock import Mock
 
-from tests.replication._base import BaseStreamTestCase
+from twisted.test.proto_helpers import MemoryReactor
 
+from synapse.server import HomeServer
+from synapse.util import Clock
 
-class BaseSlavedStoreTestCase(BaseStreamTestCase):
-    def make_homeserver(self, reactor, clock):
+from tests.replication._base import BaseStreamTestCase
 
-        hs = self.setup_test_homeserver(federation_client=Mock())
 
-        return hs
+class BaseSlavedStoreTestCase(BaseStreamTestCase):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        return self.setup_test_homeserver(federation_client=Mock())
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         super().prepare(reactor, clock, hs)
 
         self.reconnect()
 
         self.master_store = hs.get_datastores().main
         self.slaved_store = self.worker_hs.get_datastores().main
-        self._storage_controllers = hs.get_storage_controllers()
+        persistence = hs.get_storage_controllers().persistence
+        assert persistence is not None
+        self.persistance = persistence
 
-    def replicate(self):
+    def replicate(self) -> None:
         """Tell the master side of replication that something has happened, and then
         wait for the replication to occur.
         """
         self.streamer.on_notifier_poke()
         self.pump(0.1)
 
-    def check(self, method, args, expected_result=None):
+    def check(
+        self, method: str, args: Iterable[Any], expected_result: Optional[Any] = None
+    ) -> None:
         master_result = self.get_success(getattr(self.master_store, method)(*args))
         slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
         if expected_result is not None:
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index dce71f7334..ddca9d696c 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -12,15 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Iterable, Optional
+from typing import Any, Callable, Iterable, List, Optional, Tuple
 
 from canonicaljson import encode_canonical_json
 from parameterized import parameterized
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.constants import ReceiptTypes
 from synapse.api.room_versions import RoomVersions
-from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
+from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
+from synapse.events.snapshot import EventContext
 from synapse.handlers.room import RoomEventSource
+from synapse.server import HomeServer
 from synapse.storage.databases.main.event_push_actions import (
     NotifCounts,
     RoomNotifCounts,
@@ -28,6 +32,7 @@ from synapse.storage.databases.main.event_push_actions import (
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
 from synapse.types import PersistedEventPosition
+from synapse.util import Clock
 
 from tests.server import FakeTransport
 
@@ -41,19 +46,19 @@ ROOM_ID = "!room:test"
 logger = logging.getLogger(__name__)
 
 
-def dict_equals(self, other):
+def dict_equals(self: EventBase, other: EventBase) -> bool:
     me = encode_canonical_json(self.get_pdu_json())
     them = encode_canonical_json(other.get_pdu_json())
     return me == them
 
 
-def patch__eq__(cls):
+def patch__eq__(cls: object) -> Callable[[], None]:
     eq = getattr(cls, "__eq__", None)
-    cls.__eq__ = dict_equals
+    cls.__eq__ = dict_equals  # type: ignore[assignment]
 
-    def unpatch():
+    def unpatch() -> None:
         if eq is not None:
-            cls.__eq__ = eq
+            cls.__eq__ = eq  # type: ignore[assignment]
 
     return unpatch
 
@@ -62,14 +67,14 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
 
     STORE_TYPE = EventsWorkerStore
 
-    def setUp(self):
+    def setUp(self) -> None:
         # Patch up the equality operator for events so that we can check
         # whether lists of events match using assertEqual
-        self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)]
-        return super().setUp()
+        self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(EventBase)]
+        super().setUp()
 
-    def prepare(self, *args, **kwargs):
-        super().prepare(*args, **kwargs)
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        super().prepare(reactor, clock, hs)
 
         self.get_success(
             self.master_store.store_room(
@@ -80,10 +85,10 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
             )
         )
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         [unpatch() for unpatch in self.unpatches]
 
-    def test_get_latest_event_ids_in_room(self):
+    def test_get_latest_event_ids_in_room(self) -> None:
         create = self.persist(type="m.room.create", key="", creator=USER_ID)
         self.replicate()
         self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
@@ -97,7 +102,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
         self.replicate()
         self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
 
-    def test_redactions(self):
+    def test_redactions(self) -> None:
         self.persist(type="m.room.create", key="", creator=USER_ID)
         self.persist(type="m.room.member", key=USER_ID, membership="join")
 
@@ -117,7 +122,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
         )
         self.check("get_event", [msg.event_id], redacted)
 
-    def test_backfilled_redactions(self):
+    def test_backfilled_redactions(self) -> None:
         self.persist(type="m.room.create", key="", creator=USER_ID)
         self.persist(type="m.room.member", key=USER_ID, membership="join")
 
@@ -139,7 +144,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
         )
         self.check("get_event", [msg.event_id], redacted)
 
-    def test_invites(self):
+    def test_invites(self) -> None:
         self.persist(type="m.room.create", key="", creator=USER_ID)
         self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
         event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
@@ -163,7 +168,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
         )
 
     @parameterized.expand([(True,), (False,)])
-    def test_push_actions_for_user(self, send_receipt: bool):
+    def test_push_actions_for_user(self, send_receipt: bool) -> None:
         self.persist(type="m.room.create", key="", creator=USER_ID)
         self.persist(type="m.room.member", key=USER_ID, membership="join")
         self.persist(
@@ -219,7 +224,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
             ),
         )
 
-    def test_get_rooms_for_user_with_stream_ordering(self):
+    def test_get_rooms_for_user_with_stream_ordering(self) -> None:
         """Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated
         by rows in the events stream
         """
@@ -243,7 +248,9 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
             {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
         )
 
-    def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
+    def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(
+        self,
+    ) -> None:
         """Check that current_state invalidation happens correctly with multiple events
         in the persistence batch.
 
@@ -283,11 +290,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
             type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
         )
         msg, msgctx = self.build_event()
-        self.get_success(
-            self._storage_controllers.persistence.persist_events(
-                [(j2, j2ctx), (msg, msgctx)]
-            )
-        )
+        self.get_success(self.persistance.persist_events([(j2, j2ctx), (msg, msgctx)]))
         self.replicate()
         assert j2.internal_metadata.stream_ordering is not None
 
@@ -339,7 +342,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
 
     event_id = 0
 
-    def persist(self, backfill=False, **kwargs) -> FrozenEvent:
+    def persist(self, backfill: bool = False, **kwargs: Any) -> EventBase:
         """
         Returns:
             The event that was persisted.
@@ -348,32 +351,28 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
 
         if backfill:
             self.get_success(
-                self._storage_controllers.persistence.persist_events(
-                    [(event, context)], backfilled=True
-                )
+                self.persistance.persist_events([(event, context)], backfilled=True)
             )
         else:
-            self.get_success(
-                self._storage_controllers.persistence.persist_event(event, context)
-            )
+            self.get_success(self.persistance.persist_event(event, context))
 
         return event
 
     def build_event(
         self,
-        sender=USER_ID,
-        room_id=ROOM_ID,
-        type="m.room.message",
-        key=None,
+        sender: str = USER_ID,
+        room_id: str = ROOM_ID,
+        type: str = "m.room.message",
+        key: Optional[str] = None,
         internal: Optional[dict] = None,
-        depth=None,
-        prev_events: Optional[list] = None,
-        auth_events: Optional[list] = None,
-        prev_state: Optional[list] = None,
-        redacts=None,
+        depth: Optional[int] = None,
+        prev_events: Optional[List[Tuple[str, dict]]] = None,
+        auth_events: Optional[List[str]] = None,
+        prev_state: Optional[List[str]] = None,
+        redacts: Optional[str] = None,
         push_actions: Iterable = frozenset(),
-        **content,
-    ):
+        **content: object,
+    ) -> Tuple[EventBase, EventContext]:
         prev_events = prev_events or []
         auth_events = auth_events or []
         prev_state = prev_state or []