summary refs log tree commit diff
path: root/tests/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers')
-rw-r--r--tests/handlers/test_sync.py208
1 files changed, 195 insertions, 13 deletions
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 1b36324b8f..897c52c785 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -17,14 +17,16 @@
 # [This file includes modifications made by New Vector Limited]
 #
 #
-from typing import Collection, List, Optional
+from typing import Collection, ContextManager, List, Optional
 from unittest.mock import AsyncMock, Mock, patch
 
+from parameterized import parameterized
+
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EventTypes, JoinRules
 from synapse.api.errors import Codes, ResourceLimitError
-from synapse.api.filtering import Filtering
+from synapse.api.filtering import FilterCollection, Filtering
 from synapse.api.room_versions import RoomVersion, RoomVersions
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
@@ -33,7 +35,7 @@ from synapse.handlers.sync import SyncConfig, SyncResult
 from synapse.rest import admin
 from synapse.rest.client import knock, login, room
 from synapse.server import HomeServer
-from synapse.types import UserID, create_requester
+from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
 
 import tests.unittest
@@ -258,13 +260,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         # Eve tries to join the room. We monkey patch the internal logic which selects
         # the prev_events used when creating the join event, such that the ban does not
         # precede the join.
-        mocked_get_prev_events = patch.object(
-            self.hs.get_datastores().main,
-            "get_prev_events_for_room",
-            new_callable=AsyncMock,
-            return_value=[last_room_creation_event_id],
-        )
-        with mocked_get_prev_events:
+        with self._patch_get_latest_events([last_room_creation_event_id]):
             self.helper.join(room_id, eve, tok=eve_token)
 
         # Eve makes a second, incremental sync.
@@ -288,6 +284,180 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         )
         self.assertEqual(eve_initial_sync_after_join.joined, [])
 
