summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/17386.bugfix1
-rw-r--r--synapse/notifier.py7
-rw-r--r--synapse/storage/databases/main/account_data.py10
-rw-r--r--synapse/storage/databases/main/deviceinbox.py10
-rw-r--r--synapse/storage/databases/main/devices.py3
-rw-r--r--synapse/storage/databases/main/events_worker.py4
-rw-r--r--synapse/storage/databases/main/presence.py10
-rw-r--r--synapse/storage/databases/main/push_rule.py3
-rw-r--r--synapse/storage/databases/main/receipts.py10
-rw-r--r--synapse/storage/databases/main/room.py11
-rw-r--r--synapse/storage/databases/main/stream.py3
-rw-r--r--synapse/storage/util/id_generators.py5
-rw-r--r--synapse/storage/util/sequence.py24
-rw-r--r--synapse/streams/events.py64
-rw-r--r--synapse/types/__init__.py18
-rw-r--r--tests/handlers/test_sync.py73
-rw-r--r--tests/rest/client/test_sync.py4
17 files changed, 229 insertions, 31 deletions
diff --git a/changelog.d/17386.bugfix b/changelog.d/17386.bugfix
new file mode 100644
index 0000000000..9686b5c276
--- /dev/null
+++ b/changelog.d/17386.bugfix
@@ -0,0 +1 @@
+Fix bug where `/sync` requests could get blocked indefinitely after an upgrade from Synapse versions before v1.109.0.
diff --git a/synapse/notifier.py b/synapse/notifier.py
index c87eb748c0..c3ecf86ec4 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -764,6 +764,13 @@ class Notifier:
 
     async def wait_for_stream_token(self, stream_token: StreamToken) -> bool:
         """Wait for this worker to catch up with the given stream token."""
+        current_token = self.event_sources.get_current_token()
+        if stream_token.is_before_or_eq(current_token):
+            return True
+
+        # Work around a bug where older Synapse versions gave out tokens "from
+        # the future", i.e. that are ahead of the tokens persisted in the DB.
+        stream_token = await self.event_sources.bound_future_token(stream_token)
 
         start = self.clock.time_msec()
         while True:
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 9611a84932..966393869b 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -43,10 +43,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import JsonDict, JsonMapping
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -71,7 +68,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             self._instance_name in hs.config.worker.writers.account_data
         )
 
