diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 5319928c28..674dd4fb54 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -36,7 +36,14 @@ from synapse.handlers.sync import SyncConfig, SyncRequestKey, SyncResult, SyncVe
from synapse.rest import admin
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
-from synapse.types import JsonDict, StreamKeyType, UserID, create_requester
+from synapse.types import (
+ JsonDict,
+ MultiWriterStreamToken,
+ RoomStreamToken,
+ StreamKeyType,
+ UserID,
+ create_requester,
+)
from synapse.util import Clock
import tests.unittest
@@ -999,7 +1006,13 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.get_success(sync_d, by=1.0)
- def test_wait_for_invalid_future_sync_token(self) -> None:
+ @parameterized.expand(
+ [(key,) for key in StreamKeyType.__members__.values()],
+ name_func=lambda func, _, param: f"{func.__name__}_{param.args[0].name}",
+ )
+ def test_wait_for_invalid_future_sync_token(
+ self, stream_key: StreamKeyType
+ ) -> None:
"""Like the previous test, except we give a token that has a stream
position ahead of what is in the DB, i.e. its invalid and we shouldn't
wait for the stream to advance (as it may never do so).
@@ -1010,11 +1023,23 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
"""
user = self.register_user("alice", "password")
- # Create a token and arbitrarily advance one of the streams.
+ # Create a token and advance one of the streams.
current_token = self.hs.get_event_sources().get_current_token()
- since_token = current_token.copy_and_advance(
- StreamKeyType.PRESENCE, current_token.presence_key + 1
- )
+ token_value = current_token.get_field(stream_key)
+
+ # How we advance the streams depends on the type.
+ if isinstance(token_value, int):
+ since_token = current_token.copy_and_advance(stream_key, token_value + 1)
+ elif isinstance(token_value, MultiWriterStreamToken):
+ since_token = current_token.copy_and_advance(
+ stream_key, MultiWriterStreamToken(stream=token_value.stream + 1)
+ )
+ elif isinstance(token_value, RoomStreamToken):
+ since_token = current_token.copy_and_advance(
+ stream_key, RoomStreamToken(stream=token_value.stream + 1)
+ )
+ else:
+ raise Exception("Unreachable")
sync_d = defer.ensureDeferred(
self.sync_handler.wait_for_sync_for_user(
|