summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/17438.bugfix1
-rw-r--r--synapse/handlers/sliding_sync.py11
-rw-r--r--synapse/types/__init__.py65
-rw-r--r--tests/test_types.py71
4 files changed, 138 insertions, 10 deletions
diff --git a/changelog.d/17438.bugfix b/changelog.d/17438.bugfix
new file mode 100644
index 0000000000..cff6eecd48
--- /dev/null
+++ b/changelog.d/17438.bugfix
@@ -0,0 +1 @@
+Fix rare bug where `/sync` would break for a user when using workers with multiple stream writers.
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index be98b379eb..1b5262d667 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -699,10 +699,17 @@ class SlidingSyncHandler:
                 instance_to_max_stream_ordering_map[instance_name] = stream_ordering
 
         # Then assemble the `RoomStreamToken`
+        min_stream_pos = min(instance_to_max_stream_ordering_map.values())
         membership_snapshot_token = RoomStreamToken(
             # Minimum position in the `instance_map`
-            stream=min(instance_to_max_stream_ordering_map.values()),
-            instance_map=immutabledict(instance_to_max_stream_ordering_map),
+            stream=min_stream_pos,
+            instance_map=immutabledict(
+                {
+                    instance_name: stream_pos
+                    for instance_name, stream_pos in instance_to_max_stream_ordering_map.items()
+                    if stream_pos > min_stream_pos
+                }
+            ),
         )
 
         # Since we fetched the users room list at some point in time after the from/to
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index b22a13ef01..3962ecc996 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -20,6 +20,7 @@
 #
 #
 import abc
+import logging
 import re
 import string
 from enum import Enum
@@ -74,6 +75,9 @@ if TYPE_CHECKING:
     from synapse.storage.databases.main import DataStore, PurgeEventsStore
     from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
 
+
+logger = logging.getLogger(__name__)
+
 # Define a state map type from type/state_key to T (usually an event ID or
 # event)
 T = TypeVar("T")
@@ -454,6 +458,8 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
     represented by a default `stream` attribute and a map of instance name to
     stream position of any writers that are ahead of the default stream
     position.
