summary refs log tree commit diff
path: root/synapse/federation/sender/per_destination_queue.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/sender/per_destination_queue.py')
-rw-r--r--synapse/federation/sender/per_destination_queue.py151
1 files changed, 86 insertions, 65 deletions
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index be99211003..22a2735405 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -33,12 +33,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage import UserPresenceState
 from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
 
+# This is defined in the Matrix spec and enforced by the receiver.
+MAX_EDUS_PER_TRANSACTION = 100
+
 logger = logging.getLogger(__name__)
 
 
 sent_edus_counter = Counter(
-    "synapse_federation_client_sent_edus",
-    "Total number of EDUs successfully sent",
+    "synapse_federation_client_sent_edus", "Total number of EDUs successfully sent"
 )
 
 sent_edus_by_type = Counter(
@@ -58,6 +60,7 @@ class PerDestinationQueue(object):
         destination (str): the server_name of the destination that we are managing
             transmission for.
     """
+
     def __init__(self, hs, transaction_manager, destination):
         self._server_name = hs.hostname
         self._clock = hs.get_clock()
@@ -68,17 +71,17 @@ class PerDestinationQueue(object):
         self.transmission_loop_running = False
 
         # a list of tuples of (pending pdu, order)
-        self._pending_pdus = []    # type: list[tuple[EventBase, int]]
-        self._pending_edus = []    # type: list[Edu]
+        self._pending_pdus = []  # type: list[tuple[EventBase, int]]
+        self._pending_edus = []  # type: list[Edu]
 
         # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
         # based on their key (e.g. typing events by room_id)
         # Map of (edu_type, key) -> Edu
-        self._pending_edus_keyed = {}   # type: dict[tuple[str, str], Edu]
+        self._pending_edus_keyed = {}  # type: dict[tuple[str, str], Edu]
 
         # Map of user_id -> UserPresenceState of pending presence to be sent to this
         # destination
-        self._pending_presence = {}   # type: dict[str, UserPresenceState]
+        self._pending_presence = {}  # type: dict[str, UserPresenceState]
 
         # room_id -> receipt_type -> user_id -> receipt_dict
         self._pending_rrs = {}
@@ -120,9 +123,7 @@ class PerDestinationQueue(object):
         Args:
             states (iterable[UserPresenceState]): presence to send
         """
-        self._pending_presence.update({
-            state.user_id: state for state in states
-        })
+        self._pending_presence.update({state.user_id: state for state in states})
         self.attempt_new_transaction()
 
     def queue_read_receipt(self, receipt):
@@ -132,14 +133,9 @@ class PerDestinationQueue(object):
         Args:
             receipt (synapse.api.receipt_info.ReceiptInfo): receipt to be queued
         """
-        self._pending_rrs.setdefault(
-            receipt.room_id, {},
-        ).setdefault(
+        self._pending_rrs.setdefault(receipt.room_id, {}).setdefault(
             receipt.receipt_type, {}
-        )[receipt.user_id] = {
-            "event_ids": receipt.event_ids,
-            "data": receipt.data,
-        }
+        )[receipt.user_id] = {"event_ids": receipt.event_ids, "data": receipt.data}
 
     def flush_read_receipts_for_room(self, room_id):
         # if we don't have any read-receipts for this room, it may be that we've already
@@ -170,10 +166,7 @@ class PerDestinationQueue(object):
             # request at which point pending_pdus just keeps growing.
             # we need application-layer timeouts of some flavour of these
             # requests
-            logger.debug(
-                "TX [%s] Transaction already in progress",
-                self._destination
-            )
+            logger.debug("TX [%s] Transaction already in progress", self._destination)
             return
 
         logger.debug("TX [%s] Starting transaction loop", self._destination)
@@ -196,10 +189,21 @@ class PerDestinationQueue(object):
 
             pending_pdus = []
             while True:
-                device_message_edus, device_stream_id, dev_list_id = (
-                    yield self._get_new_device_messages()
+                # We have to keep 2 free slots for presence and rr_edus
+                limit = MAX_EDUS_PER_TRANSACTION - 2
+
+                device_update_edus, dev_list_id = (
+                    yield self._get_device_update_edus(limit)
+                )
+
+                limit -= len(device_update_edus)
+
+                to_device_edus, device_stream_id = (
+                    yield self._get_to_device_message_edus(limit)
                 )
 
+                pending_edus = device_update_edus + to_device_edus
+
                 # BEGIN CRITICAL SECTION
                 #
                 # In order to avoid a race condition, we need to make sure that
@@ -214,21 +218,7 @@ class PerDestinationQueue(object):
                 # We can only include at most 50 PDUs per transactions
                 pending_pdus, self._pending_pdus = pending_pdus[:50], pending_pdus[50:]
 
-                pending_edus = []
-
                 pending_edus.extend(self._get_rr_edus(force_flush=False))
-
-                # We can only include at most 100 EDUs per transactions
-                pending_edus.extend(self._pop_pending_edus(100 - len(pending_edus)))
-
-                pending_edus.extend(
-                    self._pending_edus_keyed.values()
-                )
-
-                self._pending_edus_keyed = {}
-
-                pending_edus.extend(device_message_edus)
-
                 pending_presence = self._pending_presence
                 self._pending_presence = {}
                 if pending_presence:
@@ -248,9 +238,22 @@ class PerDestinationQueue(object):
                         )
                     )
 
+                pending_edus.extend(
+                    self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
+                )
+                while (
+                    len(pending_edus) < MAX_EDUS_PER_TRANSACTION
+                    and self._pending_edus_keyed
+                ):
+                    _, val = self._pending_edus_keyed.popitem()
+                    pending_edus.append(val)
+
                 if pending_pdus:
-                    logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
-                                 self._destination, len(pending_pdus))
+                    logger.debug(
+                        "TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+                        self._destination,
+                        len(pending_pdus),
+                    )
 
                 if not pending_pdus and not pending_edus:
                     logger.debug("TX [%s] Nothing to send", self._destination)
@@ -259,7 +262,7 @@ class PerDestinationQueue(object):
 
                 # if we've decided to send a transaction anyway, and we have room, we
                 # may as well send any pending RRs
-                if len(pending_edus) < 100:
+                if len(pending_edus) < MAX_EDUS_PER_TRANSACTION:
                     pending_edus.extend(self._get_rr_edus(force_flush=True))
 
                 # END CRITICAL SECTION
@@ -274,10 +277,13 @@ class PerDestinationQueue(object):
                         sent_edus_by_type.labels(edu.edu_type).inc()
                     # Remove the acknowledged device messages from the database
                     # Only bother if we actually sent some device messages
-                    if device_message_edus:
+                    if to_device_edus:
                         yield self._store.delete_device_msgs_for_remote(
                             self._destination, device_stream_id
                         )
+
+                    # also mark the device updates as sent
+                    if device_update_edus:
                         logger.info(
                             "Marking as sent %r %r", self._destination, dev_list_id
                         )
@@ -303,22 +309,25 @@ class PerDestinationQueue(object):
         except HttpResponseException as e:
             logger.warning(
                 "TX [%s] Received %d response to transaction: %s",
-                self._destination, e.code, e,
+                self._destination,
+                e.code,
+                e,
             )
         except RequestSendFailed as e:
-            logger.warning("TX [%s] Failed to send transaction: %s", self._destination, e)
+            logger.warning(
+                "TX [%s] Failed to send transaction: %s", self._destination, e
+            )
 
             for p, _ in pending_pdus:
-                logger.info("Failed to send event %s to %s", p.event_id,
-                            self._destination)
+                logger.info(
+                    "Failed to send event %s to %s", p.event_id, self._destination
+                )
         except Exception:
-            logger.exception(
-                "TX [%s] Failed to send transaction",
-                self._destination,
-            )
+            logger.exception("TX [%s] Failed to send transaction", self._destination)
             for p, _ in pending_pdus:
-                logger.info("Failed to send event %s to %s", p.event_id,
-                            self._destination)
+                logger.info(
+                    "Failed to send event %s to %s", p.event_id, self._destination
+                )
         finally:
             # We want to be *very* sure we clear this after we stop processing
             self.transmission_loop_running = False
@@ -346,33 +355,45 @@ class PerDestinationQueue(object):
         return pending_edus
 
     @defer.inlineCallbacks
-    def _get_new_device_messages(self):
-        last_device_stream_id = self._last_device_stream_id
-        to_device_stream_id = self._store.get_to_device_stream_token()
-        contents, stream_id = yield self._store.get_new_device_msgs_for_remote(
-            self._destination, last_device_stream_id, to_device_stream_id
+    def _get_device_update_edus(self, limit):
+        last_device_list = self._last_device_list_stream_id
+
+        # Retrieve list of new device updates to send to the destination
+        now_stream_id, results = yield self._store.get_devices_by_remote(
+            self._destination, last_device_list, limit=limit,
         )
         edus = [
             Edu(
                 origin=self._server_name,
                 destination=self._destination,
-                edu_type="m.direct_to_device",
+                edu_type="m.device_list_update",
                 content=content,
             )
-            for content in contents
+            for content in results
         ]
 
-        last_device_list = self._last_device_list_stream_id
-        now_stream_id, results = yield self._store.get_devices_by_remote(
-            self._destination, last_device_list
+        assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
+
+        defer.returnValue((edus, now_stream_id))
+
+    @defer.inlineCallbacks
+    def _get_to_device_message_edus(self, limit):
+        last_device_stream_id = self._last_device_stream_id
+        to_device_stream_id = self._store.get_to_device_stream_token()
+        contents, stream_id = yield self._store.get_new_device_msgs_for_remote(
+            self._destination,
+            last_device_stream_id,
+            to_device_stream_id,
+            limit,
         )
-        edus.extend(
+        edus = [
             Edu(
                 origin=self._server_name,
                 destination=self._destination,
-                edu_type="m.device_list_update",
+                edu_type="m.direct_to_device",
                 content=content,
             )
-            for content in results
-        )
-        defer.returnValue((edus, stream_id, now_stream_id))
+            for content in contents
+        ]
+
+        defer.returnValue((edus, stream_id))