summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_client.py6
-rw-r--r--synapse/federation/federation_server.py2
-rw-r--r--synapse/federation/transaction_queue.py353
3 files changed, 215 insertions, 146 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 627acc6a4f..78719eed25 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -138,6 +138,12 @@ class FederationClient(FederationBase):
         return defer.succeed(None)
 
     @log_function
+    def send_device_messages(self, destination):
+        """Sends the device messages in the local database to the remote
+        destination"""
+        self._transaction_queue.enqueue_device_messages(destination)
+
+    @log_function
     def send_failure(self, failure, destination):
         self._transaction_queue.enqueue_failure(failure, destination)
         return defer.succeed(None)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 5621655098..3fa7b2315c 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -188,7 +188,7 @@ class FederationServer(FederationBase):
             except SynapseError as e:
                 logger.info("Failed to handle edu %r: %r", edu_type, e)
             except Exception as e:
-                logger.exception("Failed to handle edu %r", edu_type, e)
+                logger.exception("Failed to handle edu %r", edu_type)
         else:
             logger.warn("Received EDU of type %s with no handler", edu_type)
 
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index cb2ef0210c..1ac569b305 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -17,7 +17,7 @@
 from twisted.internet import defer
 
 from .persistence import TransactionActions
-from .units import Transaction
+from .units import Transaction, Edu
 
 from synapse.api.errors import HttpResponseException
 from synapse.util.async import run_on_reactor
@@ -81,6 +81,8 @@ class TransactionQueue(object):
         # destination -> list of tuple(failure, deferred)
         self.pending_failures_by_dest = {}
 
+        self.last_device_stream_id_by_dest = {}
+
         # HACK to get unique tx id
         self._next_txn_id = int(self.clock.time_msec())
 
@@ -155,179 +157,240 @@ class TransactionQueue(object):
             self._attempt_new_transaction, destination
         )
 
+    def enqueue_device_messages(self, destination):
+        if destination == self.server_name or destination == "localhost":
+            return
+
+        if not self.can_send_to(destination):
+            return
+
+        preserve_context_over_fn(
+            self._attempt_new_transaction, destination
+        )
+
     @defer.inlineCallbacks
     def _attempt_new_transaction(self, destination):
-        yield run_on_reactor()
-        while True:
-            # list of (pending_pdu, deferred, order)
-            if destination in self.pending_transactions:
-                # XXX: pending_transactions can get stuck on by a never-ending
-                # request at which point pending_pdus_by_dest just keeps growing.
-                # we need application-layer timeouts of some flavour of these
-                # requests
-                logger.debug(
-                    "TX [%s] Transaction already in progress",
-                    destination
-                )
-                return
+        # list of (pending_pdu, deferred, order)
+        if destination in self.pending_transactions:
+            # XXX: pending_transactions can get stuck on by a never-ending
+            # request at which point pending_pdus_by_dest just keeps growing.
+            # we need application-layer timeouts of some flavour of these
+            # requests
+            logger.debug(
+                "TX [%s] Transaction already in progress",
+                destination
+            )
+            return
 
-            pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
-            pending_edus = self.pending_edus_by_dest.pop(destination, [])
-            pending_failures = self.pending_failures_by_dest.pop(destination, [])
+        try:
+            self.pending_transactions[destination] = 1
 
-            if pending_pdus:
-                logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
-                             destination, len(pending_pdus))
+            yield run_on_reactor()
 
-            if not pending_pdus and not pending_edus and not pending_failures:
-                logger.debug("TX [%s] Nothing to send", destination)
-                return
+            while True:
+                    pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+                    pending_edus = self.pending_edus_by_dest.pop(destination, [])
+                    pending_failures = self.pending_failures_by_dest.pop(destination, [])
 
