summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_client.py53
-rw-r--r--synapse/federation/federation_server.py2
-rw-r--r--synapse/federation/transaction_queue.py403
-rw-r--r--synapse/federation/transport/client.py18
-rw-r--r--synapse/federation/transport/server.py10
5 files changed, 307 insertions, 179 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 627acc6a4f..06d0320b1a 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -24,7 +24,6 @@ from synapse.api.errors import (
     CodeMessageException, HttpResponseException, SynapseError,
 )
 from synapse.util import unwrapFirstError
-from synapse.util.async import concurrently_execute
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
@@ -122,8 +121,12 @@ class FederationClient(FederationBase):
             pdu.event_id
         )
 
+    def send_presence(self, destination, states):
+        if destination != self.server_name:
+            self._transaction_queue.enqueue_presence(destination, states)
+
     @log_function
-    def send_edu(self, destination, edu_type, content):
+    def send_edu(self, destination, edu_type, content, key=None):
         edu = Edu(
             origin=self.server_name,
             destination=destination,
@@ -134,10 +137,16 @@ class FederationClient(FederationBase):
         sent_edus_counter.inc()
 
         # TODO, add errback, etc.
-        self._transaction_queue.enqueue_edu(edu)
+        self._transaction_queue.enqueue_edu(edu, key=key)
         return defer.succeed(None)
 
     @log_function
+    def send_device_messages(self, destination):
+        """Sends the device messages in the local database to the remote
+        destination"""
+        self._transaction_queue.enqueue_device_messages(destination)
+
+    @log_function
     def send_failure(self, failure, destination):
         self._transaction_queue.enqueue_failure(failure, destination)
         return defer.succeed(None)
@@ -166,7 +175,7 @@ class FederationClient(FederationBase):
         )
 
     @log_function
-    def query_client_keys(self, destination, content):
+    def query_client_keys(self, destination, content, timeout):
         """Query device keys for a device hosted on a remote server.
 
         Args:
@@ -178,10 +187,12 @@ class FederationClient(FederationBase):
             response
         """
         sent_queries_counter.inc("client_device_keys")
-        return self.transport_layer.query_client_keys(destination, content)
+        return self.transport_layer.query_client_keys(
+            destination, content, timeout
+        )
 
     @log_function
-    def claim_client_keys(self, destination, content):
+    def claim_client_keys(self, destination, content, timeout):
         """Claims one-time keys for a device hosted on a remote server.
 
         Args:
@@ -193,7 +204,9 @@ class FederationClient(FederationBase):
             response
         """
         sent_queries_counter.inc("client_one_time_keys")
-        return self.transport_layer.claim_client_keys(destination, content)
+        return self.transport_layer.claim_client_keys(
+            destination, content, timeout
+        )
 
     @defer.inlineCallbacks
     @log_function
@@ -471,7 +484,7 @@ class FederationClient(FederationBase):
                 defer.DeferredList(deferreds, consumeErrors=True)
             )
             for success, result in res:
-                if success:
+                if success and result:
                     signed_events.append(result)
                     batch.discard(result.event_id)
 
@@ -705,24 +718,14 @@ class FederationClient(FederationBase):
 
         raise RuntimeError("Failed to send to any server.")
 
-    @defer.inlineCallbacks
-    def get_public_rooms(self, destinations):
-        results_by_server = {}
-
-        @defer.inlineCallbacks
-        def _get_result(s):
-            if s == self.server_name:
-                defer.returnValue()
-
-            try:
-                result = yield self.transport_layer.get_public_rooms(s)
-                results_by_server[s] = result
-            except:
-                logger.exception("Error getting room list from server %r", s)
-
-        yield concurrently_execute(_get_result, destinations, 3)
+    def get_public_rooms(self, destination, limit=None, since_token=None,
+                         search_filter=None):
+        if destination == self.server_name:
+            return
 
-        defer.returnValue(results_by_server)
+        return self.transport_layer.get_public_rooms(
+            destination, limit, since_token, search_filter
+        )
 
     @defer.inlineCallbacks
     def query_auth(self, destination, room_id, event_id, local_auth):
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 5621655098..3fa7b2315c 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -188,7 +188,7 @@ class FederationServer(FederationBase):
             except SynapseError as e:
                 logger.info("Failed to handle edu %r: %r", edu_type, e)
             except Exception as e:
