summary refs log tree commit diff
diff options
context:
space:
mode:
authorOlivier 'reivilibre <oliverw@matrix.org>2024-05-18 19:50:05 +0100
committerOlivier 'reivilibre <oliverw@matrix.org>2024-05-18 19:50:05 +0100
commit02efa51f0f72f64cf878a227f92f6e365a596c47 (patch)
treec6ed355e029d406bcdf6b281aba77384ea83037c
parentMerge branch 'erikj/device_list_sync_perf' into matrix-org-hotfixes (diff)
parentAdd current token to log line (diff)
downloadsynapse-02efa51f0f72f64cf878a227f92f6e365a596c47.tar.xz
Merge branch 'erikj/wait_for_stream_pos' into matrix-org-hotfixes
-rw-r--r--changelog.d/17215.bugfix1
-rw-r--r--pyproject.toml6
-rw-r--r--synapse/handlers/sync.py35
-rw-r--r--synapse/notifier.py23
-rw-r--r--synapse/types/__init__.py58
5 files changed, 118 insertions, 5 deletions
diff --git a/changelog.d/17215.bugfix b/changelog.d/17215.bugfix
new file mode 100644
index 0000000000..10981b798e
--- /dev/null
+++ b/changelog.d/17215.bugfix
@@ -0,0 +1 @@
+Fix bug where duplicate events could be sent down sync when using workers that are overloaded.
diff --git a/pyproject.toml b/pyproject.toml
index dd4521ff71..a7d43b1744 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -200,10 +200,8 @@ netaddr = ">=0.7.18"
 # add a lower bound to the Jinja2 dependency.
 Jinja2 = ">=3.0"
 bleach = ">=1.4.3"
-# We use `ParamSpec` and `Concatenate`, which were added in `typing-extensions` 3.10.0.0.
-# Additionally we need https://github.com/python/typing/pull/817 to allow types to be
-# generic over ParamSpecs.
-typing-extensions = ">=3.10.0.1"
+# We use `Self`, which were added in `typing-extensions` 4.0.
+typing-extensions = ">=4.0"
 # We enforce that we have a `cryptography` version that bundles an `openssl`
 # with the latest security patches.
 cryptography = ">=3.4.7"
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b7917a99d6..fa634d65c7 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -279,6 +279,23 @@ class SyncResult:
             or self.device_lists
         )
 
+    @staticmethod
+    def empty(next_batch: StreamToken) -> "SyncResult":
+        "Return a new empty result"
+        return SyncResult(
+            next_batch=next_batch,
+            presence=[],
+            account_data=[],
+            joined=[],
+            invited=[],
+            knocked=[],
+            archived=[],
+            to_device=[],
+            device_lists=DeviceListUpdates(),
+            device_one_time_keys_count={},
+            device_unused_fallback_key_types=[],
+        )
+
 
 class SyncHandler:
     def __init__(self, hs: "HomeServer"):
@@ -401,6 +418,24 @@ class SyncHandler:
         if context:
             context.tag = sync_label
 
+        if since_token is not None:
+            # We need to make sure this worker has caught up with the token. If
+            # this returns false it means we timed out waiting, and we should
+            # just return an empty response.
+            start = self.clock.time_msec()
+            if not await self.notifier.wait_for_stream_token(since_token):
+                logger.warning(
+                    "Timed out waiting for worker to catch up. Returning empty response"
+                )
+                return SyncResult.empty(since_token)
+
+            # If we've spent significant time waiting to catch up, take it off
+            # the timeout.
+            now = self.clock.time_msec()
+            if now - start > 1_000:
+                timeout -= now - start
+                timeout = max(timeout, 0)
+
         # if we have a since token, delete any to-device messages before that token
         # (since we now know that the device has received them)
         if since_token is not None:
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 7c1cd3b5f2..ced9e9ad66 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -763,6 +763,29 @@ class Notifier:
 
         return result
 
+    async def wait_for_stream_token(self, stream_token: StreamToken) -> bool:
+        """Wait for this worker to catch up with the given stream token."""
+
+        start = self.clock.time_msec()
+        while True:
+            current_token = self.event_sources.get_current_token()
+            if stream_token.is_before_or_eq(current_token):
+                return True
+
+            now = self.clock.time_msec()
+
+            if now - start > 10_000:
+                return False
+
+            logger.info(
+                "Waiting for current token to reach %s; currently at %s",
+                stream_token,
+                current_token,
+            )
+
+            # TODO: be better
+            await self.clock.sleep(0.5)
+
     async def _get_room_ids(
         self, user: UserID, explicit_room_id: Optional[str]
     ) -> Tuple[StrCollection, bool]:
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 509a2d3a0f..151658df53 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -48,7 +48,7 @@ import attr
 from immutabledict import immutabledict
 from signedjson.key import decode_verify_key_bytes
 from signedjson.types import VerifyKey
-from typing_extensions import TypedDict
+from typing_extensions import Self, TypedDict
 from unpaddedbase64 import decode_base64
 from zope.interface import Interface
 
@@ -515,6 +515,27 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
         # at `self.stream`.
         return self.instance_map.get(instance_name, self.stream)
 
+    def is_before_or_eq(self, other_token: Self) -> bool:
+        """Wether this token is before the other token, i.e. every constituent
+        part is before the other.
+
+        Essentially it is `self <= other`.
+
+        Note: if `self.is_before_or_eq(other_token) is False` then that does not
+        imply that the reverse is True.
+        """
+        if self.stream > other_token.stream:
+            return False
+
+        instances = self.instance_map.keys() | other_token.instance_map.keys()
+        for instance in instances:
+            if self.instance_map.get(
+                instance, self.stream
+            ) > other_token.instance_map.get(instance, other_token.stream):
+                return False
+
+        return True
+
 
 @attr.s(frozen=True, slots=True, order=False)
 class RoomStreamToken(AbstractMultiWriterStreamToken):
@@ -1008,6 +1029,41 @@ class StreamToken:
         """Returns the stream ID for the given key."""
         return getattr(self, key.value)
 
+    def is_before_or_eq(self, other_token: "StreamToken") -> bool:
+        """Wether this token is before the other token, i.e. every constituent
+        part is before the other.
+
+        Essentially it is `self <= other`.
+
+        Note: if `self.is_before_or_eq(other_token) is False` then that does not
+        imply that the reverse is True.
+        """
+
+        for _, key in StreamKeyType.__members__.items():
+            if key == StreamKeyType.TYPING:
+                # Typing stream is allowed to "reset", and so comparisons don't
+                # really make sense as is.
+                # TODO: Figure out a better way of tracking resets.
+                continue
+
+            self_value = self.get_field(key)
+            other_value = other_token.get_field(key)
+
+            if isinstance(self_value, RoomStreamToken):
+                assert isinstance(other_value, RoomStreamToken)
+                if not self_value.is_before_or_eq(other_value):
+                    return False
+            elif isinstance(self_value, MultiWriterStreamToken):
+                assert isinstance(other_value, MultiWriterStreamToken)
+                if not self_value.is_before_or_eq(other_value):
+                    return False
+            else:
+                assert isinstance(other_value, int)
+                if self_value > other_value:
+                    return False
+
+        return True
+
 
 StreamToken.START = StreamToken(
     RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0