diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index fa55f76916..6b202dfbd5 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -17,10 +17,11 @@
# [This file includes modifications made by New Vector Limited]
#
#
+from http import HTTPStatus
from typing import Collection, ContextManager, List, Optional
from unittest.mock import AsyncMock, Mock, patch
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
@@ -32,7 +33,13 @@ from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.federation.federation_base import event_from_pdu_json
-from synapse.handlers.sync import SyncConfig, SyncRequestKey, SyncResult, SyncVersion
+from synapse.handlers.sync import (
+ SyncConfig,
+ SyncRequestKey,
+ SyncResult,
+ SyncVersion,
+ TimelineBatch,
+)
from synapse.rest import admin
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
@@ -58,9 +65,21 @@ def generate_request_key() -> SyncRequestKey:
return ("request_key", _request_key)
+@parameterized_class(
+ ("use_state_after",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'state_after' if params_dict['use_state_after'] else 'state'}",
+)
class SyncTestCase(tests.unittest.HomeserverTestCase):
"""Tests Sync Handler."""
+ use_state_after: bool
+
servlets = [
admin.register_servlets,
knock.register_servlets,
@@ -79,7 +98,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
def test_wait_for_sync_for_user_auth_blocking(self) -> None:
user_id1 = "@user1:test"
user_id2 = "@user2:test"
- sync_config = generate_sync_config(user_id1)
+ sync_config = generate_sync_config(
+ user_id1, use_state_after=self.use_state_after
+ )
requester = create_requester(user_id1)
self.reactor.advance(100) # So we get not 0 time
@@ -112,7 +133,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.auth_blocking._hs_disabled = False
- sync_config = generate_sync_config(user_id2)
+ sync_config = generate_sync_config(
+ user_id2, use_state_after=self.use_state_after
+ )
requester = create_requester(user_id2)
e = self.get_failure(
@@ -141,7 +164,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
initial_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
requester,
- sync_config=generate_sync_config(user, device_id="dev"),
+ sync_config=generate_sync_config(
+ user, device_id="dev", use_state_after=self.use_state_after
+ ),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -175,7 +200,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
requester,
- sync_config=generate_sync_config(user),
+ sync_config=generate_sync_config(
+ user, use_state_after=self.use_state_after
+ ),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -188,7 +215,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
requester,
- sync_config=generate_sync_config(user, device_id="dev"),
+ sync_config=generate_sync_config(
+ user, device_id="dev", use_state_after=self.use_state_after
+ ),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=initial_result.next_batch,
@@ -220,7 +249,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
requester,
- sync_config=generate_sync_config(user),
+ sync_config=generate_sync_config(
+ user, use_state_after=self.use_state_after
+ ),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -233,7 +264,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
requester,
- sync_config=generate_sync_config(user, device_id="dev"),
+ sync_config=generate_sync_config(
+ user, device_id="dev", use_state_after=self.use_state_after
+ ),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=initial_result.next_batch,
@@ -276,7 +309,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
alice_sync_result: SyncResult = self.get_success(
self.sync_handler.wait_for_sync_for_user(
create_requester(owner),
- generate_sync_config(owner),
+ generate_sync_config(owner, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -296,7 +329,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Eve syncs.
eve_requester = create_requester(eve)
- eve_sync_config = generate_sync_config(eve)
+ eve_sync_config = generate_sync_config(
+ eve, use_state_after=self.use_state_after
+ )
eve_sync_after_ban: SyncResult = self.get_success(
self.sync_handler.wait_for_sync_for_user(
eve_requester,
@@ -313,7 +348,15 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# the prev_events used when creating the join event, such that the ban does not
# precede the join.
with self._patch_get_latest_events([last_room_creation_event_id]):
- self.helper.join(room_id, eve, tok=eve_token)
+ self.helper.join(
+ room_id,
+ eve,
+ tok=eve_token,
+                # Previously, this join would succeed, but now we expect it to
+                # fail at this point. The rest of the test covers the case
+                # where the join used to succeed.
+ expect_code=HTTPStatus.FORBIDDEN,
+ )
# Eve makes a second, incremental sync.
eve_incremental_sync_after_join: SyncResult = self.get_success(
@@ -367,7 +410,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester,
- generate_sync_config(alice),
+ generate_sync_config(alice, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -396,6 +439,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
filter_collection=FilterCollection(
self.hs, {"room": {"timeline": {"limit": 2}}}
),
+ use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
@@ -442,7 +486,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester,
- generate_sync_config(alice),
+ generate_sync_config(alice, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -481,6 +525,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
}
},
),
+ use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
@@ -518,6 +563,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
... and a filter that means we only return 1 event, represented by the dashed
horizontal lines: `S2` must be included in the `state` section on the second sync.
+
+    When `use_state_after` is enabled, we expect to see `S2` in the first sync.
"""
alice = self.register_user("alice", "password")
alice_tok = self.login(alice, "password")
@@ -528,7 +575,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester,
- generate_sync_config(alice),
+ generate_sync_config(alice, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -554,6 +601,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
filter_collection=FilterCollection(
self.hs, {"room": {"timeline": {"limit": 1}}}
),
+ use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
@@ -567,10 +615,18 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
[e.event_id for e in room_sync.timeline.events],
[e3_event],
)
- self.assertEqual(
- [e.event_id for e in room_sync.state.values()],
- [],
- )
+
+ if self.use_state_after:
+ # When using `state_after` we get told about s2 immediately
+ self.assertEqual(
+ [e.event_id for e in room_sync.state.values()],
+ [s2_event],
+ )
+ else:
+ self.assertEqual(
+ [e.event_id for e in room_sync.state.values()],
+ [],
+ )
# Now send another event that points to S2, but not E3.
with self._patch_get_latest_events([s2_event]):
@@ -585,6 +641,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
filter_collection=FilterCollection(
self.hs, {"room": {"timeline": {"limit": 1}}}
),
+ use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
@@ -598,10 +655,19 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
[e.event_id for e in room_sync.timeline.events],
[e4_event],
)
- self.assertEqual(
- [e.event_id for e in room_sync.state.values()],
- [s2_event],
- )
+
+ if self.use_state_after:
+            # When using `state_after` we were already told about s2, so we
+            # don't expect to see it again.
+ self.assertEqual(
+ [e.event_id for e in room_sync.state.values()],
+ [],
+ )
+ else:
+ self.assertEqual(
+ [e.event_id for e in room_sync.state.values()],
+ [s2_event],
+ )
def test_state_includes_changes_on_ungappy_syncs(self) -> None:
"""Test `state` where the sync is not gappy.
@@ -638,6 +704,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
This is the last chance for us to tell the client about S2, so it *must* be
included in the response.
+
+    When `use_state_after` is enabled, we expect to see `S2` in the first sync.
"""
alice = self.register_user("alice", "password")
alice_tok = self.login(alice, "password")
@@ -648,7 +716,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester,
- generate_sync_config(alice),
+ generate_sync_config(alice, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -673,6 +741,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
filter_collection=FilterCollection(
self.hs, {"room": {"timeline": {"limit": 1}}}
),
+ use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
@@ -684,7 +753,11 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
[e.event_id for e in room_sync.timeline.events],
[e3_event],
)
- self.assertNotIn(s2_event, [e.event_id for e in room_sync.state.values()])
+ if self.use_state_after:
+ # When using `state_after` we get told about s2 immediately
+ self.assertIn(s2_event, [e.event_id for e in room_sync.state.values()])
+ else:
+ self.assertNotIn(s2_event, [e.event_id for e in room_sync.state.values()])
# More events, E4 and E5
with self._patch_get_latest_events([e3_event]):
@@ -695,7 +768,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
incremental_sync = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester,
- generate_sync_config(alice),
+ generate_sync_config(alice, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=initial_sync_result.next_batch,
@@ -710,10 +783,19 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
[e.event_id for e in room_sync.timeline.events],
[e4_event, e5_event],
)
- self.assertEqual(
- [e.event_id for e in room_sync.state.values()],
- [s2_event],
- )
+
+ if self.use_state_after:
+            # When using `state_after` we were already told about s2, so we
+            # don't expect to see it again.
+ self.assertEqual(
+ [e.event_id for e in room_sync.state.values()],
+ [],
+ )
+ else:
+ self.assertEqual(
+ [e.event_id for e in room_sync.state.values()],
+ [s2_event],
+ )
@parameterized.expand(
[
@@ -721,7 +803,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
(True, False),
(False, True),
(True, True),
- ]
+ ],
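+    # Give each expansion a readable name, e.g.
+    # `test_archived_rooms_do_not_include_state_after_leave_True_False`.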
+ name_func=lambda func, num, p: f"{func.__name__}_{p.args[0]}_{p.args[1]}",
)
def test_archived_rooms_do_not_include_state_after_leave(
self, initial_sync: bool, empty_timeline: bool
@@ -749,7 +832,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
bob_requester,
- generate_sync_config(bob),
+ generate_sync_config(bob, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -780,7 +863,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
bob_requester,
generate_sync_config(
- bob, filter_collection=FilterCollection(self.hs, filter_dict)
+ bob,
+ filter_collection=FilterCollection(self.hs, filter_dict),
+ use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
@@ -791,7 +876,15 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
if empty_timeline:
# The timeline should be empty
self.assertEqual(sync_room_result.timeline.events, [])
+ else:
+ # The last three events in the timeline should be those leading up to the
+ # leave
+ self.assertEqual(
+ [e.event_id for e in sync_room_result.timeline.events[-3:]],
+ [before_message_event, before_state_event, leave_event],
+ )
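+
+        # With `use_state_after`, the leave event is reported in the `state`
+        # section even when it also appears in the timeline.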
+ if empty_timeline or self.use_state_after:
# And the state should include the leave event...
self.assertEqual(
sync_room_result.state[("m.room.member", bob)].event_id, leave_event
@@ -801,12 +894,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
sync_room_result.state[("test_state", "")].event_id, before_state_event
)
else:
- # The last three events in the timeline should be those leading up to the
- # leave
- self.assertEqual(
- [e.event_id for e in sync_room_result.timeline.events[-3:]],
- [before_message_event, before_state_event, leave_event],
- )
# ... And the state should be empty
self.assertEqual(sync_room_result.state, {})
@@ -843,7 +930,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
) -> List[EventBase]:
return list(pdus)
- self.client._check_sigs_and_hash_for_pulled_events_and_fetch = _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment]
+ self.client._check_sigs_and_hash_for_pulled_events_and_fetch = ( # type: ignore[method-assign]
+ _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment]
+ )
prev_events = self.get_success(self.store.get_prev_events_for_room(room_id))
@@ -877,7 +966,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
sync_result: SyncResult = self.get_success(
self.sync_handler.wait_for_sync_for_user(
create_requester(user),
- generate_sync_config(user),
+ generate_sync_config(user, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -926,7 +1015,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
private_sync_result: SyncResult = self.get_success(
self.sync_handler.wait_for_sync_for_user(
create_requester(user2),
- generate_sync_config(user2),
+ generate_sync_config(user2, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -952,7 +1041,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
sync_result: SyncResult = self.get_success(
self.sync_handler.wait_for_sync_for_user(
create_requester(user),
- generate_sync_config(user),
+ generate_sync_config(user, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -989,7 +1078,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
sync_d = defer.ensureDeferred(
self.sync_handler.wait_for_sync_for_user(
create_requester(user),
- generate_sync_config(user),
+ generate_sync_config(user, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=since_token,
@@ -1044,7 +1133,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
sync_d = defer.ensureDeferred(
self.sync_handler.wait_for_sync_for_user(
create_requester(user),
- generate_sync_config(user),
+ generate_sync_config(user, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=since_token,
@@ -1060,6 +1149,7 @@ def generate_sync_config(
user_id: str,
device_id: Optional[str] = "device_id",
filter_collection: Optional[FilterCollection] = None,
+ use_state_after: bool = False,
) -> SyncConfig:
"""Generate a sync config (with a unique request key).
@@ -1067,7 +1157,8 @@ def generate_sync_config(
user_id: user who is syncing.
device_id: device that is syncing. Defaults to "device_id".
filter_collection: filter to apply. Defaults to the default filter (ie,
- return everything, with a default limit)
+ return everything, with a default limit)
+        use_state_after: whether to set the `use_state_after` flag on the
+            returned config.
"""
if filter_collection is None:
filter_collection = Filtering(Mock()).DEFAULT_FILTER_COLLECTION
@@ -1077,4 +1168,138 @@ def generate_sync_config(
filter_collection=filter_collection,
is_guest=False,
device_id=device_id,
+ use_state_after=use_state_after,
)
+
+
+class SyncStateAfterTestCase(tests.unittest.HomeserverTestCase):
+ """Tests Sync Handler state behavior when using `use_state_after."""
+
+ servlets = [
+ admin.register_servlets,
+ knock.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.sync_handler = self.hs.get_sync_handler()
+ self.store = self.hs.get_datastores().main
+
+ # AuthBlocking reads from the hs' config on initialization. We need to
+ # modify its config instead of the hs'
+ self.auth_blocking = self.hs.get_auth_blocking()
+
+ def test_initial_sync_multiple_deltas(self) -> None:
+ """Test that if multiple state deltas have happened during processing of
+ a full state sync we return the correct state"""
+
+ user = self.register_user("user", "password")
+ tok = self.login("user", "password")
+
+ # Create a room as the user and set some custom state.
+ joined_room = self.helper.create_room_as(user, tok=tok)
+
+ first_state = self.helper.send_state(
+ joined_room, event_type="m.test_event", body={"num": 1}, tok=tok
+ )
+
+ # Take a snapshot of the stream token, to simulate doing an initial sync
+ # at this point.
+ end_stream_token = self.hs.get_event_sources().get_current_token()
+
+ # Send some state *after* the stream token
+ self.helper.send_state(
+ joined_room, event_type="m.test_event", body={"num": 2}, tok=tok
+ )
+
+ # Calculating the full state will return the first state, and not the
+ # second.
+ state = self.get_success(
+ self.sync_handler._compute_state_delta_for_full_sync(
+ room_id=joined_room,
+ sync_config=generate_sync_config(user, use_state_after=True),
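+                # An empty, `limited` timeline batch: no timeline events are
+                # returned for this sync.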
+ batch=TimelineBatch(
+ prev_batch=end_stream_token, events=[], limited=True
+ ),
+ end_token=end_stream_token,
+ members_to_fetch=None,
+ timeline_state={},
+ joined=True,
+ )
+ )
+ self.assertEqual(state[("m.test_event", "")], first_state["event_id"])
+
+ def test_incremental_sync_multiple_deltas(self) -> None:
+ """Test that if multiple state deltas have happened since an incremental
+ state sync we return the correct state"""
+
+ user = self.register_user("user", "password")
+ tok = self.login("user", "password")
+
+ # Create a room as the user and set some custom state.
+ joined_room = self.helper.create_room_as(user, tok=tok)
+
+ # Take a snapshot of the stream token, to simulate doing an incremental sync
+ # from this point.
+ since_token = self.hs.get_event_sources().get_current_token()
+
+ self.helper.send_state(
+ joined_room, event_type="m.test_event", body={"num": 1}, tok=tok
+ )
+
+ # Send some state *after* the stream token
+ second_state = self.helper.send_state(
+ joined_room, event_type="m.test_event", body={"num": 2}, tok=tok
+ )
+
+ end_stream_token = self.hs.get_event_sources().get_current_token()
+
+        # Calculating the incremental state will return the second state, and
+        # not the first.
+ state = self.get_success(
+ self.sync_handler._compute_state_delta_for_incremental_sync(
+ room_id=joined_room,
+ sync_config=generate_sync_config(user, use_state_after=True),
+ batch=TimelineBatch(
+ prev_batch=end_stream_token, events=[], limited=True
+ ),
+ since_token=since_token,
+ end_token=end_stream_token,
+ members_to_fetch=None,
+ timeline_state={},
+ )
+ )
+ self.assertEqual(state[("m.test_event", "")], second_state["event_id"])
+
+ def test_incremental_sync_lazy_loaded_no_timeline(self) -> None:
+ """Test that lazy-loading with an empty timeline doesn't return the full
+ state.
+
+ There was a bug where an empty state filter would cause the DB to return
+ the full state, rather than an empty set.
+ """
+ user = self.register_user("user", "password")
+ tok = self.login("user", "password")
+
+ # Create a room as the user and set some custom state.
+ joined_room = self.helper.create_room_as(user, tok=tok)
+
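+        # Take identical since/end tokens, so that nothing happens in the room
+        # between the two points in the stream.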
+ since_token = self.hs.get_event_sources().get_current_token()
+ end_stream_token = self.hs.get_event_sources().get_current_token()
+
+ state = self.get_success(
+ self.sync_handler._compute_state_delta_for_incremental_sync(
+ room_id=joined_room,
+ sync_config=generate_sync_config(user, use_state_after=True),
+ batch=TimelineBatch(
+ prev_batch=end_stream_token, events=[], limited=True
+ ),
+ since_token=since_token,
+ end_token=end_stream_token,
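+                # An empty set (rather than None) means lazy-loading is in
+                # use, but the empty timeline references no members.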
+ members_to_fetch=set(),
+ timeline_state={},
+ )
+ )
+
+ self.assertEqual(state, {})
|