summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/17371.misc1
-rw-r--r--changelog.d/17391.bugfix1
-rw-r--r--synapse/federation/sender/per_destination_queue.py31
-rw-r--r--synapse/streams/events.py26
-rw-r--r--tests/federation/test_federation_sender.py119
-rw-r--r--tests/handlers/test_sync.py37
6 files changed, 194 insertions, 21 deletions
diff --git a/changelog.d/17371.misc b/changelog.d/17371.misc
new file mode 100644
index 0000000000..0fbf19f4fb
--- /dev/null
+++ b/changelog.d/17371.misc
@@ -0,0 +1 @@
+Limit size of presence EDUs to 50 entries.
diff --git a/changelog.d/17391.bugfix b/changelog.d/17391.bugfix
new file mode 100644
index 0000000000..9686b5c276
--- /dev/null
+++ b/changelog.d/17391.bugfix
@@ -0,0 +1 @@
+Fix bug where `/sync` requests could get blocked indefinitely after an upgrade from Synapse versions before v1.109.0.
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index d9f2f017ed..9f1c2fe22a 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -21,6 +21,7 @@
 #
 import datetime
 import logging
+from collections import OrderedDict
 from types import TracebackType
 from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, Type
 
@@ -68,6 +69,10 @@ sent_edus_by_type = Counter(
 # If the retry interval is larger than this then we enter "catchup" mode
 CATCHUP_RETRY_INTERVAL = 60 * 60 * 1000
 
+# Limit how many presence states we add to each presence EDU, to ensure that
+# they are bounded in size.
+MAX_PRESENCE_STATES_PER_EDU = 50
+
 
 class PerDestinationQueue:
     """
@@ -144,7 +149,7 @@ class PerDestinationQueue:
 
         # Map of user_id -> UserPresenceState of pending presence to be sent to this
         # destination
-        self._pending_presence: Dict[str, UserPresenceState] = {}
+        self._pending_presence: OrderedDict[str, UserPresenceState] = OrderedDict()
 
         # List of room_id -> receipt_type -> user_id -> receipt_dict,
         #
@@ -399,7 +404,7 @@ class PerDestinationQueue:
                 # through another mechanism, because this is all volatile!
                 self._pending_edus = []
                 self._pending_edus_keyed = {}
-                self._pending_presence = {}
+                self._pending_presence.clear()
                 self._pending_receipt_edus = []
 
                 self._start_catching_up()
@@ -721,22 +726,26 @@ class _TransactionQueueManager:
 
         # Add presence EDU.
         if self.queue._pending_presence:
+            # Only send max 50 presence entries in the EDU, to bound the amount
+            # of data we're sending.
+            presence_to_add: List[JsonDict] = []
+            while (
+                self.queue._pending_presence
+                and len(presence_to_add) < MAX_PRESENCE_STATES_PER_EDU
+            ):
+                _, presence = self.queue._pending_presence.popitem(last=False)
+                presence_to_add.append(
+                    format_user_presence_state(presence, self.queue._clock.time_msec())
+                )
+
             pending_edus.append(
                 Edu(
                     origin=self.queue._server_name,
                     destination=self.queue._destination,
                     edu_type=EduTypes.PRESENCE,
-                    content={
-                        "push": [
-                            format_user_presence_state(
-                                presence, self.queue._clock.time_msec()
-                            )
-                            for presence in self.queue._pending_presence.values()
-                        ]
-                    },
+                    content={"push": presence_to_add},
                 )
             )
-            self.queue._pending_presence = {}
 
         # Add read receipt EDUs.
         pending_edus.extend(self.queue._get_receipt_edus(force_flush=False, limit=5))
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 93d5ae1a55..856f646795 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -19,6 +19,7 @@
 #
 #
 
+import logging
 from typing import TYPE_CHECKING, Sequence, Tuple
 
 import attr
@@ -41,6 +42,9 @@ if TYPE_CHECKING:
     from synapse.server import HomeServer
 
 
+logger = logging.getLogger(__name__)
+
+
 @attr.s(frozen=True, slots=True, auto_attribs=True)
 class _EventSourcesInner:
     room: RoomEventSource
@@ -139,9 +143,16 @@ class EventSources:
                         key
                     ].get_max_allocated_token()
 
-                    token = token.copy_and_replace(
-                        key, token.room_key.bound_stream_token(max_token)
-                    )
+                    if max_token < token_value.get_max_stream_pos():
+                        logger.error(
+                            "Bounding token from the future '%s': token: %s, bound: %s",
+                            key,
+                            token_value,
+                            max_token,
+                        )
+                        token = token.copy_and_replace(
+                            key, token_value.bound_stream_token(max_token)
+                        )
             else:
                 assert isinstance(current_value, int)
                 if current_value < token_value:
@@ -149,7 +160,14 @@ class EventSources:
                         key
                     ].get_max_allocated_token()
 
-                    token = token.copy_and_replace(key, min(token_value, max_token))
+                    if max_token < token_value:
+                        logger.error(
+                            "Bounding token from the future '%s': token: %s, bound: %s",
+                            key,
+                            token_value,
+                            max_token,
+                        )
+                        token = token.copy_and_replace(key, max_token)
 
         return token
 
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 9073afc70e..6a8887fe74 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -27,6 +27,8 @@ from twisted.internet import defer
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms
+from synapse.api.presence import UserPresenceState
+from synapse.federation.sender.per_destination_queue import MAX_PRESENCE_STATES_PER_EDU
 from synapse.federation.units import Transaction
 from synapse.handlers.device import DeviceHandler
 from synapse.rest import admin
@@ -266,6 +268,123 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
         )
 
 
+class FederationSenderPresenceTestCases(HomeserverTestCase):
+    """
+    Test federation sending for presence updates.
+    """
+
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        self.federation_transport_client = Mock(spec=["send_transaction"])
+        self.federation_transport_client.send_transaction = AsyncMock()
+        hs = self.setup_test_homeserver(
+            federation_transport_client=self.federation_transport_client,
+        )
+
+        return hs
+
+    def default_config(self) -> JsonDict:
+        config = super().default_config()
+        config["federation_sender_instances"] = None
+        return config
+
+    def test_presence_simple(self) -> None:
+        "Test that sending a single presence update works"
+
+        mock_send_transaction: AsyncMock = (
+            self.federation_transport_client.send_transaction
+        )
+        mock_send_transaction.return_value = {}
+
+        sender = self.hs.get_federation_sender()
+        self.get_success(
+            sender.send_presence_to_destinations(
+                [UserPresenceState.default("@user:test")],
+                ["server"],
+            )
+        )
+
+        self.pump()
+
+        # expect a call to send_transaction
+        mock_send_transaction.assert_awaited_once()
+
+        json_cb = mock_send_transaction.call_args[0][1]
+        data = json_cb()
+        self.assertEqual(
+            data["edus"],
+            [
+                {
+                    "edu_type": EduTypes.PRESENCE,
+                    "content": {
+                        "push": [
+                            {
+                                "presence": "offline",
+                                "user_id": "@user:test",
+                            }
+                        ]
+                    },
+                }
+            ],
+        )
+
+    def test_presence_batched(self) -> None:
+        """Test that sending lots of presence updates to a destination are
+        batched, rather than having them all sent in one EDU."""
+
+        mock_send_transaction: AsyncMock = (
+            self.federation_transport_client.send_transaction
+        )
+        mock_send_transaction.return_value = {}
+
+        sender = self.hs.get_federation_sender()
+
+        # We now send lots of presence updates to force the federation sender to
+        # batch the mup.
+        number_presence_updates_to_send = MAX_PRESENCE_STATES_PER_EDU * 2
+        self.get_success(
+            sender.send_presence_to_destinations(
+                [
+                    UserPresenceState.default(f"@user{i}:test")
+                    for i in range(number_presence_updates_to_send)
+                ],
+                ["server"],
+            )
+        )
+
+        self.pump()
+
+        # We should have seen at least one transcation be sent by now.
+        mock_send_transaction.assert_called()
+
+        # We don't want to specify exactly how the presence EDUs get sent out,
+        # could be one per transaction or multiple per transaction. We just want
+        # to assert that a) each presence EDU has bounded number of updates, and
+        # b) that all updates get sent out.
+        presence_edus = []
+        for transaction_call in mock_send_transaction.call_args_list:
+            json_cb = transaction_call[0][1]
+            data = json_cb()
+
+            for edu in data["edus"]:
+                self.assertEqual(edu.get("edu_type"), EduTypes.PRESENCE)
+                presence_edus.append(edu)
+
+        # A set of all user presence we see, this should end up matching the
+        # number we sent out above.
+        seen_users: Set[str] = set()
+
+        for edu in presence_edus:
+            presence_states = edu["content"]["push"]
+
+            # This is where we actually check that the number of presence
+            # updates is bounded.
+            self.assertLessEqual(len(presence_states), MAX_PRESENCE_STATES_PER_EDU)
+
+            seen_users.update(p["user_id"] for p in presence_states)
+
+        self.assertEqual(len(seen_users), number_presence_updates_to_send)
+
+
 class FederationSenderDevicesTestCases(HomeserverTestCase):
     """
     Test federation sending to update devices.
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 5319928c28..674dd4fb54 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -36,7 +36,14 @@ from synapse.handlers.sync import SyncConfig, SyncRequestKey, SyncResult, SyncVe
 from synapse.rest import admin
 from synapse.rest.client import knock, login, room
 from synapse.server import HomeServer
-from synapse.types import JsonDict, StreamKeyType, UserID, create_requester
+from synapse.types import (
+    JsonDict,
+    MultiWriterStreamToken,
+    RoomStreamToken,
+    StreamKeyType,
+    UserID,
+    create_requester,
+)
 from synapse.util import Clock
 
 import tests.unittest
@@ -999,7 +1006,13 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
 
         self.get_success(sync_d, by=1.0)
 
-    def test_wait_for_invalid_future_sync_token(self) -> None:
+    @parameterized.expand(
+        [(key,) for key in StreamKeyType.__members__.values()],
+        name_func=lambda func, _, param: f"{func.__name__}_{param.args[0].name}",
+    )
+    def test_wait_for_invalid_future_sync_token(
+        self, stream_key: StreamKeyType
+    ) -> None:
         """Like the previous test, except we give a token that has a stream
         position ahead of what is in the DB, i.e. its invalid and we shouldn't
         wait for the stream to advance (as it may never do so).
@@ -1010,11 +1023,23 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         """
         user = self.register_user("alice", "password")
 
-        # Create a token and arbitrarily advance one of the streams.
+        # Create a token and advance one of the streams.
         current_token = self.hs.get_event_sources().get_current_token()
-        since_token = current_token.copy_and_advance(
-            StreamKeyType.PRESENCE, current_token.presence_key + 1
-        )
+        token_value = current_token.get_field(stream_key)
+
+        # How we advance the streams depends on the type.
+        if isinstance(token_value, int):
+            since_token = current_token.copy_and_advance(stream_key, token_value + 1)
+        elif isinstance(token_value, MultiWriterStreamToken):
+            since_token = current_token.copy_and_advance(
+                stream_key, MultiWriterStreamToken(stream=token_value.stream + 1)
+            )
+        elif isinstance(token_value, RoomStreamToken):
+            since_token = current_token.copy_and_advance(
+                stream_key, RoomStreamToken(stream=token_value.stream + 1)
+            )
+        else:
+            raise Exception("Unreachable")
 
         sync_d = defer.ensureDeferred(
             self.sync_handler.wait_for_sync_for_user(