summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2023-01-17 09:29:58 +0000
committerGitHub <noreply@github.com>2023-01-17 09:29:58 +0000
commit2b084c5b710d9630178484e6ade597ca7fa814b6 (patch)
treed80e5964b12d2e3710dfe49186fb0229c4afc1dc /synapse
parentAdd parameter to control whether we do a partial state join (#14843) (diff)
downloadsynapse-2b084c5b710d9630178484e6ade597ca7fa814b6.tar.xz
Merge device list replication streams (#14833)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/replication/tcp/client.py8
-rw-r--r--synapse/replication/tcp/streams/__init__.py3
-rw-r--r--synapse/replication/tcp/streams/_base.py74
-rw-r--r--synapse/storage/databases/main/devices.py13
4 files changed, 65 insertions, 33 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 7263bb2796..31022ce5fb 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -187,7 +187,7 @@ class ReplicationDataHandler:
         elif stream_name == DeviceListsStream.NAME:
             all_room_ids: Set[str] = set()
             for row in rows:
-                if row.entity.startswith("@"):
+                if row.entity.startswith("@") and not row.is_signature:
                     room_ids = await self.store.get_rooms_for_user(row.entity)
                     all_room_ids.update(room_ids)
             self.notifier.on_new_event(
@@ -422,7 +422,11 @@ class FederationSenderHandler:
             # The entities are either user IDs (starting with '@') whose devices
             # have changed, or remote servers that we need to tell about
             # changes.
-            hosts = {row.entity for row in rows if not row.entity.startswith("@")}
+            hosts = {
+                row.entity
+                for row in rows
+                if not row.entity.startswith("@") and not row.is_signature
+            }
             for host in hosts:
                 self.federation_sender.send_device_messages(host, immediate=False)
 
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index a7eadfa3c9..9c67f661a3 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -37,7 +37,6 @@ from synapse.replication.tcp.streams._base import (
     Stream,
     ToDeviceStream,
     TypingStream,
-    UserSignatureStream,
 )
 from synapse.replication.tcp.streams.events import EventsStream
 from synapse.replication.tcp.streams.federation import FederationStream
@@ -62,7 +61,6 @@ STREAMS_MAP = {
         ToDeviceStream,
         FederationStream,
         AccountDataStream,
-        UserSignatureStream,
         UnPartialStatedRoomStream,
         UnPartialStatedEventStream,
     )
@@ -82,7 +80,6 @@ __all__ = [
     "DeviceListsStream",
     "ToDeviceStream",
     "AccountDataStream",
-    "UserSignatureStream",
     "UnPartialStatedRoomStream",
     "UnPartialStatedEventStream",
 ]
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index fbf78da9c2..a4bdb48c0c 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -463,18 +463,67 @@ class DeviceListsStream(Stream):
     @attr.s(slots=True, frozen=True, auto_attribs=True)
     class DeviceListsStreamRow:
         entity: str
+        # Indicates that a user has signed their own device with their user-signing key
+        is_signature: bool
 
     NAME = "device_lists"
     ROW_TYPE = DeviceListsStreamRow
 
     def __init__(self, hs: "HomeServer"):
-        store = hs.get_datastores().main
+        self.store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(store.get_device_stream_token),
-            store.get_all_device_list_changes_for_remotes,
+            current_token_without_instance(self.store.get_device_stream_token),
+            self._update_function,
+        )
+
+    async def _update_function(
+        self,
+        instance_name: str,
+        from_token: Token,
+        current_token: Token,
+        target_row_count: int,
+    ) -> StreamUpdateResult:
+        (
+            device_updates,
+            devices_to_token,
+            devices_limited,
+        ) = await self.store.get_all_device_list_changes_for_remotes(
+            instance_name, from_token, current_token, target_row_count
         )
 
+        (
+            signatures_updates,
+            signatures_to_token,
+            signatures_limited,
+        ) = await self.store.get_all_user_signature_changes_for_remotes(
+            instance_name, from_token, current_token, target_row_count
+        )
+
+        upper_limit_token = current_token
+        if devices_limited:
+            upper_limit_token = min(upper_limit_token, devices_to_token)
+        if signatures_limited:
+            upper_limit_token = min(upper_limit_token, signatures_to_token)
+
+        device_updates = [
+            (stream_id, (entity, False))
+            for stream_id, (entity,) in device_updates
+            if stream_id <= upper_limit_token
+        ]
+
+        signatures_updates = [
+            (stream_id, (entity, True))
+            for stream_id, (entity,) in signatures_updates
+            if stream_id <= upper_limit_token
+        ]
+
+        updates = list(
+            heapq.merge(device_updates, signatures_updates, key=lambda row: row[0])
+        )
+
+        return updates, upper_limit_token, devices_limited or signatures_limited
+
 
 class ToDeviceStream(Stream):
     """New to_device messages for a client"""
@@ -583,22 +632,3 @@ class AccountDataStream(Stream):
             heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0])
         )
         return updates, to_token, limited
-
-
-class UserSignatureStream(Stream):
-    """A user has signed their own device with their user-signing key"""
-
-    @attr.s(slots=True, frozen=True, auto_attribs=True)
-    class UserSignatureStreamRow:
-        user_id: str
-
-    NAME = "user_signature"
-    ROW_TYPE = UserSignatureStreamRow
-
-    def __init__(self, hs: "HomeServer"):
-        store = hs.get_datastores().main
-        super().__init__(
-            hs.get_instance_name(),
-            current_token_without_instance(store.get_device_stream_token),
-            store.get_all_user_signature_changes_for_remotes,
-        )
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index b067664473..cd186c8472 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -38,7 +38,7 @@ from synapse.logging.opentracing import (
     whitelisted_homeserver,
 )
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
+from synapse.replication.tcp.streams._base import DeviceListsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
@@ -163,9 +163,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     ) -> None:
         if stream_name == DeviceListsStream.NAME:
             self._invalidate_caches_for_devices(token, rows)
-        elif stream_name == UserSignatureStream.NAME:
-            for row in rows:
-                self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
+
         return super().process_replication_rows(stream_name, instance_name, token, rows)
 
     def process_replication_position(
@@ -173,14 +171,17 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     ) -> None:
         if stream_name == DeviceListsStream.NAME:
             self._device_list_id_gen.advance(instance_name, token)
-        elif stream_name == UserSignatureStream.NAME:
-            self._device_list_id_gen.advance(instance_name, token)
+
         super().process_replication_position(stream_name, instance_name, token)
 
     def _invalidate_caches_for_devices(
         self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
     ) -> None:
         for row in rows:
+            if row.is_signature:
+                self._user_signature_stream_cache.entity_has_changed(row.entity, token)
+                continue
+
             # The entities are either user IDs (starting with '@') whose devices
             # have changed, or remote servers that we need to tell about
             # changes.