From e0bb2681340752f2716f1f386139ca37c53f26cd Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 29 Mar 2022 22:37:50 +0100 Subject: Fix typechecker problems exposed by signedjson 1.1.2 (#12326) --- synapse/config/key.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) (limited to 'synapse/config') diff --git a/synapse/config/key.py b/synapse/config/key.py index ee83c6c06b..f5377e7d9c 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -16,7 +16,7 @@ import hashlib import logging import os -from typing import Any, Dict, Iterator, List, Optional +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional import attr import jsonschema @@ -38,6 +38,9 @@ from synapse.util.stringutils import random_string, random_string_with_symbols from ._base import Config, ConfigError +if TYPE_CHECKING: + from signedjson.key import VerifyKeyWithExpiry + INSECURE_NOTARY_ERROR = """\ Your server is configured to accept key server responses without signature validation or TLS certificate validation. This is likely to be very insecure. If @@ -300,7 +303,7 @@ class KeyConfig(Config): def read_old_signing_keys( self, old_signing_keys: Optional[JsonDict] - ) -> Dict[str, VerifyKey]: + ) -> Dict[str, "VerifyKeyWithExpiry"]: if old_signing_keys is None: return {} keys = {} @@ -308,8 +311,8 @@ class KeyConfig(Config): if is_signing_algorithm_supported(key_id): key_base64 = key_data["key"] key_bytes = decode_base64(key_base64) - verify_key = decode_verify_key_bytes(key_id, key_bytes) - verify_key.expired_ts = key_data["expired_ts"] + verify_key: "VerifyKeyWithExpiry" = decode_verify_key_bytes(key_id, key_bytes) # type: ignore[assignment] + verify_key.expired = key_data["expired_ts"] keys[key_id] = verify_key else: raise ConfigError( @@ -422,7 +425,7 @@ def _parse_key_servers( server_name = server["server_name"] result = TrustedKeyServer(server_name=server_name) - verify_keys = server.get("verify_keys") + verify_keys: Optional[Dict[str, str]] = server.get("verify_keys") if verify_keys is not None: result.verify_keys = {} for key_id, key_base64 in verify_keys.items(): -- cgit 1.4.1 From 437a8ed9efdf8f1aefa092d0761076da3ae78100 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 30 Mar 2022 11:43:04 +0200 Subject: Add a configuration to exclude rooms from sync response (#12310) --- changelog.d/12310.feature | 1 + docs/sample_config.yaml | 9 ++++ synapse/config/server.py | 13 ++++++ synapse/handlers/sync.py | 23 ++++++++--- synapse/storage/databases/main/roommember.py | 21 +++++++--- synapse/storage/databases/main/stream.py | 30 +++++++++----- tests/rest/client/test_sync.py | 62 ++++++++++++++++++++++++++++ 7 files changed, 138 insertions(+), 21 deletions(-) create mode 100644 changelog.d/12310.feature (limited to 'synapse/config') diff --git a/changelog.d/12310.feature b/changelog.d/12310.feature new file mode 100644 index 0000000000..f3fbb298f7 --- /dev/null +++ b/changelog.d/12310.feature @@ -0,0 +1 @@ +Add a configuration option to remove a specific set of rooms from sync responses. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index a21b48ab2e..b8d8c0dbf0 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -539,6 +539,15 @@ templates: # #custom_template_directory: /path/to/custom/templates/ +# List of rooms to exclude from sync responses. This is useful for server +# administrators wishing to group users into a room without these users being able +# to see it from their client. +# +# By default, no room is excluded. 
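+#
+# Rooms must be listed by their full room ID (as in the example below);
+# room aliases are not resolved.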
+# +#exclude_rooms_from_sync: +# - !foo:example.com + # Message retention policy at the server level. # diff --git a/synapse/config/server.py b/synapse/config/server.py index 38de4b8000..0f90302c95 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -680,6 +680,10 @@ class ServerConfig(Config): config.get("use_account_validity_in_account_status") or False ) + self.rooms_to_exclude_from_sync: List[str] = ( + config.get("exclude_rooms_from_sync") or [] + ) + def has_tls_listener(self) -> bool: return any(listener.tls for listener in self.listeners) @@ -1234,6 +1238,15 @@ class ServerConfig(Config): # information about using custom templates. # #custom_template_directory: /path/to/custom/templates/ + + # List of rooms to exclude from sync responses. This is useful for server + # administrators wishing to group users into a room without these users being able + # to see it from their client. + # + # By default, no room is excluded. + # + #exclude_rooms_from_sync: + # - !foo:example.com """ % locals() ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 6c569cfb1c..bceafca3b1 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -298,6 +298,8 @@ class SyncHandler: expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, ) + self.rooms_to_exclude = hs.config.server.rooms_to_exclude_from_sync + async def wait_for_sync_for_user( self, requester: Requester, @@ -1607,13 +1609,15 @@ class SyncHandler: ignored_users = await self.store.ignored_users(user_id) if since_token: room_changes = await self._get_rooms_changed( - sync_result_builder, ignored_users + sync_result_builder, ignored_users, self.rooms_to_exclude ) tags_by_room = await self.store.get_updated_tags( user_id, since_token.account_data_key ) else: - room_changes = await self._get_all_rooms(sync_result_builder, ignored_users) + room_changes = await self._get_all_rooms( + sync_result_builder, ignored_users, self.rooms_to_exclude + ) tags_by_room = await self.store.get_tags_for_user(user_id) log_kv({"rooms_changed": len(room_changes.room_entries)}) @@ -1689,7 +1693,10 @@ class SyncHandler: return False async def _get_rooms_changed( - self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] + self, + sync_result_builder: "SyncResultBuilder", + ignored_users: FrozenSet[str], + excluded_rooms: List[str], ) -> _RoomChanges: """Determine the changes in rooms to report to the user. @@ -1721,7 +1728,7 @@ class SyncHandler: # _have_rooms_changed. We could keep the results in memory to avoid a # second query, at the cost of more complicated source code. membership_change_events = await self.store.get_membership_changes_for_user( - user_id, since_token.room_key, now_token.room_key + user_id, since_token.room_key, now_token.room_key, excluded_rooms ) mem_change_events_by_room_id: Dict[str, List[EventBase]] = {} @@ -1922,7 +1929,10 @@ class SyncHandler: ) async def _get_all_rooms( - self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] + self, + sync_result_builder: "SyncResultBuilder", + ignored_users: FrozenSet[str], + ignored_rooms: List[str], ) -> _RoomChanges: """Returns entries for all rooms for the user. @@ -1933,7 +1943,7 @@ class SyncHandler: Args: sync_result_builder ignored_users: Set of users ignored by user. - + ignored_rooms: List of rooms to ignore. 
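+
+        Returns:
+            A `_RoomChanges` describing the rooms to include in the sync,
+            with any rooms listed in `ignored_rooms` filtered out.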
""" user_id = sync_result_builder.sync_config.user.to_string() @@ -1944,6 +1954,7 @@ class SyncHandler: room_list = await self.store.get_rooms_for_local_user_where_membership_is( user_id=user_id, membership_list=Membership.LIST, + excluded_rooms=ignored_rooms, ) room_entries = [] diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 3248da5356..98d09b3736 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -361,7 +361,10 @@ class RoomMemberWorkerStore(EventsWorkerStore): return None async def get_rooms_for_local_user_where_membership_is( - self, user_id: str, membership_list: Collection[str] + self, + user_id: str, + membership_list: Collection[str], + excluded_rooms: Optional[List[str]] = None, ) -> List[RoomsForUser]: """Get all the rooms for this *local* user where the membership for this user matches one in the membership list. @@ -372,6 +375,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): user_id: The user ID. membership_list: A list of synapse.api.constants.Membership values which the user must be in. + excluded_rooms: A list of rooms to ignore. Returns: The RoomsForUser that the user matches the membership types. @@ -386,12 +390,19 @@ class RoomMemberWorkerStore(EventsWorkerStore): membership_list, ) - # Now we filter out forgotten rooms - forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id) - return [room for room in rooms if room.room_id not in forgotten_rooms] + # Now we filter out forgotten and excluded rooms + rooms_to_exclude: Set[str] = await self.get_forgotten_rooms_for_user(user_id) + + if excluded_rooms is not None: + rooms_to_exclude.update(set(excluded_rooms)) + + return [room for room in rooms if room.room_id not in rooms_to_exclude] def _get_rooms_for_local_user_where_membership_is_txn( - self, txn, user_id: str, membership_list: List[str] + self, + txn, + user_id: str, + membership_list: List[str], ) -> List[RoomsForUser]: # Paranoia check. if not self.hs.is_mine_id(user_id): diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 39e1efe373..8e764790db 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -36,7 +36,7 @@ what sort order was used: """ import logging -from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set, Tuple import attr from frozendict import frozendict @@ -585,7 +585,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return ret, key async def get_membership_changes_for_user( - self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken + self, + user_id: str, + from_key: RoomStreamToken, + to_key: RoomStreamToken, + excluded_rooms: Optional[List[str]] = None, ) -> List[EventBase]: """Fetch membership events for a given user. @@ -610,23 +614,29 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): min_from_id = from_key.stream max_to_id = to_key.get_max_stream_pos() + args: List[Any] = [user_id, min_from_id, max_to_id] + + ignore_room_clause = "" + if excluded_rooms is not None and len(excluded_rooms) > 0: + ignore_room_clause = "AND e.room_id NOT IN (%s)" % ",".join( + "?" 
for _ in excluded_rooms + ) + args = args + excluded_rooms + sql = """ SELECT m.event_id, instance_name, topological_ordering, stream_ordering FROM events AS e, room_memberships AS m WHERE e.event_id = m.event_id AND m.user_id = ? AND e.stream_ordering > ? AND e.stream_ordering <= ? + %s ORDER BY e.stream_ordering ASC - """ - txn.execute( - sql, - ( - user_id, - min_from_id, - max_to_id, - ), + """ % ( + ignore_room_clause, ) + txn.execute(sql, args) + rows = [ _EventDictReturn(event_id, None, stream_ordering) for event_id, instance_name, topological_ordering, stream_ordering in txn diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index 4351013952..f0f3a54f82 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -772,3 +772,65 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase): self.assertIn( self.user_id, device_list_changes, incremental_sync_channel.json_body ) + + +class ExcludeRoomTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + sync.register_servlets, + room.register_servlets, + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self.user_id = self.register_user("user", "password") + self.tok = self.login("user", "password") + + self.excluded_room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + self.included_room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + # We need to manually append the room ID, because we can't know the ID before + # creating the room, and we can't set the config after starting the homeserver. + self.hs.get_sync_handler().rooms_to_exclude.append(self.excluded_room_id) + + def test_join_leave(self) -> None: + """Tests that rooms are correctly excluded from the 'join' and 'leave' sections of + sync responses. + """ + channel = self.make_request("GET", "/sync", access_token=self.tok) + self.assertEqual(channel.code, 200, channel.result) + + self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["join"]) + self.assertIn(self.included_room_id, channel.json_body["rooms"]["join"]) + + self.helper.leave(self.excluded_room_id, self.user_id, tok=self.tok) + self.helper.leave(self.included_room_id, self.user_id, tok=self.tok) + + channel = self.make_request( + "GET", + "/sync?since=" + channel.json_body["next_batch"], + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["leave"]) + self.assertIn(self.included_room_id, channel.json_body["rooms"]["leave"]) + + def test_invite(self) -> None: + """Tests that rooms are correctly excluded from the 'invite' section of sync + responses. 
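+
+        The invitee's sync response should omit the excluded room from the
+        'invite' section even though they have a pending invite to it, while
+        the invite to the included room must still appear.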
+ """ + invitee = self.register_user("invitee", "password") + invitee_tok = self.login("invitee", "password") + + self.helper.invite(self.excluded_room_id, self.user_id, invitee, tok=self.tok) + self.helper.invite(self.included_room_id, self.user_id, invitee, tok=self.tok) + + channel = self.make_request("GET", "/sync", access_token=invitee_tok) + self.assertEqual(channel.code, 200, channel.result) + + self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["invite"]) + self.assertIn(self.included_room_id, channel.json_body["rooms"]["invite"]) -- cgit 1.4.1 From d8d0271977938d89585613e9a77537c33c4dc4a9 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 30 Mar 2022 14:39:27 +0100 Subject: Send device list updates to application services (MSC3202) - part 1 (#11881) Co-authored-by: Patrick Cloke --- changelog.d/11881.feature | 1 + synapse/appservice/__init__.py | 12 +- synapse/appservice/api.py | 10 +- synapse/appservice/scheduler.py | 53 +++++++- synapse/config/appservice.py | 1 + synapse/config/experimental.py | 5 +- synapse/handlers/appservice.py | 148 ++++++++++++++++++++- synapse/handlers/sync.py | 40 ++---- synapse/storage/databases/main/appservice.py | 14 +- synapse/storage/databases/main/devices.py | 48 +++++-- ...3202_add_device_list_appservice_stream_type.sql | 23 ++++ synapse/types.py | 25 ++++ tests/appservice/test_scheduler.py | 54 +++++--- tests/handlers/test_appservice.py | 121 ++++++++++++++++- tests/storage/test_appservice.py | 17 ++- 15 files changed, 490 insertions(+), 82 deletions(-) create mode 100644 changelog.d/11881.feature create mode 100644 synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql (limited to 'synapse/config') diff --git a/changelog.d/11881.feature b/changelog.d/11881.feature new file mode 100644 index 0000000000..392294ffc3 --- /dev/null +++ b/changelog.d/11881.feature @@ -0,0 +1 @@ +Send device list changes to application services as specified by [MSC3202](https://github.com/matrix-org/matrix-spec-proposals/pull/3202), using unstable prefixes. The `msc3202_transaction_extensions` experimental homeserver config option must be enabled and `org.matrix.msc3202: true` must be present in the application service registration file for device list changes to be sent. The "left" field is currently always empty. \ No newline at end of file diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 07ec95f1d6..d23d9221bc 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -1,4 +1,5 @@ # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2022 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -22,7 +23,13 @@ from netaddr import IPSet from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id +from synapse.types import ( + DeviceListUpdates, + GroupID, + JsonDict, + UserID, + get_domain_from_id, +) from synapse.util.caches.descriptors import _CacheContext, cached if TYPE_CHECKING: @@ -400,6 +407,7 @@ class AppServiceTransaction: to_device_messages: List[JsonDict], one_time_key_counts: TransactionOneTimeKeyCounts, unused_fallback_keys: TransactionUnusedFallbackKeys, + device_list_summary: DeviceListUpdates, ): self.service = service self.id = id @@ -408,6 +416,7 @@ class AppServiceTransaction: self.to_device_messages = to_device_messages self.one_time_key_counts = one_time_key_counts self.unused_fallback_keys = unused_fallback_keys + self.device_list_summary = device_list_summary async def send(self, as_api: "ApplicationServiceApi") -> bool: """Sends this transaction using the provided AS API interface. @@ -424,6 +433,7 @@ class AppServiceTransaction: to_device_messages=self.to_device_messages, one_time_key_counts=self.one_time_key_counts, unused_fallback_keys=self.unused_fallback_keys, + device_list_summary=self.device_list_summary, txn_id=self.id, ) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 98fe354014..0cdbb04bfb 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -1,4 +1,5 @@ # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2022 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,7 +28,7 @@ from synapse.appservice import ( from synapse.events import EventBase from synapse.events.utils import SerializeEventConfig, serialize_event from synapse.http.client import SimpleHttpClient -from synapse.types import JsonDict, ThirdPartyInstanceID +from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: @@ -225,6 +226,7 @@ class ApplicationServiceApi(SimpleHttpClient): to_device_messages: List[JsonDict], one_time_key_counts: TransactionOneTimeKeyCounts, unused_fallback_keys: TransactionUnusedFallbackKeys, + device_list_summary: DeviceListUpdates, txn_id: Optional[int] = None, ) -> bool: """ @@ -268,6 +270,7 @@ class ApplicationServiceApi(SimpleHttpClient): } ) + # TODO: Update to stable prefixes once MSC3202 completes FCP merge if service.msc3202_transaction_extensions: if one_time_key_counts: body[ @@ -277,6 +280,11 @@ class ApplicationServiceApi(SimpleHttpClient): body[ "org.matrix.msc3202.device_unused_fallback_keys" ] = unused_fallback_keys + if device_list_summary: + body["org.matrix.msc3202.device_lists"] = { + "changed": list(device_list_summary.changed), + "left": list(device_list_summary.left), + } try: await self.put_json( diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index a6084b9c35..3b49e60716 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -72,7 +72,7 @@ from synapse.events import EventBase from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main import DataStore -from synapse.types import JsonDict +from synapse.types import DeviceListUpdates, JsonDict from synapse.util import Clock if TYPE_CHECKING: @@ -122,6 +122,7 @@ class 
ApplicationServiceScheduler: events: Optional[Collection[EventBase]] = None, ephemeral: Optional[Collection[JsonDict]] = None, to_device_messages: Optional[Collection[JsonDict]] = None, + device_list_summary: Optional[DeviceListUpdates] = None, ) -> None: """ Enqueue some data to be sent off to an application service. @@ -133,10 +134,18 @@ class ApplicationServiceScheduler: to_device_messages: The to-device messages to send. These differ from normal to-device messages sent to clients, as they have 'to_device_id' and 'to_user_id' fields. + device_list_summary: A summary of users that the application service either needs + to refresh the device lists of, or those that the application service need no + longer track the device lists of. """ # We purposefully allow this method to run with empty events/ephemeral # collections, so that callers do not need to check iterable size themselves. - if not events and not ephemeral and not to_device_messages: + if ( + not events + and not ephemeral + and not to_device_messages + and not device_list_summary + ): return if events: @@ -147,6 +156,10 @@ class ApplicationServiceScheduler: self.queuer.queued_to_device_messages.setdefault(appservice.id, []).extend( to_device_messages ) + if device_list_summary: + self.queuer.queued_device_list_summaries.setdefault( + appservice.id, [] + ).append(device_list_summary) # Kick off a new application service transaction self.queuer.start_background_request(appservice) @@ -169,6 +182,8 @@ class _ServiceQueuer: self.queued_ephemeral: Dict[str, List[JsonDict]] = {} # dict of {service_id: [to_device_message_json]} self.queued_to_device_messages: Dict[str, List[JsonDict]] = {} + # dict of {service_id: [device_list_summary]} + self.queued_device_list_summaries: Dict[str, List[DeviceListUpdates]] = {} # the appservices which currently have a transaction in flight self.requests_in_flight: Set[str] = set() @@ -212,7 +227,35 @@ class _ServiceQueuer: ] del all_to_device_messages[:MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION] - if not events and not ephemeral and not to_device_messages_to_send: + # Consolidate any pending device list summaries into a single, up-to-date + # summary. + # Note: this code assumes that in a single DeviceListUpdates, a user will + # never be in both "changed" and "left" sets. + device_list_summary = DeviceListUpdates() + for summary in self.queued_device_list_summaries.get(service.id, []): + # For every user in the incoming "changed" set: + # * Remove them from the existing "left" set if necessary + # (as we need to start tracking them again) + # * Add them to the existing "changed" set if necessary. + device_list_summary.left.difference_update(summary.changed) + device_list_summary.changed.update(summary.changed) + + # For every user in the incoming "left" set: + # * Remove them from the existing "changed" set if necessary + # (we no longer need to track them) + # * Add them to the existing "left" set if necessary. 
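+            #
+            # Summaries are applied in queue order, so the most recent update
+            # for a given user decides whether they end up in "changed" or
+            # "left".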
+ device_list_summary.changed.difference_update(summary.left) + device_list_summary.left.update(summary.left) + self.queued_device_list_summaries.clear() + + if ( + not events + and not ephemeral + and not to_device_messages_to_send + # DeviceListUpdates is True if either the 'changed' or 'left' sets have + # at least one entry, otherwise False + and not device_list_summary + ): return one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None @@ -240,6 +283,7 @@ class _ServiceQueuer: to_device_messages_to_send, one_time_key_counts, unused_fallback_keys, + device_list_summary, ) except Exception: logger.exception("AS request failed") @@ -322,6 +366,7 @@ class _TransactionController: to_device_messages: Optional[List[JsonDict]] = None, one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None, unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None, + device_list_summary: Optional[DeviceListUpdates] = None, ) -> None: """ Create a transaction with the given data and send to the provided @@ -336,6 +381,7 @@ class _TransactionController: appservice devices in the transaction. unused_fallback_keys: Lists of unused fallback keys for relevant appservice devices in the transaction. + device_list_summary: The device list summary to include in the transaction. """ try: txn = await self.store.create_appservice_txn( @@ -345,6 +391,7 @@ class _TransactionController: to_device_messages=to_device_messages or [], one_time_key_counts=one_time_key_counts or {}, unused_fallback_keys=unused_fallback_keys or {}, + device_list_summary=device_list_summary or DeviceListUpdates(), ) service_is_up = await self._is_service_up(service) if service_is_up: diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 439bfe1526..ada165f238 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -170,6 +170,7 @@ def _load_appservice( # When enabled, appservice transactions contain the following information: # - device One-Time Key counts # - device unused fallback key usage states + # - device list changes msc3202_transaction_extensions = as_info.get("org.matrix.msc3202", False) if not isinstance(msc3202_transaction_extensions, bool): raise ValueError( diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 064db4487c..d6bb1f752b 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -59,8 +59,9 @@ class ExperimentalConfig(Config): "msc3202_device_masquerading", False ) - # Portion of MSC3202 related to transaction extensions: - # sending one-time key counts and fallback key usage to application services. + # The portion of MSC3202 related to transaction extensions: + # sending device list changes, one-time key counts and fallback key + # usage to application services. 
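+        #
+        # Enabled via the experimental features section of the homeserver
+        # config, e.g. (illustrative homeserver.yaml snippet):
+        #
+        #   experimental_features:
+        #     msc3202_transaction_extensions: true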
self.msc3202_transaction_extensions: bool = experimental.get( "msc3202_transaction_extensions", False ) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index bd913e524e..316c4b677c 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -33,7 +33,13 @@ from synapse.metrics.background_process_metrics import ( wrap_as_background_process, ) from synapse.storage.databases.main.directory import RoomAliasMapping -from synapse.types import JsonDict, RoomAlias, RoomStreamToken, UserID +from synapse.types import ( + DeviceListUpdates, + JsonDict, + RoomAlias, + RoomStreamToken, + UserID, +) from synapse.util.async_helpers import Linearizer from synapse.util.metrics import Measure @@ -58,6 +64,9 @@ class ApplicationServicesHandler: self._msc2409_to_device_messages_enabled = ( hs.config.experimental.msc2409_to_device_messages_enabled ) + self._msc3202_transaction_extensions_enabled = ( + hs.config.experimental.msc3202_transaction_extensions + ) self.current_max = 0 self.is_processing = False @@ -204,9 +213,9 @@ class ApplicationServicesHandler: Args: stream_key: The stream the event came from. - `stream_key` can be "typing_key", "receipt_key", "presence_key" or - "to_device_key". Any other value for `stream_key` will cause this function - to return early. + `stream_key` can be "typing_key", "receipt_key", "presence_key", + "to_device_key" or "device_list_key". Any other value for `stream_key` + will cause this function to return early. Ephemeral events will only be pushed to appservices that have opted into receiving them by setting `push_ephemeral` to true in their registration @@ -230,6 +239,7 @@ class ApplicationServicesHandler: "receipt_key", "presence_key", "to_device_key", + "device_list_key", ): return @@ -253,15 +263,37 @@ class ApplicationServicesHandler: ): return + # Ignore device lists if the feature flag is not enabled + if ( + stream_key == "device_list_key" + and not self._msc3202_transaction_extensions_enabled + ): + return + # Check whether there are any appservices which have registered to receive # ephemeral events. # # Note that whether these events are actually relevant to these appservices # is decided later on. 
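+        #
+        # For device list updates, interest additionally requires that the
+        # appservice has opted into MSC3202 transaction extensions; the
+        # filter below checks `service.msc3202_transaction_extensions`.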
+ services = self.store.get_app_services() services = [ service - for service in self.store.get_app_services() - if service.supports_ephemeral + for service in services + # Different stream keys require different support booleans + if ( + stream_key + in ( + "typing_key", + "receipt_key", + "presence_key", + "to_device_key", + ) + and service.supports_ephemeral + ) + or ( + stream_key == "device_list_key" + and service.msc3202_transaction_extensions + ) ] if not services: # Bail out early if none of the target appservices have explicitly registered @@ -336,6 +368,20 @@ class ApplicationServicesHandler: service, "to_device", new_token ) + elif stream_key == "device_list_key": + device_list_summary = await self._get_device_list_summary( + service, new_token + ) + if device_list_summary: + self.scheduler.enqueue_for_appservice( + service, device_list_summary=device_list_summary + ) + + # Persist the latest handled stream token for this appservice + await self.store.set_appservice_stream_type_pos( + service, "device_list", new_token + ) + async def _handle_typing( self, service: ApplicationService, new_token: int ) -> List[JsonDict]: @@ -542,6 +588,96 @@ class ApplicationServicesHandler: return message_payload + async def _get_device_list_summary( + self, + appservice: ApplicationService, + new_key: int, + ) -> DeviceListUpdates: + """ + Retrieve a list of users who have changed their device lists. + + Args: + appservice: The application service to retrieve device list changes for. + new_key: The stream key of the device list change that triggered this method call. + + Returns: + A set of device list updates, comprised of users that the appservices needs to: + * resync the device list of, and + * stop tracking the device list of. + """ + # Fetch the last successfully processed device list update stream ID + # for this appservice. + from_key = await self.store.get_type_stream_id_for_appservice( + appservice, "device_list" + ) + + # Fetch the users who have modified their device list since then. + users_with_changed_device_lists = ( + await self.store.get_users_whose_devices_changed(from_key, to_key=new_key) + ) + + # Filter out any users the application service is not interested in + # + # For each user who changed their device list, we want to check whether this + # appservice would be interested in the change. + filtered_users_with_changed_device_lists = { + user_id + for user_id in users_with_changed_device_lists + if await self._is_appservice_interested_in_device_lists_of_user( + appservice, user_id + ) + } + + # Create a summary of "changed" and "left" users. + # TODO: Calculate "left" users. + device_list_summary = DeviceListUpdates( + changed=filtered_users_with_changed_device_lists + ) + + return device_list_summary + + async def _is_appservice_interested_in_device_lists_of_user( + self, + appservice: ApplicationService, + user_id: str, + ) -> bool: + """ + Returns whether a given application service is interested in the device list + updates of a given user. + + The application service is interested in the user's device list updates if any + of the following are true: + * The user is the appservice's sender localpart user. + * The user is in the appservice's user namespace. + * At least one member of one room that the user is a part of is in the + appservice's user namespace. + * The appservice is explicitly (via room ID or alias) interested in at + least one room that the user is in. + + Args: + appservice: The application service to gauge interest of. 
+ user_id: The ID of the user whose device list interest is in question. + + Returns: + True if the application service is interested in the user's device lists, False + otherwise. + """ + # This method checks against both the sender localpart user as well as if the + # user is in the appservice's user namespace. + if appservice.is_interested_in_user(user_id): + return True + + # Determine whether any of the rooms the user is in justifies sending this + # device list update to the application service. + room_ids = await self.store.get_rooms_for_user(user_id) + for room_id in room_ids: + # This method covers checking room members for appservice interest as well as + # room ID and alias checks. + if await appservice.is_interested_in_room(room_id, self.store): + return True + + return False + async def query_user_exists(self, user_id: str) -> bool: """Check if any application service knows this user_id exists. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index bceafca3b1..303c38c746 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -13,17 +13,7 @@ # limitations under the License. import itertools import logging -from typing import ( - TYPE_CHECKING, - Any, - Collection, - Dict, - FrozenSet, - List, - Optional, - Set, - Tuple, -) +from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple import attr from prometheus_client import Counter @@ -41,6 +31,7 @@ from synapse.storage.databases.main.event_push_actions import NotifCounts from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( + DeviceListUpdates, JsonDict, MutableStateMap, Requester, @@ -184,21 +175,6 @@ class GroupsSyncResult: return bool(self.join or self.invite or self.leave) -@attr.s(slots=True, frozen=True, auto_attribs=True) -class DeviceLists: - """ - Attributes: - changed: List of user_ids whose devices may have changed - left: List of user_ids whose devices we no longer track - """ - - changed: Collection[str] - left: Collection[str] - - def __bool__(self) -> bool: - return bool(self.changed or self.left) - - @attr.s(slots=True, auto_attribs=True) class _RoomChanges: """The set of room entries to include in the sync, plus the set of joined @@ -240,7 +216,7 @@ class SyncResult: knocked: List[KnockedSyncResult] archived: List[ArchivedSyncResult] to_device: List[JsonDict] - device_lists: DeviceLists + device_lists: DeviceListUpdates device_one_time_keys_count: JsonDict device_unused_fallback_key_types: List[str] groups: Optional[GroupsSyncResult] @@ -1264,8 +1240,8 @@ class SyncHandler: newly_joined_or_invited_or_knocked_users: Set[str], newly_left_rooms: Set[str], newly_left_users: Set[str], - ) -> DeviceLists: - """Generate the DeviceLists section of sync + ) -> DeviceListUpdates: + """Generate the DeviceListUpdates section of sync Args: sync_result_builder @@ -1383,9 +1359,11 @@ class SyncHandler: if any(e.room_id in joined_rooms for e in entries): newly_left_users.discard(user_id) - return DeviceLists(changed=users_that_have_changed, left=newly_left_users) + return DeviceListUpdates( + changed=users_that_have_changed, left=newly_left_users + ) else: - return DeviceLists(changed=[], left=[]) + return DeviceListUpdates() async def _generate_sync_entry_for_to_device( self, sync_result_builder: "SyncResultBuilder" diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index abea4383c7..55e1ab099d 100644 --- 
a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -29,7 +29,7 @@ from synapse.storage._base import db_to_json from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore -from synapse.types import JsonDict +from synapse.types import DeviceListUpdates, JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import _CacheContext, cached @@ -217,6 +217,7 @@ class ApplicationServiceTransactionWorkerStore( to_device_messages: List[JsonDict], one_time_key_counts: TransactionOneTimeKeyCounts, unused_fallback_keys: TransactionUnusedFallbackKeys, + device_list_summary: DeviceListUpdates, ) -> AppServiceTransaction: """Atomically creates a new transaction for this application service with the given list of events. Ephemeral events are NOT persisted to the @@ -231,6 +232,7 @@ class ApplicationServiceTransactionWorkerStore( appservice devices in the transaction. unused_fallback_keys: Lists of unused fallback keys for relevant appservice devices in the transaction. + device_list_summary: The device list summary to include in the transaction. Returns: A new transaction. @@ -268,6 +270,7 @@ class ApplicationServiceTransactionWorkerStore( to_device_messages=to_device_messages, one_time_key_counts=one_time_key_counts, unused_fallback_keys=unused_fallback_keys, + device_list_summary=device_list_summary, ) return await self.db_pool.runInteraction( @@ -359,8 +362,8 @@ class ApplicationServiceTransactionWorkerStore( events = await self.get_events_as_list(event_ids) - # TODO: to-device messages, one-time key counts and unused fallback keys - # are not yet populated for catch-up transactions. + # TODO: to-device messages, one-time key counts, device list summaries and unused + # fallback keys are not yet populated for catch-up transactions. # We likely want to populate those for reliability. 
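+        # Until then, catch-up transactions carry an empty DeviceListUpdates()
+        # (see below), so no device list changes are replayed on recovery.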
return AppServiceTransaction( service=service, @@ -370,6 +373,7 @@ class ApplicationServiceTransactionWorkerStore( to_device_messages=[], one_time_key_counts={}, unused_fallback_keys={}, + device_list_summary=DeviceListUpdates(), ) def _get_last_txn(self, txn, service_id: Optional[str]) -> int: @@ -430,7 +434,7 @@ class ApplicationServiceTransactionWorkerStore( async def get_type_stream_id_for_appservice( self, service: ApplicationService, type: str ) -> int: - if type not in ("read_receipt", "presence", "to_device"): + if type not in ("read_receipt", "presence", "to_device", "device_list"): raise ValueError( "Expected type to be a valid application stream id type, got %s" % (type,) @@ -458,7 +462,7 @@ class ApplicationServiceTransactionWorkerStore( async def set_appservice_stream_type_pos( self, service: ApplicationService, stream_type: str, pos: Optional[int] ) -> None: - if stream_type not in ("read_receipt", "presence", "to_device"): + if stream_type not in ("read_receipt", "presence", "to_device", "device_list"): raise ValueError( "Expected type to be a valid application stream id type, got %s" % (stream_type,) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 3b3a089b76..f08f7834d3 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -681,42 +681,64 @@ class DeviceWorkerStore(SQLBaseStore): return self._device_list_stream_cache.get_all_entities_changed(from_key) async def get_users_whose_devices_changed( - self, from_key: int, user_ids: Iterable[str] + self, + from_key: int, + user_ids: Optional[Iterable[str]] = None, + to_key: Optional[int] = None, ) -> Set[str]: """Get set of users whose devices have changed since `from_key` that are in the given list of user_ids. Args: - from_key: The device lists stream token - user_ids: The user IDs to query for devices. + from_key: The minimum device lists stream token to query device list changes for, + exclusive. + user_ids: If provided, only check if these users have changed their device lists. + Otherwise changes from all users are returned. + to_key: The maximum device lists stream token to query device list changes for, + inclusive. Returns: - The set of user_ids whose devices have changed since `from_key` + The set of user_ids whose devices have changed since `from_key` (exclusive) + until `to_key` (inclusive). """ - # Get set of users who *may* have changed. Users not in the returned # list have definitely not changed. - to_check = self._device_list_stream_cache.get_entities_changed( - user_ids, from_key - ) + if user_ids is None: + # Get set of all users that have had device list changes since 'from_key' + user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed( + from_key + ) + else: + # The same as above, but filter results to only those users in 'user_ids' + user_ids_to_check = self._device_list_stream_cache.get_entities_changed( + user_ids, from_key + ) - if not to_check: + if not user_ids_to_check: return set() def _get_users_whose_devices_changed_txn(txn): changes = set() - sql = """ + stream_id_where_clause = "stream_id > ?" + sql_args = [from_key] + + if to_key: + stream_id_where_clause += " AND stream_id <= ?" + sql_args.append(to_key) + + sql = f""" SELECT DISTINCT user_id FROM device_lists_stream - WHERE stream_id > ? 
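+                -- from_key is exclusive; to_key (when given) is inclusive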
+ WHERE {stream_id_where_clause} AND """ - for chunk in batch_iter(to_check, 100): + # Query device changes with a batch of users at a time + for chunk in batch_iter(user_ids_to_check, 100): clause, args = make_in_list_sql_clause( txn.database_engine, "user_id", chunk ) - txn.execute(sql + clause, (from_key,) + tuple(args)) + txn.execute(sql + clause, sql_args + args) changes.update(user_id for user_id, in txn) return changes diff --git a/synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql b/synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql new file mode 100644 index 0000000000..7590e34b94 --- /dev/null +++ b/synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql @@ -0,0 +1,23 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a column to track what device list changes stream id that this application +-- service has been caught up to. + +-- We explicitly don't set this field as "NOT NULL", as having NULL as a possible +-- state is useful for determining if we've ever sent traffic for a stream type +-- to an appservice. See https://github.com/matrix-org/synapse/issues/10836 for +-- one way this can be used. +ALTER TABLE application_services_state ADD COLUMN device_list_stream_id BIGINT; \ No newline at end of file diff --git a/synapse/types.py b/synapse/types.py index 5ce2a5b0a5..500597b3a4 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -25,6 +25,7 @@ from typing import ( Match, MutableMapping, Optional, + Set, Tuple, Type, TypeVar, @@ -748,6 +749,30 @@ class ReadReceipt: data: JsonDict +@attr.s(slots=True, frozen=True, auto_attribs=True) +class DeviceListUpdates: + """ + An object containing a diff of information regarding other users' device lists, intended for + a recipient to carry out device list tracking. + + Attributes: + changed: A set of users whose device lists have changed recently. + left: A set of users who the recipient no longer needs to track the device lists of. + Typically when those users no longer share any end-to-end encryption enabled rooms. + """ + + # We need to use a factory here, otherwise `set` is not evaluated at + # object instantiation, but instead at class definition instantiation. + # The latter happening only once, thus always giving you the same sets + # across multiple DeviceListUpdates instances. + # Also see: don't define mutable default arguments. 
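+    #
+    # An empty DeviceListUpdates() is falsy (see __bool__ below), letting
+    # callers cheaply check whether there is anything to send.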
+ changed: Set[str] = attr.ib(factory=set) + left: Set[str] = attr.ib(factory=set) + + def __bool__(self) -> bool: + return bool(self.changed or self.left) + + def get_verify_key_from_cross_signing_key(key_info): """Get the key ID and signedjson verify key from a cross-signing key dict diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 1cbb059357..0b22afdc75 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -24,6 +24,7 @@ from synapse.appservice.scheduler import ( ) from synapse.logging.context import make_deferred_yieldable from synapse.server import HomeServer +from synapse.types import DeviceListUpdates from synapse.util import Clock from tests import unittest @@ -70,6 +71,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): to_device_messages=[], # txn made and saved one_time_key_counts={}, unused_fallback_keys={}, + device_list_summary=DeviceListUpdates(), ) self.assertEqual(0, len(self.txnctrl.recoverers)) # no recoverer made txn.complete.assert_called_once_with(self.store) # txn completed @@ -96,6 +98,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): to_device_messages=[], # txn made and saved one_time_key_counts={}, unused_fallback_keys={}, + device_list_summary=DeviceListUpdates(), ) self.assertEqual(0, txn.send.call_count) # txn not sent though self.assertEqual(0, txn.complete.call_count) # or completed @@ -124,6 +127,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): to_device_messages=[], one_time_key_counts={}, unused_fallback_keys={}, + device_list_summary=DeviceListUpdates(), ) self.assertEqual(1, self.recoverer_fn.call_count) # recoverer made self.assertEqual(1, self.recoverer.recover.call_count) # and invoked @@ -225,7 +229,9 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): service = Mock(id=4) event = Mock() self.scheduler.enqueue_for_appservice(service, events=[event]) - self.txn_ctrl.send.assert_called_once_with(service, [event], [], [], None, None) + self.txn_ctrl.send.assert_called_once_with( + service, [event], [], [], None, None, DeviceListUpdates() + ) def test_send_single_event_with_queue(self): d = defer.Deferred() @@ -240,12 +246,14 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # (call enqueue_for_appservice multiple times deliberately) self.scheduler.enqueue_for_appservice(service, events=[event2]) self.scheduler.enqueue_for_appservice(service, events=[event3]) - self.txn_ctrl.send.assert_called_with(service, [event], [], [], None, None) + self.txn_ctrl.send.assert_called_with( + service, [event], [], [], None, None, DeviceListUpdates() + ) self.assertEqual(1, self.txn_ctrl.send.call_count) # Resolve the send event: expect the queued events to be sent d.callback(service) self.txn_ctrl.send.assert_called_with( - service, [event2, event3], [], [], None, None + service, [event2, event3], [], [], None, None, DeviceListUpdates() ) self.assertEqual(2, self.txn_ctrl.send.call_count) @@ -272,15 +280,21 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # send events for different ASes and make sure they are sent self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event]) self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2]) - self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [], None, None) + self.txn_ctrl.send.assert_called_with( + srv1, [srv_1_event], [], [], None, 
None, DeviceListUpdates() + ) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event]) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2]) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [], None, None) + self.txn_ctrl.send.assert_called_with( + srv2, [srv_2_event], [], [], None, None, DeviceListUpdates() + ) # make sure callbacks for a service only send queued events for THAT # service srv_2_defer.callback(srv2) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [], None, None) + self.txn_ctrl.send.assert_called_with( + srv2, [srv_2_event2], [], [], None, None, DeviceListUpdates() + ) self.assertEqual(3, self.txn_ctrl.send.call_count) def test_send_large_txns(self): @@ -300,17 +314,17 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # Expect the first event to be sent immediately. self.txn_ctrl.send.assert_called_with( - service, [event_list[0]], [], [], None, None + service, [event_list[0]], [], [], None, None, DeviceListUpdates() ) srv_1_defer.callback(service) # Then send the next 100 events self.txn_ctrl.send.assert_called_with( - service, event_list[1:101], [], [], None, None + service, event_list[1:101], [], [], None, None, DeviceListUpdates() ) srv_2_defer.callback(service) # Then the final 99 events self.txn_ctrl.send.assert_called_with( - service, event_list[101:], [], [], None, None + service, event_list[101:], [], [], None, None, DeviceListUpdates() ) self.assertEqual(3, self.txn_ctrl.send.call_count) @@ -320,7 +334,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): event_list = [Mock(name="event")] self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) self.txn_ctrl.send.assert_called_once_with( - service, [], event_list, [], None, None + service, [], event_list, [], None, None, DeviceListUpdates() ) def test_send_multiple_ephemeral_no_queue(self): @@ -329,7 +343,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")] self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) self.txn_ctrl.send.assert_called_once_with( - service, [], event_list, [], None, None + service, [], event_list, [], None, None, DeviceListUpdates() ) def test_send_single_ephemeral_with_queue(self): @@ -345,13 +359,21 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # Send more events: expect send() to NOT be called multiple times. 
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3) - self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [], None, None) + self.txn_ctrl.send.assert_called_with( + service, [], event_list_1, [], None, None, DeviceListUpdates() + ) self.assertEqual(1, self.txn_ctrl.send.call_count) # Resolve txn_ctrl.send d.callback(service) # Expect the queued events to be sent self.txn_ctrl.send.assert_called_with( - service, [], event_list_2 + event_list_3, [], None, None + service, + [], + event_list_2 + event_list_3, + [], + None, + None, + DeviceListUpdates(), ) self.assertEqual(2, self.txn_ctrl.send.call_count) @@ -365,8 +387,10 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): event_list = first_chunk + second_chunk self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) self.txn_ctrl.send.assert_called_once_with( - service, [], first_chunk, [], None, None + service, [], first_chunk, [], None, None, DeviceListUpdates() ) d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [], None, None) + self.txn_ctrl.send.assert_called_with( + service, [], second_chunk, [], None, None, DeviceListUpdates() + ) self.assertEqual(2, self.txn_ctrl.send.call_count) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index cead9f90df..8c72cf6b30 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -15,6 +15,8 @@ from typing import Dict, Iterable, List, Optional from unittest.mock import Mock +from parameterized import parameterized + from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor @@ -471,6 +473,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): to_device_messages, _otks, _fbks, + _device_list_summary, ) = self.send_mock.call_args[0] # Assert that this was the same to-device message that local_user sent @@ -583,7 +586,15 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): service_id_to_message_count: Dict[str, int] = {} for call in self.send_mock.call_args_list: - service, _events, _ephemeral, to_device_messages, _otks, _fbks = call[0] + ( + service, + _events, + _ephemeral, + to_device_messages, + _otks, + _fbks, + _device_list_summary, + ) = call[0] # Check that this was made to an interested service self.assertIn(service, interested_appservices) @@ -627,6 +638,114 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): return appservice +class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase): + """ + Tests that the ApplicationServicesHandler sends device list updates to application + services correctly. 
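+
+    Updates should only be sent when both the `msc3202_transaction_extensions`
+    experimental homeserver option and the appservice's `org.matrix.msc3202`
+    registration flag are enabled; the parameterized cases below exercise
+    every combination.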
+ """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + # Allow us to modify cached feature flags mid-test + self.as_handler = hs.get_application_service_handler() + + # Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that + # will be sent over the wire + self.put_json = simple_async_mock() + hs.get_application_service_api().put_json = self.put_json # type: ignore[assignment] + + # Mock out application services, and allow defining our own in tests + self._services: List[ApplicationService] = [] + self.hs.get_datastores().main.get_app_services = Mock( + return_value=self._services + ) + + # Test across a variety of configuration values + @parameterized.expand( + [ + (True, True, True), + (True, False, False), + (False, True, False), + (False, False, False), + ] + ) + def test_application_service_receives_device_list_updates( + self, + experimental_feature_enabled: bool, + as_supports_txn_extensions: bool, + as_should_receive_device_list_updates: bool, + ): + """ + Tests that an application service receives notice of changed device + lists for a user, when a user changes their device lists. + + Arguments above are populated by parameterized. + + Args: + as_should_receive_device_list_updates: Whether we expect the AS to receive the + device list changes. + experimental_feature_enabled: Whether the "msc3202_transaction_extensions" experimental + feature is enabled. This feature must be enabled for device lists to ASs to work. + as_supports_txn_extensions: Whether the application service has explicitly registered + to receive information defined by MSC3202 - which includes device list changes. 
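+
+            (These map positionally onto the tuples given to
+            `parameterized.expand` above: experimental flag, appservice
+            registration flag, expected outcome.)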
+ """ + # Change whether the experimental feature is enabled or disabled before making + # device list changes + self.as_handler._msc3202_transaction_extensions_enabled = ( + experimental_feature_enabled + ) + + # Create an appservice that is interested in "local_user" + appservice = ApplicationService( + token=random_string(10), + hostname="example.com", + id=random_string(10), + sender="@as:example.com", + rate_limited=False, + namespaces={ + ApplicationService.NS_USERS: [ + { + "regex": "@local_user:.+", + "exclusive": False, + } + ], + }, + supports_ephemeral=True, + msc3202_transaction_extensions=as_supports_txn_extensions, + # Must be set for Synapse to try pushing data to the AS + hs_token="abcde", + url="some_url", + ) + + # Register the application service + self._services.append(appservice) + + # Register a user on the homeserver + self.local_user = self.register_user("local_user", "password") + self.local_user_token = self.login("local_user", "password") + + if as_should_receive_device_list_updates: + # Ensure that the resulting JSON uses the unstable prefix and contains the + # expected users + self.put_json.assert_called_once() + json_body = self.put_json.call_args[1]["json_body"] + + # Our application service should have received a device list update with + # "local_user" in the "changed" list + device_list_dict = json_body.get("org.matrix.msc3202.device_lists", {}) + self.assertEqual([], device_list_dict["left"]) + self.assertEqual([self.local_user], device_list_dict["changed"]) + + else: + # No device list changes should have been sent out + self.put_json.assert_not_called() + + class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): # Argument indices for pulling out arguments from a `send_mock`. ARG_OTK_COUNTS = 4 diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 6249eb8c11..97de9a59e0 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) +from synapse.types import DeviceListUpdates from synapse.util import Clock from tests import unittest @@ -267,7 +268,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) txn = self.get_success( defer.ensureDeferred( - self.store.create_appservice_txn(service, events, [], [], {}, {}) + self.store.create_appservice_txn( + service, events, [], [], {}, {}, DeviceListUpdates() + ) ) ) self.assertEqual(txn.id, 1) @@ -283,7 +286,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._insert_txn(service.id, 9644, events)) self.get_success(self._insert_txn(service.id, 9645, events)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], [], {}, {}) + self.store.create_appservice_txn( + service, events, [], [], {}, {}, DeviceListUpdates() + ) ) self.assertEqual(txn.id, 9646) self.assertEqual(txn.events, events) @@ -296,7 +301,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) self.get_success(self._set_last_txn(service.id, 9643)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], [], {}, {}) + self.store.create_appservice_txn( + service, events, [], [], {}, {}, DeviceListUpdates() + ) ) 
From bebf994ee804ef63ce16801c6694713fcd685320 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Thu, 31 Mar 2022 15:05:13 -0400
Subject: Move MSC2654 support behind an experimental configuration flag.
 (#12295)

To match the current thinking on disabling experimental features by default.
---
 changelog.d/12295.misc         | 1 +
 synapse/config/experimental.py | 3 +++
 synapse/rest/client/sync.py    | 4 +++-
 tests/rest/client/test_sync.py | 5 +++++
 4 files changed, 12 insertions(+), 1 deletion(-)
 create mode 100644 changelog.d/12295.misc
 (limited to 'synapse/config')

diff --git a/changelog.d/12295.misc b/changelog.d/12295.misc
new file mode 100644
index 0000000000..9c34e16909
--- /dev/null
+++ b/changelog.d/12295.misc
@@ -0,0 +1 @@
+Move [MSC2654](https://github.com/matrix-org/matrix-doc/pull/2654) support behind an experimental configuration flag.
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index d6bb1f752b..43db5fcdd9 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -78,3 +78,6 @@ class ExperimentalConfig(Config):
 
         # The deprecated groups feature.
         self.groups_enabled: bool = experimental.get("groups_enabled", True)
+
+        # MSC2654: Unread counts
+        self.msc2654_enabled: bool = experimental.get("msc2654_enabled", False)
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 53c385a86c..0bf32f873b 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -99,6 +99,7 @@ class SyncRestServlet(RestServlet):
         self.presence_handler = hs.get_presence_handler()
         self._server_notices_sender = hs.get_server_notices_sender()
         self._event_serializer = hs.get_event_client_serializer()
+        self._msc2654_enabled = hs.config.experimental.msc2654_enabled
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         # This will always be set by the time Twisted calls us.
@@ -521,7 +522,8 @@ class SyncRestServlet(RestServlet):
             result["ephemeral"] = {"events": ephemeral_events}
         result["unread_notifications"] = room.unread_notifications
         result["summary"] = room.summary
-        result["org.matrix.msc2654.unread_count"] = room.unread_count
+        if self._msc2654_enabled:
+            result["org.matrix.msc2654.unread_count"] = room.unread_count
 
         return result
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 3cebbd18f0..773c16a54c 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -496,6 +496,11 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
         receipts.register_servlets,
     ]
 
+    def default_config(self) -> JsonDict:
+        config = super().default_config()
+        config["experimental_features"] = {"msc2654_enabled": True}
+        return config
+
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.url = "/sync?since=%s"
         self.next_batch = "s0"
-- 
cgit 1.4.1
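
The pattern in this commit — an experimental flag that defaults to off, gating an unstable-prefixed response field — can be summarised with a small self-contained sketch. Names mirror the diff above, but the classes here are invented stand-ins, not Synapse's real config machinery.

class FakeExperimentalConfig:
    def __init__(self, experimental):
        # MSC2654 (unread counts) stays off unless explicitly enabled.
        self.msc2654_enabled = experimental.get("msc2654_enabled", False)

def serialise_room(unread_count, config):
    result = {}
    # The unstable-prefixed field only appears when the flag is on.
    if config.msc2654_enabled:
        result["org.matrix.msc2654.unread_count"] = unread_count
    return result

assert serialise_room(3, FakeExperimentalConfig({})) == {}
assert serialise_room(3, FakeExperimentalConfig({"msc2654_enabled": True})) == {
    "org.matrix.msc2654.unread_count": 3
}
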
From 5c9e39e6192e952ba8a5bb8e5485bc6067f91699 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Mon, 4 Apr 2022 15:25:20 +0100
Subject: Track device list updates per room. (#12321)

This is a first step in dealing with #7721.

The idea is basically that rather than calculating the full set of users a
device list update needs to be sent to up front, we instead simply record the
rooms the user was in at the time of the change. This will allow a few things:

1. we can defer calculating the set of remote servers that need to be poked
   about the change; and
2. during `/sync` and `/keys/changes` we can also avoid calculating users who
   share rooms with other users, and instead just look at the rooms that have
   changed.

However, care needs to be taken to correctly handle server downgrades. As such
this PR writes to both `device_lists_changes_in_room` and the
`device_lists_outbound_pokes` table synchronously. In a future release we can
then bump the database schema compat version to `69` and assume that the new
`device_lists_changes_in_room` table exists and is handled.

There is a temporary option to disable writing to `device_lists_outbound_pokes`
synchronously, allowing us to test that the new code path works (and, by
implication, that upgrading to a future release and downgrading to this one
will work correctly).

Note: Ideally we'd do the calculation of rooms to servers on a worker (e.g.
the background worker), but currently only master can write to the
`device_lists_outbound_pokes` table.
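
As a rough mental model of the flow described above (a toy sketch, not the actual Synapse code; all names are invented stand-ins for the two tables involved): the request path records cheap per-room rows, and a background pass later expands each room into the remote servers to poke.

# Toy stand-ins for the two tables; each row of `changes_in_room` is
# (user_id, device_id, room_id, converted_to_destinations).
changes_in_room = []
outbound_pokes = []  # (destination_host, user_id, device_id)

def record_device_change(user_id, device_id, room_ids):
    # Cheap synchronous write on the request path: one row per room,
    # with no remote-server calculation yet.
    for room_id in room_ids:
        changes_in_room.append((user_id, device_id, room_id, False))

def convert_pending(users_in_room, local_server):
    # Deferred background work: expand each unconverted row's room into
    # the set of remote servers to poke, then mark the row converted.
    for i, (user_id, device_id, room_id, converted) in enumerate(changes_in_room):
        if converted:
            continue
        hosts = {u.split(":", 1)[1] for u in users_in_room(room_id)}
        hosts.discard(local_server)
        outbound_pokes.extend((host, user_id, device_id) for host in sorted(hosts))
        changes_in_room[i] = (user_id, device_id, room_id, True)

record_device_change("@alice:host1", "DEV1", ["!room:host1"])
convert_pending(lambda room_id: ["@alice:host1", "@bob:host2"], "host1")
assert outbound_pokes == [("host2", "@alice:host1", "DEV1")]
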
---
 changelog.d/12321.misc                             |   1 +
 synapse/_scripts/synapse_port_db.py                |   1 +
 synapse/config/server.py                           |   8 +
 synapse/handlers/device.py                         | 132 +++++++++++--
 synapse/replication/slave/storage/devices.py       |   1 +
 synapse/storage/databases/main/__init__.py         |   1 +
 synapse/storage/databases/main/devices.py          | 217 ++++++++++++++++++---
 synapse/storage/schema/__init__.py                 |   1 +
 .../delta/69/01device_list_oubound_by_room.sql     |  38 ++++
 tests/federation/test_federation_sender.py         |  23 ++-
 tests/storage/test_devices.py                      |  14 +-
 11 files changed, 390 insertions(+), 47 deletions(-)
 create mode 100644 changelog.d/12321.misc
 create mode 100644 synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql
 (limited to 'synapse/config')

diff --git a/changelog.d/12321.misc b/changelog.d/12321.misc
new file mode 100644
index 0000000000..200e7c44fe
--- /dev/null
+++ b/changelog.d/12321.misc
@@ -0,0 +1 @@
+Add groundwork for speeding up device list updates for users in large numbers of rooms.
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index c38666da18..6324df883b 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -97,6 +97,7 @@ BOOLEAN_COLUMNS = {
     "users": ["shadow_banned"],
     "e2e_fallback_keys_json": ["used"],
     "access_tokens": ["used"],
+    "device_lists_changes_in_room": ["converted_to_destinations"],
 }
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 0f90302c95..b3a9e50752 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -680,6 +680,14 @@ class ServerConfig(Config):
             config.get("use_account_validity_in_account_status") or False
         )
 
+        # This is a temporary option that enables fully using the new
+        # `device_lists_changes_in_room` without the backwards compat code. This
+        # is primarily for testing. If enabled the server should *not* be
+        # downgraded, as it may lead to missing device list updates.
+        self.use_new_device_lists_changes_in_room = (
+            config.get("use_new_device_lists_changes_in_room") or False
+        )
+
         self.rooms_to_exclude_from_sync: List[str] = (
             config.get("exclude_rooms_from_sync") or []
         )
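
For example, a test can opt into the new code path by overriding its homeserver config, much as the federation sender tests later in this patch do. This is a sketch only; `default_config` here is a plain function mirroring the test-suite hook of the same name, not Synapse's real implementation.

# Sketch of a HomeserverTestCase-style config override enabling the
# temporary option; only safe where the server will never be downgraded.
def default_config(base: dict) -> dict:
    config = dict(base)
    config["use_new_device_lists_changes_in_room"] = True
    return config

assert default_config({})["use_new_device_lists_changes_in_room"] is True
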
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index d5ccaa0c37..c710c02cf9 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -37,7 +37,10 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+    run_as_background_process,
+    wrap_as_background_process,
+)
 from synapse.types import (
     JsonDict,
     StreamToken,
@@ -278,6 +281,22 @@ class DeviceHandler(DeviceWorkerHandler):
 
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
+        # Whether `_handle_new_device_update_async` is currently processing.
+        self._handle_new_device_update_is_processing = False
+
+        # If a new device update may have happened while the loop was
+        # processing.
+        self._handle_new_device_update_new_data = False
+
+        # On start up check if there are any updates pending.
+        hs.get_reactor().callWhenRunning(self._handle_new_device_update_async)
+
+        # Used to decide if we calculate outbound pokes up front or not. By
+        # default we do to allow safely downgrading Synapse.
+        self.use_new_device_lists_changes_in_room = (
+            hs.config.server.use_new_device_lists_changes_in_room
+        )
+
     def _check_device_name_length(self, name: Optional[str]) -> None:
         """
         Checks whether a device name is longer than the maximum allowed length.
@@ -469,19 +488,26 @@ class DeviceHandler(DeviceWorkerHandler):
             # No changes to notify about, so this is a no-op.
             return
 
-        users_who_share_room = await self.store.get_users_who_share_room_with_user(
-            user_id
-        )
+        room_ids = await self.store.get_rooms_for_user(user_id)
+
+        hosts: Optional[Set[str]] = None
+        if not self.use_new_device_lists_changes_in_room:
+            hosts = set()
 
-        hosts: Set[str] = set()
-        if self.hs.is_mine_id(user_id):
-            hosts.update(get_domain_from_id(u) for u in users_who_share_room)
-            hosts.discard(self.server_name)
+            if self.hs.is_mine_id(user_id):
+                for room_id in room_ids:
+                    joined_users = await self.store.get_users_in_room(room_id)
+                    hosts.update(get_domain_from_id(u) for u in joined_users)
 
-        set_tag("target_hosts", hosts)
+            set_tag("target_hosts", hosts)
+
+            hosts.discard(self.server_name)
 
         position = await self.store.add_device_change_to_streams(
-            user_id, device_ids, list(hosts)
+            user_id,
+            device_ids,
+            hosts=hosts,
+            room_ids=room_ids,
         )
 
         if not position:
@@ -495,9 +521,12 @@ class DeviceHandler(DeviceWorkerHandler):
         # specify the user ID too since the user should always get their own device list
         # updates, even if they aren't in any rooms.
-        users_to_notify = users_who_share_room.union({user_id})
+        self.notifier.on_new_event(
+            "device_list_key", position, users={user_id}, rooms=room_ids
+        )
 
-        self.notifier.on_new_event("device_list_key", position, users=users_to_notify)
+        # We may need to do some processing asynchronously.
+        self._handle_new_device_update_async()
 
         if hosts:
             logger.info(
@@ -614,6 +643,85 @@ class DeviceHandler(DeviceWorkerHandler):
 
         return {"success": True}
 
+    @wrap_as_background_process("_handle_new_device_update_async")
+    async def _handle_new_device_update_async(self) -> None:
+        """Called when we have a new local device list update that we need to
+        send out over federation.
+
+        This happens in the background so as not to block the original request
+        that generated the device update.
+        """
+        if self._handle_new_device_update_is_processing:
+            self._handle_new_device_update_new_data = True
+            return
+
+        self._handle_new_device_update_is_processing = True
+
+        # The stream ID we processed previous iteration (if any), and the set of
+        # hosts we've already poked about for this update. This is so that we
+        # don't poke the same remote server about the same update repeatedly.
+        current_stream_id = None
+        hosts_already_sent_to: Set[str] = set()
+
+        try:
+            while True:
+                self._handle_new_device_update_new_data = False
+                rows = await self.store.get_uncoverted_outbound_room_pokes()
+                if not rows:
+                    # If the DB returned nothing then there is nothing left to
+                    # do, *unless* a new device list update happened during the
+                    # DB query.
+                    if self._handle_new_device_update_new_data:
+                        continue
+                    else:
+                        return
+
+                for user_id, device_id, room_id, stream_id, opentracing_context in rows:
+                    joined_user_ids = await self.store.get_users_in_room(room_id)
+                    hosts = {get_domain_from_id(u) for u in joined_user_ids}
+                    hosts.discard(self.server_name)
+
+                    # Check if we've already sent this update to some hosts
+                    if current_stream_id == stream_id:
+                        hosts -= hosts_already_sent_to
+
+                    await self.store.add_device_list_outbound_pokes(
+                        user_id=user_id,
+                        device_id=device_id,
+                        room_id=room_id,
+                        stream_id=stream_id,
+                        hosts=hosts,
+                        context=opentracing_context,
+                    )
+
+                    # Notify replication that we've updated the device list stream.
+                    self.notifier.notify_replication()
+
+                    if hosts:
+                        logger.info(
+                            "Sending device list update notif for %r to: %r",
+                            user_id,
+                            hosts,
+                        )
+                        for host in hosts:
+                            self.federation_sender.send_device_messages(
+                                host, immediate=False
+                            )
+                            log_kv(
+                                {"message": "sent device update to host", "host": host}
+                            )
+
+                    if current_stream_id != stream_id:
+                        # Clear the set of hosts we've already sent to as we're
+                        # processing a new update.
+                        hosts_already_sent_to.clear()
+
+                    hosts_already_sent_to.update(hosts)
+                    current_stream_id = stream_id
+
+        finally:
+            self._handle_new_device_update_is_processing = False
+
 
 def _update_device_from_client_ips(
     device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
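
The deduplication at the heart of the loop above can be isolated into a runnable toy model (not part of the patch): rows sharing a stream ID belong to one logical update, so a host reachable through several of the user's rooms is only poked once per update.

from typing import Set

def hosts_to_poke(rows, hosts_for_room):
    current_stream_id = None
    hosts_already_sent_to: Set[str] = set()
    sent = []
    for stream_id, room_id in rows:
        hosts = set(hosts_for_room[room_id])
        if current_stream_id == stream_id:
            hosts -= hosts_already_sent_to  # same update: skip repeat pokes
        if current_stream_id != stream_id:
            hosts_already_sent_to.clear()   # new update: reset the sent set
        hosts_already_sent_to.update(hosts)
        current_stream_id = stream_id
        sent.append(sorted(hosts))
    return sent

# Two rooms share host2; for a single stream ID it is only poked once.
rooms = {"!a": {"host2"}, "!b": {"host2", "host3"}}
assert hosts_to_poke([(5, "!a"), (5, "!b")], rooms) == [["host2"], ["host3"]]
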
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 0ffd34f1da..f040e33bfb 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -44,6 +44,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
             extra_tables=[
                 ("user_signature_stream", "stream_id"),
                 ("device_lists_outbound_pokes", "stream_id"),
+                ("device_lists_changes_in_room", "stream_id"),
             ],
         )
         device_list_max = self._device_list_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 1ea0b2aa6f..cdbe3872fa 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -146,6 +146,7 @@ class DataStore(
             extra_tables=[
                 ("user_signature_stream", "stream_id"),
                 ("device_lists_outbound_pokes", "stream_id"),
+                ("device_lists_changes_in_room", "stream_id"),
             ],
         )
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index f08f7834d3..07eea4b3d2 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -810,6 +810,7 @@ class DeviceWorkerStore(SQLBaseStore):
                 SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
             ) AS e
             WHERE ? < stream_id AND stream_id <= ?
+            ORDER BY stream_id ASC
             LIMIT ?
         """
 
@@ -1528,7 +1529,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         )
 
     async def add_device_change_to_streams(
-        self, user_id: str, device_ids: Collection[str], hosts: Collection[str]
+        self,
+        user_id: str,
+        device_ids: Collection[str],
+        hosts: Optional[Collection[str]],
+        room_ids: Collection[str],
     ) -> Optional[int]:
         """Persist that a user's devices have been updated, and which hosts
         (if any) should be poked.
@@ -1537,7 +1542,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             user_id: The ID of the user whose device changed.
             device_ids: The IDs of any changed devices. If empty, this function
                 will return None.
-            hosts: The remote destinations that should be notified of the change.
+            hosts: The remote destinations that should be notified of the change. If
+                None then the set of hosts has *not* been calculated, and will be
+                calculated later by a background task.
+            room_ids: The rooms that the user is in.
 
         Returns:
             The maximum stream ID of device list updates that were added to the database, or
@@ -1546,34 +1554,62 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         if not device_ids:
             return None
 
-        async with self._device_list_id_gen.get_next_mult(
-            len(device_ids)
-        ) as stream_ids:
-            await self.db_pool.runInteraction(
-                "add_device_change_to_stream",
-                self._add_device_change_to_stream_txn,
+        context = get_active_span_text_map()
+
+        def add_device_changes_txn(
+            txn, stream_ids_for_device_change, stream_ids_for_outbound_pokes
+        ):
+            self._add_device_change_to_stream_txn(
+                txn,
                 user_id,
                 device_ids,
-                stream_ids,
+                stream_ids_for_device_change,
             )
 
-        if not hosts:
-            return stream_ids[-1]
+            self._add_device_outbound_room_poke_txn(
+                txn,
+                user_id,
+                device_ids,
+                room_ids,
+                stream_ids_for_device_change,
+                context,
+                hosts_have_been_calculated=hosts is not None,
+            )
 
-        context = get_active_span_text_map()
-        async with self._device_list_id_gen.get_next_mult(
-            len(hosts) * len(device_ids)
-        ) as stream_ids:
-            await self.db_pool.runInteraction(
-                "add_device_outbound_poke_to_stream",
-                self._add_device_outbound_poke_to_stream_txn,
+            # If the set of hosts to send to has not been calculated yet (and so
+            # `hosts` is None) or there are no `hosts` to send to, then skip
+            # trying to persist them to the DB.
+            if not hosts:
+                return
+
+            self._add_device_outbound_poke_to_stream_txn(
+                txn,
                 user_id,
                 device_ids,
                 hosts,
-                stream_ids,
+                stream_ids_for_outbound_pokes,
                 context,
             )
 
+        # `device_lists_stream` wants a stream ID per device update.
+        num_stream_ids = len(device_ids)
+
+        if hosts:
+            # `device_lists_outbound_pokes` wants a different stream ID for
+            # each row, which is a row per host per device update.
+            num_stream_ids += len(hosts) * len(device_ids)
+
+        async with self._device_list_id_gen.get_next_mult(num_stream_ids) as stream_ids:
+            stream_ids_for_device_change = stream_ids[: len(device_ids)]
+            stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :]
+
+            await self.db_pool.runInteraction(
+                "add_device_change_to_stream",
+                add_device_changes_txn,
+                stream_ids_for_device_change,
+                stream_ids_for_outbound_pokes,
+            )
+
         return stream_ids[-1]
 
     def _add_device_change_to_stream_txn(
@@ -1617,7 +1653,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         user_id: str,
         device_ids: Iterable[str],
         hosts: Collection[str],
-        stream_ids: List[str],
+        stream_ids: List[int],
         context: Dict[str, str],
     ) -> None:
         for host in hosts:
@@ -1628,8 +1664,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             )
 
         now = self._clock.time_msec()
-        next_stream_id = iter(stream_ids)
+        stream_id_iterator = iter(stream_ids)
 
+        encoded_context = json_encoder.encode(context)
         self.db_pool.simple_insert_many_txn(
             txn,
             table="device_lists_outbound_pokes",
@@ -1645,16 +1682,146 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             values=[
                 (
                     destination,
-                    next(next_stream_id),
+                    next(stream_id_iterator),
                     user_id,
                     device_id,
                     False,
                     now,
-                    json_encoder.encode(context)
-                    if whitelisted_homeserver(destination)
-                    else "{}",
+                    encoded_context if whitelisted_homeserver(destination) else "{}",
                 )
                 for destination in hosts
                 for device_id in device_ids
             ],
         )
+
+    def _add_device_outbound_room_poke_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_ids: Iterable[str],
+        room_ids: Collection[str],
+        stream_ids: List[str],
+        context: Dict[str, str],
+        hosts_have_been_calculated: bool,
+    ) -> None:
+        """Record that the user has updated their device in the given rooms.
+
+        Args:
+            hosts_have_been_calculated: True if `device_lists_outbound_pokes`
+                has been updated already with the updates.
+        """
+
+        # We only need to convert to outbound pokes if they are our user.
+        converted_to_destinations = (
+            hosts_have_been_calculated or not self.hs.is_mine_id(user_id)
+        )
+
+        encoded_context = json_encoder.encode(context)
+
+        # The `device_lists_changes_in_room.stream_id` column matches the
+        # corresponding `stream_id` of the update in the `device_lists_stream`
+        # table, i.e. all rows persisted for the same device update will have
+        # the same `stream_id` (but different room IDs).
+        self.db_pool.simple_insert_many_txn(
+            txn,
+            table="device_lists_changes_in_room",
+            keys=(
+                "user_id",
+                "device_id",
+                "room_id",
+                "stream_id",
+                "converted_to_destinations",
+                "opentracing_context",
+            ),
+            values=[
+                (
+                    user_id,
+                    device_id,
+                    room_id,
+                    stream_id,
+                    converted_to_destinations,
+                    encoded_context,
+                )
+                for room_id in room_ids
+                for device_id, stream_id in zip(device_ids, stream_ids)
+            ],
+        )
+
+    async def get_uncoverted_outbound_room_pokes(
+        self, limit: int = 10
+    ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
+        """Get device list changes by room that have not yet been handled and
+        written to `device_lists_outbound_pokes`.
+
+        Returns:
+            A list of user ID, device ID, room ID, stream ID and optional opentracing context.
+        """
+
+        sql = """
+            SELECT user_id, device_id, room_id, stream_id, opentracing_context
+            FROM device_lists_changes_in_room
+            WHERE NOT converted_to_destinations
+            ORDER BY stream_id
+            LIMIT ?
+        """
+
+        def get_uncoverted_outbound_room_pokes_txn(txn):
+            txn.execute(sql, (limit,))
+            return txn.fetchall()
+
+        return await self.db_pool.runInteraction(
+            "get_uncoverted_outbound_room_pokes", get_uncoverted_outbound_room_pokes_txn
+        )
+
+    async def add_device_list_outbound_pokes(
+        self,
+        user_id: str,
+        device_id: str,
+        room_id: str,
+        stream_id: int,
+        hosts: Collection[str],
+        context: Optional[Dict[str, str]],
+    ) -> None:
+        """Queue the device update to be sent to the given set of hosts,
+        calculated from the room ID.
+
+        Marks the associated row in `device_lists_changes_in_room` as handled.
+        """
+
+        def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]):
+            if hosts:
+                self._add_device_outbound_poke_to_stream_txn(
+                    txn,
+                    user_id=user_id,
+                    device_ids=[device_id],
+                    hosts=hosts,
+                    stream_ids=stream_ids,
+                    context=context,
+                )
+
+            self.db_pool.simple_update_txn(
+                txn,
+                table="device_lists_changes_in_room",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "stream_id": stream_id,
+                    "room_id": room_id,
+                },
+                updatevalues={"converted_to_destinations": True},
+            )
+
+        if not hosts:
+            # If there are no hosts then we don't try and generate stream IDs.
+            return await self.db_pool.runInteraction(
+                "add_device_list_outbound_pokes",
+                add_device_list_outbound_pokes_txn,
+                [],
+            )
+
+        async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
+            return await self.db_pool.runInteraction(
+                "add_device_list_outbound_pokes",
+                add_device_list_outbound_pokes_txn,
+                stream_ids,
+            )
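
The stream ID accounting in `add_device_change_to_streams` above is easy to check by hand. A worked example, with a plain range standing in for the ID generator:

device_ids = ["D1", "D2"]
hosts = ["host2", "host3"]

# One stream ID per device update for `device_lists_stream`...
num_stream_ids = len(device_ids)
if hosts:
    # ...plus one per (host, device) pair for `device_lists_outbound_pokes`.
    num_stream_ids += len(hosts) * len(device_ids)

stream_ids = list(range(100, 100 + num_stream_ids))  # stand-in for the ID generator
stream_ids_for_device_change = stream_ids[: len(device_ids)]
stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :]

assert stream_ids_for_device_change == [100, 101]
assert len(stream_ids_for_outbound_pokes) == 2 * 2  # hosts x devices
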
+ """ + + def get_uncoverted_outbound_room_pokes_txn(txn): + txn.execute(sql, (limit,)) + return txn.fetchall() + + return await self.db_pool.runInteraction( + "get_uncoverted_outbound_room_pokes", get_uncoverted_outbound_room_pokes_txn + ) + + async def add_device_list_outbound_pokes( + self, + user_id: str, + device_id: str, + room_id: str, + stream_id: int, + hosts: Collection[str], + context: Optional[Dict[str, str]], + ) -> None: + """Queue the device update to be sent to the given set of hosts, + calculated from the room ID. + + Marks the associated row in `device_lists_changes_in_room` as handled. + """ + + def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]): + if hosts: + self._add_device_outbound_poke_to_stream_txn( + txn, + user_id=user_id, + device_ids=[device_id], + hosts=hosts, + stream_ids=stream_ids, + context=context, + ) + + self.db_pool.simple_update_txn( + txn, + table="device_lists_changes_in_room", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "stream_id": stream_id, + "room_id": room_id, + }, + updatevalues={"converted_to_destinations": True}, + ) + + if not hosts: + # If there are no hosts then we don't try and generate stream IDs. + return await self.db_pool.runInteraction( + "add_device_list_outbound_pokes", + add_device_list_outbound_pokes_txn, + [], + ) + + async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: + return await self.db_pool.runInteraction( + "add_device_list_outbound_pokes", + add_device_list_outbound_pokes_txn, + stream_ids, + ) diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index ea900e0f3d..151f2aa9bb 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -60,6 +60,7 @@ Changes in SCHEMA_VERSION = 68: new events. Changes in SCHEMA_VERSION = 69: + - We now write to `device_lists_changes_in_room` table. - Use sequence to generate future `application_services_txns.txn_id`s """ diff --git a/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql b/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql new file mode 100644 index 0000000000..b5b1782b2a --- /dev/null +++ b/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql @@ -0,0 +1,38 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE device_lists_changes_in_room ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + room_id TEXT NOT NULL, + + -- This initially matches `device_lists_stream.stream_id`. Note that we + -- delete older values from `device_lists_stream`, so we can't use a foreign + -- constraint here. + -- + -- The table will contain rows with the same `stream_id` but different + -- `room_id`, as for each device update we store a row per room the user is + -- joined to. Therefore `(stream_id, room_id)` gives a unique index. 
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index e90592855a..a6e91956af 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -14,6 +14,7 @@
 from typing import Optional
 from unittest.mock import Mock
 
+from parameterized import parameterized_class
 from signedjson import key, sign
 from signedjson.types import BaseKey, SigningKey
 
@@ -154,6 +155,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
         )
 
 
+@parameterized_class(
+    [
+        {"enable_room_poke_code_path": False},
+        {"enable_room_poke_code_path": True},
+    ]
+)
 class FederationSenderDevicesTestCases(HomeserverTestCase):
     servlets = [
         admin.register_servlets,
@@ -168,17 +175,21 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
     def default_config(self):
         c = super().default_config()
         c["send_federation"] = True
+        c["use_new_device_lists_changes_in_room"] = self.enable_room_poke_code_path
         return c
 
     def prepare(self, reactor, clock, hs):
-        # stub out get_users_who_share_room_with_user so that it claims that
-        # `@user2:host2` is in the room
-        def get_users_who_share_room_with_user(user_id):
+        # stub out `get_rooms_for_user` and `get_users_in_room` so that the
+        # server thinks the user shares a room with `@user2:host2`
+        def get_rooms_for_user(user_id):
+            return defer.succeed({"!room:host1"})
+
+        hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user
+
+        def get_users_in_room(room_id):
             return defer.succeed({"@user2:host2"})
 
-        hs.get_datastores().main.get_users_who_share_room_with_user = (
-            get_users_who_share_room_with_user
-        )
+        hs.get_datastores().main.get_users_in_room = get_users_in_room
 
         # whenever send_transaction is called, record the edu data
         self.edus = []
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 21ffc5a909..d1227dd4ac 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -96,7 +96,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
 
         # Add two device updates with sequential `stream_id`s
         self.get_success(
-            self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+            self.store.add_device_change_to_streams(
+                "user_id", device_ids, ["somehost"], ["!some:room"]
+            )
         )
 
         # Get all device updates ever meant for this remote
@@ -122,7 +124,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
             "device_id5",
         ]
         self.get_success(
-            self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+            self.store.add_device_change_to_streams(
+                "user_id", device_ids, ["somehost"], ["!some:room"]
+            )
        )
 
         # Get device updates meant for this remote
@@ -144,7 +148,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
         # Add some more device updates to ensure it still resumes properly
         device_ids = ["device_id6", "device_id7"]
         self.get_success(
-            self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+            self.store.add_device_change_to_streams(
+                "user_id", device_ids, ["somehost"], ["!some:room"]
+            )
         )
 
         # Get the next batch of device updates
@@ -220,7 +226,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
 
         self.get_success(
             self.store.add_device_change_to_streams(
-                "@user_id:test", device_ids, ["somehost"]
+                "@user_id:test", device_ids, ["somehost"], ["!some:room"]
             )
         )
 
-- 
cgit 1.4.1
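
To summarise the two call patterns these tests exercise (a stub-based sketch; `StubStore` is invented and the real store method does actual database work):

import asyncio

class StubStore:
    """Invented stand-in; the real DeviceStore method writes to the DB."""

    async def add_device_change_to_streams(self, user_id, device_ids, hosts, room_ids):
        # Return a fake maximum stream ID, like the real method does.
        return 1

async def main():
    store = StubStore()
    # Compat path: hosts pre-computed by the caller.
    await store.add_device_change_to_streams(
        "@user:test", ["DEV1"], hosts=["somehost"], room_ids=["!some:room"]
    )
    # New path: hosts=None defers the host calculation to the background
    # process, which fills in device_lists_outbound_pokes later.
    await store.add_device_change_to_streams(
        "@user:test", ["DEV1"], hosts=None, room_ids=["!some:room"]
    )

asyncio.run(main())
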