diff --git a/tests/replication/storage/test_events.py b/tests/replication/storage/test_events.py
new file mode 100644
index 0000000000..f7c6417a09
--- /dev/null
+++ b/tests/replication/storage/test_events.py
@@ -0,0 +1,420 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import Any, Callable, Iterable, List, Optional, Tuple
+
+from canonicaljson import encode_canonical_json
+from parameterized import parameterized
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import ReceiptTypes
+from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
+from synapse.events.snapshot import EventContext
+from synapse.handlers.room import RoomEventSource
+from synapse.server import HomeServer
+from synapse.storage.databases.main.event_push_actions import (
+ NotifCounts,
+ RoomNotifCounts,
+)
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
+from synapse.types import PersistedEventPosition
+from synapse.util import Clock
+
+from tests.server import FakeTransport
+
+from ._base import BaseWorkerStoreTestCase
+
+USER_ID = "@feeling:test"
+USER_ID_2 = "@bright:test"
+OUTLIER = {"outlier": True}
+ROOM_ID = "!room:test"
+
+logger = logging.getLogger(__name__)
+
+
+def dict_equals(self: EventBase, other: EventBase) -> bool:
+ me = encode_canonical_json(self.get_pdu_json())
+ them = encode_canonical_json(other.get_pdu_json())
+ return me == them
+
+
+def patch__eq__(cls: object) -> Callable[[], None]:
+ eq = getattr(cls, "__eq__", None)
+ cls.__eq__ = dict_equals # type: ignore[assignment]
+
+ def unpatch() -> None:
+ if eq is not None:
+ cls.__eq__ = eq # type: ignore[assignment]
+
+ return unpatch
+
+
+class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
+ STORE_TYPE = EventsWorkerStore
+
+ def setUp(self) -> None:
+ # Patch up the equality operator for events so that we can check
+ # whether lists of events match using assertEqual
+ self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(EventBase)]
+ super().setUp()
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ super().prepare(reactor, clock, hs)
+
+ self.get_success(
+ self.master_store.store_room(
+ ROOM_ID,
+ USER_ID,
+ is_public=False,
+ room_version=RoomVersions.V1,
+ )
+ )
+
+ def tearDown(self) -> None:
+ [unpatch() for unpatch in self.unpatches]
+
+ def test_get_latest_event_ids_in_room(self) -> None:
+ create = self.persist(type="m.room.create", key="", creator=USER_ID)
+ self.replicate()
+ self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
+
+ join = self.persist(
+ type="m.room.member",
+ key=USER_ID,
+ membership="join",
+ prev_events=[(create.event_id, {})],
+ )
+ self.replicate()
+ self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
+
+ def test_redactions(self) -> None:
+ self.persist(type="m.room.create", key="", creator=USER_ID)
+ self.persist(type="m.room.member", key=USER_ID, membership="join")
+
+ msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
+ self.replicate()
+ self.check("get_event", [msg.event_id], msg)
+
+ redaction = self.persist(type="m.room.redaction", redacts=msg.event_id)
+ self.replicate()
+
+ msg_dict = msg.get_dict()
+ msg_dict["content"] = {}
+ msg_dict["unsigned"]["redacted_by"] = redaction.event_id
+ msg_dict["unsigned"]["redacted_because"] = redaction
+ redacted = make_event_from_dict(
+ msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict()
+ )
+ self.check("get_event", [msg.event_id], redacted)
+
+ def test_backfilled_redactions(self) -> None:
+ self.persist(type="m.room.create", key="", creator=USER_ID)
+ self.persist(type="m.room.member", key=USER_ID, membership="join")
+
+ msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
+ self.replicate()
+ self.check("get_event", [msg.event_id], msg)
+
+ redaction = self.persist(
+ type="m.room.redaction", redacts=msg.event_id, backfill=True
+ )
+ self.replicate()
+
+ msg_dict = msg.get_dict()
+ msg_dict["content"] = {}
+ msg_dict["unsigned"]["redacted_by"] = redaction.event_id
+ msg_dict["unsigned"]["redacted_because"] = redaction
+ redacted = make_event_from_dict(
+ msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict()
+ )
+ self.check("get_event", [msg.event_id], redacted)
+
+ def test_invites(self) -> None:
+ self.persist(type="m.room.create", key="", creator=USER_ID)
+ self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
+ event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
+ assert event.internal_metadata.stream_ordering is not None
+
+ self.replicate()
+
+ self.check(
+ "get_invited_rooms_for_local_user",
+ [USER_ID_2],
+ [
+ RoomsForUser(
+ ROOM_ID,
+ USER_ID,
+ "invite",
+ event.event_id,
+ event.internal_metadata.stream_ordering,
+ RoomVersions.V1.identifier,
+ )
+ ],
+ )
+
+ @parameterized.expand([(True,), (False,)])
+ def test_push_actions_for_user(self, send_receipt: bool) -> None:
+ self.persist(type="m.room.create", key="", creator=USER_ID)
+ self.persist(type="m.room.member", key=USER_ID, membership="join")
+ self.persist(
+ type="m.room.member", sender=USER_ID, key=USER_ID_2, membership="join"
+ )
+ event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello")
+ self.replicate()
+
+ if send_receipt:
+ self.get_success(
+ self.master_store.insert_receipt(
+ ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], None, {}
+ )
+ )
+
+ self.check(
+ "get_unread_event_push_actions_by_room_for_user",
+ [ROOM_ID, USER_ID_2],
+ RoomNotifCounts(
+ NotifCounts(highlight_count=0, unread_count=0, notify_count=0), {}
+ ),
+ )
+
+ self.persist(
+ type="m.room.message",
+ msgtype="m.text",
+ body="world",
+ push_actions=[(USER_ID_2, ["notify"])],
+ )
+ self.replicate()
+ self.check(
+ "get_unread_event_push_actions_by_room_for_user",
+ [ROOM_ID, USER_ID_2],
+ RoomNotifCounts(
+ NotifCounts(highlight_count=0, unread_count=0, notify_count=1), {}
+ ),
+ )
+
+ self.persist(
+ type="m.room.message",
+ msgtype="m.text",
+ body="world",
+ push_actions=[
+ (USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}])
+ ],
+ )
+ self.replicate()
+ self.check(
+ "get_unread_event_push_actions_by_room_for_user",
+ [ROOM_ID, USER_ID_2],
+ RoomNotifCounts(
+ NotifCounts(highlight_count=1, unread_count=0, notify_count=2), {}
+ ),
+ )
+
+ def test_get_rooms_for_user_with_stream_ordering(self) -> None:
+ """Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated
+ by rows in the events stream
+ """
+ self.persist(type="m.room.create", key="", creator=USER_ID)
+ self.persist(type="m.room.member", key=USER_ID, membership="join")
+ self.replicate()
+ self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())
+
+ j2 = self.persist(
+ type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
+ )
+ assert j2.internal_metadata.stream_ordering is not None
+ self.replicate()
+
+ expected_pos = PersistedEventPosition(
+ "master", j2.internal_metadata.stream_ordering
+ )
+ self.check(
+ "get_rooms_for_user_with_stream_ordering",
+ (USER_ID_2,),
+ {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
+ )
+
+ def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(
+ self,
+ ) -> None:
+ """Check that current_state invalidation happens correctly with multiple events
+ in the persistence batch.
+
+ This test attempts to reproduce a race condition between the event persistence
+ loop and a worker-based Sync handler.
+
+ The problem occurred when the master persisted several events in one batch. It
+ only updates the current_state at the end of each batch, so the obvious thing
+ to do is then to issue a current_state_delta stream update corresponding to the
+ last stream_id in the batch.
+
+ However, that raises the possibility that a worker will see the replication
+ notification for a join event before the current_state caches are invalidated.
+
+ The test involves:
+ * creating a join and a message event for a user, and persisting them in the
+ same batch
+
+ * controlling the replication stream so that updates are sent gradually
+
+ * between each bunch of replication updates, check that we see a consistent
+ snapshot of the state.
+ """
+ self.persist(type="m.room.create", key="", creator=USER_ID)
+ self.persist(type="m.room.member", key=USER_ID, membership="join")
+ self.replicate()
+ self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())
+
+ # limit the replication rate
+ repl_transport = self._server_transport
+ assert isinstance(repl_transport, FakeTransport)
+ repl_transport.autoflush = False
+
+ # build the join and message events and persist them in the same batch.
+ logger.info("----- build test events ------")
+ j2, j2ctx = self.build_event(
+ type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
+ )
+ msg, msgctx = self.build_event()
+ self.get_success(self.persistance.persist_events([(j2, j2ctx), (msg, msgctx)]))
+ self.replicate()
+ assert j2.internal_metadata.stream_ordering is not None
+
+ event_source = RoomEventSource(self.hs)
+ event_source.store = self.worker_store
+ current_token = event_source.get_current_key()
+
+ # gradually stream out the replication
+ while repl_transport.buffer:
+ logger.info("------ flush ------")
+ repl_transport.flush(30)
+ self.pump(0)
+
+ prev_token = current_token
+ current_token = event_source.get_current_key()
+
+ # attempt to replicate the behaviour of the sync handler.
+ #
+ # First, we get a list of the rooms we are joined to
+ joined_rooms = self.get_success(
+ self.worker_store.get_rooms_for_user_with_stream_ordering(USER_ID_2)
+ )
+
+ # Then, we get a list of the events since the last sync
+ membership_changes = self.get_success(
+ self.worker_store.get_membership_changes_for_user(
+ USER_ID_2, prev_token, current_token
+ )
+ )
+
+ logger.info(
+ "%s->%s: joined_rooms=%r membership_changes=%r",
+ prev_token,
+ current_token,
+ joined_rooms,
+ membership_changes,
+ )
+
+ # the membership change is only any use to us if the room is in the
+ # joined_rooms list.
+ if membership_changes:
+ expected_pos = PersistedEventPosition(
+ "master", j2.internal_metadata.stream_ordering
+ )
+ self.assertEqual(
+ joined_rooms,
+ {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
+ )
+
+ event_id = 0
+
+ def persist(self, backfill: bool = False, **kwargs: Any) -> EventBase:
+ """
+ Returns:
+ The event that was persisted.
+ """
+ event, context = self.build_event(**kwargs)
+
+ if backfill:
+ self.get_success(
+ self.persistance.persist_events([(event, context)], backfilled=True)
+ )
+ else:
+ self.get_success(self.persistance.persist_event(event, context))
+
+ return event
+
+ def build_event(
+ self,
+ sender: str = USER_ID,
+ room_id: str = ROOM_ID,
+ type: str = "m.room.message",
+ key: Optional[str] = None,
+ internal: Optional[dict] = None,
+ depth: Optional[int] = None,
+ prev_events: Optional[List[Tuple[str, dict]]] = None,
+ auth_events: Optional[List[str]] = None,
+ prev_state: Optional[List[str]] = None,
+ redacts: Optional[str] = None,
+ push_actions: Iterable = frozenset(),
+ **content: object,
+ ) -> Tuple[EventBase, EventContext]:
+ prev_events = prev_events or []
+ auth_events = auth_events or []
+ prev_state = prev_state or []
+
+ if depth is None:
+ depth = self.event_id
+
+ if not prev_events:
+ latest_event_ids = self.get_success(
+ self.master_store.get_latest_event_ids_in_room(room_id)
+ )
+ prev_events = [(ev_id, {}) for ev_id in latest_event_ids]
+
+ event_dict = {
+ "sender": sender,
+ "type": type,
+ "content": content,
+ "event_id": "$%d:blue" % (self.event_id,),
+ "room_id": room_id,
+ "depth": depth,
+ "origin_server_ts": self.event_id,
+ "prev_events": prev_events,
+ "auth_events": auth_events,
+ }
+ if key is not None:
+ event_dict["state_key"] = key
+ event_dict["prev_state"] = prev_state
+
+ if redacts is not None:
+ event_dict["redacts"] = redacts
+
+ event = make_event_from_dict(event_dict, internal_metadata_dict=internal or {})
+
+ self.event_id += 1
+ state_handler = self.hs.get_state_handler()
+ context = self.get_success(state_handler.compute_event_context(event))
+
+ self.get_success(
+ self.master_store.add_push_actions_to_staging(
+ event.event_id,
+ dict(push_actions),
+ False,
+ "main",
+ )
+ )
+ return event, context
|