summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/app/federation_sender.py20
-rw-r--r--synapse/federation/send_queue.py32
2 files changed, 22 insertions, 30 deletions
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 38d11fdd0f..63a91f1177 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -38,7 +38,11 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.replication.slave.storage.transactions import SlavedTransactionStore
 from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.replication.tcp.streams._base import ReceiptsStream
+from synapse.replication.tcp.streams._base import (
+    DeviceListsStream,
+    ReceiptsStream,
+    ToDeviceStream,
+)
 from synapse.server import HomeServer
 from synapse.storage.database import Database
 from synapse.types import ReadReceipt
@@ -256,6 +260,20 @@ class FederationSenderHandler(object):
                 "process_receipts_for_federation", self._on_new_receipts, rows
             )
 
+        # ... as well as device updates and messages
+        elif stream_name == DeviceListsStream.NAME:
+            hosts = set(row.destination for row in rows)
+            for host in hosts:
+                self.federation_sender.send_device_messages(host)
+
+        elif stream_name == ToDeviceStream.NAME:
+            # The to_device stream includes stuff to be pushed to both local
+            # clients and remote servers, so we ignore entities that start with
+            # '@' (since they'll be local users rather than destinations).
+            hosts = set(row.entity for row in rows if not row.entity.startswith("@"))
+            for host in hosts:
+                self.federation_sender.send_device_messages(host)
+
     @defer.inlineCallbacks
     def _on_new_receipts(self, rows):
         """
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 174f6e42be..0bb82a6bb3 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -69,8 +69,6 @@ class FederationRemoteSendQueue(object):
 
         self.edus = SortedDict()  # stream position -> Edu
 
-        self.device_messages = SortedDict()  # stream position -> destination
-
         self.pos = 1
         self.pos_time = SortedDict()
 
@@ -92,7 +90,6 @@ class FederationRemoteSendQueue(object):
             "keyed_edu",
             "keyed_edu_changed",
             "edus",
-            "device_messages",
             "pos_time",
             "presence_destinations",
         ]:
@@ -171,12 +168,6 @@ class FederationRemoteSendQueue(object):
             for key in keys[:i]:
                 del self.edus[key]
 
-            # Delete things out of device map
-            keys = self.device_messages.keys()
-            i = self.device_messages.bisect_left(position_to_delete)
-            for key in keys[:i]:
-                del self.device_messages[key]
-
     def notify_new_events(self, current_id):
         """As per FederationSender"""
         # We don't need to replicate this as it gets sent down a different
@@ -249,9 +240,8 @@ class FederationRemoteSendQueue(object):
 
     def send_device_messages(self, destination):
         """As per FederationSender"""
-        pos = self._next_pos()
-        self.device_messages[pos] = destination
-        self.notifier.on_new_replication_data()
+        # We don't need to replicate this as it gets sent down a different
+        # stream.
 
     def get_current_token(self):
         return self.pos - 1
@@ -339,14 +329,6 @@ class FederationRemoteSendQueue(object):
         for (pos, edu) in edus:
             rows.append((pos, EduRow(edu)))
 
-        # Fetch changed device messages
-        i = self.device_messages.bisect_right(from_token)
-        j = self.device_messages.bisect_right(to_token) + 1
-        device_messages = {v: k for k, v in self.device_messages.items()[i:j]}
-
-        for (destination, pos) in iteritems(device_messages):
-            rows.append((pos, DeviceRow(destination=destination)))
-
         # Sort rows based on pos
         rows.sort()
 
@@ -504,7 +486,6 @@ ParsedFederationStreamData = namedtuple(
         "presence_destinations",  # list of tuples of UserPresenceState and destinations
         "keyed_edus",  # dict of destination -> { key -> Edu }
         "edus",  # dict of destination -> [Edu]
-        "device_destinations",  # set of destinations
     ),
 )
 
@@ -523,11 +504,7 @@ def process_rows_for_federation(transaction_queue, rows):
     # them into the appropriate collection and then send them off.
 
     buff = ParsedFederationStreamData(
-        presence=[],
-        presence_destinations=[],
-        keyed_edus={},
-        edus={},
-        device_destinations=set(),
+        presence=[], presence_destinations=[], keyed_edus={}, edus={},
     )
 
     # Parse the rows in the stream and add to the buffer
@@ -555,6 +532,3 @@ def process_rows_for_federation(transaction_queue, rows):
     for destination, edu_list in iteritems(buff.edus):
         for edu in edu_list:
             transaction_queue.send_edu(edu, None)
-
-    for destination in buff.device_destinations:
-        transaction_queue.send_device_messages(destination)