diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 02371ce724..5319928c28 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -22,6 +22,7 @@ from unittest.mock import AsyncMock, Mock, patch
from parameterized import parameterized
+from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules
@@ -35,7 +36,7 @@ from synapse.handlers.sync import SyncConfig, SyncRequestKey, SyncResult, SyncVe
from synapse.rest import admin
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID, create_requester
+from synapse.types import JsonDict, StreamKeyType, UserID, create_requester
from synapse.util import Clock
import tests.unittest
@@ -959,6 +960,76 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.fail("No push rules found")
+ def test_wait_for_future_sync_token(self) -> None:
+ """Test that if we receive a token that is ahead of our current token,
+ we'll wait until the stream position advances.
+
+ This can happen if replication streams start lagging, and the client's
+ previous sync request was serviced by a worker ahead of ours.
+ """
+ user = self.register_user("alice", "password")
+
+ # We simulate a lagging stream by getting a stream ID from the ID gen
+ # and then waiting to mark it as "persisted".
+ presence_id_gen = self.store.get_presence_stream_id_gen()
+ ctx_mgr = presence_id_gen.get_next()
+ stream_id = self.get_success(ctx_mgr.__aenter__())
+
+ # Create the new token based on the stream ID above.
+ current_token = self.hs.get_event_sources().get_current_token()
+ since_token = current_token.copy_and_advance(StreamKeyType.PRESENCE, stream_id)
+
+ sync_d = defer.ensureDeferred(
+ self.sync_handler.wait_for_sync_for_user(
+ create_requester(user),
+ generate_sync_config(user),
+ sync_version=SyncVersion.SYNC_V2,
+ request_key=generate_request_key(),
+ since_token=since_token,
+ timeout=0,
+ )
+ )
+
+ # This should block waiting for the presence stream to update
+ self.pump()
+ self.assertFalse(sync_d.called)
+
+ # Marking the stream ID as persisted should unblock the request.
+ self.get_success(ctx_mgr.__aexit__(None, None, None))
+
+ self.get_success(sync_d, by=1.0)
+
+ def test_wait_for_invalid_future_sync_token(self) -> None:
+ """Like the previous test, except we give a token that has a stream
+ position ahead of what is in the DB, i.e. its invalid and we shouldn't
+ wait for the stream to advance (as it may never do so).
+
+ This can happen due to older versions of Synapse giving out stream
+ positions without persisting them in the DB, and so on restart the
+ stream would get reset back to an older position.
+ """
+ user = self.register_user("alice", "password")
+
+ # Create a token and arbitrarily advance one of the streams.
+ current_token = self.hs.get_event_sources().get_current_token()
+ since_token = current_token.copy_and_advance(
+ StreamKeyType.PRESENCE, current_token.presence_key + 1
+ )
+
+ sync_d = defer.ensureDeferred(
+ self.sync_handler.wait_for_sync_for_user(
+ create_requester(user),
+ generate_sync_config(user),
+ sync_version=SyncVersion.SYNC_V2,
+ request_key=generate_request_key(),
+ since_token=since_token,
+ timeout=0,
+ )
+ )
+
+ # We should return without waiting for the presence stream to advance.
+ self.get_success(sync_d)
+
def generate_sync_config(
user_id: str,
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index bfb26139d3..12c11f342c 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -1386,10 +1386,12 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
# Create a future token that will cause us to wait. Since we never send a new
# event to reach that future stream_ordering, the worker will wait until the
# full timeout.
+ stream_id_gen = self.store.get_events_stream_id_generator()
+ stream_id = self.get_success(stream_id_gen.get_next().__aenter__())
current_token = self.event_sources.get_current_token()
future_position_token = current_token.copy_and_replace(
StreamKeyType.ROOM,
- RoomStreamToken(stream=current_token.room_key.stream + 1),
+ RoomStreamToken(stream=stream_id),
)
future_position_token_serialized = self.get_success(
|