summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/17215.bugfix1
-rw-r--r--pyproject.toml6
-rw-r--r--synapse/handlers/sync.py35
-rw-r--r--synapse/notifier.py23
-rw-r--r--synapse/storage/databases/main/events.py7
-rw-r--r--synapse/storage/databases/main/events_worker.py11
-rw-r--r--synapse/types/__init__.py58
7 files changed, 134 insertions, 7 deletions
diff --git a/changelog.d/17215.bugfix b/changelog.d/17215.bugfix
new file mode 100644
index 0000000000..10981b798e
--- /dev/null
+++ b/changelog.d/17215.bugfix
@@ -0,0 +1 @@
+Fix bug where duplicate events could be sent down sync when using workers that are overloaded.
diff --git a/pyproject.toml b/pyproject.toml
index ea14b98199..9a3348be49 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -200,10 +200,8 @@ netaddr = ">=0.7.18"
 # add a lower bound to the Jinja2 dependency.
 Jinja2 = ">=3.0"
 bleach = ">=1.4.3"
-# We use `ParamSpec` and `Concatenate`, which were added in `typing-extensions` 3.10.0.0.
-# Additionally we need https://github.com/python/typing/pull/817 to allow types to be
-# generic over ParamSpecs.
-typing-extensions = ">=3.10.0.1"
+# We use `Self`, which were added in `typing-extensions` 4.0.
+typing-extensions = ">=4.0"
 # We enforce that we have a `cryptography` version that bundles an `openssl`
 # with the latest security patches.
 cryptography = ">=3.4.7"
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index ac5bddd52f..1d7d9dfdd0 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -284,6 +284,23 @@ class SyncResult:
             or self.device_lists
         )
 
+    @staticmethod
+    def empty(next_batch: StreamToken) -> "SyncResult":
+        "Return a new empty result"
+        return SyncResult(
+            next_batch=next_batch,
+            presence=[],
+            account_data=[],
+            joined=[],
+            invited=[],
+            knocked=[],
+            archived=[],
+            to_device=[],
+            device_lists=DeviceListUpdates(),
+            device_one_time_keys_count={},
+            device_unused_fallback_key_types=[],
+        )
+
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class E2eeSyncResult:
@@ -497,6 +514,24 @@ class SyncHandler:
         if context:
             context.tag = sync_label
 
+        if since_token is not None:
+            # We need to make sure this worker has caught up with the token. If
+            # this returns false it means we timed out waiting, and we should
+            # just return an empty response.
+            start = self.clock.time_msec()
+            if not await self.notifier.wait_for_stream_token(since_token):
+                logger.warning(
+                    "Timed out waiting for worker to catch up. Returning empty response"
+                )
+                return SyncResult.empty(since_token)
+
+            # If we've spent significant time waiting to catch up, take it off
+            # the timeout.
+            now = self.clock.time_msec()
+            if now - start > 1_000:
+                timeout -= now - start
+                timeout = max(timeout, 0)
+
         # if we have a since token, delete any to-device messages before that token
         # (since we now know that the device has received them)
         if since_token is not None:
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 7c1cd3b5f2..ced9e9ad66 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -763,6 +763,29 @@ class Notifier:
 
         return result
 
+    async def wait_for_stream_token(self, stream_token: StreamToken) -> bool:
+        """Wait for this worker to catch up with the given stream token."""
+
+        start = self.clock.time_msec()
+        while True:
+            current_token = self.event_sources.get_current_token()
+            if stream_token.is_before_or_eq(current_token):
+                return True
+
+            now = self.clock.time_msec()
+
+            if now - start > 10_000:
+                return False
+
+            logger.info(
+                "Waiting for current token to reach %s; currently at %s",
+                stream_token,
+                current_token,
+            )
+
+            # TODO: be better
+            await self.clock.sleep(0.5)
+
     async def _get_room_ids(
         self, user: UserID, explicit_room_id: Optional[str]
     ) -> Tuple[StrCollection, bool]:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index fd7167904d..f1bd85aa27 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -95,6 +95,10 @@ class DeltaState:
     to_insert: StateMap[str]
     no_longer_in_room: bool = False
 
