diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 1b36324b8f..897c52c785 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -17,14 +17,16 @@
# [This file includes modifications made by New Vector Limited]
#
#
-from typing import Collection, List, Optional
+from typing import Collection, ContextManager, List, Optional
from unittest.mock import AsyncMock, Mock, patch
+from parameterized import parameterized
+
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, ResourceLimitError
-from synapse.api.filtering import Filtering
+from synapse.api.filtering import FilterCollection, Filtering
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
@@ -33,7 +35,7 @@ from synapse.handlers.sync import SyncConfig, SyncResult
from synapse.rest import admin
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
-from synapse.types import UserID, create_requester
+from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
import tests.unittest
@@ -258,13 +260,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Eve tries to join the room. We monkey patch the internal logic which selects
# the prev_events used when creating the join event, such that the ban does not
# precede the join.
- mocked_get_prev_events = patch.object(
- self.hs.get_datastores().main,
- "get_prev_events_for_room",
- new_callable=AsyncMock,
- return_value=[last_room_creation_event_id],
- )
- with mocked_get_prev_events:
+ with self._patch_get_latest_events([last_room_creation_event_id]):
self.helper.join(room_id, eve, tok=eve_token)
# Eve makes a second, incremental sync.
@@ -288,6 +284,180 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
)
self.assertEqual(eve_initial_sync_after_join.joined, [])
+ def test_state_includes_changes_on_forks(self) -> None:
+ """State changes that happen on a fork of the DAG must be included in `state`
+
+ Given the following DAG:
+
+ E1
+ ↗ ↖
+ | S2
+ | ↑
+ --|------|----
+ | |
+ E3 |
+ ↖ /
+ E4
+
+ ... and a filter that means we only return 2 events, represented by the dashed
+ horizontal line: `S2` must be included in the `state` section.
+ """
+ alice = self.register_user("alice", "password")
+ alice_tok = self.login(alice, "password")
+ alice_requester = create_requester(alice)
+ room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok)
+
+ # Do an initial sync as Alice to get a known starting point.
+ initial_sync_result = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ alice_requester, generate_sync_config(alice)
+ )
+ )
+ last_room_creation_event_id = (
+ initial_sync_result.joined[0].timeline.events[-1].event_id
+ )
+
+ # Send a state event, and a regular event, both using the same prev ID
+ with self._patch_get_latest_events([last_room_creation_event_id]):
+ s2_event = self.helper.send_state(room_id, "s2", {}, tok=alice_tok)[
+ "event_id"
+ ]
+ e3_event = self.helper.send(room_id, "e3", tok=alice_tok)["event_id"]
+
+ # Send a final event, joining the two branches of the dag
+ e4_event = self.helper.send(room_id, "e4", tok=alice_tok)["event_id"]
+
+ # do an incremental sync, with a filter that will ensure we only get two of
+ # the three new events.
+ incremental_sync = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ alice_requester,
+ generate_sync_config(
+ alice,
+ filter_collection=FilterCollection(
+ self.hs, {"room": {"timeline": {"limit": 2}}}
+ ),
+ ),
+ since_token=initial_sync_result.next_batch,
+ )
+ )
+
+ # The state event should appear in the 'state' section of the response.
+ room_sync = incremental_sync.joined[0]
+ self.assertEqual(room_sync.room_id, room_id)
+ self.assertTrue(room_sync.timeline.limited)
+ self.assertEqual(
+ [e.event_id for e in room_sync.timeline.events],
+ [e3_event, e4_event],
+ )
+ self.assertEqual(
+ [e.event_id for e in room_sync.state.values()],
+ [s2_event],
+ )
+
+ @parameterized.expand(
+ [
+ (False, False),
+ (True, False),
+ (False, True),
+ (True, True),
+ ]
+ )
+ def test_archived_rooms_do_not_include_state_after_leave(
+ self, initial_sync: bool, empty_timeline: bool
+ ) -> None:
+ """If the user leaves the room, state changes that happen after they leave are not returned.
+
+ We try with both a zero and a normal timeline limit,
+ and we try both an initial sync and an incremental sync for both.
+ """
+ if empty_timeline and not initial_sync:
+ # FIXME synapse doesn't return the room at all in this situation!
+ self.skipTest("Synapse does not correctly handle this case")
+
+ # Alice creates the room, and bob joins.
+ alice = self.register_user("alice", "password")
+ alice_tok = self.login(alice, "password")
+
+ bob = self.register_user("bob", "password")
+ bob_tok = self.login(bob, "password")
+ bob_requester = create_requester(bob)
+
+ room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok)
+ self.helper.join(room_id, bob, tok=bob_tok)
+
+ initial_sync_result = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ bob_requester, generate_sync_config(bob)
+ )
+ )
+
+ # Alice sends a message and a state
+ before_message_event = self.helper.send(room_id, "before", tok=alice_tok)[
+ "event_id"
+ ]
+ before_state_event = self.helper.send_state(
+ room_id, "test_state", {"body": "before"}, tok=alice_tok
+ )["event_id"]
+
+ # Bob leaves
+ leave_event = self.helper.leave(room_id, bob, tok=bob_tok)["event_id"]
+
+ # Alice sends some more stuff
+ self.helper.send(room_id, "after", tok=alice_tok)["event_id"]
+ self.helper.send_state(room_id, "test_state", {"body": "after"}, tok=alice_tok)[
+ "event_id"
+ ]
+
+ # And now, Bob resyncs.
+ filter_dict: JsonDict = {"room": {"include_leave": True}}
+ if empty_timeline:
+ filter_dict["room"]["timeline"] = {"limit": 0}
+ sync_room_result = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ bob_requester,
+ generate_sync_config(
+ bob, filter_collection=FilterCollection(self.hs, filter_dict)
+ ),
+ since_token=None if initial_sync else initial_sync_result.next_batch,
+ )
+ ).archived[0]
+
+ if empty_timeline:
+ # The timeline should be empty
+ self.assertEqual(sync_room_result.timeline.events, [])
+
+ # And the state should include the leave event...
+ self.assertEqual(
+ sync_room_result.state[("m.room.member", bob)].event_id, leave_event
+ )
+ # ... and the state change before he left.
+ self.assertEqual(
+ sync_room_result.state[("test_state", "")].event_id, before_state_event
+ )
+ else:
+ # The last three events in the timeline should be those leading up to the
+ # leave
+ self.assertEqual(
+ [e.event_id for e in sync_room_result.timeline.events[-3:]],
+ [before_message_event, before_state_event, leave_event],
+ )
+ # ... And the state should be empty
+ self.assertEqual(sync_room_result.state, {})
+
+ def _patch_get_latest_events(self, latest_events: List[str]) -> ContextManager:
+ """Monkey-patch `get_prev_events_for_room`
+
+ Returns a context manager which will replace the implementation of
+ `get_prev_events_for_room` with one which returns `latest_events`.
+ """
+ return patch.object(
+ self.hs.get_datastores().main,
+ "get_prev_events_for_room",
+ new_callable=AsyncMock,
+ return_value=latest_events,
+ )
+
def test_call_invite_in_public_room_not_returned(self) -> None:
user = self.register_user("alice", "password")
tok = self.login(user, "password")
@@ -401,14 +571,26 @@ _request_key = 0
def generate_sync_config(
- user_id: str, device_id: Optional[str] = "device_id"
+ user_id: str,
+ device_id: Optional[str] = "device_id",
+ filter_collection: Optional[FilterCollection] = None,
) -> SyncConfig:
- """Generate a sync config (with a unique request key)."""
+ """Generate a sync config (with a unique request key).
+
+ Args:
+ user_id: user who is syncing.
+ device_id: device that is syncing. Defaults to "device_id".
+ filter_collection: filter to apply. Defaults to the default filter (ie,
+ return everything, with a default limit)
+ """
+ if filter_collection is None:
+ filter_collection = Filtering(Mock()).DEFAULT_FILTER_COLLECTION
+
global _request_key
_request_key += 1
return SyncConfig(
user=UserID.from_string(user_id),
- filter_collection=Filtering(Mock()).DEFAULT_FILTER_COLLECTION,
+ filter_collection=filter_collection,
is_guest=False,
request_key=("request_key", _request_key),
device_id=device_id,
|