diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index dd7401ac8e..93d5ae1a55 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -30,7 +30,12 @@ from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource
from synapse.logging.opentracing import trace
from synapse.streams import EventSource
-from synapse.types import MultiWriterStreamToken, StreamKeyType, StreamToken
+from synapse.types import (
+ AbstractMultiWriterStreamToken,
+ MultiWriterStreamToken,
+ StreamKeyType,
+ StreamToken,
+)
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -91,6 +96,63 @@ class EventSources:
)
return token
+ async def bound_future_token(self, token: StreamToken) -> StreamToken:
+ """Bound a token that is ahead of the current token to the maximum
+ persisted values.
+
+ This ensures that if we wait for the given token we know the stream will
+ eventually advance to that point.
+
+ This works around a bug where older Synapse versions will give out
+ tokens for streams, and then after a restart will give back tokens where
+ the stream has "gone backwards".
+ """
+
+ current_token = self.get_current_token()
+
+ stream_key_to_id_gen = {
+ StreamKeyType.ROOM: self.store.get_events_stream_id_generator(),
+ StreamKeyType.PRESENCE: self.store.get_presence_stream_id_gen(),
+ StreamKeyType.RECEIPT: self.store.get_receipts_stream_id_gen(),
+ StreamKeyType.ACCOUNT_DATA: self.store.get_account_data_id_generator(),
+ StreamKeyType.PUSH_RULES: self.store.get_push_rules_stream_id_gen(),
+ StreamKeyType.TO_DEVICE: self.store.get_to_device_id_generator(),
+ StreamKeyType.DEVICE_LIST: self.store.get_device_stream_id_generator(),
+ StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(),
+ }
+
+ for _, key in StreamKeyType.__members__.items():
+ if key == StreamKeyType.TYPING:
+ # Typing stream is allowed to "reset", and so comparisons don't
+ # really make sense as is.
+ # TODO: Figure out a better way of tracking resets.
+ continue
+
+ token_value = token.get_field(key)
+ current_value = current_token.get_field(key)
+
+ if isinstance(token_value, AbstractMultiWriterStreamToken):
+ assert type(current_value) is type(token_value)
+
+ if not token_value.is_before_or_eq(current_value): # type: ignore[arg-type]
+ max_token = await stream_key_to_id_gen[
+ key
+ ].get_max_allocated_token()
+
+ token = token.copy_and_replace(
+ key, token.room_key.bound_stream_token(max_token)
+ )
+ else:
+ assert isinstance(current_value, int)
+ if current_value < token_value:
+ max_token = await stream_key_to_id_gen[
+ key
+ ].get_max_allocated_token()
+
+ token = token.copy_and_replace(key, min(token_value, max_token))
+
+ return token
+
@trace
async def get_start_token_for_pagination(self, room_id: str) -> StreamToken:
"""Get the start token for a given room to be used to paginate
|