-                logger.exception("Failed to handle edu %r", edu_type, e)
+                logger.exception("Failed to handle edu %r", edu_type)
         else:
             logger.warn("Received EDU of type %s with no handler", edu_type)
 
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index cb2ef0210c..f8ca93e4c3 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -17,7 +17,7 @@
 from twisted.internet import defer
 
 from .persistence import TransactionActions
-from .units import Transaction
+from .units import Transaction, Edu
 
 from synapse.api.errors import HttpResponseException
 from synapse.util.async import run_on_reactor
@@ -26,6 +26,7 @@ from synapse.util.retryutils import (
     get_retry_limiter, NotRetryingDestination,
 )
 from synapse.util.metrics import measure_func
+from synapse.handlers.presence import format_user_presence_state
 import synapse.metrics
 
 import logging
@@ -69,18 +70,28 @@ class TransactionQueue(object):
         # destination -> list of tuple(edu, deferred)
         self.pending_edus_by_dest = edus = {}
 
+        # Presence needs to be separate as we send single aggragate EDUs
+        self.pending_presence_by_dest = presence = {}
+        self.pending_edus_keyed_by_dest = edus_keyed = {}
+
         metrics.register_callback(
             "pending_pdus",
             lambda: sum(map(len, pdus.values())),
         )
         metrics.register_callback(
             "pending_edus",
-            lambda: sum(map(len, edus.values())),
+            lambda: (
+                sum(map(len, edus.values()))
+                + sum(map(len, presence.values()))
+                + sum(map(len, edus_keyed.values()))
+            ),
         )
 
         # destination -> list of tuple(failure, deferred)
         self.pending_failures_by_dest = {}
 
+        self.last_device_stream_id_by_dest = {}
+
         # HACK to get unique tx id
         self._next_txn_id = int(self.clock.time_msec())
 
@@ -128,13 +139,27 @@ class TransactionQueue(object):
                 self._attempt_new_transaction, destination
             )
 
-    def enqueue_edu(self, edu):
+    def enqueue_presence(self, destination, states):
+        self.pending_presence_by_dest.setdefault(destination, {}).update({
+            state.user_id: state for state in states
+        })
+
+        preserve_context_over_fn(
+            self._attempt_new_transaction, destination
+        )
+
+    def enqueue_edu(self, edu, key=None):
         destination = edu.destination
 
         if not self.can_send_to(destination):
             return
 
-        self.pending_edus_by_dest.setdefault(destination, []).append(edu)
+        if key:
+            self.pending_edus_keyed_by_dest.setdefault(
+                destination, {}
+            )[(edu.edu_type, key)] = edu
+        else:
+            self.pending_edus_by_dest.setdefault(destination, []).append(edu)
 
         preserve_context_over_fn(
             self._attempt_new_transaction, destination
@@ -155,179 +180,261 @@ class TransactionQueue(object):
             self._attempt_new_transaction, destination
         )
 
+    def enqueue_device_messages(self, destination):
+        if destination == self.server_name or destination == "localhost":
+            return
+
+        if not self.can_send_to(destination):
+            return
+
+        preserve_context_over_fn(
+            self._attempt_new_transaction, destination
+        )
+
     @defer.inlineCallbacks
     def _attempt_new_transaction(self, destination):
-        yield run_on_reactor()
-        while True:
-            # list of (pending_pdu, deferred, order)
-            if destination in self.pending_transactions:
-                # XXX: pending_transactions can get stuck on by a never-ending
-                # request at which point pending_pdus_by_dest just keeps growing.
-                # we need application-layer timeouts of some flavour of these
-                # requests
-                logger.debug(
-                    "TX [%s] Transaction already in progress",
-                    destination
-                )
-                return
+        # list of (pending_pdu, deferred, order)
+        if destination in self.pending_transactions:
+            # XXX: pending_transactions can get stuck on by a never-ending
+            # request at which point pending_pdus_by_dest just keeps growing.
+            # we need application-layer timeouts of some flavour of these
+            # requests
+            logger.debug(
+                "TX [%s] Transaction already in progress",
+                destination
+            )
+            return
+
+        try:
+            self.pending_transactions[destination] = 1
+
+            yield run_on_reactor()
+
+            while True:
+                    pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+                    pending_edus = self.pending_edus_by_dest.pop(destination, [])
+                    pending_presence = self.pending_presence_by_dest.pop(destination, {})
+                    pending_failures = self.pending_failures_by_dest.pop(destination, [])
+
+                    pending_edus.extend(
+                        self.pending_edus_keyed_by_dest.pop(destination, {}).values()
+                    )
+
+                    limiter = yield get_retry_limiter(
+                        destination,
+                        self.clock,
+                        self.store,
+                    )
+
+                    device_message_edus, device_stream_id = (
+                        yield self._get_new_device_messages(destination)
+                    )
+
+                    pending_edus.extend(device_message_edus)
+                    if pending_presence:
+                        pending_edus.append(
+                            Edu(
+                                origin=self.server_name,
+                                destination=destination,
+                                edu_type="m.presence",
+                                content={
+                                    "push": [
+                                        format_user_presence_state(
+                                            presence, self.clock.time_msec()
+                                        )
+                                        for presence in pending_presence.values()
+                                    ]
+                                },
+                            )
+                        )
 
