diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index dce71f7334..ddca9d696c 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -12,15 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Iterable, Optional
+from typing import Any, Callable, Iterable, List, Optional, Tuple
from canonicaljson import encode_canonical_json
from parameterized import parameterized
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import ReceiptTypes
from synapse.api.room_versions import RoomVersions
-from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
+from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
+from synapse.events.snapshot import EventContext
from synapse.handlers.room import RoomEventSource
+from synapse.server import HomeServer
from synapse.storage.databases.main.event_push_actions import (
NotifCounts,
RoomNotifCounts,
@@ -28,6 +32,7 @@ from synapse.storage.databases.main.event_push_actions import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
from synapse.types import PersistedEventPosition
+from synapse.util import Clock
from tests.server import FakeTransport
@@ -41,19 +46,19 @@ ROOM_ID = "!room:test"
logger = logging.getLogger(__name__)
-def dict_equals(self, other):
+def dict_equals(self: EventBase, other: EventBase) -> bool:
me = encode_canonical_json(self.get_pdu_json())
them = encode_canonical_json(other.get_pdu_json())
return me == them
-def patch__eq__(cls):
+def patch__eq__(cls: object) -> Callable[[], None]:
eq = getattr(cls, "__eq__", None)
- cls.__eq__ = dict_equals
+ cls.__eq__ = dict_equals # type: ignore[assignment]
- def unpatch():
+ def unpatch() -> None:
if eq is not None:
- cls.__eq__ = eq
+ cls.__eq__ = eq # type: ignore[assignment]
return unpatch
@@ -62,14 +67,14 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
STORE_TYPE = EventsWorkerStore
- def setUp(self):
+ def setUp(self) -> None:
# Patch up the equality operator for events so that we can check
# whether lists of events match using assertEqual
- self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)]
- return super().setUp()
+ self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(EventBase)]
+ super().setUp()
- def prepare(self, *args, **kwargs):
- super().prepare(*args, **kwargs)
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ super().prepare(reactor, clock, hs)
self.get_success(
self.master_store.store_room(
@@ -80,10 +85,10 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
)
)
- def tearDown(self):
+ def tearDown(self) -> None:
[unpatch() for unpatch in self.unpatches]
- def test_get_latest_event_ids_in_room(self):
+ def test_get_latest_event_ids_in_room(self) -> None:
create = self.persist(type="m.room.create", key="", creator=USER_ID)
self.replicate()
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
@@ -97,7 +102,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
self.replicate()
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
- def test_redactions(self):
+ def test_redactions(self) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.member", key=USER_ID, membership="join")
@@ -117,7 +122,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
)
self.check("get_event", [msg.event_id], redacted)
- def test_backfilled_redactions(self):
+ def test_backfilled_redactions(self) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.member", key=USER_ID, membership="join")
@@ -139,7 +144,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
)
self.check("get_event", [msg.event_id], redacted)
- def test_invites(self):
+ def test_invites(self) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
@@ -163,7 +168,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
)
@parameterized.expand([(True,), (False,)])
- def test_push_actions_for_user(self, send_receipt: bool):
+ def test_push_actions_for_user(self, send_receipt: bool) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.member", key=USER_ID, membership="join")
self.persist(
@@ -219,7 +224,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
),
)
- def test_get_rooms_for_user_with_stream_ordering(self):
+ def test_get_rooms_for_user_with_stream_ordering(self) -> None:
"""Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated
by rows in the events stream
"""
@@ -243,7 +248,9 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
{GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
)
- def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
+ def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(
+ self,
+ ) -> None:
"""Check that current_state invalidation happens correctly with multiple events
in the persistence batch.
@@ -283,11 +290,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
)
msg, msgctx = self.build_event()
- self.get_success(
- self._storage_controllers.persistence.persist_events(
- [(j2, j2ctx), (msg, msgctx)]
- )
- )
+ self.get_success(self.persistance.persist_events([(j2, j2ctx), (msg, msgctx)]))
self.replicate()
assert j2.internal_metadata.stream_ordering is not None
@@ -339,7 +342,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
event_id = 0
- def persist(self, backfill=False, **kwargs) -> FrozenEvent:
+ def persist(self, backfill: bool = False, **kwargs: Any) -> EventBase:
"""
Returns:
The event that was persisted.
@@ -348,32 +351,28 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
if backfill:
self.get_success(
- self._storage_controllers.persistence.persist_events(
- [(event, context)], backfilled=True
- )
+ self.persistance.persist_events([(event, context)], backfilled=True)
)
else:
- self.get_success(
- self._storage_controllers.persistence.persist_event(event, context)
- )
+ self.get_success(self.persistance.persist_event(event, context))
return event
def build_event(
self,
- sender=USER_ID,
- room_id=ROOM_ID,
- type="m.room.message",
- key=None,
+ sender: str = USER_ID,
+ room_id: str = ROOM_ID,
+ type: str = "m.room.message",
+ key: Optional[str] = None,
internal: Optional[dict] = None,
- depth=None,
- prev_events: Optional[list] = None,
- auth_events: Optional[list] = None,
- prev_state: Optional[list] = None,
- redacts=None,
+ depth: Optional[int] = None,
+ prev_events: Optional[List[Tuple[str, dict]]] = None,
+ auth_events: Optional[List[str]] = None,
+ prev_state: Optional[List[str]] = None,
+ redacts: Optional[str] = None,
push_actions: Iterable = frozenset(),
- **content,
- ):
+ **content: object,
+ ) -> Tuple[EventBase, EventContext]:
prev_events = prev_events or []
auth_events = auth_events or []
prev_state = prev_state or []
|