summary refs log tree commit diff
path: root/synapse/streams/events.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/streams/events.py')
-rw-r--r-- synapse/streams/events.py | 64
1 file changed, 63 insertions(+), 1 deletion(-)
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index dd7401ac8e..93d5ae1a55 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -30,7 +30,12 @@ from synapse.handlers.room import RoomEventSource
 from synapse.handlers.typing import TypingNotificationEventSource
 from synapse.logging.opentracing import trace
 from synapse.streams import EventSource
-from synapse.types import MultiWriterStreamToken, StreamKeyType, StreamToken
+from synapse.types import (
+    AbstractMultiWriterStreamToken,
+    MultiWriterStreamToken,
+    StreamKeyType,
+    StreamToken,
+)
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -91,6 +96,63 @@ class EventSources:
         )
         return token
 
+    async def bound_future_token(self, token: StreamToken) -> StreamToken:
+        """Bound a token that is ahead of the current token to the maximum
+        persisted values.
+
+        This ensures that if we wait for the given token we know the stream will
+        eventually advance to that point.
+
+        This works around a bug where older Synapse versions will give out
+        tokens for streams, and then after a restart will give back tokens where
+        the stream has "gone backwards".
+        """
+
+        current_token = self.get_current_token()
+
+        stream_key_to_id_gen = {
+            StreamKeyType.ROOM: self.store.get_events_stream_id_generator(),
+            StreamKeyType.PRESENCE: self.store.get_presence_stream_id_gen(),
+            StreamKeyType.RECEIPT: self.store.get_receipts_stream_id_gen(),
+            StreamKeyType.ACCOUNT_DATA: self.store.get_account_data_id_generator(),
+            StreamKeyType.PUSH_RULES: self.store.get_push_rules_stream_id_gen(),
+            StreamKeyType.TO_DEVICE: self.store.get_to_device_id_generator(),
+            StreamKeyType.DEVICE_LIST: self.store.get_device_stream_id_generator(),
+            StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(),
+        }
+
+        for _, key in StreamKeyType.__members__.items():
+            if key == StreamKeyType.TYPING:
+                # Typing stream is allowed to "reset", and so comparisons don't
+                # really make sense as is.
+                # TODO: Figure out a better way of tracking resets.
+                continue
+
+            token_value = token.get_field(key)
+            current_value = current_token.get_field(key)
+
+            if isinstance(token_value, AbstractMultiWriterStreamToken):
+                assert type(current_value) is type(token_value)
+
+                if not token_value.is_before_or_eq(current_value):  # type: ignore[arg-type]
+                    max_token = await stream_key_to_id_gen[
+                        key
+                    ].get_max_allocated_token()
+
+                    token = token.copy_and_replace(
+                        key, token.room_key.bound_stream_token(max_token)
+                    )
+            else:
+                assert isinstance(current_value, int)
+                if current_value < token_value:
+                    max_token = await stream_key_to_id_gen[
+                        key
+                    ].get_max_allocated_token()
+
+                    token = token.copy_and_replace(key, min(token_value, max_token))
+
+        return token
+
     @trace
     async def get_start_token_for_pagination(self, room_id: str) -> StreamToken:
         """Get the start token for a given room to be used to paginate