-            pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
-            pending_edus = self.pending_edus_by_dest.pop(destination, [])
-            pending_failures = self.pending_failures_by_dest.pop(destination, [])
+                    if pending_pdus:
+                        logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+                                     destination, len(pending_pdus))
 
-            if pending_pdus:
-                logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
-                             destination, len(pending_pdus))
+                    if not pending_pdus and not pending_edus and not pending_failures:
+                        logger.debug("TX [%s] Nothing to send", destination)
+                        self.last_device_stream_id_by_dest[destination] = (
+                            device_stream_id
+                        )
+                        return
 
-            if not pending_pdus and not pending_edus and not pending_failures:
-                logger.debug("TX [%s] Nothing to send", destination)
-                return
+                    success = yield self._send_new_transaction(
+                        destination, pending_pdus, pending_edus, pending_failures,
+                        device_stream_id,
+                        should_delete_from_device_stream=bool(device_message_edus),
+                        limiter=limiter,
+                    )
+                    if not success:
+                        break
+        except NotRetryingDestination:
+            logger.info(
+                "TX [%s] not ready for retry yet - "
+                "dropping transaction for now",
+                destination,
+            )
+        finally:
+            # We want to be *very* sure we delete this after we stop processing
+            self.pending_transactions.pop(destination, None)
 
-            yield self._send_new_transaction(
-                destination, pending_pdus, pending_edus, pending_failures
+    @defer.inlineCallbacks
+    def _get_new_device_messages(self, destination):
+        last_device_stream_id = self.last_device_stream_id_by_dest.get(destination, 0)
+        to_device_stream_id = self.store.get_to_device_stream_token()
+        contents, stream_id = yield self.store.get_new_device_msgs_for_remote(
+            destination, last_device_stream_id, to_device_stream_id
+        )
+        edus = [
+            Edu(
+                origin=self.server_name,
+                destination=destination,
+                edu_type="m.direct_to_device",
+                content=content,
             )
+            for content in contents
+        ]
+        defer.returnValue((edus, stream_id))
 
     @measure_func("_send_new_transaction")
     @defer.inlineCallbacks
     def _send_new_transaction(self, destination, pending_pdus, pending_edus,
-                              pending_failures):
+                              pending_failures, device_stream_id,
+                              should_delete_from_device_stream, limiter):
 
-            # Sort based on the order field
-            pending_pdus.sort(key=lambda t: t[1])
-            pdus = [x[0] for x in pending_pdus]
-            edus = pending_edus
-            failures = [x.get_dict() for x in pending_failures]
+        # Sort based on the order field
+        pending_pdus.sort(key=lambda t: t[1])
+        pdus = [x[0] for x in pending_pdus]
+        edus = pending_edus
+        failures = [x.get_dict() for x in pending_failures]
 
-            try:
-                self.pending_transactions[destination] = 1
+        success = True
 
-                logger.debug("TX [%s] _attempt_new_transaction", destination)
+        try:
+            logger.debug("TX [%s] _attempt_new_transaction", destination)
 
-                txn_id = str(self._next_txn_id)
+            txn_id = str(self._next_txn_id)
 
-                limiter = yield get_retry_limiter(
-                    destination,
-                    self.clock,
-                    self.store,
-                )
+            logger.debug(
+                "TX [%s] {%s} Attempting new transaction"
+                " (pdus: %d, edus: %d, failures: %d)",
+                destination, txn_id,
+                len(pdus),
+                len(edus),
+                len(failures)
+            )
 
-                logger.debug(
-                    "TX [%s] {%s} Attempting new transaction"
-                    " (pdus: %d, edus: %d, failures: %d)",
-                    destination, txn_id,
-                    len(pending_pdus),
-                    len(pending_edus),
-                    len(pending_failures)
-                )
+            logger.debug("TX [%s] Persisting transaction...", destination)
 