-        self._account_data_id_gen: AbstractStreamIdGenerator
+        self._account_data_id_gen: MultiWriterIdGenerator
 
         self._account_data_id_gen = MultiWriterIdGenerator(
             db_conn=db_conn,
@@ -113,6 +110,9 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         """
         return self._account_data_id_gen.get_current_token()
 
+    def get_account_data_id_generator(self) -> MultiWriterIdGenerator:
+        return self._account_data_id_gen
+
     @cached()
     async def get_global_account_data_for_user(
         self, user_id: str
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 07333efff8..304ac42411 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -50,10 +50,7 @@ from synapse.storage.database import (
     LoggingTransaction,
     make_in_list_sql_clause,
 )
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -92,7 +89,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             self._instance_name in hs.config.worker.writers.to_device
         )
 
-        self._to_device_msg_id_gen: AbstractStreamIdGenerator = MultiWriterIdGenerator(
+        self._to_device_msg_id_gen: MultiWriterIdGenerator = MultiWriterIdGenerator(
             db_conn=db_conn,
             db=database,
             notifier=hs.get_replication_notifier(),
@@ -169,6 +166,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
     def get_to_device_stream_token(self) -> int:
         return self._to_device_msg_id_gen.get_current_token()
 
+    def get_to_device_id_generator(self) -> MultiWriterIdGenerator:
+        return self._to_device_msg_id_gen
+
     async def get_messages_for_user_devices(
         self,
         user_ids: Collection[str],
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 59a035dd62..53024bddc3 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -243,6 +243,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     def get_device_stream_token(self) -> int:
         return self._device_list_id_gen.get_current_token()
 
+    def get_device_stream_id_generator(self) -> MultiWriterIdGenerator:
+        return self._device_list_id_gen
+
     async def count_devices_by_users(
         self, user_ids: Optional[Collection[str]] = None
     ) -> int:
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index e264d36f02..198e65cfa5 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -192,8 +192,8 @@ class EventsWorkerStore(SQLBaseStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        self._stream_id_gen: AbstractStreamIdGenerator
-        self._backfill_id_gen: AbstractStreamIdGenerator
+        self._stream_id_gen: MultiWriterIdGenerator
+        self._backfill_id_gen: MultiWriterIdGenerator
 
         self._stream_id_gen = MultiWriterIdGenerator(
             db_conn=db_conn,
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 923e764491..065c885603 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -42,10 +42,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines._base import IsolationLevel
 from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.iterutils import batch_iter
@@ -83,7 +80,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
         super().__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
-        self._presence_id_gen: AbstractStreamIdGenerator
+        self._presence_id_gen: MultiWriterIdGenerator
 
         self._can_persist_presence = (
             self._instance_name in hs.config.worker.writers.presence
@@ -455,6 +452,9 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
     def get_current_presence_token(self) -> int:
         return self._presence_id_gen.get_current_token()
 
+    def get_presence_stream_id_gen(self) -> MultiWriterIdGenerator:
+        return self._presence_id_gen
+
     def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]:
         """Fetch non-offline presence from the database so that we can register
         the appropriate time outs.
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 2a39dc9f90..bbdde17711 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -178,6 +178,9 @@ class PushRulesWorkerStore(
         """
         return self._push_rules_stream_id_gen.get_current_token()
 
+    def get_push_rules_stream_id_gen(self) -> MultiWriterIdGenerator:
+        return self._push_rules_stream_id_gen
+
     def process_replication_rows(
         self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
     ) -> None:
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 8432560a89..3bde0ae0d4 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -45,10 +45,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.engines._base import IsolationLevel
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import (
     JsonDict,
     JsonMapping,
@@ -76,7 +73,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         # In the worker store this is an ID tracker which we overwrite in the non-worker
         # class below that is used on the main process.
-        self._receipts_id_gen: AbstractStreamIdGenerator
+        self._receipts_id_gen: MultiWriterIdGenerator
 
         self._can_write_to_receipts = (
             self._instance_name in hs.config.worker.writers.receipts
@@ -136,6 +133,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
     def get_receipt_stream_id_for_instance(self, instance_name: str) -> int:
         return self._receipts_id_gen.get_current_token_for_writer(instance_name)
 
+    def get_receipts_stream_id_gen(self) -> MultiWriterIdGenerator:
+        return self._receipts_id_gen
+
     def get_last_unthreaded_receipt_for_user_txn(
         self,
         txn: LoggingTransaction,
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index d5627b1d6e..80a4bf95f2 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -59,11 +59,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    IdGenerator,
-    MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator
 from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
@@ -151,7 +147,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
 
         self.config: HomeServerConfig = hs.config
 
-        self._un_partial_stated_rooms_stream_id_gen: AbstractStreamIdGenerator
+        self._un_partial_stated_rooms_stream_id_gen: MultiWriterIdGenerator
 
         self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
             db_conn=db_conn,
@@ -1409,6 +1405,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             instance_name
         )
 
+    def get_un_partial_stated_rooms_id_generator(self) -> MultiWriterIdGenerator:
+        return self._un_partial_stated_rooms_stream_id_gen
+
     async def get_un_partial_stated_rooms_between(
         self, last_id: int, current_id: int, room_ids: Collection[str]
     ) -> Set[str]:
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index ff0d723684..b7eb3116ae 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -577,6 +577,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return RoomStreamToken(stream=min_pos, instance_map=immutabledict(positions))
 
+    def get_events_stream_id_generator(self) -> MultiWriterIdGenerator:
+        return self._stream_id_gen
+
     async def get_room_events_stream_for_rooms(
         self,
         room_ids: Collection[str],
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 48f88a6f8a..e8588f33cf 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -812,6 +812,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         pos = self.get_current_token_for_writer(self._instance_name)
         txn.execute(sql, (self._stream_name, self._instance_name, pos))
 
+    async def get_max_allocated_token(self) -> int:
+        return await self._db.runInteraction(
+            "get_max_allocated_token", self._sequence_gen.get_max_allocated
+        )
+
 
 @attr.s(frozen=True, auto_attribs=True)
 class _AsyncCtxManagerWrapper(Generic[T]):
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index c4c0602b28..cac3eba1a5 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -88,6 +88,10 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
         """
         ...
 
+    @abc.abstractmethod
+    def get_max_allocated(self, txn: Cursor) -> int:
+        """Get the maximum ID that we have allocated"""
+
 
 class PostgresSequenceGenerator(SequenceGenerator):
     """An implementation of SequenceGenerator which uses a postgres sequence"""
@@ -190,6 +194,17 @@ class PostgresSequenceGenerator(SequenceGenerator):
                 % {"seq": self._sequence_name, "stream_name": stream_name}
             )
 
+    def get_max_allocated(self, txn: Cursor) -> int:
+        # We just read from the sequence what the last value we fetched was.
+        txn.execute(f"SELECT last_value, is_called FROM {self._sequence_name}")
+        row = txn.fetchone()
+        assert row is not None
+
+        last_value, is_called = row
+        if not is_called:
+            last_value -= 1
+        return last_value
+
 
 GetFirstCallbackType = Callable[[Cursor], int]
 
@@ -248,6 +263,15 @@ class LocalSequenceGenerator(SequenceGenerator):
         # There is nothing to do for in memory sequences
         pass
 
+    def get_max_allocated(self, txn: Cursor) -> int:
+        with self._lock:
+            if self._current_max_id is None:
+                assert self._callback is not None
+                self._current_max_id = self._callback(txn)
+                self._callback = None
+
+            return self._current_max_id
+
 
 def build_sequence_generator(
     db_conn: "LoggingDatabaseConnection",
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index dd7401ac8e..93d5ae1a55 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -30,7 +30,12 @@ from synapse.handlers.room import RoomEventSource
 from synapse.handlers.typing import TypingNotificationEventSource
 from synapse.logging.opentracing import trace
 from synapse.streams import EventSource
-from synapse.types import MultiWriterStreamToken, StreamKeyType, StreamToken
+from synapse.types import (
+    AbstractMultiWriterStreamToken,
+    MultiWriterStreamToken,
+    StreamKeyType,
+    StreamToken,
+)
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -91,6 +96,63 @@ class EventSources:
         )
         return token
 
+    async def bound_future_token(self, token: StreamToken) -> StreamToken:
+        """Bound a token that is ahead of the current token to the maximum
+        persisted values.
+
+        This ensures that if we wait for the given token we know the stream will
+        eventually advance to that point.
+
+        This works around a bug where older Synapse versions will give out
+        tokens for streams, and then after a restart will give back tokens where
+        the stream has "gone backwards".
+        """
+
+        current_token = self.get_current_token()
+
+        stream_key_to_id_gen = {
+            StreamKeyType.ROOM: self.store.get_events_stream_id_generator(),
+            StreamKeyType.PRESENCE: self.store.get_presence_stream_id_gen(),
+            StreamKeyType.RECEIPT: self.store.get_receipts_stream_id_gen(),
+            StreamKeyType.ACCOUNT_DATA: self.store.get_account_data_id_generator(),
+            StreamKeyType.PUSH_RULES: self.store.get_push_rules_stream_id_gen(),
+            StreamKeyType.TO_DEVICE: self.store.get_to_device_id_generator(),
+            StreamKeyType.DEVICE_LIST: self.store.get_device_stream_id_generator(),
+            StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(),
+        }
+
+        for _, key in StreamKeyType.__members__.items():
+            if key == StreamKeyType.TYPING:
+                # Typing stream is allowed to "reset", and so comparisons don't
+                # really make sense as is.
+                # TODO: Figure out a better way of tracking resets.
+                continue
+
+            token_value = token.get_field(key)
+            current_value = current_token.get_field(key)
+
+            if isinstance(token_value, AbstractMultiWriterStreamToken):
+                assert type(current_value) is type(token_value)
+
+                if not token_value.is_before_or_eq(current_value):  # type: ignore[arg-type]
+                    max_token = await stream_key_to_id_gen[
+                        key
+                    ].get_max_allocated_token()
+
+                    token = token.copy_and_replace(
+                        key, token.room_key.bound_stream_token(max_token)
+                    )
+            else:
+                assert isinstance(current_value, int)
+                if current_value < token_value:
+                    max_token = await stream_key_to_id_gen[
+                        key
+                    ].get_max_allocated_token()
+
+                    token = token.copy_and_replace(key, min(token_value, max_token))
+
+        return token
+
     @trace
     async def get_start_token_for_pagination(self, room_id: str) -> StreamToken:
         """Get the start token for a given room to be used to paginate
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 151658df53..8ab9f90238 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -536,6 +536,16 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
 
         return True
 
+    def bound_stream_token(self, max_stream: int) -> "Self":
+        """Bound the stream positions to a maximum value"""
+
+        return type(self)(
+            stream=min(self.stream, max_stream),
+            instance_map=immutabledict(
+                {k: min(s, max_stream) for k, s in self.instance_map.items()}
+            ),
+        )
+
 
 @attr.s(frozen=True, slots=True, order=False)
 class RoomStreamToken(AbstractMultiWriterStreamToken):
@@ -722,6 +732,14 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
         else:
             return "s%d" % (self.stream,)
 
+    def bound_stream_token(self, max_stream: int) -> "RoomStreamToken":
+        """See super class"""
+
+        # This only makes sense for stream tokens.
+        assert self.topological is None
+
+        return super().bound_stream_token(max_stream)
+
 
 @attr.s(frozen=True, slots=True, order=False)
 class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 02371ce724..5319928c28 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -22,6 +22,7 @@ from unittest.mock import AsyncMock, Mock, patch
 
 from parameterized import parameterized
 
+from twisted.internet import defer
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules
@@ -35,7 +36,7 @@ from synapse.handlers.sync import SyncConfig, SyncRequestKey, SyncResult, SyncVe
 from synapse.rest import admin
 from synapse.rest.client import knock, login, room
 from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID, create_requester
+from synapse.types import JsonDict, StreamKeyType, UserID, create_requester
 from synapse.util import Clock
 
 import tests.unittest
@@ -959,6 +960,76 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
 
         self.fail("No push rules found")
 
+    def test_wait_for_future_sync_token(self) -> None:
+        """Test that if we receive a token that is ahead of our current token,
+        we'll wait until the stream position advances.
+
+        This can happen if replication streams start lagging, and the client's
+        previous sync request was serviced by a worker ahead of ours.
+        """
+        user = self.register_user("alice", "password")
+
+        # We simulate a lagging stream by getting a stream ID from the ID gen
+        # and then waiting to mark it as "persisted".
+        presence_id_gen = self.store.get_presence_stream_id_gen()
+        ctx_mgr = presence_id_gen.get_next()
+        stream_id = self.get_success(ctx_mgr.__aenter__())
+
+        # Create the new token based on the stream ID above.
+        current_token = self.hs.get_event_sources().get_current_token()
+        since_token = current_token.copy_and_advance(StreamKeyType.PRESENCE, stream_id)
+
+        sync_d = defer.ensureDeferred(
+            self.sync_handler.wait_for_sync_for_user(
+                create_requester(user),
+                generate_sync_config(user),
+                sync_version=SyncVersion.SYNC_V2,
+                request_key=generate_request_key(),
+                since_token=since_token,
+                timeout=0,
+            )
+        )
+
+        # This should block waiting for the presence stream to update
+        self.pump()
+        self.assertFalse(sync_d.called)
+
+        # Marking the stream ID as persisted should unblock the request.
+        self.get_success(ctx_mgr.__aexit__(None, None, None))
+
+        self.get_success(sync_d, by=1.0)
+
+    def test_wait_for_invalid_future_sync_token(self) -> None:
+        """Like the previous test, except we give a token that has a stream
+        position ahead of what is in the DB, i.e. its invalid and we shouldn't
+        wait for the stream to advance (as it may never do so).
+
+        This can happen due to older versions of Synapse giving out stream
+        positions without persisting them in the DB, and so on restart the
+        stream would get reset back to an older position.
+        """
+        user = self.register_user("alice", "password")
+
+        # Create a token and arbitrarily advance one of the streams.
+        current_token = self.hs.get_event_sources().get_current_token()
+        since_token = current_token.copy_and_advance(
+            StreamKeyType.PRESENCE, current_token.presence_key + 1
+        )
+
+        sync_d = defer.ensureDeferred(
+            self.sync_handler.wait_for_sync_for_user(
+                create_requester(user),
+                generate_sync_config(user),
+                sync_version=SyncVersion.SYNC_V2,
+                request_key=generate_request_key(),
+                since_token=since_token,
+                timeout=0,
+            )
+        )
+
+        # We should return without waiting for the presence stream to advance.
+        self.get_success(sync_d)
+
 
 def generate_sync_config(
     user_id: str,
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index bfb26139d3..12c11f342c 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -1386,10 +1386,12 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
         # Create a future token that will cause us to wait. Since we never send a new
         # event to reach that future stream_ordering, the worker will wait until the
         # full timeout.
+        stream_id_gen = self.store.get_events_stream_id_generator()
+        stream_id = self.get_success(stream_id_gen.get_next().__aenter__())
         current_token = self.event_sources.get_current_token()
         future_position_token = current_token.copy_and_replace(
             StreamKeyType.ROOM,
-            RoomStreamToken(stream=current_token.room_key.stream + 1),
+            RoomStreamToken(stream=stream_id),
         )
 
         future_position_token_serialized = self.get_success(