summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_client.py10
-rw-r--r--synapse/federation/federation_server.py3
-rw-r--r--synapse/federation/transaction_queue.py133
-rw-r--r--synapse/federation/transport/client.py26
-rw-r--r--synapse/federation/transport/server.py8
5 files changed, 123 insertions, 57 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index c9175bb33d..b5bcfd705a 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -127,6 +127,16 @@ class FederationClient(FederationBase):
         )
 
     @log_function
+    def query_user_devices(self, destination, user_id, timeout=30000):
+        """Query the device keys for a list of user ids hosted on a remote
+        server.
+        """
+        sent_queries_counter.inc("user_devices")
+        return self.transport_layer.query_user_devices(
+            destination, user_id, timeout
+        )
+
+    @log_function
     def claim_client_keys(self, destination, content, timeout):
         """Claims one-time keys for a device hosted on a remote server.
 
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 862ccbef5d..e922b7ff4a 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -416,6 +416,9 @@ class FederationServer(FederationBase):
     def on_query_client_keys(self, origin, content):
         return self.on_query_request("client_keys", content)
 
+    def on_query_user_devices(self, origin, user_id):
+        return self.on_query_request("user_devices", user_id)
+
     @defer.inlineCallbacks
     @log_function
     def on_claim_client_keys(self, origin, content):
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 6b3a7abb9e..d18f6b6cfd 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -100,6 +100,7 @@ class TransactionQueue(object):
         self.pending_failures_by_dest = {}
 
         self.last_device_stream_id_by_dest = {}
+        self.last_device_list_stream_id_by_dest = {}
 
         # HACK to get unique tx id
         self._next_txn_id = int(self.clock.time_msec())
@@ -305,62 +306,74 @@ class TransactionQueue(object):
             yield run_on_reactor()
 
             while True:
-                    pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
-                    pending_edus = self.pending_edus_by_dest.pop(destination, [])
-                    pending_presence = self.pending_presence_by_dest.pop(destination, {})
-                    pending_failures = self.pending_failures_by_dest.pop(destination, [])
+                pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+                pending_edus = self.pending_edus_by_dest.pop(destination, [])
+                pending_presence = self.pending_presence_by_dest.pop(destination, {})
+                pending_failures = self.pending_failures_by_dest.pop(destination, [])
 
-                    pending_edus.extend(
-                        self.pending_edus_keyed_by_dest.pop(destination, {}).values()
-                    )
+                pending_edus.extend(
+                    self.pending_edus_keyed_by_dest.pop(destination, {}).values()
+                )
 
-                    limiter = yield get_retry_limiter(
-                        destination,
-                        self.clock,
-                        self.store,
-                    )
+                limiter = yield get_retry_limiter(
+                    destination,
+                    self.clock,
+                    self.store,
+                )
 
-                    device_message_edus, device_stream_id = (
-                        yield self._get_new_device_messages(destination)
-                    )
+                device_message_edus, device_stream_id, dev_list_id = (
+                    yield self._get_new_device_messages(destination)
+                )
 
-                    pending_edus.extend(device_message_edus)
-                    if pending_presence:
-                        pending_edus.append(
-                            Edu(
-                                origin=self.server_name,
-                                destination=destination,
-                                edu_type="m.presence",
-                                content={
-                                    "push": [
-                                        format_user_presence_state(
-                                            presence, self.clock.time_msec()
-                                        )
-                                        for presence in pending_presence.values()
-                                    ]
-                                },
-                            )
+                pending_edus.extend(device_message_edus)
+                if pending_presence:
+                    pending_edus.append(
+                        Edu(
+                            origin=self.server_name,
+                            destination=destination,
+                            edu_type="m.presence",
+                            content={
+                                "push": [
+                                    format_user_presence_state(
+                                        presence, self.clock.time_msec()
+                                    )
+                                    for presence in pending_presence.values()
+                                ]
+                            },
                         )
+                    )
+
+                if pending_pdus:
+                    logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+                                 destination, len(pending_pdus))
 
-                    if pending_pdus:
-                        logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
-                                     destination, len(pending_pdus))
+                if not pending_pdus and not pending_edus and not pending_failures:
+                    logger.debug("TX [%s] Nothing to send", destination)
+                    self.last_device_stream_id_by_dest[destination] = (
+                        device_stream_id
+                    )
+                    return
 
-                    if not pending_pdus and not pending_edus and not pending_failures:
-                        logger.debug("TX [%s] Nothing to send", destination)
-                        self.last_device_stream_id_by_dest[destination] = (
-                            device_stream_id
+                success = yield self._send_new_transaction(
+                    destination, pending_pdus, pending_edus, pending_failures,
+                    limiter=limiter,
+                )
+                if success:
+                    # Remove the acknowledged device messages from the database
+                    # Only bother if we actually sent some device messages
+                    if device_message_edus:
+                        yield self.store.delete_device_msgs_for_remote(
+                            destination, device_stream_id
+                        )
+                        logger.info("Marking as sent %r %r", destination, dev_list_id)
+                        yield self.store.mark_as_sent_devices_by_remote(
+                            destination, dev_list_id
                         )
-                        return
 
-                    success = yield self._send_new_transaction(
-                        destination, pending_pdus, pending_edus, pending_failures,
-                        device_stream_id,
-                        should_delete_from_device_stream=bool(device_message_edus),
-                        limiter=limiter,
-                    )
-                    if not success:
-                        break
+                    self.last_device_stream_id_by_dest[destination] = device_stream_id
+                    self.last_device_list_stream_id_by_dest[destination] = dev_list_id
+                else:
+                    break
         except NotRetryingDestination:
             logger.debug(
                 "TX [%s] not ready for retry yet - "
@@ -387,13 +400,26 @@ class TransactionQueue(object):
             )
             for content in contents
         ]
-        defer.returnValue((edus, stream_id))
+
+        last_device_list = self.last_device_list_stream_id_by_dest.get(destination, 0)
+        now_stream_id, results = yield self.store.get_devices_by_remote(
+            destination, last_device_list
+        )
+        edus.extend(
+            Edu(
+                origin=self.server_name,
+                destination=destination,
+                edu_type="m.device_list_update",
+                content=content,
+            )
+            for content in results
+        )
+        defer.returnValue((edus, stream_id, now_stream_id))
 
     @measure_func("_send_new_transaction")
     @defer.inlineCallbacks
     def _send_new_transaction(self, destination, pending_pdus, pending_edus,
-                              pending_failures, device_stream_id,
-                              should_delete_from_device_stream, limiter):
+                              pending_failures, limiter):
 
         # Sort based on the order field
         pending_pdus.sort(key=lambda t: t[1])
@@ -504,13 +530,6 @@ class TransactionQueue(object):
                         "Failed to send event %s to %s", p.event_id, destination
                     )
                 success = False
-            else:
-                # Remove the acknowledged device messages from the database
-                if should_delete_from_device_stream:
-                    yield self.store.delete_device_msgs_for_remote(
-                        destination, device_stream_id
-                    )
-                self.last_device_stream_id_by_dest[destination] = device_stream_id
         except RuntimeError as e:
             # We capture this here as there as nothing actually listens
             # for this finishing functions deferred.
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 915af34409..f49e8a2cc4 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -348,6 +348,32 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
+    def query_user_devices(self, destination, user_id, timeout):
+        """Query the devices for a user id hosted on a remote server.
+
+        Response:
+            {
+              "stream_id": "...",
+              "devices": [ { ... } ]
+            }
+
+        Args:
+            destination(str): The server to query.
+            query_content(dict): The user ids to query.
+        Returns:
+            A dict containg the device keys.
+        """
+        path = PREFIX + "/user/devices/" + user_id
+
+        content = yield self.client.get_json(
+            destination=destination,
+            path=path,
+            timeout=timeout,
+        )
+        defer.returnValue(content)
+
+    @defer.inlineCallbacks
+    @log_function
     def claim_client_keys(self, destination, query_content, timeout):
         """Claim one-time keys for a list of devices hosted on a remote server.
 
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 159dbd1747..c840da834c 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -409,6 +409,13 @@ class FederationClientKeysQueryServlet(BaseFederationServlet):
         return self.handler.on_query_client_keys(origin, content)
 
 
+class FederationUserDevicesQueryServlet(BaseFederationServlet):
+    PATH = "/user/devices/(?P<user_id>[^/]*)"
+
+    def on_GET(self, origin, content, query, user_id):
+        return self.handler.on_query_user_devices(origin, user_id)
+
+
 class FederationClientKeysClaimServlet(BaseFederationServlet):
     PATH = "/user/keys/claim"
 
@@ -613,6 +620,7 @@ SERVLET_CLASSES = (
     FederationGetMissingEventsServlet,
     FederationEventAuthServlet,
     FederationClientKeysQueryServlet,
+    FederationUserDevicesQueryServlet,
     FederationClientKeysClaimServlet,
     FederationThirdPartyInviteExchangeServlet,
     On3pidBindServlet,