-                logger.debug("TX [%s] Persisting transaction...", destination)
+            transaction = Transaction.create_new(
+                origin_server_ts=int(self.clock.time_msec()),
+                transaction_id=txn_id,
+                origin=self.server_name,
+                destination=destination,
+                pdus=pdus,
+                edus=edus,
+                pdu_failures=failures,
+            )
 
-                transaction = Transaction.create_new(
-                    origin_server_ts=int(self.clock.time_msec()),
-                    transaction_id=txn_id,
-                    origin=self.server_name,
-                    destination=destination,
-                    pdus=pdus,
-                    edus=edus,
-                    pdu_failures=failures,
-                )
+            self._next_txn_id += 1
 
-                self._next_txn_id += 1
+            yield self.transaction_actions.prepare_to_send(transaction)
+
+            logger.debug("TX [%s] Persisted transaction", destination)
+            logger.info(
+                "TX [%s] {%s} Sending transaction [%s],"
+                " (PDUs: %d, EDUs: %d, failures: %d)",
+                destination, txn_id,
+                transaction.transaction_id,
+                len(pdus),
+                len(edus),
+                len(failures),
+            )
 
-                yield self.transaction_actions.prepare_to_send(transaction)
+            with limiter:
+                # Actually send the transaction
+
+                # FIXME (erikj): This is a bit of a hack to make the Pdu age
+                # keys work
+                def json_data_cb():
+                    data = transaction.get_dict()
+                    now = int(self.clock.time_msec())
+                    if "pdus" in data:
+                        for p in data["pdus"]:
+                            if "age_ts" in p:
+                                unsigned = p.setdefault("unsigned", {})
+                                unsigned["age"] = now - int(p["age_ts"])
+                                del p["age_ts"]
+                    return data
+
+                try:
+                    response = yield self.transport_layer.send_transaction(
+                        transaction, json_data_cb
+                    )
+                    code = 200
+
+                    if response:
+                        for e_id, r in response.get("pdus", {}).items():
+                            if "error" in r:
+                                logger.warn(
+                                    "Transaction returned error for %s: %s",
+                                    e_id, r,
+                                )
+                except HttpResponseException as e:
+                    code = e.code
+                    response = e.response
 
-                logger.debug("TX [%s] Persisted transaction", destination)
                 logger.info(
-                    "TX [%s] {%s} Sending transaction [%s],"
-                    " (PDUs: %d, EDUs: %d, failures: %d)",
-                    destination, txn_id,
-                    transaction.transaction_id,
-                    len(pending_pdus),
-                    len(pending_edus),
-                    len(pending_failures),
+                    "TX [%s] {%s} got %d response",
+                    destination, txn_id, code
                 )
 
-                with limiter:
-                    # Actually send the transaction
-
-                    # FIXME (erikj): This is a bit of a hack to make the Pdu age
-                    # keys work
-                    def json_data_cb():
-                        data = transaction.get_dict()
-                        now = int(self.clock.time_msec())
-                        if "pdus" in data:
-                            for p in data["pdus"]:
-                                if "age_ts" in p:
-                                    unsigned = p.setdefault("unsigned", {})
-                                    unsigned["age"] = now - int(p["age_ts"])
-                                    del p["age_ts"]
-                        return data
-
-                    try:
-                        response = yield self.transport_layer.send_transaction(
-                            transaction, json_data_cb
-                        )
-                        code = 200
-
-                        if response:
-                            for e_id, r in response.get("pdus", {}).items():
-                                if "error" in r:
-                                    logger.warn(
-                                        "Transaction returned error for %s: %s",
-                                        e_id, r,
-                                    )
-                    except HttpResponseException as e:
-                        code = e.code
-                        response = e.response
-
-                    logger.info(
-                        "TX [%s] {%s} got %d response",
-                        destination, txn_id, code
-                    )
+                logger.debug("TX [%s] Sent transaction", destination)
+                logger.debug("TX [%s] Marking as delivered...", destination)
 
-                    logger.debug("TX [%s] Sent transaction", destination)
-                    logger.debug("TX [%s] Marking as delivered...", destination)
+            yield self.transaction_actions.delivered(
+                transaction, code, response
+            )
 
-                yield self.transaction_actions.delivered(
-                    transaction, code, response
-                )
+            logger.debug("TX [%s] Marked as delivered", destination)
 