+    def is_noop(self) -> bool:
+        """Whether this state delta is actually empty"""
+        return not self.to_delete and not self.to_insert and not self.no_longer_in_room
+
 
 class PersistEventsStore:
     """Contains all the functions for writing events to the database.
@@ -1017,6 +1021,9 @@ class PersistEventsStore:
     ) -> None:
         """Update the current state stored in the datatabase for the given room"""
 
+        if state_delta.is_noop():
+            return
+
         async with self._stream_id_gen.get_next() as stream_ordering:
             await self.db_pool.runInteraction(
                 "update_current_state",
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 426df2a9d2..c06c44deb1 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -200,7 +200,11 @@ class EventsWorkerStore(SQLBaseStore):
             notifier=hs.get_replication_notifier(),
             stream_name="events",
             instance_name=hs.get_instance_name(),
-            tables=[("events", "instance_name", "stream_ordering")],
+            tables=[
+                ("events", "instance_name", "stream_ordering"),
+                ("current_state_delta_stream", "instance_name", "stream_id"),
+                ("ex_outlier_stream", "instance_name", "event_stream_ordering"),
+            ],
             sequence_name="events_stream_seq",
             writers=hs.config.worker.writers.events,
         )
@@ -210,7 +214,10 @@ class EventsWorkerStore(SQLBaseStore):
             notifier=hs.get_replication_notifier(),
             stream_name="backfill",
             instance_name=hs.get_instance_name(),
-            tables=[("events", "instance_name", "stream_ordering")],
+            tables=[
+                ("events", "instance_name", "stream_ordering"),
+                ("ex_outlier_stream", "instance_name", "event_stream_ordering"),
+            ],
             sequence_name="events_backfill_stream_seq",
             positive=False,
             writers=hs.config.worker.writers.events,
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 509a2d3a0f..151658df53 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -48,7 +48,7 @@ import attr
 from immutabledict import immutabledict
 from signedjson.key import decode_verify_key_bytes
 from signedjson.types import VerifyKey
-from typing_extensions import TypedDict
+from typing_extensions import Self, TypedDict
 from unpaddedbase64 import decode_base64
 from zope.interface import Interface
 
@@ -515,6 +515,27 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
         # at `self.stream`.
         return self.instance_map.get(instance_name, self.stream)
 
+    def is_before_or_eq(self, other_token: Self) -> bool:
+        """Wether this token is before the other token, i.e. every constituent
+        part is before the other.
+
+        Essentially it is `self <= other`.
+
+        Note: if `self.is_before_or_eq(other_token) is False` then that does not
+        imply that the reverse is True.
+        """
+        if self.stream > other_token.stream:
+            return False
+
+        instances = self.instance_map.keys() | other_token.instance_map.keys()
+        for instance in instances:
+            if self.instance_map.get(
+                instance, self.stream
+            ) > other_token.instance_map.get(instance, other_token.stream):
+                return False
+
+        return True
+
 
 @attr.s(frozen=True, slots=True, order=False)
 class RoomStreamToken(AbstractMultiWriterStreamToken):
@@ -1008,6 +1029,41 @@ class StreamToken:
         """Returns the stream ID for the given key."""
         return getattr(self, key.value)
 
+    def is_before_or_eq(self, other_token: "StreamToken") -> bool:
+        """Wether this token is before the other token, i.e. every constituent
+        part is before the other.
+
+        Essentially it is `self <= other`.
+
+        Note: if `self.is_before_or_eq(other_token) is False` then that does not
+        imply that the reverse is True.
+        """
+
+        for _, key in StreamKeyType.__members__.items():
+            if key == StreamKeyType.TYPING:
+                # Typing stream is allowed to "reset", and so comparisons don't
+                # really make sense as is.
+                # TODO: Figure out a better way of tracking resets.
+                continue
+
+            self_value = self.get_field(key)
+            other_value = other_token.get_field(key)
+
+            if isinstance(self_value, RoomStreamToken):
+                assert isinstance(other_value, RoomStreamToken)
+                if not self_value.is_before_or_eq(other_value):
+                    return False
+            elif isinstance(self_value, MultiWriterStreamToken):
+                assert isinstance(other_value, MultiWriterStreamToken)
+                if not self_value.is_before_or_eq(other_value):
+                    return False
+            else:
+                assert isinstance(other_value, int)
+                if self_value > other_value:
+                    return False
+
+        return True
+
 
 StreamToken.START = StreamToken(
     RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0