From 375b0a8a119bb925ca280f050a25a931662fcbb5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 16 May 2023 15:56:38 -0400 Subject: Update code to refer to "workers". (#15606) A bunch of comments and variables are out of date and use obsolete terms. --- tests/app/test_openid_listener.py | 2 +- tests/replication/slave/__init__.py | 13 - tests/replication/slave/storage/__init__.py | 13 - tests/replication/slave/storage/_base.py | 72 ----- tests/replication/slave/storage/test_events.py | 420 ------------------------- tests/replication/storage/__init__.py | 13 + tests/replication/storage/_base.py | 72 +++++ tests/replication/storage/test_events.py | 420 +++++++++++++++++++++++++ 8 files changed, 506 insertions(+), 519 deletions(-) delete mode 100644 tests/replication/slave/__init__.py delete mode 100644 tests/replication/slave/storage/__init__.py delete mode 100644 tests/replication/slave/storage/_base.py delete mode 100644 tests/replication/slave/storage/test_events.py create mode 100644 tests/replication/storage/__init__.py create mode 100644 tests/replication/storage/_base.py create mode 100644 tests/replication/storage/test_events.py (limited to 'tests') diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index 2ee343d8a4..6e0413400e 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -38,7 +38,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): def default_config(self) -> JsonDict: conf = super().default_config() - # we're using FederationReaderServer, which uses a SlavedStore, so we + # we're using GenericWorkerServer, which uses a GenericWorkerStore, so we # have to tell the FederationHandler not to try to access stuff that is only # in the primary store. conf["worker_app"] = "yes" diff --git a/tests/replication/slave/__init__.py b/tests/replication/slave/__init__.py deleted file mode 100644 index f43a360a80..0000000000 --- a/tests/replication/slave/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/replication/slave/storage/__init__.py b/tests/replication/slave/storage/__init__.py deleted file mode 100644 index f43a360a80..0000000000 --- a/tests/replication/slave/storage/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py deleted file mode 100644 index 4c9b494344..0000000000 --- a/tests/replication/slave/storage/_base.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Iterable, Optional -from unittest.mock import Mock - -from twisted.test.proto_helpers import MemoryReactor - -from synapse.server import HomeServer -from synapse.util import Clock - -from tests.replication._base import BaseStreamTestCase - - -class BaseSlavedStoreTestCase(BaseStreamTestCase): - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - return self.setup_test_homeserver(federation_client=Mock()) - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - super().prepare(reactor, clock, hs) - - self.reconnect() - - self.master_store = hs.get_datastores().main - self.slaved_store = self.worker_hs.get_datastores().main - persistence = hs.get_storage_controllers().persistence - assert persistence is not None - self.persistance = persistence - - def replicate(self) -> None: - """Tell the master side of replication that something has happened, and then - wait for the replication to occur. - """ - self.streamer.on_notifier_poke() - self.pump(0.1) - - def check( - self, method: str, args: Iterable[Any], expected_result: Optional[Any] = None - ) -> None: - master_result = self.get_success(getattr(self.master_store, method)(*args)) - slaved_result = self.get_success(getattr(self.slaved_store, method)(*args)) - if expected_result is not None: - self.assertEqual( - master_result, - expected_result, - "Expected master result to be %r but was %r" - % (expected_result, master_result), - ) - self.assertEqual( - slaved_result, - expected_result, - "Expected slave result to be %r but was %r" - % (expected_result, slaved_result), - ) - self.assertEqual( - master_result, - slaved_result, - "Slave result %r does not match master result %r" - % (slaved_result, master_result), - ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py deleted file mode 100644 index b2125b1fea..0000000000 --- a/tests/replication/slave/storage/test_events.py +++ /dev/null @@ -1,420 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import logging -from typing import Any, Callable, Iterable, List, Optional, Tuple - -from canonicaljson import encode_canonical_json -from parameterized import parameterized - -from twisted.test.proto_helpers import MemoryReactor - -from synapse.api.constants import ReceiptTypes -from synapse.api.room_versions import RoomVersions -from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict -from synapse.events.snapshot import EventContext -from synapse.handlers.room import RoomEventSource -from synapse.server import HomeServer -from synapse.storage.databases.main.event_push_actions import ( - NotifCounts, - RoomNotifCounts, -) -from synapse.storage.databases.main.events_worker import EventsWorkerStore -from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser -from synapse.types import PersistedEventPosition -from synapse.util import Clock - -from tests.server import FakeTransport - -from ._base import BaseSlavedStoreTestCase - -USER_ID = "@feeling:test" -USER_ID_2 = "@bright:test" -OUTLIER = {"outlier": True} -ROOM_ID = "!room:test" - -logger = logging.getLogger(__name__) - - -def dict_equals(self: EventBase, other: EventBase) -> bool: - me = encode_canonical_json(self.get_pdu_json()) - them = encode_canonical_json(other.get_pdu_json()) - return me == them - - -def patch__eq__(cls: object) -> Callable[[], None]: - eq = getattr(cls, "__eq__", None) - cls.__eq__ = dict_equals # type: ignore[assignment] - - def unpatch() -> None: - if eq is not None: - cls.__eq__ = eq # type: ignore[assignment] - - return unpatch - - -class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase): - STORE_TYPE = EventsWorkerStore - - def setUp(self) -> None: - # Patch up the equality operator for events so that we can check - # whether lists of events match using assertEqual - self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(EventBase)] - super().setUp() - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - super().prepare(reactor, clock, hs) - - self.get_success( - self.master_store.store_room( - ROOM_ID, - USER_ID, - is_public=False, - room_version=RoomVersions.V1, - ) - ) - - def tearDown(self) -> None: - [unpatch() for unpatch in self.unpatches] - - def test_get_latest_event_ids_in_room(self) -> None: - create = self.persist(type="m.room.create", key="", creator=USER_ID) - self.replicate() - self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]) - - join = self.persist( - type="m.room.member", - key=USER_ID, - membership="join", - prev_events=[(create.event_id, {})], - ) - self.replicate() - self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]) - - def test_redactions(self) -> None: - self.persist(type="m.room.create", key="", creator=USER_ID) - self.persist(type="m.room.member", key=USER_ID, membership="join") - - msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello") - self.replicate() - self.check("get_event", [msg.event_id], msg) - - redaction = self.persist(type="m.room.redaction", redacts=msg.event_id) - self.replicate() - - msg_dict = msg.get_dict() - msg_dict["content"] = {} - msg_dict["unsigned"]["redacted_by"] = redaction.event_id - msg_dict["unsigned"]["redacted_because"] = redaction - redacted = make_event_from_dict( - msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict() - ) - self.check("get_event", [msg.event_id], redacted) - - def test_backfilled_redactions(self) -> None: - self.persist(type="m.room.create", key="", 
creator=USER_ID) - self.persist(type="m.room.member", key=USER_ID, membership="join") - - msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello") - self.replicate() - self.check("get_event", [msg.event_id], msg) - - redaction = self.persist( - type="m.room.redaction", redacts=msg.event_id, backfill=True - ) - self.replicate() - - msg_dict = msg.get_dict() - msg_dict["content"] = {} - msg_dict["unsigned"]["redacted_by"] = redaction.event_id - msg_dict["unsigned"]["redacted_because"] = redaction - redacted = make_event_from_dict( - msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict() - ) - self.check("get_event", [msg.event_id], redacted) - - def test_invites(self) -> None: - self.persist(type="m.room.create", key="", creator=USER_ID) - self.check("get_invited_rooms_for_local_user", [USER_ID_2], []) - event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite") - assert event.internal_metadata.stream_ordering is not None - - self.replicate() - - self.check( - "get_invited_rooms_for_local_user", - [USER_ID_2], - [ - RoomsForUser( - ROOM_ID, - USER_ID, - "invite", - event.event_id, - event.internal_metadata.stream_ordering, - RoomVersions.V1.identifier, - ) - ], - ) - - @parameterized.expand([(True,), (False,)]) - def test_push_actions_for_user(self, send_receipt: bool) -> None: - self.persist(type="m.room.create", key="", creator=USER_ID) - self.persist(type="m.room.member", key=USER_ID, membership="join") - self.persist( - type="m.room.member", sender=USER_ID, key=USER_ID_2, membership="join" - ) - event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello") - self.replicate() - - if send_receipt: - self.get_success( - self.master_store.insert_receipt( - ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], None, {} - ) - ) - - self.check( - "get_unread_event_push_actions_by_room_for_user", - [ROOM_ID, USER_ID_2], - RoomNotifCounts( - NotifCounts(highlight_count=0, unread_count=0, notify_count=0), {} - ), - ) - - self.persist( - type="m.room.message", - msgtype="m.text", - body="world", - push_actions=[(USER_ID_2, ["notify"])], - ) - self.replicate() - self.check( - "get_unread_event_push_actions_by_room_for_user", - [ROOM_ID, USER_ID_2], - RoomNotifCounts( - NotifCounts(highlight_count=0, unread_count=0, notify_count=1), {} - ), - ) - - self.persist( - type="m.room.message", - msgtype="m.text", - body="world", - push_actions=[ - (USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}]) - ], - ) - self.replicate() - self.check( - "get_unread_event_push_actions_by_room_for_user", - [ROOM_ID, USER_ID_2], - RoomNotifCounts( - NotifCounts(highlight_count=1, unread_count=0, notify_count=2), {} - ), - ) - - def test_get_rooms_for_user_with_stream_ordering(self) -> None: - """Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated - by rows in the events stream - """ - self.persist(type="m.room.create", key="", creator=USER_ID) - self.persist(type="m.room.member", key=USER_ID, membership="join") - self.replicate() - self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set()) - - j2 = self.persist( - type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" - ) - assert j2.internal_metadata.stream_ordering is not None - self.replicate() - - expected_pos = PersistedEventPosition( - "master", j2.internal_metadata.stream_ordering - ) - self.check( - "get_rooms_for_user_with_stream_ordering", - (USER_ID_2,), - {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)}, - ) - 
- def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist( - self, - ) -> None: - """Check that current_state invalidation happens correctly with multiple events - in the persistence batch. - - This test attempts to reproduce a race condition between the event persistence - loop and a worker-based Sync handler. - - The problem occurred when the master persisted several events in one batch. It - only updates the current_state at the end of each batch, so the obvious thing - to do is then to issue a current_state_delta stream update corresponding to the - last stream_id in the batch. - - However, that raises the possibility that a worker will see the replication - notification for a join event before the current_state caches are invalidated. - - The test involves: - * creating a join and a message event for a user, and persisting them in the - same batch - - * controlling the replication stream so that updates are sent gradually - - * between each bunch of replication updates, check that we see a consistent - snapshot of the state. - """ - self.persist(type="m.room.create", key="", creator=USER_ID) - self.persist(type="m.room.member", key=USER_ID, membership="join") - self.replicate() - self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set()) - - # limit the replication rate - repl_transport = self._server_transport - assert isinstance(repl_transport, FakeTransport) - repl_transport.autoflush = False - - # build the join and message events and persist them in the same batch. - logger.info("----- build test events ------") - j2, j2ctx = self.build_event( - type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" - ) - msg, msgctx = self.build_event() - self.get_success(self.persistance.persist_events([(j2, j2ctx), (msg, msgctx)])) - self.replicate() - assert j2.internal_metadata.stream_ordering is not None - - event_source = RoomEventSource(self.hs) - event_source.store = self.slaved_store - current_token = event_source.get_current_key() - - # gradually stream out the replication - while repl_transport.buffer: - logger.info("------ flush ------") - repl_transport.flush(30) - self.pump(0) - - prev_token = current_token - current_token = event_source.get_current_key() - - # attempt to replicate the behaviour of the sync handler. - # - # First, we get a list of the rooms we are joined to - joined_rooms = self.get_success( - self.slaved_store.get_rooms_for_user_with_stream_ordering(USER_ID_2) - ) - - # Then, we get a list of the events since the last sync - membership_changes = self.get_success( - self.slaved_store.get_membership_changes_for_user( - USER_ID_2, prev_token, current_token - ) - ) - - logger.info( - "%s->%s: joined_rooms=%r membership_changes=%r", - prev_token, - current_token, - joined_rooms, - membership_changes, - ) - - # the membership change is only any use to us if the room is in the - # joined_rooms list. - if membership_changes: - expected_pos = PersistedEventPosition( - "master", j2.internal_metadata.stream_ordering - ) - self.assertEqual( - joined_rooms, - {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)}, - ) - - event_id = 0 - - def persist(self, backfill: bool = False, **kwargs: Any) -> EventBase: - """ - Returns: - The event that was persisted. 
- """ - event, context = self.build_event(**kwargs) - - if backfill: - self.get_success( - self.persistance.persist_events([(event, context)], backfilled=True) - ) - else: - self.get_success(self.persistance.persist_event(event, context)) - - return event - - def build_event( - self, - sender: str = USER_ID, - room_id: str = ROOM_ID, - type: str = "m.room.message", - key: Optional[str] = None, - internal: Optional[dict] = None, - depth: Optional[int] = None, - prev_events: Optional[List[Tuple[str, dict]]] = None, - auth_events: Optional[List[str]] = None, - prev_state: Optional[List[str]] = None, - redacts: Optional[str] = None, - push_actions: Iterable = frozenset(), - **content: object, - ) -> Tuple[EventBase, EventContext]: - prev_events = prev_events or [] - auth_events = auth_events or [] - prev_state = prev_state or [] - - if depth is None: - depth = self.event_id - - if not prev_events: - latest_event_ids = self.get_success( - self.master_store.get_latest_event_ids_in_room(room_id) - ) - prev_events = [(ev_id, {}) for ev_id in latest_event_ids] - - event_dict = { - "sender": sender, - "type": type, - "content": content, - "event_id": "$%d:blue" % (self.event_id,), - "room_id": room_id, - "depth": depth, - "origin_server_ts": self.event_id, - "prev_events": prev_events, - "auth_events": auth_events, - } - if key is not None: - event_dict["state_key"] = key - event_dict["prev_state"] = prev_state - - if redacts is not None: - event_dict["redacts"] = redacts - - event = make_event_from_dict(event_dict, internal_metadata_dict=internal or {}) - - self.event_id += 1 - state_handler = self.hs.get_state_handler() - context = self.get_success(state_handler.compute_event_context(event)) - - self.get_success( - self.master_store.add_push_actions_to_staging( - event.event_id, - dict(push_actions), - False, - "main", - ) - ) - return event, context diff --git a/tests/replication/storage/__init__.py b/tests/replication/storage/__init__.py new file mode 100644 index 0000000000..f43a360a80 --- /dev/null +++ b/tests/replication/storage/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/replication/storage/_base.py b/tests/replication/storage/_base.py new file mode 100644 index 0000000000..de26a62ae1 --- /dev/null +++ b/tests/replication/storage/_base.py @@ -0,0 +1,72 @@ +# Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Iterable, Optional +from unittest.mock import Mock + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer +from synapse.util import Clock + +from tests.replication._base import BaseStreamTestCase + + +class BaseWorkerStoreTestCase(BaseStreamTestCase): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + return self.setup_test_homeserver(federation_client=Mock()) + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) + + self.reconnect() + + self.master_store = hs.get_datastores().main + self.worker_store = self.worker_hs.get_datastores().main + persistence = hs.get_storage_controllers().persistence + assert persistence is not None + self.persistance = persistence + + def replicate(self) -> None: + """Tell the master side of replication that something has happened, and then + wait for the replication to occur. + """ + self.streamer.on_notifier_poke() + self.pump(0.1) + + def check( + self, method: str, args: Iterable[Any], expected_result: Optional[Any] = None + ) -> None: + master_result = self.get_success(getattr(self.master_store, method)(*args)) + worker_result = self.get_success(getattr(self.worker_store, method)(*args)) + if expected_result is not None: + self.assertEqual( + master_result, + expected_result, + "Expected master result to be %r but was %r" + % (expected_result, master_result), + ) + self.assertEqual( + worker_result, + expected_result, + "Expected worker result to be %r but was %r" + % (expected_result, worker_result), + ) + self.assertEqual( + master_result, + worker_result, + "Worker result %r does not match master result %r" + % (worker_result, master_result), + ) diff --git a/tests/replication/storage/test_events.py b/tests/replication/storage/test_events.py new file mode 100644 index 0000000000..f7c6417a09 --- /dev/null +++ b/tests/replication/storage/test_events.py @@ -0,0 +1,420 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +from typing import Any, Callable, Iterable, List, Optional, Tuple + +from canonicaljson import encode_canonical_json +from parameterized import parameterized + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import ReceiptTypes +from synapse.api.room_versions import RoomVersions +from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict +from synapse.events.snapshot import EventContext +from synapse.handlers.room import RoomEventSource +from synapse.server import HomeServer +from synapse.storage.databases.main.event_push_actions import ( + NotifCounts, + RoomNotifCounts, +) +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser +from synapse.types import PersistedEventPosition +from synapse.util import Clock + +from tests.server import FakeTransport + +from ._base import BaseWorkerStoreTestCase + +USER_ID = "@feeling:test" +USER_ID_2 = "@bright:test" +OUTLIER = {"outlier": True} +ROOM_ID = "!room:test" + +logger = logging.getLogger(__name__) + + +def dict_equals(self: EventBase, other: EventBase) -> bool: + me = encode_canonical_json(self.get_pdu_json()) + them = encode_canonical_json(other.get_pdu_json()) + return me == them + + +def patch__eq__(cls: object) -> Callable[[], None]: + eq = getattr(cls, "__eq__", None) + cls.__eq__ = dict_equals # type: ignore[assignment] + + def unpatch() -> None: + if eq is not None: + cls.__eq__ = eq # type: ignore[assignment] + + return unpatch + + +class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): + STORE_TYPE = EventsWorkerStore + + def setUp(self) -> None: + # Patch up the equality operator for events so that we can check + # whether lists of events match using assertEqual + self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(EventBase)] + super().setUp() + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) + + self.get_success( + self.master_store.store_room( + ROOM_ID, + USER_ID, + is_public=False, + room_version=RoomVersions.V1, + ) + ) + + def tearDown(self) -> None: + [unpatch() for unpatch in self.unpatches] + + def test_get_latest_event_ids_in_room(self) -> None: + create = self.persist(type="m.room.create", key="", creator=USER_ID) + self.replicate() + self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]) + + join = self.persist( + type="m.room.member", + key=USER_ID, + membership="join", + prev_events=[(create.event_id, {})], + ) + self.replicate() + self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]) + + def test_redactions(self) -> None: + self.persist(type="m.room.create", key="", creator=USER_ID) + self.persist(type="m.room.member", key=USER_ID, membership="join") + + msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello") + self.replicate() + self.check("get_event", [msg.event_id], msg) + + redaction = self.persist(type="m.room.redaction", redacts=msg.event_id) + self.replicate() + + msg_dict = msg.get_dict() + msg_dict["content"] = {} + msg_dict["unsigned"]["redacted_by"] = redaction.event_id + msg_dict["unsigned"]["redacted_because"] = redaction + redacted = make_event_from_dict( + msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict() + ) + self.check("get_event", [msg.event_id], redacted) + + def test_backfilled_redactions(self) -> None: + self.persist(type="m.room.create", key="", 
creator=USER_ID) + self.persist(type="m.room.member", key=USER_ID, membership="join") + + msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello") + self.replicate() + self.check("get_event", [msg.event_id], msg) + + redaction = self.persist( + type="m.room.redaction", redacts=msg.event_id, backfill=True + ) + self.replicate() + + msg_dict = msg.get_dict() + msg_dict["content"] = {} + msg_dict["unsigned"]["redacted_by"] = redaction.event_id + msg_dict["unsigned"]["redacted_because"] = redaction + redacted = make_event_from_dict( + msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict() + ) + self.check("get_event", [msg.event_id], redacted) + + def test_invites(self) -> None: + self.persist(type="m.room.create", key="", creator=USER_ID) + self.check("get_invited_rooms_for_local_user", [USER_ID_2], []) + event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite") + assert event.internal_metadata.stream_ordering is not None + + self.replicate() + + self.check( + "get_invited_rooms_for_local_user", + [USER_ID_2], + [ + RoomsForUser( + ROOM_ID, + USER_ID, + "invite", + event.event_id, + event.internal_metadata.stream_ordering, + RoomVersions.V1.identifier, + ) + ], + ) + + @parameterized.expand([(True,), (False,)]) + def test_push_actions_for_user(self, send_receipt: bool) -> None: + self.persist(type="m.room.create", key="", creator=USER_ID) + self.persist(type="m.room.member", key=USER_ID, membership="join") + self.persist( + type="m.room.member", sender=USER_ID, key=USER_ID_2, membership="join" + ) + event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello") + self.replicate() + + if send_receipt: + self.get_success( + self.master_store.insert_receipt( + ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], None, {} + ) + ) + + self.check( + "get_unread_event_push_actions_by_room_for_user", + [ROOM_ID, USER_ID_2], + RoomNotifCounts( + NotifCounts(highlight_count=0, unread_count=0, notify_count=0), {} + ), + ) + + self.persist( + type="m.room.message", + msgtype="m.text", + body="world", + push_actions=[(USER_ID_2, ["notify"])], + ) + self.replicate() + self.check( + "get_unread_event_push_actions_by_room_for_user", + [ROOM_ID, USER_ID_2], + RoomNotifCounts( + NotifCounts(highlight_count=0, unread_count=0, notify_count=1), {} + ), + ) + + self.persist( + type="m.room.message", + msgtype="m.text", + body="world", + push_actions=[ + (USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}]) + ], + ) + self.replicate() + self.check( + "get_unread_event_push_actions_by_room_for_user", + [ROOM_ID, USER_ID_2], + RoomNotifCounts( + NotifCounts(highlight_count=1, unread_count=0, notify_count=2), {} + ), + ) + + def test_get_rooms_for_user_with_stream_ordering(self) -> None: + """Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated + by rows in the events stream + """ + self.persist(type="m.room.create", key="", creator=USER_ID) + self.persist(type="m.room.member", key=USER_ID, membership="join") + self.replicate() + self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set()) + + j2 = self.persist( + type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" + ) + assert j2.internal_metadata.stream_ordering is not None + self.replicate() + + expected_pos = PersistedEventPosition( + "master", j2.internal_metadata.stream_ordering + ) + self.check( + "get_rooms_for_user_with_stream_ordering", + (USER_ID_2,), + {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)}, + ) + 
+ def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist( + self, + ) -> None: + """Check that current_state invalidation happens correctly with multiple events + in the persistence batch. + + This test attempts to reproduce a race condition between the event persistence + loop and a worker-based Sync handler. + + The problem occurred when the master persisted several events in one batch. It + only updates the current_state at the end of each batch, so the obvious thing + to do is then to issue a current_state_delta stream update corresponding to the + last stream_id in the batch. + + However, that raises the possibility that a worker will see the replication + notification for a join event before the current_state caches are invalidated. + + The test involves: + * creating a join and a message event for a user, and persisting them in the + same batch + + * controlling the replication stream so that updates are sent gradually + + * between each bunch of replication updates, check that we see a consistent + snapshot of the state. + """ + self.persist(type="m.room.create", key="", creator=USER_ID) + self.persist(type="m.room.member", key=USER_ID, membership="join") + self.replicate() + self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set()) + + # limit the replication rate + repl_transport = self._server_transport + assert isinstance(repl_transport, FakeTransport) + repl_transport.autoflush = False + + # build the join and message events and persist them in the same batch. + logger.info("----- build test events ------") + j2, j2ctx = self.build_event( + type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" + ) + msg, msgctx = self.build_event() + self.get_success(self.persistance.persist_events([(j2, j2ctx), (msg, msgctx)])) + self.replicate() + assert j2.internal_metadata.stream_ordering is not None + + event_source = RoomEventSource(self.hs) + event_source.store = self.worker_store + current_token = event_source.get_current_key() + + # gradually stream out the replication + while repl_transport.buffer: + logger.info("------ flush ------") + repl_transport.flush(30) + self.pump(0) + + prev_token = current_token + current_token = event_source.get_current_key() + + # attempt to replicate the behaviour of the sync handler. + # + # First, we get a list of the rooms we are joined to + joined_rooms = self.get_success( + self.worker_store.get_rooms_for_user_with_stream_ordering(USER_ID_2) + ) + + # Then, we get a list of the events since the last sync + membership_changes = self.get_success( + self.worker_store.get_membership_changes_for_user( + USER_ID_2, prev_token, current_token + ) + ) + + logger.info( + "%s->%s: joined_rooms=%r membership_changes=%r", + prev_token, + current_token, + joined_rooms, + membership_changes, + ) + + # the membership change is only any use to us if the room is in the + # joined_rooms list. + if membership_changes: + expected_pos = PersistedEventPosition( + "master", j2.internal_metadata.stream_ordering + ) + self.assertEqual( + joined_rooms, + {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)}, + ) + + event_id = 0 + + def persist(self, backfill: bool = False, **kwargs: Any) -> EventBase: + """ + Returns: + The event that was persisted. 
+ """ + event, context = self.build_event(**kwargs) + + if backfill: + self.get_success( + self.persistance.persist_events([(event, context)], backfilled=True) + ) + else: + self.get_success(self.persistance.persist_event(event, context)) + + return event + + def build_event( + self, + sender: str = USER_ID, + room_id: str = ROOM_ID, + type: str = "m.room.message", + key: Optional[str] = None, + internal: Optional[dict] = None, + depth: Optional[int] = None, + prev_events: Optional[List[Tuple[str, dict]]] = None, + auth_events: Optional[List[str]] = None, + prev_state: Optional[List[str]] = None, + redacts: Optional[str] = None, + push_actions: Iterable = frozenset(), + **content: object, + ) -> Tuple[EventBase, EventContext]: + prev_events = prev_events or [] + auth_events = auth_events or [] + prev_state = prev_state or [] + + if depth is None: + depth = self.event_id + + if not prev_events: + latest_event_ids = self.get_success( + self.master_store.get_latest_event_ids_in_room(room_id) + ) + prev_events = [(ev_id, {}) for ev_id in latest_event_ids] + + event_dict = { + "sender": sender, + "type": type, + "content": content, + "event_id": "$%d:blue" % (self.event_id,), + "room_id": room_id, + "depth": depth, + "origin_server_ts": self.event_id, + "prev_events": prev_events, + "auth_events": auth_events, + } + if key is not None: + event_dict["state_key"] = key + event_dict["prev_state"] = prev_state + + if redacts is not None: + event_dict["redacts"] = redacts + + event = make_event_from_dict(event_dict, internal_metadata_dict=internal or {}) + + self.event_id += 1 + state_handler = self.hs.get_state_handler() + context = self.get_success(state_handler.compute_event_context(event)) + + self.get_success( + self.master_store.add_push_actions_to_staging( + event.event_id, + dict(push_actions), + False, + "main", + ) + ) + return event, context -- cgit 1.5.1
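As a rough illustration of how the renamed test plumbing fits together (not part of the commit above): a test persists events through the master store, calls replicate() to poke the replication stream, and then uses check() to assert that the master store and the worker store return the same result. The sketch below reuses the persist()/replicate()/check() helpers and the USER_ID constant defined in tests/replication/storage/test_events.py; the class name, test method name and message body are invented for the example, and in practice such a test would more likely be added as a method on EventsWorkerStoreTestCase itself rather than a subclass.

from tests.replication.storage.test_events import EventsWorkerStoreTestCase, USER_ID


class ExampleWorkerReplicationTestCase(EventsWorkerStoreTestCase):
    """Hypothetical test case, shown only to illustrate the renamed helpers."""

    def test_message_replicates_to_worker_store(self) -> None:
        # Build up a room on the master store using the persist() helper from
        # EventsWorkerStoreTestCase (room creation, then a join for USER_ID).
        self.persist(type="m.room.create", key="", creator=USER_ID)
        self.persist(type="m.room.member", key=USER_ID, membership="join")

        # Persist a message on the master store, then poke the replication
        # stream and wait for the worker store to catch up.
        msg = self.persist(type="m.room.message", msgtype="m.text", body="ping")
        self.replicate()

        # check() runs the same storage method against both the master store
        # and the worker store and asserts that the two results agree (and,
        # here, that both equal the persisted event).
        self.check("get_event", [msg.event_id], msg)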