-                logger.debug("TX [%s] Marked as delivered", destination)
+            if code != 200:
+                for p in pdus:
+                    logger.info(
+                        "Failed to send event %s to %s", p.event_id, destination
+                    )
+                success = False
+            else:
+                # Remove the acknowledged device messages from the database
+                if should_delete_from_device_stream:
+                    yield self.store.delete_device_msgs_for_remote(
+                        destination, device_stream_id
+                    )
+                self.last_device_stream_id_by_dest[destination] = device_stream_id
+        except RuntimeError as e:
+            # We capture this here as there as nothing actually listens
+            # for this finishing functions deferred.
+            logger.warn(
+                "TX [%s] Problem in _attempt_transaction: %s",
+                destination,
+                e,
+            )
 
-                if code != 200:
-                    for p in pdus:
-                        logger.info(
-                            "Failed to send event %s to %s", p.event_id, destination
-                        )
-            except NotRetryingDestination:
-                logger.info(
-                    "TX [%s] not ready for retry yet - "
-                    "dropping transaction for now",
-                    destination,
-                )
-            except RuntimeError as e:
-                # We capture this here as there as nothing actually listens
-                # for this finishing functions deferred.
-                logger.warn(
-                    "TX [%s] Problem in _attempt_transaction: %s",
-                    destination,
-                    e,
-                )
+            success = False
+
+            for p in pdus:
+                logger.info("Failed to send event %s to %s", p.event_id, destination)
+        except Exception as e:
+            # We capture this here as there as nothing actually listens
+            # for this finishing functions deferred.
+            logger.warn(
+                "TX [%s] Problem in _attempt_transaction: %s",
+                destination,
+                e,
+            )
 
-                for p in pdus:
-                    logger.info("Failed to send event %s to %s", p.event_id, destination)
-            except Exception as e:
-                # We capture this here as there as nothing actually listens
-                # for this finishing functions deferred.
-                logger.warn(
-                    "TX [%s] Problem in _attempt_transaction: %s",
-                    destination,
-                    e,
-                )
+            success = False
 
-                for p in pdus:
-                    logger.info("Failed to send event %s to %s", p.event_id, destination)
+            for p in pdus:
+                logger.info("Failed to send event %s to %s", p.event_id, destination)
 
-            finally:
-                # We want to be *very* sure we delete this after we stop processing
-                self.pending_transactions.pop(destination, None)
+        defer.returnValue(success)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 3d088e43cb..db45c7826c 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -248,12 +248,22 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
-    def get_public_rooms(self, remote_server):
+    def get_public_rooms(self, remote_server, limit, since_token,
+                         search_filter=None):
         path = PREFIX + "/publicRooms"
 
+        args = {}
+        if limit:
+            args["limit"] = [str(limit)]
+        if since_token:
+            args["since"] = [since_token]
+
+        # TODO(erikj): Actually send the search_filter across federation.
+
         response = yield self.client.get_json(
             destination=remote_server,
             path=path,
+            args=args,
         )
 
         defer.returnValue(response)
@@ -298,7 +308,7 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
-    def query_client_keys(self, destination, query_content):
+    def query_client_keys(self, destination, query_content, timeout):
         """Query the device keys for a list of user ids hosted on a remote
         server.
 
@@ -327,12 +337,13 @@ class TransportLayerClient(object):
             destination=destination,
             path=path,
             data=query_content,
+            timeout=timeout,
         )
         defer.returnValue(content)
 
     @defer.inlineCallbacks
     @log_function
-    def claim_client_keys(self, destination, query_content):
+    def claim_client_keys(self, destination, query_content, timeout):
         """Claim one-time keys for a list of devices hosted on a remote server.
 
         Request:
@@ -363,6 +374,7 @@ class TransportLayerClient(object):
             destination=destination,
             path=path,
             data=query_content,
+            timeout=timeout,
         )
         defer.returnValue(content)
 
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 37c0d4fbc4..fec337be64 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -18,7 +18,9 @@ from twisted.internet import defer
 from synapse.api.urls import FEDERATION_PREFIX as PREFIX
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import JsonResource
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import (
+    parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
+)
 from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.versionstring import get_version_string
 
@@ -554,7 +556,11 @@ class PublicRoomList(BaseFederationServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, origin, content, query):
-        data = yield self.room_list_handler.get_local_public_room_list()
+        limit = parse_integer_from_args(query, "limit", 0)
+        since_token = parse_string_from_args(query, "since", None)
+        data = yield self.room_list_handler.get_local_public_room_list(
+            limit, since_token
+        )
         defer.returnValue((200, data))