summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erikj@element.io>2024-07-15 16:13:04 +0100
committerGitHub <noreply@github.com>2024-07-15 16:13:04 +0100
commitdf11af14dbd2faad916924cab96e75bd7c95a66a (patch)
tree2de4fa075d8876aafda9cce06862c577f1219a5e /synapse
parentBump types-jsonschema from 4.22.0.20240610 to 4.23.0.20240712 (#17446) (diff)
downloadsynapse-df11af14dbd2faad916924cab96e75bd7c95a66a.tar.xz
Fix bug where sync could get stuck when using workers (#17438)
This is because we serialized the token wrong if the instance map
contained entries from before the minimum token.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/sliding_sync.py11
-rw-r--r--synapse/types/__init__.py65
2 files changed, 66 insertions, 10 deletions
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index be98b379eb..1b5262d667 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -699,10 +699,17 @@ class SlidingSyncHandler:
                 instance_to_max_stream_ordering_map[instance_name] = stream_ordering
 
         # Then assemble the `RoomStreamToken`
+        min_stream_pos = min(instance_to_max_stream_ordering_map.values())
         membership_snapshot_token = RoomStreamToken(
             # Minimum position in the `instance_map`
-            stream=min(instance_to_max_stream_ordering_map.values()),
-            instance_map=immutabledict(instance_to_max_stream_ordering_map),
+            stream=min_stream_pos,
+            instance_map=immutabledict(
+                {
+                    instance_name: stream_pos
+                    for instance_name, stream_pos in instance_to_max_stream_ordering_map.items()
+                    if stream_pos > min_stream_pos
+                }
+            ),
         )
 
         # Since we fetched the users room list at some point in time after the from/to
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index b22a13ef01..3962ecc996 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -20,6 +20,7 @@
 #
 #
 import abc
+import logging
 import re
 import string
 from enum import Enum
@@ -74,6 +75,9 @@ if TYPE_CHECKING:
     from synapse.storage.databases.main import DataStore, PurgeEventsStore
     from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
 
+
+logger = logging.getLogger(__name__)
+
 # Define a state map type from type/state_key to T (usually an event ID or
 # event)
 T = TypeVar("T")
@@ -454,6 +458,8 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
     represented by a default `stream` attribute and a map of instance name to
     stream position of any writers that are ahead of the default stream
     position.
+
+    The values in `instance_map` must be greater than the `stream` attribute.
     """
 
     stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True)
@@ -468,6 +474,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
         kw_only=True,
     )
 
+    def __attrs_post_init__(self) -> None:
+        # Enforce that all instances have a value greater than the min stream
+        # position.
+        for i, v in self.instance_map.items():
+            if v <= self.stream:
+                raise ValueError(
+                    f"'instance_map' includes a stream position before the main 'stream' attribute. Instance: {i}"
+                )
+
     @classmethod
     @abc.abstractmethod
     async def parse(cls, store: "DataStore", string: str) -> "Self":
@@ -494,6 +509,9 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
             for instance in set(self.instance_map).union(other.instance_map)
         }
 
+        # Filter out any redundant entries.
+        instance_map = {i: s for i, s in instance_map.items() if s > max_stream}
+
         return attr.evolve(
             self, stream=max_stream, instance_map=immutabledict(instance_map)
         )
@@ -539,10 +557,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
     def bound_stream_token(self, max_stream: int) -> "Self":
         """Bound the stream positions to a maximum value"""
 
+        min_pos = min(self.stream, max_stream)
         return type(self)(
-            stream=min(self.stream, max_stream),
+            stream=min_pos,
             instance_map=immutabledict(
-                {k: min(s, max_stream) for k, s in self.instance_map.items()}
+                {
+                    k: min(s, max_stream)
+                    for k, s in self.instance_map.items()
+                    if min(s, max_stream) > min_pos
+                }
             ),
         )
 
@@ -637,6 +660,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
                 "Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'."
             )
 
+        super().__attrs_post_init__()
+
     @classmethod
     async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
         try:
@@ -651,6 +676,11 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
 
                 instance_map = {}
                 for part in parts[1:]:
+                    if not part:
+                        # Handle tokens of the form `m5~`, which were created by
+                        # a bug
+                        continue
+
                     key, value = part.split(".")
                     instance_id = int(key)
                     pos = int(value)
@@ -666,7 +696,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
         except CancelledError:
             raise
         except Exception:
-            pass
+            # We log an exception here as even though this *might* be a client
+            # handing a bad token, its more likely that Synapse returned a bad
+            # token (and we really want to catch those!).
+            logger.exception("Failed to parse stream token: %r", string)
         raise SynapseError(400, "Invalid room stream token %r" % (string,))
 
     @classmethod
@@ -713,6 +746,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
         return self.instance_map.get(instance_name, self.stream)
 
     async def to_string(self, store: "DataStore") -> str:
+        """See class level docstring for information about the format."""
+
         if self.topological is not None:
             return "t%d-%d" % (self.topological, self.stream)
         elif self.instance_map:
@@ -727,8 +762,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
                 instance_id = await store.get_id_for_instance(name)
                 entries.append(f"{instance_id}.{pos}")
 
-            encoded_map = "~".join(entries)
-            return f"m{self.stream}~{encoded_map}"
+            if entries:
+                encoded_map = "~".join(entries)
+                return f"m{self.stream}~{encoded_map}"
+            return f"s{self.stream}"
         else:
             return "s%d" % (self.stream,)
 
@@ -756,6 +793,11 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
 
                 instance_map = {}
                 for part in parts[1:]:
+                    if not part:
+                        # Handle tokens of the form `m5~`, which were created by
+                        # a bug
+                        continue
+
                     key, value = part.split(".")
                     instance_id = int(key)
                     pos = int(value)
@@ -770,10 +812,15 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
         except CancelledError:
             raise
         except Exception:
-            pass
+            # We log an exception here as even though this *might* be a client
+            # handing a bad token, its more likely that Synapse returned a bad
+            # token (and we really want to catch those!).
+            logger.exception("Failed to parse stream token: %r", string)
         raise SynapseError(400, "Invalid stream token %r" % (string,))
 
     async def to_string(self, store: "DataStore") -> str:
+        """See class level docstring for information about the format."""
+
         if self.instance_map:
             entries = []
             for name, pos in self.instance_map.items():
@@ -786,8 +833,10 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
                 instance_id = await store.get_id_for_instance(name)
                 entries.append(f"{instance_id}.{pos}")
 
-            encoded_map = "~".join(entries)
-            return f"m{self.stream}~{encoded_map}"
+            if entries:
+                encoded_map = "~".join(entries)
+                return f"m{self.stream}~{encoded_map}"
+            return str(self.stream)
         else:
             return str(self.stream)