diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 7263bb2796..31022ce5fb 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -187,7 +187,7 @@ class ReplicationDataHandler:
elif stream_name == DeviceListsStream.NAME:
all_room_ids: Set[str] = set()
for row in rows:
- if row.entity.startswith("@"):
+ if row.entity.startswith("@") and not row.is_signature:
room_ids = await self.store.get_rooms_for_user(row.entity)
all_room_ids.update(room_ids)
self.notifier.on_new_event(
@@ -422,7 +422,11 @@ class FederationSenderHandler:
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
- hosts = {row.entity for row in rows if not row.entity.startswith("@")}
+ hosts = {
+ row.entity
+ for row in rows
+ if not row.entity.startswith("@") and not row.is_signature
+ }
for host in hosts:
self.federation_sender.send_device_messages(host, immediate=False)
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index a7eadfa3c9..9c67f661a3 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -37,7 +37,6 @@ from synapse.replication.tcp.streams._base import (
Stream,
ToDeviceStream,
TypingStream,
- UserSignatureStream,
)
from synapse.replication.tcp.streams.events import EventsStream
from synapse.replication.tcp.streams.federation import FederationStream
@@ -62,7 +61,6 @@ STREAMS_MAP = {
ToDeviceStream,
FederationStream,
AccountDataStream,
- UserSignatureStream,
UnPartialStatedRoomStream,
UnPartialStatedEventStream,
)
@@ -82,7 +80,6 @@ __all__ = [
"DeviceListsStream",
"ToDeviceStream",
"AccountDataStream",
- "UserSignatureStream",
"UnPartialStatedRoomStream",
"UnPartialStatedEventStream",
]
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index fbf78da9c2..a4bdb48c0c 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -463,18 +463,67 @@ class DeviceListsStream(Stream):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceListsStreamRow:
entity: str
+ # Indicates that a user has signed their own device with their user-signing key
+ is_signature: bool
NAME = "device_lists"
ROW_TYPE = DeviceListsStreamRow
def __init__(self, hs: "HomeServer"):
- store = hs.get_datastores().main
+ self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(store.get_device_stream_token),
- store.get_all_device_list_changes_for_remotes,
+ current_token_without_instance(self.store.get_device_stream_token),
+ self._update_function,
+ )
+
+ async def _update_function(
+ self,
+ instance_name: str,
+ from_token: Token,
+ current_token: Token,
+ target_row_count: int,
+ ) -> StreamUpdateResult:
+ (
+ device_updates,
+ devices_to_token,
+ devices_limited,
+ ) = await self.store.get_all_device_list_changes_for_remotes(
+ instance_name, from_token, current_token, target_row_count
)
+ (
+ signatures_updates,
+ signatures_to_token,
+ signatures_limited,
+ ) = await self.store.get_all_user_signature_changes_for_remotes(
+ instance_name, from_token, current_token, target_row_count
+ )
+
+ upper_limit_token = current_token
+ if devices_limited:
+ upper_limit_token = min(upper_limit_token, devices_to_token)
+ if signatures_limited:
+ upper_limit_token = min(upper_limit_token, signatures_to_token)
+
+ device_updates = [
+ (stream_id, (entity, False))
+ for stream_id, (entity,) in device_updates
+ if stream_id <= upper_limit_token
+ ]
+
+ signatures_updates = [
+ (stream_id, (entity, True))
+ for stream_id, (entity,) in signatures_updates
+ if stream_id <= upper_limit_token
+ ]
+
+ updates = list(
+ heapq.merge(device_updates, signatures_updates, key=lambda row: row[0])
+ )
+
+ return updates, upper_limit_token, devices_limited or signatures_limited
+
class ToDeviceStream(Stream):
"""New to_device messages for a client"""
@@ -583,22 +632,3 @@ class AccountDataStream(Stream):
heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0])
)
return updates, to_token, limited
-
-
-class UserSignatureStream(Stream):
- """A user has signed their own device with their user-signing key"""
-
- @attr.s(slots=True, frozen=True, auto_attribs=True)
- class UserSignatureStreamRow:
- user_id: str
-
- NAME = "user_signature"
- ROW_TYPE = UserSignatureStreamRow
-
- def __init__(self, hs: "HomeServer"):
- store = hs.get_datastores().main
- super().__init__(
- hs.get_instance_name(),
- current_token_without_instance(store.get_device_stream_token),
- store.get_all_user_signature_changes_for_remotes,
- )
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index b067664473..cd186c8472 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -38,7 +38,7 @@ from synapse.logging.opentracing import (
whitelisted_homeserver,
)
from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
+from synapse.replication.tcp.streams._base import DeviceListsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -163,9 +163,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
) -> None:
if stream_name == DeviceListsStream.NAME:
self._invalidate_caches_for_devices(token, rows)
- elif stream_name == UserSignatureStream.NAME:
- for row in rows:
- self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
+
return super().process_replication_rows(stream_name, instance_name, token, rows)
def process_replication_position(
@@ -173,14 +171,17 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
) -> None:
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
- elif stream_name == UserSignatureStream.NAME:
- self._device_list_id_gen.advance(instance_name, token)
+
super().process_replication_position(stream_name, instance_name, token)
def _invalidate_caches_for_devices(
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
) -> None:
for row in rows:
+ if row.is_signature:
+ self._user_signature_stream_cache.entity_has_changed(row.entity, token)
+ continue
+
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
|