+
+    The values in `instance_map` must be greater than the `stream` attribute.
     """
 
     stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True)
@@ -468,6 +474,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
         kw_only=True,
     )
 
+    def __attrs_post_init__(self) -> None:
+        # Enforce that all instances have a value greater than the min stream
+        # position.
+        for i, v in self.instance_map.items():
+            if v <= self.stream:
+                raise ValueError(
+                    f"'instance_map' includes a stream position before the main 'stream' attribute. Instance: {i}"
+                )
+
     @classmethod
     @abc.abstractmethod
     async def parse(cls, store: "DataStore", string: str) -> "Self":
@@ -494,6 +509,9 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
             for instance in set(self.instance_map).union(other.instance_map)
         }
 
+        # Filter out any redundant entries.
+        instance_map = {i: s for i, s in instance_map.items() if s > max_stream}
+
         return attr.evolve(
             self, stream=max_stream, instance_map=immutabledict(instance_map)
         )
@@ -539,10 +557,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
     def bound_stream_token(self, max_stream: int) -> "Self":
         """Bound the stream positions to a maximum value"""
 
+        min_pos = min(self.stream, max_stream)
         return type(self)(
-            stream=min(self.stream, max_stream),
+            stream=min_pos,
             instance_map=immutabledict(
-                {k: min(s, max_stream) for k, s in self.instance_map.items()}
+                {
+                    k: min(s, max_stream)
+                    for k, s in self.instance_map.items()
+                    if min(s, max_stream) > min_pos
+                }
             ),
         )
 
@@ -637,6 +660,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
                 "Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'."
             )
 
+        super().__attrs_post_init__()
+
     @classmethod
     async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
         try:
@@ -651,6 +676,11 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
 
                 instance_map = {}
                 for part in parts[1:]:
+                    if not part:
+                        # Handle tokens of the form `m5~`, which were created by
+                        # a bug
+                        continue
+
                     key, value = part.split(".")
                     instance_id = int(key)
                     pos = int(value)
@@ -666,7 +696,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
         except CancelledError:
             raise
         except Exception:
-            pass
+            # We log an exception here as even though this *might* be a client
+            # handing a bad token, its more likely that Synapse returned a bad
+            # token (and we really want to catch those!).
+            logger.exception("Failed to parse stream token: %r", string)
         raise SynapseError(400, "Invalid room stream token %r" % (string,))
 
     @classmethod
@@ -713,6 +746,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
         return self.instance_map.get(instance_name, self.stream)
 
     async def to_string(self, store: "DataStore") -> str:
+        """See class level docstring for information about the format."""
+
         if self.topological is not None:
             return "t%d-%d" % (self.topological, self.stream)
         elif self.instance_map:
@@ -727,8 +762,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
                 instance_id = await store.get_id_for_instance(name)
                 entries.append(f"{instance_id}.{pos}")
 
-            encoded_map = "~".join(entries)
-            return f"m{self.stream}~{encoded_map}"
+            if entries:
+                encoded_map = "~".join(entries)
+                return f"m{self.stream}~{encoded_map}"
+            return f"s{self.stream}"
         else:
             return "s%d" % (self.stream,)
 
@@ -756,6 +793,11 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
 
                 instance_map = {}
                 for part in parts[1:]:
+                    if not part:
+                        # Handle tokens of the form `m5~`, which were created by
+                        # a bug
+                        continue
+
                     key, value = part.split(".")
                     instance_id = int(key)
                     pos = int(value)
@@ -770,10 +812,15 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
         except CancelledError:
             raise
         except Exception:
-            pass
+            # We log an exception here as even though this *might* be a client
+            # handing a bad token, its more likely that Synapse returned a bad
+            # token (and we really want to catch those!).
+            logger.exception("Failed to parse stream token: %r", string)
         raise SynapseError(400, "Invalid stream token %r" % (string,))
 
     async def to_string(self, store: "DataStore") -> str:
+        """See class level docstring for information about the format."""
+
         if self.instance_map:
             entries = []
             for name, pos in self.instance_map.items():
@@ -786,8 +833,10 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
                 instance_id = await store.get_id_for_instance(name)
                 entries.append(f"{instance_id}.{pos}")
 
-            encoded_map = "~".join(entries)
-            return f"m{self.stream}~{encoded_map}"
+            if entries:
+                encoded_map = "~".join(entries)
+                return f"m{self.stream}~{encoded_map}"
+            return str(self.stream)
         else:
             return str(self.stream)
 
diff --git a/tests/test_types.py b/tests/test_types.py
index 944aa784fc..00adc65a5a 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -19,9 +19,18 @@
 #
 #
 
+from typing import Type
+from unittest import skipUnless
+
+from immutabledict import immutabledict
+from parameterized import parameterized_class
+
 from synapse.api.errors import SynapseError
 from synapse.types import (
+    AbstractMultiWriterStreamToken,
+    MultiWriterStreamToken,
     RoomAlias,
+    RoomStreamToken,
     UserID,
     get_domain_from_id,
     get_localpart_from_id,
@@ -29,6 +38,7 @@ from synapse.types import (
 )
 
 from tests import unittest
+from tests.utils import USE_POSTGRES_FOR_TESTS
 
 
 class IsMineIDTests(unittest.HomeserverTestCase):
@@ -127,3 +137,64 @@ class MapUsernameTestCase(unittest.TestCase):
         # this should work with either a unicode or a bytes
         self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast")
         self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast")
+
+
+@parameterized_class(
+    ("token_type",),
+    [
+        (MultiWriterStreamToken,),
+        (RoomStreamToken,),
+    ],
+    class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{params_dict['token_type'].__name__}",
+)
+class MultiWriterTokenTestCase(unittest.HomeserverTestCase):
+    """Tests for the different types of multi writer tokens."""
+
+    token_type: Type[AbstractMultiWriterStreamToken]
+
+    def test_basic_token(self) -> None:
+        """Test that a simple stream token can be serialized and unserialized"""
+        store = self.hs.get_datastores().main
+
+        token = self.token_type(stream=5)
+
+        string_token = self.get_success(token.to_string(store))
+
+        if isinstance(token, RoomStreamToken):
+            self.assertEqual(string_token, "s5")
+        else:
+            self.assertEqual(string_token, "5")
+
+        parsed_token = self.get_success(self.token_type.parse(store, string_token))
+        self.assertEqual(parsed_token, token)
+
+    @skipUnless(USE_POSTGRES_FOR_TESTS, "Requires Postgres")
+    def test_instance_map(self) -> None:
+        """Test for stream token with instance map"""
+        store = self.hs.get_datastores().main
+
+        token = self.token_type(stream=5, instance_map=immutabledict({"foo": 6}))
+
+        string_token = self.get_success(token.to_string(store))
+        self.assertEqual(string_token, "m5~1.6")
+
+        parsed_token = self.get_success(self.token_type.parse(store, string_token))
+        self.assertEqual(parsed_token, token)
+
+    def test_instance_map_assertion(self) -> None:
+        """Test that we assert values in the instance map are greater than the
+        min stream position"""
+
+        with self.assertRaises(ValueError):
+            self.token_type(stream=5, instance_map=immutabledict({"foo": 4}))
+
+        with self.assertRaises(ValueError):
+            self.token_type(stream=5, instance_map=immutabledict({"foo": 5}))
+
+    def test_parse_bad_token(self) -> None:
+        """Test that we can parse tokens produced by a bug in Synapse of the
+        form `m5~`"""
+        store = self.hs.get_datastores().main
+
+        parsed_token = self.get_success(self.token_type.parse(store, "m5~"))
+        self.assertEqual(parsed_token, self.token_type(stream=5))