summary refs log tree commit diff
path: root/synapse/federation/transaction_queue.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/transaction_queue.py')
-rw-r--r--synapse/federation/transaction_queue.py65
1 files changed, 56 insertions, 9 deletions
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index cb2ef0210c..633c79c352 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -17,7 +17,7 @@
 from twisted.internet import defer
 
 from .persistence import TransactionActions
-from .units import Transaction
+from .units import Transaction, Edu
 
 from synapse.api.errors import HttpResponseException
 from synapse.util.async import run_on_reactor
@@ -81,6 +81,8 @@ class TransactionQueue(object):
         # destination -> list of tuple(failure, deferred)
         self.pending_failures_by_dest = {}
 
+        self.last_device_stream_id_by_dest = {}
+
         # HACK to get unique tx id
         self._next_txn_id = int(self.clock.time_msec())
 
@@ -155,6 +157,17 @@ class TransactionQueue(object):
             self._attempt_new_transaction, destination
         )
 
+    def enqueue_device_messages(self, destination):
+        if destination == self.server_name or destination == "localhost":
+            return
+
+        if not self.can_send_to(destination):
+            return
+
+        preserve_context_over_fn(
+            self._attempt_new_transaction, destination
+        )
+
     @defer.inlineCallbacks
     def _attempt_new_transaction(self, destination):
         yield run_on_reactor()
@@ -175,6 +188,12 @@ class TransactionQueue(object):
             pending_edus = self.pending_edus_by_dest.pop(destination, [])
             pending_failures = self.pending_failures_by_dest.pop(destination, [])
 
+            device_message_edus, device_stream_id = (
+                yield self._get_new_device_messages(destination)
+            )
+
+            pending_edus.extend(device_message_edus)
+
             if pending_pdus:
                 logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
                              destination, len(pending_pdus))
@@ -184,13 +203,34 @@ class TransactionQueue(object):
                 return
 
             yield self._send_new_transaction(
-                destination, pending_pdus, pending_edus, pending_failures
+                destination, pending_pdus, pending_edus, pending_failures,
+                device_stream_id,
+                should_delete_from_device_stream=bool(device_message_edus)
             )
 
+    @defer.inlineCallbacks
+    def _get_new_device_messages(self, destination):
+        last_device_stream_id = self.last_device_stream_id_by_dest.get(destination, 0)
+        to_device_stream_id = self.store.get_to_device_stream_token()
+        contents, stream_id = yield self.store.get_new_device_msgs_for_remote(
+            destination, last_device_stream_id, to_device_stream_id
+        )
+        edus = [
+            Edu(
+                origin=self.server_name,
+                destination=destination,
+                edu_type="m.direct_to_device",
+                content=content,
+            )
+            for content in contents
+        ]
+        defer.returnValue((edus, stream_id))
+
     @measure_func("_send_new_transaction")
     @defer.inlineCallbacks
     def _send_new_transaction(self, destination, pending_pdus, pending_edus,
-                              pending_failures):
+                              pending_failures, device_stream_id,
+                              should_delete_from_device_stream):
 
             # Sort based on the order field
             pending_pdus.sort(key=lambda t: t[1])
@@ -215,9 +255,9 @@ class TransactionQueue(object):
                     "TX [%s] {%s} Attempting new transaction"
                     " (pdus: %d, edus: %d, failures: %d)",
                     destination, txn_id,
-                    len(pending_pdus),
-                    len(pending_edus),
-                    len(pending_failures)
+                    len(pdus),
+                    len(edus),
+                    len(failures)
                 )
 
                 logger.debug("TX [%s] Persisting transaction...", destination)
@@ -242,9 +282,9 @@ class TransactionQueue(object):
                     " (PDUs: %d, EDUs: %d, failures: %d)",
                     destination, txn_id,
                     transaction.transaction_id,
-                    len(pending_pdus),
-                    len(pending_edus),
-                    len(pending_failures),
+                    len(pdus),
+                    len(edus),
+                    len(failures),
                 )
 
                 with limiter:
@@ -299,6 +339,13 @@ class TransactionQueue(object):
                         logger.info(
                             "Failed to send event %s to %s", p.event_id, destination
                         )
+                else:
+                    # Remove the acknowledged device messages from the database
+                    if should_delete_from_device_stream:
+                        yield self.store.delete_device_msgs_for_remote(
+                            destination, device_stream_id
+                        )
+                    self.last_device_stream_id_by_dest[destination] = device_stream_id
             except NotRetryingDestination:
                 logger.info(
                     "TX [%s] not ready for retry yet - "