diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/replication/tcp/client.py | 8 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/__init__.py | 3 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/_base.py | 74 | ||||
-rw-r--r-- | synapse/storage/databases/main/devices.py | 13 |
4 files changed, 65 insertions, 33 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 7263bb2796..31022ce5fb 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -187,7 +187,7 @@ class ReplicationDataHandler: elif stream_name == DeviceListsStream.NAME: all_room_ids: Set[str] = set() for row in rows: - if row.entity.startswith("@"): + if row.entity.startswith("@") and not row.is_signature: room_ids = await self.store.get_rooms_for_user(row.entity) all_room_ids.update(room_ids) self.notifier.on_new_event( @@ -422,7 +422,11 @@ class FederationSenderHandler: # The entities are either user IDs (starting with '@') whose devices # have changed, or remote servers that we need to tell about # changes. - hosts = {row.entity for row in rows if not row.entity.startswith("@")} + hosts = { + row.entity + for row in rows + if not row.entity.startswith("@") and not row.is_signature + } for host in hosts: self.federation_sender.send_device_messages(host, immediate=False) diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index a7eadfa3c9..9c67f661a3 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -37,7 +37,6 @@ from synapse.replication.tcp.streams._base import ( Stream, ToDeviceStream, TypingStream, - UserSignatureStream, ) from synapse.replication.tcp.streams.events import EventsStream from synapse.replication.tcp.streams.federation import FederationStream @@ -62,7 +61,6 @@ STREAMS_MAP = { ToDeviceStream, FederationStream, AccountDataStream, - UserSignatureStream, UnPartialStatedRoomStream, UnPartialStatedEventStream, ) @@ -82,7 +80,6 @@ __all__ = [ "DeviceListsStream", "ToDeviceStream", "AccountDataStream", - "UserSignatureStream", "UnPartialStatedRoomStream", "UnPartialStatedEventStream", ] diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index fbf78da9c2..a4bdb48c0c 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -463,18 +463,67 @@ class DeviceListsStream(Stream): @attr.s(slots=True, frozen=True, auto_attribs=True) class DeviceListsStreamRow: entity: str + # Indicates that a user has signed their own device with their user-signing key + is_signature: bool NAME = "device_lists" ROW_TYPE = DeviceListsStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastores().main + self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - current_token_without_instance(store.get_device_stream_token), - store.get_all_device_list_changes_for_remotes, + current_token_without_instance(self.store.get_device_stream_token), + self._update_function, + ) + + async def _update_function( + self, + instance_name: str, + from_token: Token, + current_token: Token, + target_row_count: int, + ) -> StreamUpdateResult: + ( + device_updates, + devices_to_token, + devices_limited, + ) = await self.store.get_all_device_list_changes_for_remotes( + instance_name, from_token, current_token, target_row_count ) + ( + signatures_updates, + signatures_to_token, + signatures_limited, + ) = await self.store.get_all_user_signature_changes_for_remotes( + instance_name, from_token, current_token, target_row_count + ) + + upper_limit_token = current_token + if devices_limited: + upper_limit_token = min(upper_limit_token, devices_to_token) + if signatures_limited: + upper_limit_token = min(upper_limit_token, signatures_to_token) + + device_updates = [ + (stream_id, (entity, False)) + for stream_id, (entity,) in device_updates + if stream_id <= upper_limit_token + ] + + signatures_updates = [ + (stream_id, (entity, True)) + for stream_id, (entity,) in signatures_updates + if stream_id <= upper_limit_token + ] + + updates = list( + heapq.merge(device_updates, signatures_updates, key=lambda row: row[0]) + ) + + return updates, upper_limit_token, devices_limited or signatures_limited + class ToDeviceStream(Stream): """New to_device messages for a client""" @@ -583,22 +632,3 @@ class AccountDataStream(Stream): heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0]) ) return updates, to_token, limited - - -class UserSignatureStream(Stream): - """A user has signed their own device with their user-signing key""" - - @attr.s(slots=True, frozen=True, auto_attribs=True) - class UserSignatureStreamRow: - user_id: str - - NAME = "user_signature" - ROW_TYPE = UserSignatureStreamRow - - def __init__(self, hs: "HomeServer"): - store = hs.get_datastores().main - super().__init__( - hs.get_instance_name(), - current_token_without_instance(store.get_device_stream_token), - store.get_all_user_signature_changes_for_remotes, - ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index b067664473..cd186c8472 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -38,7 +38,7 @@ from synapse.logging.opentracing import ( whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream +from synapse.replication.tcp.streams._base import DeviceListsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -163,9 +163,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ) -> None: if stream_name == DeviceListsStream.NAME: self._invalidate_caches_for_devices(token, rows) - elif stream_name == UserSignatureStream.NAME: - for row in rows: - self._user_signature_stream_cache.entity_has_changed(row.user_id, token) + return super().process_replication_rows(stream_name, instance_name, token, rows) def process_replication_position( @@ -173,14 +171,17 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ) -> None: if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(instance_name, token) - elif stream_name == UserSignatureStream.NAME: - self._device_list_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) def _invalidate_caches_for_devices( self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] ) -> None: for row in rows: + if row.is_signature: + self._user_signature_stream_cache.entity_has_changed(row.entity, token) + continue + # The entities are either user IDs (starting with '@') whose devices # have changed, or remote servers that we need to tell about # changes. |