summary refs log tree commit diff
path: root/synapse/federation/sender
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/sender')
-rw-r--r--synapse/federation/sender/per_destination_queue.py40
1 files changed, 26 insertions, 14 deletions
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 564c57203d..22a2735405 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -189,11 +189,21 @@ class PerDestinationQueue(object):
 
             pending_pdus = []
             while True:
-                device_message_edus, device_stream_id, dev_list_id = (
-                    # We have to keep 2 free slots for presence and rr_edus
-                    yield self._get_new_device_messages(MAX_EDUS_PER_TRANSACTION - 2)
+                # We have to keep 2 free slots for presence and rr_edus
+                limit = MAX_EDUS_PER_TRANSACTION - 2
+
+                device_update_edus, dev_list_id = (
+                    yield self._get_device_update_edus(limit)
+                )
+
+                limit -= len(device_update_edus)
+
+                to_device_edus, device_stream_id = (
+                    yield self._get_to_device_message_edus(limit)
                 )
 
+                pending_edus = device_update_edus + to_device_edus
+
                 # BEGIN CRITICAL SECTION
                 #
                 # In order to avoid a race condition, we need to make sure that
@@ -208,10 +218,6 @@ class PerDestinationQueue(object):
                 # We can only include at most 50 PDUs per transactions
                 pending_pdus, self._pending_pdus = pending_pdus[:50], pending_pdus[50:]
 
-                pending_edus = []
-
-                # We can only include at most 100 EDUs per transactions
-                # rr_edus and pending_presence take at most one slot each
                 pending_edus.extend(self._get_rr_edus(force_flush=False))
                 pending_presence = self._pending_presence
                 self._pending_presence = {}
@@ -232,7 +238,6 @@ class PerDestinationQueue(object):
                         )
                     )
 
-                pending_edus.extend(device_message_edus)
                 pending_edus.extend(
                     self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
                 )
@@ -272,10 +277,13 @@ class PerDestinationQueue(object):
                         sent_edus_by_type.labels(edu.edu_type).inc()
                     # Remove the acknowledged device messages from the database
                     # Only bother if we actually sent some device messages
-                    if device_message_edus:
+                    if to_device_edus:
                         yield self._store.delete_device_msgs_for_remote(
                             self._destination, device_stream_id
                         )
+
+                    # also mark the device updates as sent
+                    if device_update_edus:
                         logger.info(
                             "Marking as sent %r %r", self._destination, dev_list_id
                         )
@@ -347,7 +355,7 @@ class PerDestinationQueue(object):
         return pending_edus
 
     @defer.inlineCallbacks
-    def _get_new_device_messages(self, limit):
+    def _get_device_update_edus(self, limit):
         last_device_list = self._last_device_list_stream_id
 
         # Retrieve list of new device updates to send to the destination
@@ -366,15 +374,19 @@ class PerDestinationQueue(object):
 
         assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
 
+        defer.returnValue((edus, now_stream_id))
+
+    @defer.inlineCallbacks
+    def _get_to_device_message_edus(self, limit):
         last_device_stream_id = self._last_device_stream_id
         to_device_stream_id = self._store.get_to_device_stream_token()
         contents, stream_id = yield self._store.get_new_device_msgs_for_remote(
             self._destination,
             last_device_stream_id,
             to_device_stream_id,
-            limit - len(edus),
+            limit,
         )
-        edus.extend(
+        edus = [
             Edu(
                 origin=self._server_name,
                 destination=self._destination,
@@ -382,6 +394,6 @@ class PerDestinationQueue(object):
                 content=content,
             )
             for content in contents
-        )
+        ]
 
-        defer.returnValue((edus, stream_id, now_stream_id))
+        defer.returnValue((edus, stream_id))