diff options
Diffstat (limited to 'tests/handlers')
-rw-r--r-- | tests/handlers/test_sync.py | 208 |
1 files changed, 195 insertions, 13 deletions
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 1b36324b8f..897c52c785 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -17,14 +17,16 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Collection, List, Optional +from typing import Collection, ContextManager, List, Optional from unittest.mock import AsyncMock, Mock, patch +from parameterized import parameterized + from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes, JoinRules from synapse.api.errors import Codes, ResourceLimitError -from synapse.api.filtering import Filtering +from synapse.api.filtering import FilterCollection, Filtering from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.events import EventBase from synapse.events.snapshot import EventContext @@ -33,7 +35,7 @@ from synapse.handlers.sync import SyncConfig, SyncResult from synapse.rest import admin from synapse.rest.client import knock, login, room from synapse.server import HomeServer -from synapse.types import UserID, create_requester +from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock import tests.unittest @@ -258,13 +260,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # Eve tries to join the room. We monkey patch the internal logic which selects # the prev_events used when creating the join event, such that the ban does not # precede the join. - mocked_get_prev_events = patch.object( - self.hs.get_datastores().main, - "get_prev_events_for_room", - new_callable=AsyncMock, - return_value=[last_room_creation_event_id], - ) - with mocked_get_prev_events: + with self._patch_get_latest_events([last_room_creation_event_id]): self.helper.join(room_id, eve, tok=eve_token) # Eve makes a second, incremental sync. @@ -288,6 +284,180 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): ) self.assertEqual(eve_initial_sync_after_join.joined, []) + def test_state_includes_changes_on_forks(self) -> None: + """State changes that happen on a fork of the DAG must be included in `state` + + Given the following DAG: + + E1 + ↗ ↖ + | S2 + | ↑ + --|------|---- + | | + E3 | + ↖ / + E4 + + ... and a filter that means we only return 2 events, represented by the dashed + horizontal line: `S2` must be included in the `state` section. + """ + alice = self.register_user("alice", "password") + alice_tok = self.login(alice, "password") + alice_requester = create_requester(alice) + room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok) + + # Do an initial sync as Alice to get a known starting point. + initial_sync_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + alice_requester, generate_sync_config(alice) + ) + ) + last_room_creation_event_id = ( + initial_sync_result.joined[0].timeline.events[-1].event_id + ) + + # Send a state event, and a regular event, both using the same prev ID + with self._patch_get_latest_events([last_room_creation_event_id]): + s2_event = self.helper.send_state(room_id, "s2", {}, tok=alice_tok)[ + "event_id" + ] + e3_event = self.helper.send(room_id, "e3", tok=alice_tok)["event_id"] + + # Send a final event, joining the two branches of the dag + e4_event = self.helper.send(room_id, "e4", tok=alice_tok)["event_id"] + + # do an incremental sync, with a filter that will ensure we only get two of + # the three new events. + incremental_sync = self.get_success( + self.sync_handler.wait_for_sync_for_user( + alice_requester, + generate_sync_config( + alice, + filter_collection=FilterCollection( + self.hs, {"room": {"timeline": {"limit": 2}}} + ), + ), + since_token=initial_sync_result.next_batch, + ) + ) + + # The state event should appear in the 'state' section of the response. + room_sync = incremental_sync.joined[0] + self.assertEqual(room_sync.room_id, room_id) + self.assertTrue(room_sync.timeline.limited) + self.assertEqual( + [e.event_id for e in room_sync.timeline.events], + [e3_event, e4_event], + ) + self.assertEqual( + [e.event_id for e in room_sync.state.values()], + [s2_event], + ) + + @parameterized.expand( + [ + (False, False), + (True, False), + (False, True), + (True, True), + ] + ) + def test_archived_rooms_do_not_include_state_after_leave( + self, initial_sync: bool, empty_timeline: bool + ) -> None: + """If the user leaves the room, state changes that happen after they leave are not returned. + + We try with both a zero and a normal timeline limit, + and we try both an initial sync and an incremental sync for both. + """ + if empty_timeline and not initial_sync: + # FIXME synapse doesn't return the room at all in this situation! + self.skipTest("Synapse does not correctly handle this case") + + # Alice creates the room, and bob joins. + alice = self.register_user("alice", "password") + alice_tok = self.login(alice, "password") + + bob = self.register_user("bob", "password") + bob_tok = self.login(bob, "password") + bob_requester = create_requester(bob) + + room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok) + self.helper.join(room_id, bob, tok=bob_tok) + + initial_sync_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + bob_requester, generate_sync_config(bob) + ) + ) + + # Alice sends a message and a state + before_message_event = self.helper.send(room_id, "before", tok=alice_tok)[ + "event_id" + ] + before_state_event = self.helper.send_state( + room_id, "test_state", {"body": "before"}, tok=alice_tok + )["event_id"] + + # Bob leaves + leave_event = self.helper.leave(room_id, bob, tok=bob_tok)["event_id"] + + # Alice sends some more stuff + self.helper.send(room_id, "after", tok=alice_tok)["event_id"] + self.helper.send_state(room_id, "test_state", {"body": "after"}, tok=alice_tok)[ + "event_id" + ] + + # And now, Bob resyncs. + filter_dict: JsonDict = {"room": {"include_leave": True}} + if empty_timeline: + filter_dict["room"]["timeline"] = {"limit": 0} + sync_room_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + bob_requester, + generate_sync_config( + bob, filter_collection=FilterCollection(self.hs, filter_dict) + ), + since_token=None if initial_sync else initial_sync_result.next_batch, + ) + ).archived[0] + + if empty_timeline: + # The timeline should be empty + self.assertEqual(sync_room_result.timeline.events, []) + + # And the state should include the leave event... + self.assertEqual( + sync_room_result.state[("m.room.member", bob)].event_id, leave_event + ) + # ... and the state change before he left. + self.assertEqual( + sync_room_result.state[("test_state", "")].event_id, before_state_event + ) + else: + # The last three events in the timeline should be those leading up to the + # leave + self.assertEqual( + [e.event_id for e in sync_room_result.timeline.events[-3:]], + [before_message_event, before_state_event, leave_event], + ) + # ... And the state should be empty + self.assertEqual(sync_room_result.state, {}) + + def _patch_get_latest_events(self, latest_events: List[str]) -> ContextManager: + """Monkey-patch `get_prev_events_for_room` + + Returns a context manager which will replace the implementation of + `get_prev_events_for_room` with one which returns `latest_events`. + """ + return patch.object( + self.hs.get_datastores().main, + "get_prev_events_for_room", + new_callable=AsyncMock, + return_value=latest_events, + ) + def test_call_invite_in_public_room_not_returned(self) -> None: user = self.register_user("alice", "password") tok = self.login(user, "password") @@ -401,14 +571,26 @@ _request_key = 0 def generate_sync_config( - user_id: str, device_id: Optional[str] = "device_id" + user_id: str, + device_id: Optional[str] = "device_id", + filter_collection: Optional[FilterCollection] = None, ) -> SyncConfig: - """Generate a sync config (with a unique request key).""" + """Generate a sync config (with a unique request key). + + Args: + user_id: user who is syncing. + device_id: device that is syncing. Defaults to "device_id". + filter_collection: filter to apply. Defaults to the default filter (ie, + return everything, with a default limit) + """ + if filter_collection is None: + filter_collection = Filtering(Mock()).DEFAULT_FILTER_COLLECTION + global _request_key _request_key += 1 return SyncConfig( user=UserID.from_string(user_id), - filter_collection=Filtering(Mock()).DEFAULT_FILTER_COLLECTION, + filter_collection=filter_collection, is_guest=False, request_key=("request_key", _request_key), device_id=device_id, |