summary refs log tree commit diff
path: root/synapse/replication/tcp
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r--synapse/replication/tcp/client.py8
-rw-r--r--synapse/replication/tcp/streams/__init__.py3
-rw-r--r--synapse/replication/tcp/streams/_base.py74
3 files changed, 58 insertions, 27 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 7263bb2796..31022ce5fb 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -187,7 +187,7 @@ class ReplicationDataHandler:
         elif stream_name == DeviceListsStream.NAME:
             all_room_ids: Set[str] = set()
             for row in rows:
-                if row.entity.startswith("@"):
+                if row.entity.startswith("@") and not row.is_signature:
                     room_ids = await self.store.get_rooms_for_user(row.entity)
                     all_room_ids.update(room_ids)
             self.notifier.on_new_event(
@@ -422,7 +422,11 @@ class FederationSenderHandler:
             # The entities are either user IDs (starting with '@') whose devices
             # have changed, or remote servers that we need to tell about
             # changes.
-            hosts = {row.entity for row in rows if not row.entity.startswith("@")}
+            hosts = {
+                row.entity
+                for row in rows
+                if not row.entity.startswith("@") and not row.is_signature
+            }
             for host in hosts:
                 self.federation_sender.send_device_messages(host, immediate=False)
 
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index a7eadfa3c9..9c67f661a3 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -37,7 +37,6 @@ from synapse.replication.tcp.streams._base import (
     Stream,
     ToDeviceStream,
     TypingStream,
-    UserSignatureStream,
 )
 from synapse.replication.tcp.streams.events import EventsStream
 from synapse.replication.tcp.streams.federation import FederationStream
@@ -62,7 +61,6 @@ STREAMS_MAP = {
         ToDeviceStream,
         FederationStream,
         AccountDataStream,
-        UserSignatureStream,
         UnPartialStatedRoomStream,
         UnPartialStatedEventStream,
     )
@@ -82,7 +80,6 @@ __all__ = [
     "DeviceListsStream",
     "ToDeviceStream",
     "AccountDataStream",
-    "UserSignatureStream",
     "UnPartialStatedRoomStream",
     "UnPartialStatedEventStream",
 ]
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index fbf78da9c2..a4bdb48c0c 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -463,18 +463,67 @@ class DeviceListsStream(Stream):
     @attr.s(slots=True, frozen=True, auto_attribs=True)
     class DeviceListsStreamRow:
         entity: str
+        # Indicates that a user has signed their own device with their user-signing key
+        is_signature: bool
 
     NAME = "device_lists"
     ROW_TYPE = DeviceListsStreamRow
 
     def __init__(self, hs: "HomeServer"):
-        store = hs.get_datastores().main
+        self.store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(store.get_device_stream_token),
-            store.get_all_device_list_changes_for_remotes,
+            current_token_without_instance(self.store.get_device_stream_token),
+            self._update_function,
+        )
+
+    async def _update_function(
+        self,
+        instance_name: str,
+        from_token: Token,
+        current_token: Token,
+        target_row_count: int,
+    ) -> StreamUpdateResult:
+        (
+            device_updates,
+            devices_to_token,
+            devices_limited,
+        ) = await self.store.get_all_device_list_changes_for_remotes(
+            instance_name, from_token, current_token, target_row_count
         )
 
+        (
+            signatures_updates,
+            signatures_to_token,
+            signatures_limited,
+        ) = await self.store.get_all_user_signature_changes_for_remotes(
+            instance_name, from_token, current_token, target_row_count
+        )
+
+        upper_limit_token = current_token
+        if devices_limited:
+            upper_limit_token = min(upper_limit_token, devices_to_token)
+        if signatures_limited:
+            upper_limit_token = min(upper_limit_token, signatures_to_token)
+
+        device_updates = [
+            (stream_id, (entity, False))
+            for stream_id, (entity,) in device_updates
+            if stream_id <= upper_limit_token
+        ]
+
+        signatures_updates = [
+            (stream_id, (entity, True))
+            for stream_id, (entity,) in signatures_updates
+            if stream_id <= upper_limit_token
+        ]
+
+        updates = list(
+            heapq.merge(device_updates, signatures_updates, key=lambda row: row[0])
+        )
+
+        return updates, upper_limit_token, devices_limited or signatures_limited
+
 
 class ToDeviceStream(Stream):
     """New to_device messages for a client"""
@@ -583,22 +632,3 @@ class AccountDataStream(Stream):
             heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0])
         )
         return updates, to_token, limited
-
-
-class UserSignatureStream(Stream):
-    """A user has signed their own device with their user-signing key"""
-
-    @attr.s(slots=True, frozen=True, auto_attribs=True)
-    class UserSignatureStreamRow:
-        user_id: str
-
-    NAME = "user_signature"
-    ROW_TYPE = UserSignatureStreamRow
-
-    def __init__(self, hs: "HomeServer"):
-        store = hs.get_datastores().main
-        super().__init__(
-            hs.get_instance_name(),
-            current_token_without_instance(store.get_device_stream_token),
-            store.get_all_user_signature_changes_for_remotes,
-        )