summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_client.py6
-rw-r--r--synapse/federation/federation_server.py2
-rw-r--r--synapse/federation/transaction_queue.py66
3 files changed, 64 insertions, 10 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 627acc6a4f..78719eed25 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -138,6 +138,12 @@ class FederationClient(FederationBase):
         return defer.succeed(None)
 
     @log_function
+    def send_device_messages(self, destination):
+        """Sends the device messages in the local database to the remote
+        destination"""
+        self._transaction_queue.enqueue_device_messages(destination)
+
+    @log_function
     def send_failure(self, failure, destination):
         self._transaction_queue.enqueue_failure(failure, destination)
         return defer.succeed(None)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 5621655098..3fa7b2315c 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -188,7 +188,7 @@ class FederationServer(FederationBase):
             except SynapseError as e:
                 logger.info("Failed to handle edu %r: %r", edu_type, e)
             except Exception as e:
-                logger.exception("Failed to handle edu %r", edu_type, e)
+                logger.exception("Failed to handle edu %r", edu_type)
         else:
             logger.warn("Received EDU of type %s with no handler", edu_type)
 
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index cb2ef0210c..5c7245d383 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -17,7 +17,7 @@
 from twisted.internet import defer
 
 from .persistence import TransactionActions
-from .units import Transaction
+from .units import Transaction, Edu
 
 from synapse.api.errors import HttpResponseException
 from synapse.util.async import run_on_reactor
@@ -81,6 +81,8 @@ class TransactionQueue(object):
         # destination -> list of tuple(failure, deferred)
         self.pending_failures_by_dest = {}
 
+        self.last_device_stream_id_by_dest = {}
+
         # HACK to get unique tx id
         self._next_txn_id = int(self.clock.time_msec())
 
@@ -155,6 +157,17 @@ class TransactionQueue(object):
             self._attempt_new_transaction, destination
         )
 
+    def enqueue_device_messages(self, destination):
+        if destination == self.server_name or destination == "localhost":
+            return
+
+        if not self.can_send_to(destination):
+            return
+
+        preserve_context_over_fn(
+            self._attempt_new_transaction, destination
+        )
+
     @defer.inlineCallbacks
     def _attempt_new_transaction(self, destination):
         yield run_on_reactor()
@@ -175,22 +188,50 @@ class TransactionQueue(object):
             pending_edus = self.pending_edus_by_dest.pop(destination, [])
             pending_failures = self.pending_failures_by_dest.pop(destination, [])
 
+            device_message_edus, device_stream_id = (
+                yield self._get_new_device_messages(destination)
+            )
+
+            pending_edus.extend(device_message_edus)
+
             if pending_pdus:
                 logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
                              destination, len(pending_pdus))
 
             if not pending_pdus and not pending_edus and not pending_failures:
                 logger.debug("TX [%s] Nothing to send", destination)
+                self.last_device_stream_id_by_dest[destination] = device_stream_id
                 return
 
             yield self._send_new_transaction(
-                destination, pending_pdus, pending_edus, pending_failures
+                destination, pending_pdus, pending_edus, pending_failures,
+                device_stream_id,
+                should_delete_from_device_stream=bool(device_message_edus)
             )
 
+    @defer.inlineCallbacks
+    def _get_new_device_messages(self, destination):
+        last_device_stream_id = self.last_device_stream_id_by_dest.get(destination, 0)
+        to_device_stream_id = self.store.get_to_device_stream_token()
+        contents, stream_id = yield self.store.get_new_device_msgs_for_remote(
+            destination, last_device_stream_id, to_device_stream_id
+        )
+        edus = [
+            Edu(
+                origin=self.server_name,
+                destination=destination,
+                edu_type="m.direct_to_device",
+                content=content,
+            )
+            for content in contents
+        ]
+        defer.returnValue((edus, stream_id))
+
     @measure_func("_send_new_transaction")
     @defer.inlineCallbacks
     def _send_new_transaction(self, destination, pending_pdus, pending_edus,
-                              pending_failures):
+                              pending_failures, device_stream_id,
+                              should_delete_from_device_stream):
 
             # Sort based on the order field
             pending_pdus.sort(key=lambda t: t[1])
@@ -215,9 +256,9 @@ class TransactionQueue(object):
                     "TX [%s] {%s} Attempting new transaction"
                     " (pdus: %d, edus: %d, failures: %d)",
                     destination, txn_id,
-                    len(pending_pdus),
-                    len(pending_edus),
-                    len(pending_failures)
+                    len(pdus),
+                    len(edus),
+                    len(failures)
                 )
 
                 logger.debug("TX [%s] Persisting transaction...", destination)
@@ -242,9 +283,9 @@ class TransactionQueue(object):
                     " (PDUs: %d, EDUs: %d, failures: %d)",
                     destination, txn_id,
                     transaction.transaction_id,
-                    len(pending_pdus),
-                    len(pending_edus),
-                    len(pending_failures),
+                    len(pdus),
+                    len(edus),
+                    len(failures),
                 )
 
                 with limiter:
@@ -299,6 +340,13 @@ class TransactionQueue(object):
                         logger.info(
                             "Failed to send event %s to %s", p.event_id, destination
                         )
+                else:
+                    # Remove the acknowledged device messages from the database
+                    if should_delete_from_device_stream:
+                        yield self.store.delete_device_msgs_for_remote(
+                            destination, device_stream_id
+                        )
+                    self.last_device_stream_id_by_dest[destination] = device_stream_id
             except NotRetryingDestination:
                 logger.info(
                     "TX [%s] not ready for retry yet - "