summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/federation/transaction_queue.py121
-rw-r--r--synapse/handlers/device.py1
-rw-r--r--synapse/storage/devices.py40
3 files changed, 82 insertions, 80 deletions
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 65c6673a87..d18f6b6cfd 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -306,62 +306,74 @@ class TransactionQueue(object):
             yield run_on_reactor()
 
             while True:
-                    pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
-                    pending_edus = self.pending_edus_by_dest.pop(destination, [])
-                    pending_presence = self.pending_presence_by_dest.pop(destination, {})
-                    pending_failures = self.pending_failures_by_dest.pop(destination, [])
+                pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+                pending_edus = self.pending_edus_by_dest.pop(destination, [])
+                pending_presence = self.pending_presence_by_dest.pop(destination, {})
+                pending_failures = self.pending_failures_by_dest.pop(destination, [])
 
-                    pending_edus.extend(
-                        self.pending_edus_keyed_by_dest.pop(destination, {}).values()
-                    )
+                pending_edus.extend(
+                    self.pending_edus_keyed_by_dest.pop(destination, {}).values()
+                )
 
-                    limiter = yield get_retry_limiter(
-                        destination,
-                        self.clock,
-                        self.store,
-                    )
+                limiter = yield get_retry_limiter(
+                    destination,
+                    self.clock,
+                    self.store,
+                )
 
-                    device_message_edus, device_stream_id = (
-                        yield self._get_new_device_messages(destination)
-                    )
+                device_message_edus, device_stream_id, dev_list_id = (
+                    yield self._get_new_device_messages(destination)
+                )
 
-                    pending_edus.extend(device_message_edus)
-                    if pending_presence:
-                        pending_edus.append(
-                            Edu(
-                                origin=self.server_name,
-                                destination=destination,
-                                edu_type="m.presence",
-                                content={
-                                    "push": [
-                                        format_user_presence_state(
-                                            presence, self.clock.time_msec()
-                                        )
-                                        for presence in pending_presence.values()
-                                    ]
-                                },
-                            )
+                pending_edus.extend(device_message_edus)
+                if pending_presence:
+                    pending_edus.append(
+                        Edu(
+                            origin=self.server_name,
+                            destination=destination,
+                            edu_type="m.presence",
+                            content={
+                                "push": [
+                                    format_user_presence_state(
+                                        presence, self.clock.time_msec()
+                                    )
+                                    for presence in pending_presence.values()
+                                ]
+                            },
                         )
+                    )
 
-                    if pending_pdus:
-                        logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
-                                     destination, len(pending_pdus))
+                if pending_pdus:
+                    logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+                                 destination, len(pending_pdus))
 
-                    if not pending_pdus and not pending_edus and not pending_failures:
-                        logger.debug("TX [%s] Nothing to send", destination)
-                        self.last_device_stream_id_by_dest[destination] = (
-                            device_stream_id
+                if not pending_pdus and not pending_edus and not pending_failures:
+                    logger.debug("TX [%s] Nothing to send", destination)
+                    self.last_device_stream_id_by_dest[destination] = (
+                        device_stream_id
+                    )
+                    return
+
+                success = yield self._send_new_transaction(
+                    destination, pending_pdus, pending_edus, pending_failures,
+                    limiter=limiter,
+                )
+                if success:
+                    # Remove the acknowledged device messages from the database
+                    # Only bother if we actually sent some device messages
+                    if device_message_edus:
+                        yield self.store.delete_device_msgs_for_remote(
+                            destination, device_stream_id
+                        )
+                        logger.info("Marking as sent %r %r", destination, dev_list_id)
+                        yield self.store.mark_as_sent_devices_by_remote(
+                            destination, dev_list_id
                         )
-                        return
 
-                    success = yield self._send_new_transaction(
-                        destination, pending_pdus, pending_edus, pending_failures,
-                        device_stream_id,
-                        includes_device_messages=bool(device_message_edus),
-                        limiter=limiter,
-                    )
-                    if not success:
-                        break
+                    self.last_device_stream_id_by_dest[destination] = device_stream_id
+                    self.last_device_list_stream_id_by_dest[destination] = dev_list_id
+                else:
+                    break
         except NotRetryingDestination:
             logger.debug(
                 "TX [%s] not ready for retry yet - "
@@ -374,8 +386,6 @@ class TransactionQueue(object):
 
     @defer.inlineCallbacks
     def _get_new_device_messages(self, destination):
-        # TODO: Send appropriate device list messages
-
         last_device_stream_id = self.last_device_stream_id_by_dest.get(destination, 0)
         to_device_stream_id = self.store.get_to_device_stream_token()
         contents, stream_id = yield self.store.get_new_device_msgs_for_remote(
@@ -404,13 +414,12 @@ class TransactionQueue(object):
             )
             for content in results
         )
-        defer.returnValue((edus, stream_id))
+        defer.returnValue((edus, stream_id, now_stream_id))
 
     @measure_func("_send_new_transaction")
     @defer.inlineCallbacks
     def _send_new_transaction(self, destination, pending_pdus, pending_edus,
-                              pending_failures, device_stream_id,
-                              includes_device_messages, limiter):
+                              pending_failures, limiter):
 
         # Sort based on the order field
         pending_pdus.sort(key=lambda t: t[1])
