summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/slave/storage/devices.py36
-rw-r--r--synapse/replication/tcp/streams/_base.py13
2 files changed, 32 insertions, 17 deletions
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 1c77687eea..23b1650e41 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -29,7 +29,13 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
         self.hs = hs
 
         self._device_list_id_gen = SlavedIdTracker(
-            db_conn, "device_lists_stream", "stream_id"
+            db_conn,
+            "device_lists_stream",
+            "stream_id",
+            extra_tables=[
+                ("user_signature_stream", "stream_id"),
+                ("device_lists_outbound_pokes", "stream_id"),
+            ],
         )
         device_list_max = self._device_list_id_gen.get_current_token()
         self._device_list_stream_cache = StreamChangeCache(
@@ -55,23 +61,27 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == DeviceListsStream.NAME:
             self._device_list_id_gen.advance(token)
-            for row in rows:
-                self._invalidate_caches_for_devices(token, row.user_id, row.destination)
+            self._invalidate_caches_for_devices(token, rows)
         elif stream_name == UserSignatureStream.NAME:
+            self._device_list_id_gen.advance(token)
             for row in rows:
                 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
         return super(SlavedDeviceStore, self).process_replication_rows(
             stream_name, token, rows
         )
 
-    def _invalidate_caches_for_devices(self, token, user_id, destination):
-        self._device_list_stream_cache.entity_has_changed(user_id, token)
-
-        if destination:
-            self._device_list_federation_stream_cache.entity_has_changed(
-                destination, token
-            )
+    def _invalidate_caches_for_devices(self, token, rows):
+        for row in rows:
+            # The entities are either user IDs (starting with '@') whose devices
+            # have changed, or remote servers that we need to tell about
+            # changes.
+            if row.entity.startswith("@"):
+                self._device_list_stream_cache.entity_has_changed(row.entity, token)
+                self.get_cached_devices_for_user.invalidate((row.entity,))
+                self._get_cached_user_device.invalidate_many((row.entity,))
+                self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
 
-        self.get_cached_devices_for_user.invalidate((user_id,))
-        self._get_cached_user_device.invalidate_many((user_id,))
-        self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
+            else:
+                self._device_list_federation_stream_cache.entity_has_changed(
+                    row.entity, token
+                )
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 208e8a667b..7a8b6e9df1 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -94,9 +94,13 @@ PublicRoomsStreamRow = namedtuple(
         "network_id",  # str, optional
     ),
 )
-DeviceListsStreamRow = namedtuple(
-    "DeviceListsStreamRow", ("user_id", "destination")  # str  # str
-)
+
+
+@attr.s
+class DeviceListsStreamRow:
+    entity = attr.ib(type=str)
+
+
 ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",))  # str
 TagAccountDataStreamRow = namedtuple(
     "TagAccountDataStreamRow", ("user_id", "room_id", "data")  # str  # str  # dict
@@ -363,7 +367,8 @@ class PublicRoomsStream(Stream):
 
 
 class DeviceListsStream(Stream):
-    """Someone added/changed/removed a device
+    """Either a user has updated their devices or a remote server needs to be
+    told about a device update.
     """
 
     NAME = "device_lists"