-            yield self._send_new_transaction(
-                destination, pending_pdus, pending_edus, pending_failures
+                    limiter = yield get_retry_limiter(
+                        destination,
+                        self.clock,
+                        self.store,
+                    )
+
+                    device_message_edus, device_stream_id = (
+                        yield self._get_new_device_messages(destination)
+                    )
+
+                    pending_edus.extend(device_message_edus)
+
+                    if pending_pdus:
+                        logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+                                     destination, len(pending_pdus))
+
+                    if not pending_pdus and not pending_edus and not pending_failures:
+                        logger.debug("TX [%s] Nothing to send", destination)
+                        self.last_device_stream_id_by_dest[destination] = (
+                            device_stream_id
+                        )
+                        return
+
+                    success = yield self._send_new_transaction(
+                        destination, pending_pdus, pending_edus, pending_failures,
+                        device_stream_id,
+                        should_delete_from_device_stream=bool(device_message_edus),
+                        limiter=limiter,
+                    )
+                    if not success:
+                        break
+        except NotRetryingDestination:
+            logger.info(
+                "TX [%s] not ready for retry yet - "
+                "dropping transaction for now",
+                destination,
+            )
+        finally:
+            # We want to be *very* sure we delete this after we stop processing
+            self.pending_transactions.pop(destination, None)
+
+    @defer.inlineCallbacks
+    def _get_new_device_messages(self, destination):
+        last_device_stream_id = self.last_device_stream_id_by_dest.get(destination, 0)
+        to_device_stream_id = self.store.get_to_device_stream_token()
+        contents, stream_id = yield self.store.get_new_device_msgs_for_remote(
+            destination, last_device_stream_id, to_device_stream_id
+        )
+        edus = [
+            Edu(
+                origin=self.server_name,
+                destination=destination,
+                edu_type="m.direct_to_device",
+                content=content,
             )
+            for content in contents
+        ]
+        defer.returnValue((edus, stream_id))
 
     @measure_func("_send_new_transaction")
     @defer.inlineCallbacks
     def _send_new_transaction(self, destination, pending_pdus, pending_edus,
-                              pending_failures):
+                              pending_failures, device_stream_id,
+                              should_delete_from_device_stream, limiter):
 
-            # Sort based on the order field
-            pending_pdus.sort(key=lambda t: t[1])
-            pdus = [x[0] for x in pending_pdus]
-            edus = pending_edus
-            failures = [x.get_dict() for x in pending_failures]
+        # Sort based on the order field
+        pending_pdus.sort(key=lambda t: t[1])
+        pdus = [x[0] for x in pending_pdus]
+        edus = pending_edus
+        failures = [x.get_dict() for x in pending_failures]
 
-            try:
-                self.pending_transactions[destination] = 1
+        success = True
 
-                logger.debug("TX [%s] _attempt_new_transaction", destination)
+        try:
+            logger.debug("TX [%s] _attempt_new_transaction", destination)
 
-                txn_id = str(self._next_txn_id)
+            txn_id = str(self._next_txn_id)
 
-                limiter = yield get_retry_limiter(
-                    destination,
-                    self.clock,
-                    self.store,
-                )
+            logger.debug(
+                "TX [%s] {%s} Attempting new transaction"
+                " (pdus: %d, edus: %d, failures: %d)",
+                destination, txn_id,
+                len(pdus),
+                len(edus),
+                len(failures)
+            )
 
-                logger.debug(
-                    "TX [%s] {%s} Attempting new transaction"
-                    " (pdus: %d, edus: %d, failures: %d)",
-                    destination, txn_id,
-                    len(pending_pdus),
-                    len(pending_edus),
-                    len(pending_failures)
-                )
+            logger.debug("TX [%s] Persisting transaction...", destination)
 
-                logger.debug("TX [%s] Persisting transaction...", destination)
+            transaction = Transaction.create_new(
+                origin_server_ts=int(self.clock.time_msec()),
+                transaction_id=txn_id,
+                origin=self.server_name,
+                destination=destination,
+                pdus=pdus,
+                edus=edus,
+                pdu_failures=failures,
+            )
 
-                transaction = Transaction.create_new(
-                    origin_server_ts=int(self.clock.time_msec()),
-                    transaction_id=txn_id,
-                    origin=self.server_name,
-                    destination=destination,
-                    pdus=pdus,
-                    edus=edus,
-                    pdu_failures=failures,
-                )
+            self._next_txn_id += 1
+
+            yield self.transaction_actions.prepare_to_send(transaction)
 
-                self._next_txn_id += 1
+            logger.debug("TX [%s] Persisted transaction", destination)
+            logger.info(
+                "TX [%s] {%s} Sending transaction [%s],"
+                " (PDUs: %d, EDUs: %d, failures: %d)",
+                destination, txn_id,
+                transaction.transaction_id,
+                len(pdus),
+                len(edus),
+                len(failures),
+            )
 
-                yield self.transaction_actions.prepare_to_send(transaction)
+            with limiter:
+                # Actually send the transaction
+
+                # FIXME (erikj): This is a bit of a hack to make the Pdu age
+                # keys work
+                def json_data_cb():
+                    data = transaction.get_dict()
+                    now = int(self.clock.time_msec())
+                    if "pdus" in data:
+                        for p in data["pdus"]:
+                            if "age_ts" in p:
+                                unsigned = p.setdefault("unsigned", {})
+                                unsigned["age"] = now - int(p["age_ts"])
+                                del p["age_ts"]
+                    return data
+
+                try:
+                    response = yield self.transport_layer.send_transaction(
+                        transaction, json_data_cb
+                    )
+                    code = 200
+
+                    if response:
+                        for e_id, r in response.get("pdus", {}).items():
+                            if "error" in r:
+                                logger.warn(
+                                    "Transaction returned error for %s: %s",
+                                    e_id, r,
+                                )
+                except HttpResponseException as e:
+                    code = e.code
+                    response = e.response
 