@@ -521,14 +530,6 @@ class TransactionQueue(object):
                         "Failed to send event %s to %s", p.event_id, destination
                     )
                 success = False
-            else:
-                # Remove the acknowledged device messages from the database
-                # Only bother if we actually sent some device messages
-                if includes_device_messages:
-                    yield self.store.delete_device_msgs_for_remote(
-                        destination, device_stream_id
-                    )
-                self.last_device_stream_id_by_dest[destination] = device_stream_id
         except RuntimeError as e:
             # We capture this here as there as nothing actually listens
             # for this finishing functions deferred.
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index d92780b642..ba4c48d590 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -29,6 +29,7 @@ class DeviceHandler(BaseHandler):
         super(DeviceHandler, self).__init__(hs)
 
         self.state = hs.get_state_handler()
+        self.federation = hs.get_federation_sender()
 
     @defer.inlineCallbacks
     def check_device_registered(self, user_id, device_id,
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index b594f501f9..9628e2ff75 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -141,11 +141,11 @@ class DeviceStore(SQLBaseStore):
     def get_devices_by_remote(self, destination, from_stream_id):
         now_stream_id = self._device_list_id_gen.get_current_token()
 
-        has_changed = self._device_list_stream_cache.has_entity_changed(
+        has_changed = self._device_list_federation_stream_cache.has_entity_changed(
             destination, int(from_stream_id)
         )
         if not has_changed:
-            defer.returnValue((now_stream_id, []))
+            return (now_stream_id, [])
 
         return self.runInteraction(
             "get_devices_by_remote", self._get_devices_by_remote_txn,
@@ -165,7 +165,7 @@ class DeviceStore(SQLBaseStore):
         rows = txn.fetchall()
 
         if not rows:
-            return now_stream_id, []
+            return (now_stream_id, [])
 
         # maps (user_id, device_id) -> stream_id
         query_map = {(r[0], r[1]): r[2] for r in rows}
@@ -189,7 +189,7 @@ class DeviceStore(SQLBaseStore):
                 result = {
                     "user_id": user_id,
                     "device_id": device_id,
-                    "prev_id": prev_id,
+                    "prev_id": [prev_id] if prev_id else [],
                     "stream_id": stream_id,
                 }
 
@@ -202,9 +202,9 @@ class DeviceStore(SQLBaseStore):
                 if device_display_name:
                     result["device_display_name"] = device_display_name
 
-                results.setdefault(user_id, {})[device_id] = result
+                results.append(result)
 
-        return now_stream_id, results
+        return (now_stream_id, results)
 
     def mark_as_sent_devices_by_remote(self, destination, stream_id):
         return self.runInteraction(
@@ -212,19 +212,6 @@ class DeviceStore(SQLBaseStore):
             destination, stream_id,
         )
 
-    @defer.inlineCallbacks
-    def get_user_whose_devices_changed(self, from_key):
-        from_key = int(from_key)
-        changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
-        if changed is not None:
-            defer.returnValue(set(changed))
-
-        sql = """
-            SELECT user_id FROM device_lists_stream WHERE stream_id > ?
-        """
-        rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
-        defer.returnValue(set(row["user_id"] for row in rows))
-
     def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
         sql = """
             DELETE FROM device_lists_outbound_pokes
@@ -239,7 +226,20 @@ class DeviceStore(SQLBaseStore):
             UPDATE device_lists_outbound_pokes SET sent = ?
             WHERE destination = ? AND stream_id <= ?
         """
-        txn.execute(sql, (destination, True,))
+        txn.execute(sql, (True, destination, stream_id,))
+
+    @defer.inlineCallbacks
+    def get_user_whose_devices_changed(self, from_key):
+        from_key = int(from_key)
+        changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
+        if changed is not None:
+            defer.returnValue(set(changed))
+
+        sql = """
+            SELECT user_id FROM device_lists_stream WHERE stream_id > ?
+        """
+        rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
+        defer.returnValue(set(row["user_id"] for row in rows))
 
     @defer.inlineCallbacks
     def add_device_change_to_streams(self, user_id, device_id, hosts):