+    def test_state_includes_changes_on_forks(self) -> None:
+        """State changes that happen on a fork of the DAG must be included in `state`
+
+        Given the following DAG:
+
+             E1
+           ↗    ↖
+          |      S2
+          |      ↑
+        --|------|----
+          |      |
+          E3     |
+           ↖    /
+             E4
+
+        ... and a filter that means we only return 2 events, represented by the dashed
+        horizontal line: `S2` must be included in the `state` section.
+        """
+        alice = self.register_user("alice", "password")
+        alice_tok = self.login(alice, "password")
+        alice_requester = create_requester(alice)
+        room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok)
+
+        # Do an initial sync as Alice to get a known starting point.
+        initial_sync_result = self.get_success(
+            self.sync_handler.wait_for_sync_for_user(
+                alice_requester, generate_sync_config(alice)
+            )
+        )
+        last_room_creation_event_id = (
+            initial_sync_result.joined[0].timeline.events[-1].event_id
+        )
+
+        # Send a state event, and a regular event, both using the same prev ID
+        with self._patch_get_latest_events([last_room_creation_event_id]):
+            s2_event = self.helper.send_state(room_id, "s2", {}, tok=alice_tok)[
+                "event_id"
+            ]
+            e3_event = self.helper.send(room_id, "e3", tok=alice_tok)["event_id"]
+
+        # Send a final event, joining the two branches of the dag
+        e4_event = self.helper.send(room_id, "e4", tok=alice_tok)["event_id"]
+
+        # do an incremental sync, with a filter that will ensure we only get two of
+        # the three new events.
+        incremental_sync = self.get_success(
+            self.sync_handler.wait_for_sync_for_user(
+                alice_requester,
+                generate_sync_config(
+                    alice,
+                    filter_collection=FilterCollection(
+                        self.hs, {"room": {"timeline": {"limit": 2}}}
+                    ),
+                ),
+                since_token=initial_sync_result.next_batch,
+            )
+        )
+
+        # The state event should appear in the 'state' section of the response.
+        room_sync = incremental_sync.joined[0]
+        self.assertEqual(room_sync.room_id, room_id)
+        self.assertTrue(room_sync.timeline.limited)
+        self.assertEqual(
+            [e.event_id for e in room_sync.timeline.events],
+            [e3_event, e4_event],
+        )
+        self.assertEqual(
+            [e.event_id for e in room_sync.state.values()],
+            [s2_event],
+        )
+
+    @parameterized.expand(
+        [
+            (False, False),
+            (True, False),
+            (False, True),
+            (True, True),
+        ]
+    )
+    def test_archived_rooms_do_not_include_state_after_leave(
+        self, initial_sync: bool, empty_timeline: bool
+    ) -> None:
+        """If the user leaves the room, state changes that happen after they leave are not returned.
+
+        We try with both a zero and a normal timeline limit,
+        and we try both an initial sync and an incremental sync for both.
+        """
+        if empty_timeline and not initial_sync:
+            # FIXME synapse doesn't return the room at all in this situation!
+            self.skipTest("Synapse does not correctly handle this case")
+
+        # Alice creates the room, and bob joins.
+        alice = self.register_user("alice", "password")
+        alice_tok = self.login(alice, "password")
+
+        bob = self.register_user("bob", "password")
+        bob_tok = self.login(bob, "password")
+        bob_requester = create_requester(bob)
+
+        room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok)
+        self.helper.join(room_id, bob, tok=bob_tok)
+
+        initial_sync_result = self.get_success(
+            self.sync_handler.wait_for_sync_for_user(
+                bob_requester, generate_sync_config(bob)
+            )
+        )
+
+        # Alice sends a message and a state
+        before_message_event = self.helper.send(room_id, "before", tok=alice_tok)[
+            "event_id"
+        ]
+        before_state_event = self.helper.send_state(
+            room_id, "test_state", {"body": "before"}, tok=alice_tok
+        )["event_id"]
+
+        # Bob leaves
+        leave_event = self.helper.leave(room_id, bob, tok=bob_tok)["event_id"]
+
+        # Alice sends some more stuff
+        self.helper.send(room_id, "after", tok=alice_tok)["event_id"]
+        self.helper.send_state(room_id, "test_state", {"body": "after"}, tok=alice_tok)[
+            "event_id"
+        ]
+
+        # And now, Bob resyncs.
+        filter_dict: JsonDict = {"room": {"include_leave": True}}
+        if empty_timeline:
+            filter_dict["room"]["timeline"] = {"limit": 0}
+        sync_room_result = self.get_success(
+            self.sync_handler.wait_for_sync_for_user(
+                bob_requester,
+                generate_sync_config(
+                    bob, filter_collection=FilterCollection(self.hs, filter_dict)
+                ),
+                since_token=None if initial_sync else initial_sync_result.next_batch,
+            )
+        ).archived[0]
+
+        if empty_timeline:
+            # The timeline should be empty
+            self.assertEqual(sync_room_result.timeline.events, [])
+
+            # And the state should include the leave event...
+            self.assertEqual(
+                sync_room_result.state[("m.room.member", bob)].event_id, leave_event
+            )
+            # ... and the state change before he left.
+            self.assertEqual(
+                sync_room_result.state[("test_state", "")].event_id, before_state_event
+            )
+        else:
+            # The last three events in the timeline should be those leading up to the
+            # leave
+            self.assertEqual(
+                [e.event_id for e in sync_room_result.timeline.events[-3:]],
+                [before_message_event, before_state_event, leave_event],
+            )
+            # ... And the state should be empty
+            self.assertEqual(sync_room_result.state, {})
+
+    def _patch_get_latest_events(self, latest_events: List[str]) -> ContextManager:
+        """Monkey-patch `get_prev_events_for_room`
+
+        Returns a context manager which will replace the implementation of
+        `get_prev_events_for_room` with one which returns `latest_events`.
+        """
+        return patch.object(
+            self.hs.get_datastores().main,
+            "get_prev_events_for_room",
+            new_callable=AsyncMock,
+            return_value=latest_events,
+        )
+
     def test_call_invite_in_public_room_not_returned(self) -> None:
         user = self.register_user("alice", "password")
         tok = self.login(user, "password")
@@ -401,14 +571,26 @@ _request_key = 0
 
 
 def generate_sync_config(
-    user_id: str, device_id: Optional[str] = "device_id"
+    user_id: str,
+    device_id: Optional[str] = "device_id",
+    filter_collection: Optional[FilterCollection] = None,
 ) -> SyncConfig:
-    """Generate a sync config (with a unique request key)."""
+    """Generate a sync config (with a unique request key).
+
+    Args:
+        user_id: user who is syncing.
+        device_id: device that is syncing. Defaults to "device_id".
+        filter_collection: filter to apply. Defaults to the default filter (ie,
+           return everything, with a default limit)
+    """
+    if filter_collection is None:
+        filter_collection = Filtering(Mock()).DEFAULT_FILTER_COLLECTION
+
     global _request_key
     _request_key += 1
     return SyncConfig(
         user=UserID.from_string(user_id),
-        filter_collection=Filtering(Mock()).DEFAULT_FILTER_COLLECTION,
+        filter_collection=filter_collection,
         is_guest=False,
         request_key=("request_key", _request_key),
         device_id=device_id,