-                logger.debug("TX [%s] Persisted transaction", destination)
                 logger.info(
-                    "TX [%s] {%s} Sending transaction [%s],"
-                    " (PDUs: %d, EDUs: %d, failures: %d)",
-                    destination, txn_id,
-                    transaction.transaction_id,
-                    len(pending_pdus),
-                    len(pending_edus),
-                    len(pending_failures),
+                    "TX [%s] {%s} got %d response",
+                    destination, txn_id, code
                 )
 
-                with limiter:
-                    # Actually send the transaction
-
-                    # FIXME (erikj): This is a bit of a hack to make the Pdu age
-                    # keys work
-                    def json_data_cb():
-                        data = transaction.get_dict()
-                        now = int(self.clock.time_msec())
-                        if "pdus" in data:
-                            for p in data["pdus"]:
-                                if "age_ts" in p:
-                                    unsigned = p.setdefault("unsigned", {})
-                                    unsigned["age"] = now - int(p["age_ts"])
-                                    del p["age_ts"]
-                        return data
-
-                    try:
-                        response = yield self.transport_layer.send_transaction(
-                            transaction, json_data_cb
-                        )
-                        code = 200
-
-                        if response:
-                            for e_id, r in response.get("pdus", {}).items():
-                                if "error" in r:
-                                    logger.warn(
-                                        "Transaction returned error for %s: %s",
-                                        e_id, r,
-                                    )
-                    except HttpResponseException as e:
-                        code = e.code
-                        response = e.response
+                logger.debug("TX [%s] Sent transaction", destination)
+                logger.debug("TX [%s] Marking as delivered...", destination)
 
-                    logger.info(
-                        "TX [%s] {%s} got %d response",
-                        destination, txn_id, code
-                    )
-
-                    logger.debug("TX [%s] Sent transaction", destination)
-                    logger.debug("TX [%s] Marking as delivered...", destination)
+            yield self.transaction_actions.delivered(
+                transaction, code, response
+            )
 
-                yield self.transaction_actions.delivered(
-                    transaction, code, response
-                )
+            logger.debug("TX [%s] Marked as delivered", destination)
 
-                logger.debug("TX [%s] Marked as delivered", destination)
+            if code != 200:
+                for p in pdus:
+                    logger.info(
+                        "Failed to send event %s to %s", p.event_id, destination
+                    )
+                success = False
+            else:
+                # Remove the acknowledged device messages from the database
+                if should_delete_from_device_stream:
+                    yield self.store.delete_device_msgs_for_remote(
+                        destination, device_stream_id
+                    )
+                self.last_device_stream_id_by_dest[destination] = device_stream_id
+        except RuntimeError as e:
+            # We capture this here as there as nothing actually listens
+            # for this finishing functions deferred.
+            logger.warn(
+                "TX [%s] Problem in _attempt_transaction: %s",
+                destination,
+                e,
+            )
 
-                if code != 200:
-                    for p in pdus:
-                        logger.info(
-                            "Failed to send event %s to %s", p.event_id, destination
-                        )
-            except NotRetryingDestination:
-                logger.info(
-                    "TX [%s] not ready for retry yet - "
-                    "dropping transaction for now",
-                    destination,
-                )
-            except RuntimeError as e:
-                # We capture this here as there as nothing actually listens
-                # for this finishing functions deferred.
-                logger.warn(
-                    "TX [%s] Problem in _attempt_transaction: %s",
-                    destination,
-                    e,
-                )
+            success = False
+
+            for p in pdus:
+                logger.info("Failed to send event %s to %s", p.event_id, destination)
+        except Exception as e:
+            # We capture this here as there as nothing actually listens
+            # for this finishing functions deferred.
+            logger.warn(
+                "TX [%s] Problem in _attempt_transaction: %s",
+                destination,
+                e,
+            )
 
-                for p in pdus:
-                    logger.info("Failed to send event %s to %s", p.event_id, destination)
-            except Exception as e:
-                # We capture this here as there as nothing actually listens
-                # for this finishing functions deferred.
-                logger.warn(
-                    "TX [%s] Problem in _attempt_transaction: %s",
-                    destination,
-                    e,
-                )
+            success = False
 
-                for p in pdus:
-                    logger.info("Failed to send event %s to %s", p.event_id, destination)
+            for p in pdus:
+                logger.info("Failed to send event %s to %s", p.event_id, destination)
 
-            finally:
-                # We want to be *very* sure we delete this after we stop processing
-                self.pending_transactions.pop(destination, None)
+        defer.returnValue(success)