From f9f03426de338ae1879e174f63adf698bbfc3a4b Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 19 Aug 2022 17:17:10 +0100 Subject: Implement MSC3852: Expose `last_seen_user_agent` to users for their own devices; also expose to Admin API (#13549) --- synapse/handlers/device.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) (limited to 'synapse/handlers/device.py') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 1a8379854c..f5c586f657 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -74,6 +74,7 @@ class DeviceWorkerHandler: self._state_storage = hs.get_storage_controllers().state self._auth_handler = hs.get_auth_handler() self.server_name = hs.hostname + self._msc3852_enabled = hs.config.experimental.msc3852_enabled @trace async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: @@ -747,7 +748,13 @@ def _update_device_from_client_ips( device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]] ) -> None: ip = client_ips.get((device["user_id"], device["device_id"]), {}) - device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}) + device.update( + { + "last_seen_user_agent": ip.get("user_agent"), + "last_seen_ts": ip.get("last_seen"), + "last_seen_ip": ip.get("ip"), + } + ) class DeviceListUpdater: -- cgit 1.5.1 From 1a209efdb2a6c51e52dd277de7581099852877ae Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Wed, 24 Aug 2022 14:15:37 -0500 Subject: Update `get_users_in_room` mis-use to get hosts with dedicated `get_current_hosts_in_room` (#13605) See https://github.com/matrix-org/synapse/pull/13575#discussion_r953023755 --- changelog.d/13605.misc | 1 + synapse/handlers/device.py | 8 ++++++-- synapse/handlers/directory.py | 12 +++++++----- synapse/handlers/presence.py | 3 +-- synapse/handlers/typing.py | 7 ++++--- tests/federation/test_federation_sender.py | 17 ++++++++++++----- 6 files changed, 31 insertions(+), 17 deletions(-) create mode 100644 changelog.d/13605.misc (limited to 'synapse/handlers/device.py') diff --git a/changelog.d/13605.misc b/changelog.d/13605.misc new file mode 100644 index 0000000000..88d518383b --- /dev/null +++ b/changelog.d/13605.misc @@ -0,0 +1 @@ +Refactor `get_users_in_room(room_id)` mis-use with dedicated `get_current_hosts_in_room(room_id)` function. 
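For context on the refactoring above: a Matrix user ID embeds its homeserver name after the first colon, so the old pattern recovered the set of remote hosts by loading every joined user ID and splitting each one. The sketch below is illustrative only (a standalone stand-in for Synapse's `get_domain_from_id` helper, not the project's implementation) and shows why a dedicated `get_current_hosts_in_room` lookup is cheaper: the full member list is fetched just to be collapsed into a much smaller set of server names.

    def get_domain_from_id(user_id: str) -> str:
        # "@alice:example.com" -> "example.com"
        idx = user_id.find(":")
        if idx == -1:
            raise ValueError("Invalid Matrix ID: %r" % (user_id,))
        return user_id[idx + 1 :]

    joined_user_ids = ["@alice:example.com", "@bob:example.com", "@carol:other.org"]
    hosts = {get_domain_from_id(u) for u in joined_user_ids}
    # {"example.com", "other.org"} -- potentially thousands of member rows are
    # read only to produce a handful of hosts, which a dedicated store method
    # can return directly.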
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index f5c586f657..9c2c3a0e68 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -310,6 +310,7 @@ class DeviceHandler(DeviceWorkerHandler): super().__init__(hs) self.federation_sender = hs.get_federation_sender() + self._storage_controllers = hs.get_storage_controllers() self.device_list_updater = DeviceListUpdater(hs, self) @@ -694,8 +695,11 @@ class DeviceHandler(DeviceWorkerHandler): # Ignore any users that aren't ours if self.hs.is_mine_id(user_id): - joined_user_ids = await self.store.get_users_in_room(room_id) - hosts = {get_domain_from_id(u) for u in joined_user_ids} + hosts = set( + await self._storage_controllers.state.get_current_hosts_in_room( + room_id + ) + ) hosts.discard(self.server_name) # Check if we've already sent this update to some hosts diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 948f66a94d..7127d5aefc 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -30,7 +30,7 @@ from synapse.api.errors import ( from synapse.appservice import ApplicationService from synapse.module_api import NOT_SPAM from synapse.storage.databases.main.directory import RoomAliasMapping -from synapse.types import JsonDict, Requester, RoomAlias, get_domain_from_id +from synapse.types import JsonDict, Requester, RoomAlias if TYPE_CHECKING: from synapse.server import HomeServer @@ -83,8 +83,9 @@ class DirectoryHandler: # TODO(erikj): Add transactions. # TODO(erikj): Check if there is a current association. if not servers: - users = await self.store.get_users_in_room(room_id) - servers = {get_domain_from_id(u) for u in users} + servers = await self._storage_controllers.state.get_current_hosts_in_room( + room_id + ) if not servers: raise SynapseError(400, "Failed to get server list") @@ -287,8 +288,9 @@ class DirectoryHandler: Codes.NOT_FOUND, ) - users = await self.store.get_users_in_room(room_id) - extra_servers = {get_domain_from_id(u) for u in users} + extra_servers = await self._storage_controllers.state.get_current_hosts_in_room( + room_id + ) servers_set = set(extra_servers) | set(servers) # If this server is in the list of servers, return it first. 
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 741504ba9f..4e575ffbaa 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -2051,8 +2051,7 @@ async def get_interested_remotes( ) for room_id, states in room_ids_to_states.items(): - user_ids = await store.get_users_in_room(room_id) - hosts = {get_domain_from_id(user_id) for user_id in user_ids} + hosts = await store.get_current_hosts_in_room(room_id) for host in hosts: hosts_and_states.setdefault(host, set()).update(states) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index bcac3372a2..a4cd8b8f0c 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -26,7 +26,7 @@ from synapse.metrics.background_process_metrics import ( ) from synapse.replication.tcp.streams import TypingStream from synapse.streams import EventSource -from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id +from synapse.types import JsonDict, Requester, StreamKeyType, UserID from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -362,8 +362,9 @@ class TypingWriterHandler(FollowerTypingHandler): ) return - users = await self.store.get_users_in_room(room_id) - domains = {get_domain_from_id(u) for u in users} + domains = await self._storage_controllers.state.get_current_hosts_in_room( + room_id + ) if self.server_name in domains: logger.info("Got typing update from %s: %r", user_id, content) diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 01a1db6115..a5aa500ef8 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -173,17 +173,24 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): return c def prepare(self, reactor, clock, hs): - # stub out `get_rooms_for_user` and `get_users_in_room` so that the + test_room_id = "!room:host1" + + # stub out `get_rooms_for_user` and `get_current_hosts_in_room` so that the # server thinks the user shares a room with `@user2:host2` def get_rooms_for_user(user_id): - return defer.succeed({"!room:host1"}) + return defer.succeed({test_room_id}) hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user - def get_users_in_room(room_id): - return defer.succeed({"@user2:host2"}) + async def get_current_hosts_in_room(room_id): + if room_id == test_room_id: + return ["host2"] + + # TODO: We should fail the test when we encounter an unxpected room ID. + # We can't just use `self.fail(...)` here because the app code is greedy + # with `Exception` and will catch it before the test can see it. - hs.get_datastores().main.get_users_in_room = get_users_in_room + hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room # whenever send_transaction is called, record the edu data self.edus = [] -- cgit 1.5.1 From d3d9ca156e323fe194b1bcb1af1628f65a2f3c1c Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 7 Sep 2022 11:03:32 +0000 Subject: Cancel the processing of key query requests when they time out. 
(#13680) --- changelog.d/13680.feature | 1 + synapse/api/auth.py | 5 +++ synapse/handlers/device.py | 3 ++ synapse/handlers/e2e_keys.py | 40 +++++++++++++--------- synapse/rest/client/keys.py | 6 ++-- synapse/storage/controllers/state.py | 4 +++ synapse/storage/databases/main/devices.py | 4 +++ synapse/storage/databases/main/end_to_end_keys.py | 5 ++- synapse/storage/databases/main/event_federation.py | 2 ++ synapse/storage/databases/main/events_worker.py | 4 +++ synapse/storage/databases/main/roommember.py | 2 ++ synapse/storage/databases/main/state.py | 2 ++ synapse/storage/databases/main/stream.py | 2 ++ synapse/storage/databases/state/store.py | 3 ++ .../storage/util/partial_state_events_tracker.py | 3 ++ synapse/types.py | 5 +++ tests/http/server/_base.py | 10 +++++- tests/rest/client/test_keys.py | 29 ++++++++++++++++ 18 files changed, 110 insertions(+), 20 deletions(-) create mode 100644 changelog.d/13680.feature (limited to 'synapse/handlers/device.py') diff --git a/changelog.d/13680.feature b/changelog.d/13680.feature new file mode 100644 index 0000000000..4234c7e082 --- /dev/null +++ b/changelog.d/13680.feature @@ -0,0 +1 @@ +Cancel the processing of key query requests when they time out. \ No newline at end of file diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 9a1aea083f..8e54ef84b2 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -38,6 +38,7 @@ from synapse.logging.opentracing import ( trace, ) from synapse.types import Requester, create_requester +from synapse.util.cancellation import cancellable if TYPE_CHECKING: from synapse.server import HomeServer @@ -118,6 +119,7 @@ class Auth: errcode=Codes.NOT_JOINED, ) + @cancellable async def get_user_by_req( self, request: SynapseRequest, @@ -166,6 +168,7 @@ class Auth: parent_span.set_tag("appservice_id", requester.app_service.id) return requester + @cancellable async def _wrapped_get_user_by_req( self, request: SynapseRequest, @@ -281,6 +284,7 @@ class Auth: 403, "Application service has not registered this user (%s)" % user_id ) + @cancellable async def _get_appservice_user(self, request: Request) -> Optional[Requester]: """ Given a request, reads the request parameters to determine: @@ -523,6 +527,7 @@ class Auth: return bool(query_params) or bool(auth_headers) @staticmethod + @cancellable def get_access_token_from_request(request: Request) -> str: """Extracts the access_token from the request. 
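The `@cancellable` markers added in this commit declare that it is safe for Synapse to cancel the decorated coroutine when the client disconnects or the request times out. The underlying mechanism is Twisted Deferred cancellation; the sketch below is a minimal, self-contained illustration using only stock Twisted APIs. It is not Synapse's own `cancellable`/`delay_cancellation` machinery, which the diffs that follow wire in, and the function names are hypothetical.

    from twisted.internet import defer, reactor

    def fetch_keys_slowly() -> defer.Deferred:
        # Stand-in for a slow remote key query. The canceller stops the
        # pending work so nothing fires after cancellation.
        def _cancel(d: defer.Deferred) -> None:
            if delayed_result.active():
                delayed_result.cancel()

        d = defer.Deferred(_cancel)
        delayed_result = reactor.callLater(30, d.callback, {"device_keys": {}})
        return d

    def query_with_timeout(timeout: float) -> defer.Deferred:
        d = fetch_keys_slowly()
        # Give up on the query after `timeout` seconds: cancelling the
        # Deferred makes it errback with CancelledError instead of hanging.
        timer = reactor.callLater(timeout, d.cancel)

        def _stop_timer(result):
            if timer.active():
                timer.cancel()
            return result

        return d.addBoth(_stop_timer)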
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 9c2c3a0e68..c5ac169644 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -52,6 +52,7 @@ from synapse.types import ( from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.cancellation import cancellable from synapse.util.metrics import measure_func from synapse.util.retryutils import NotRetryingDestination @@ -124,6 +125,7 @@ class DeviceWorkerHandler: return device + @cancellable async def get_device_changes_in_shared_rooms( self, user_id: str, room_ids: Collection[str], from_token: StreamToken ) -> Collection[str]: @@ -163,6 +165,7 @@ class DeviceWorkerHandler: @trace @measure_func("device.get_user_ids_changed") + @cancellable async def get_user_ids_changed( self, user_id: str, from_token: StreamToken ) -> JsonDict: diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index c938339ddd..ec81639c78 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -37,7 +37,8 @@ from synapse.types import ( get_verify_key_from_cross_signing_key, ) from synapse.util import json_decoder, unwrapFirstError -from synapse.util.async_helpers import Linearizer +from synapse.util.async_helpers import Linearizer, delay_cancellation +from synapse.util.cancellation import cancellable from synapse.util.retryutils import NotRetryingDestination if TYPE_CHECKING: @@ -91,6 +92,7 @@ class E2eKeysHandler: ) @trace + @cancellable async def query_devices( self, query_body: JsonDict, @@ -208,22 +210,26 @@ class E2eKeysHandler: r[user_id] = remote_queries[user_id] # Now fetch any devices that we don't have in our cache + # TODO It might make sense to propagate cancellations into the + # deferreds which are querying remote homeservers. 
await make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self._query_devices_for_destination, - results, - cross_signing_keys, - failures, - destination, - queries, - timeout, - ) - for destination, queries in remote_queries_not_in_cache.items() - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) + delay_cancellation( + defer.gatherResults( + [ + run_in_background( + self._query_devices_for_destination, + results, + cross_signing_keys, + failures, + destination, + queries, + timeout, + ) + for destination, queries in remote_queries_not_in_cache.items() + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + ) ) ret = {"device_keys": results, "failures": failures} @@ -347,6 +353,7 @@ class E2eKeysHandler: return + @cancellable async def get_cross_signing_keys_from_cache( self, query: Iterable[str], from_user_id: Optional[str] ) -> Dict[str, Dict[str, dict]]: @@ -393,6 +400,7 @@ class E2eKeysHandler: } @trace + @cancellable async def query_local_devices( self, query: Mapping[str, Optional[List[str]]] ) -> Dict[str, Dict[str, dict]]: diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index a395694fa5..f653d2a3e1 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -27,9 +27,9 @@ from synapse.http.servlet import ( ) from synapse.http.site import SynapseRequest from synapse.logging.opentracing import log_kv, set_tag +from synapse.rest.client._base import client_patterns, interactive_auth_handler from synapse.types import JsonDict, StreamToken - -from ._base import client_patterns, interactive_auth_handler +from synapse.util.cancellation import cancellable if TYPE_CHECKING: from synapse.server import HomeServer @@ -156,6 +156,7 @@ class KeyQueryServlet(RestServlet): self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() + @cancellable async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() @@ -199,6 +200,7 @@ class KeyChangesServlet(RestServlet): self.device_handler = hs.get_device_handler() self.store = hs.get_datastores().main + @cancellable async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index ba5380ce3e..bbe568bf05 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -36,6 +36,7 @@ from synapse.storage.util.partial_state_events_tracker import ( PartialStateEventsTracker, ) from synapse.types import MutableStateMap, StateMap +from synapse.util.cancellation import cancellable if TYPE_CHECKING: from synapse.server import HomeServer @@ -229,6 +230,7 @@ class StateStorageController: @trace @tag_args + @cancellable async def get_state_ids_for_events( self, event_ids: Collection[str], @@ -350,6 +352,7 @@ class StateStorageController: @trace @tag_args + @cancellable async def get_state_group_for_events( self, event_ids: Collection[str], @@ -398,6 +401,7 @@ class StateStorageController: event_id, room_id, prev_group, delta_ids, current_state_ids ) + @cancellable async def get_current_state_ids( self, room_id: str, diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index ca0fe8c4be..5d700ca6c3 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -53,6 
+53,7 @@ from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -668,6 +669,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore): ... @trace + @cancellable async def get_user_devices_from_cache( self, query_list: List[Tuple[str, Optional[str]]] ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]: @@ -743,6 +745,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore): return self._device_list_stream_cache.get_all_entities_changed(from_key) + @cancellable async def get_users_whose_devices_changed( self, from_key: int, @@ -1221,6 +1224,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore): desc="get_min_device_lists_changes_in_room", ) + @cancellable async def get_device_list_changes_in_rooms( self, room_ids: Collection[str], from_id: int ) -> Optional[Set[str]]: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 46c0d06157..8e9e1b0b4b 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -50,6 +50,7 @@ from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -135,6 +136,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker return now_stream_id, [] @trace + @cancellable async def get_e2e_device_keys_for_cs_api( self, query_list: List[Tuple[str, Optional[str]]] ) -> Dict[str, Dict[str, JsonDict]]: @@ -197,6 +199,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ... @trace + @cancellable async def get_e2e_device_keys_and_signatures( self, query_list: Collection[Tuple[str, Optional[str]]], @@ -887,6 +890,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker return keys + @cancellable async def get_e2e_cross_signing_keys_bulk( self, user_ids: List[str], from_user_id: Optional[str] = None ) -> Dict[str, Optional[Dict[str, JsonDict]]]: @@ -902,7 +906,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker keys were not found, either their user ID will not be in the dict, or their user ID will map to None. 
""" - result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids) if from_user_id: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index e687f87eca..ca47a22bf1 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -48,6 +48,7 @@ from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache +from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -976,6 +977,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas return int(min_depth) if min_depth is not None else None + @cancellable async def get_forward_extremities_for_room_at_stream_ordering( self, room_id: str, stream_ordering: int ) -> List[str]: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 84f17a9945..52914febf9 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -81,6 +81,7 @@ from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred, delay_cancellation from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import AsyncLruCache +from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -339,6 +340,7 @@ class EventsWorkerStore(SQLBaseStore): ) -> Optional[EventBase]: ... + @cancellable async def get_event( self, event_id: str, @@ -433,6 +435,7 @@ class EventsWorkerStore(SQLBaseStore): @trace @tag_args + @cancellable async def get_events_as_list( self, event_ids: Collection[str], @@ -584,6 +587,7 @@ class EventsWorkerStore(SQLBaseStore): return events + @cancellable async def _get_events_from_cache_or_db( self, event_ids: Iterable[str], allow_rejected: bool = False ) -> Dict[str, EventCacheEntry]: diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 4f0adb136a..a77e49dc66 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -55,6 +55,7 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList +from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -770,6 +771,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): _get_users_server_still_shares_room_with_txn, ) + @cancellable async def get_rooms_for_user( self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None ) -> FrozenSet[str]: diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 0b10af0e58..e607ccfdc9 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -36,6 +36,7 @@ from synapse.storage.state import StateFilter from synapse.types import JsonDict, JsonMapping, StateMap from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.cancellation import cancellable from 
synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -281,6 +282,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) # FIXME: how should this be cached? + @cancellable async def get_partial_filtered_current_state_ids( self, room_id: str, state_filter: Optional[StateFilter] = None ) -> StateMap[str]: diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index a347430aa7..3f9bfaeac5 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -72,6 +72,7 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import PersistedEventPosition, RoomStreamToken from synapse.util.caches.descriptors import cached from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.util.cancellation import cancellable if TYPE_CHECKING: from synapse.server import HomeServer @@ -597,6 +598,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return ret, key + @cancellable async def get_membership_changes_for_user( self, user_id: str, diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index bb64543c1f..f8cfcaca83 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -31,6 +31,7 @@ from synapse.storage.util.sequence import build_sequence_generator from synapse.types import MutableStateMap, StateKey, StateMap from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache +from synapse.util.cancellation import cancellable if TYPE_CHECKING: from synapse.server import HomeServer @@ -156,6 +157,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): "get_state_group_delta", _get_state_group_delta_txn ) + @cancellable async def _get_state_groups_from_groups( self, groups: List[int], state_filter: StateFilter ) -> Dict[int, StateMap[str]]: @@ -235,6 +237,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state_filter.filter_state(state_dict_ids), not missing_types + @cancellable async def _get_state_for_groups( self, groups: Iterable[int], state_filter: Optional[StateFilter] = None ) -> Dict[int, MutableStateMap[str]]: diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py index b4bf49dace..8d8894d1d5 100644 --- a/synapse/storage/util/partial_state_events_tracker.py +++ b/synapse/storage/util/partial_state_events_tracker.py @@ -24,6 +24,7 @@ from synapse.logging.opentracing import trace_with_opname from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.room import RoomWorkerStore from synapse.util import unwrapFirstError +from synapse.util.cancellation import cancellable logger = logging.getLogger(__name__) @@ -60,6 +61,7 @@ class PartialStateEventsTracker: o.callback(None) @trace_with_opname("PartialStateEventsTracker.await_full_state") + @cancellable async def await_full_state(self, event_ids: Collection[str]) -> None: """Wait for all the given events to have full state. @@ -154,6 +156,7 @@ class PartialCurrentStateTracker: o.callback(None) @trace_with_opname("PartialCurrentStateTracker.await_full_state") + @cancellable async def await_full_state(self, room_id: str) -> None: # We add the deferred immediately so that the DB call to check for # partial state doesn't race when we unpartial the room. 
diff --git a/synapse/types.py b/synapse/types.py index 668d48d646..ec44601f54 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -52,6 +52,7 @@ from twisted.internet.interfaces import ( ) from synapse.api.errors import Codes, SynapseError +from synapse.util.cancellation import cancellable from synapse.util.stringutils import parse_and_validate_server_name if TYPE_CHECKING: @@ -699,7 +700,11 @@ class StreamToken: START: ClassVar["StreamToken"] @classmethod + @cancellable async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": + """ + Creates a RoomStreamToken from its textual representation. + """ try: keys = string.split(cls._SEPARATOR) while len(keys) < len(attr.fields(cls)): diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index 5726e60cee..5071f83574 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -140,6 +140,8 @@ def make_request_with_cancellation_test( method: str, path: str, content: Union[bytes, str, JsonDict] = b"", + *, + token: Optional[str] = None, ) -> FakeChannel: """Performs a request repeatedly, disconnecting at successive `await`s, until one completes. @@ -211,7 +213,13 @@ def make_request_with_cancellation_test( with deferred_patch.patch(): # Start the request. channel = make_request( - reactor, site, method, path, content, await_result=False + reactor, + site, + method, + path, + content, + await_result=False, + access_token=token, ) request = channel.request diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py index bbc8e74243..741fecea77 100644 --- a/tests/rest/client/test_keys.py +++ b/tests/rest/client/test_keys.py @@ -19,6 +19,7 @@ from synapse.rest import admin from synapse.rest.client import keys, login from tests import unittest +from tests.http.server._base import make_request_with_cancellation_test class KeyQueryTestCase(unittest.HomeserverTestCase): @@ -89,3 +90,31 @@ class KeyQueryTestCase(unittest.HomeserverTestCase): Codes.BAD_JSON, channel.result, ) + + def test_key_query_cancellation(self) -> None: + """ + Tests that /keys/query is cancellable and does not swallow the + CancelledError. + """ + self.register_user("alice", "wonderland") + alice_token = self.login("alice", "wonderland") + + bob = self.register_user("bob", "uncle") + + channel = make_request_with_cancellation_test( + "test_key_query_cancellation", + self.reactor, + self.site, + "POST", + "/_matrix/client/r0/keys/query", + { + "device_keys": { + # Empty list means we request keys for all bob's devices + bob: [], + }, + }, + token=alice_token, + ) + + self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertIn(bob, channel.json_body["device_keys"]) -- cgit 1.5.1 From c73774467edb04c372caecb9e843542654f7610b Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Wed, 14 Sep 2022 10:42:57 +0100 Subject: Fix bug in device list caching when remote users leave rooms (#13749) When a remote user leaves the last room shared with the homeserver, we have to mark their device list as unsubscribed, otherwise we would hold on to a stale device list in our cache. Crucially, the device list would remain cached even after the remote user rejoined the room, which could lead to E2EE failures until the next change to the remote user's device list. Fixes #13651. 
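Before the diff, a rough sketch of the cache-validation idea this fix uses: compare the users whose device lists we have cached against the users this server still shares a room with, and treat the difference as uncached so their keys are re-fetched over federation. The store method name mirrors the one used in the patch below; the wrapper function itself is illustrative, not the actual implementation.

    async def split_cached_and_stale(store, remote_results: dict) -> set:
        """Drop cached device lists for users we no longer share a room with."""
        cached_users = set(remote_results)
        still_shared = await store.get_users_server_still_shares_room_with(
            cached_users
        )
        stale_users = cached_users - still_shared
        for user_id in stale_users:
            # Their cached list may be outdated (they may even have rejoined
            # since we stopped tracking them), so fall back to querying their
            # homeserver directly.
            remote_results.pop(user_id)
        return stale_users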
Signed-off-by: Sean Quah --- changelog.d/13749.bugfix | 1 + synapse/handlers/device.py | 11 ----------- synapse/handlers/e2e_keys.py | 26 ++++++++++++++++++++++++++ synapse/storage/controllers/persist_events.py | 20 +++++++++++++++++--- tests/handlers/test_e2e_keys.py | 8 +++++++- 5 files changed, 51 insertions(+), 15 deletions(-) create mode 100644 changelog.d/13749.bugfix (limited to 'synapse/handlers/device.py') diff --git a/changelog.d/13749.bugfix b/changelog.d/13749.bugfix new file mode 100644 index 0000000000..8ffafec07b --- /dev/null +++ b/changelog.d/13749.bugfix @@ -0,0 +1 @@ +Fix a long standing bug where device lists would remain cached when remote users left and rejoined the last room shared with the local homeserver. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index c5ac169644..901e2310b7 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -45,7 +45,6 @@ from synapse.types import ( JsonDict, StreamKeyType, StreamToken, - UserID, get_domain_from_id, get_verify_key_from_cross_signing_key, ) @@ -324,8 +323,6 @@ class DeviceHandler(DeviceWorkerHandler): self.device_list_updater.incoming_device_list_update, ) - hs.get_distributor().observe("user_left_room", self.user_left_room) - # Whether `_handle_new_device_update_async` is currently processing. self._handle_new_device_update_is_processing = False @@ -569,14 +566,6 @@ class DeviceHandler(DeviceWorkerHandler): StreamKeyType.DEVICE_LIST, position, users=[from_user_id] ) - async def user_left_room(self, user: UserID, room_id: str) -> None: - user_id = user.to_string() - room_ids = await self.store.get_rooms_for_user(user_id) - if not room_ids: - # We no longer share rooms with this user, so we'll no longer - # receive device updates. Mark this in DB. - await self.store.mark_remote_user_device_list_as_unsubscribed(user_id) - async def store_dehydrated_device( self, user_id: str, diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index ec81639c78..8eed63ccf3 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -175,6 +175,32 @@ class E2eKeysHandler: user_ids_not_in_cache, remote_results, ) = await self.store.get_user_devices_from_cache(query_list) + + # Check that the homeserver still shares a room with all cached users. + # Note that this check may be slightly racy when a remote user leaves a + # room after we have fetched their cached device list. In the worst case + # we will do extra federation queries for devices that we had cached. + cached_users = set(remote_results.keys()) + valid_cached_users = ( + await self.store.get_users_server_still_shares_room_with( + remote_results.keys() + ) + ) + invalid_cached_users = cached_users - valid_cached_users + if invalid_cached_users: + # Fix up results. If we get here, there is either a bug in device + # list tracking, or we hit the race mentioned above. + user_ids_not_in_cache.update(invalid_cached_users) + for invalid_user_id in invalid_cached_users: + remote_results.pop(invalid_user_id) + # This log message may be removed if it turns out it's almost + # entirely triggered by races. + logger.error( + "Devices for %s were cached, but the server no longer shares " + "any rooms with them. 
The cached device lists are stale.", + invalid_cached_users, + ) + for user_id, devices in remote_results.items(): user_devices = results.setdefault(user_id, {}) for device_id, device in devices.items(): diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index dad3731b9b..501dbbc990 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -598,9 +598,9 @@ class EventsPersistenceStorageController: # room state_delta_for_room: Dict[str, DeltaState] = {} - # Set of remote users which were in rooms the server has left. We - # should check if we still share any rooms and if not we mark their - # device lists as stale. + # Set of remote users which were in rooms the server has left or who may + # have left rooms the server is in. We should check if we still share any + # rooms and if not we mark their device lists as stale. potentially_left_users: Set[str] = set() if not backfilled: @@ -725,6 +725,20 @@ class EventsPersistenceStorageController: current_state = {} delta.no_longer_in_room = True + # Add all remote users that might have left rooms. + potentially_left_users.update( + user_id + for event_type, user_id in delta.to_delete + if event_type == EventTypes.Member + and not self.is_mine_id(user_id) + ) + potentially_left_users.update( + user_id + for event_type, user_id in delta.to_insert.keys() + if event_type == EventTypes.Member + and not self.is_mine_id(user_id) + ) + state_delta_for_room[room_id] = delta await self.persist_events_store._persist_events_and_state_updates( diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 1e6ad4b663..95698bc275 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -891,6 +891,12 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): new_callable=mock.MagicMock, return_value=make_awaitable(["some_room_id"]), ) + mock_get_users = mock.patch.object( + self.store, + "get_users_server_still_shares_room_with", + new_callable=mock.MagicMock, + return_value=make_awaitable({remote_user_id}), + ) mock_request = mock.patch.object( self.hs.get_federation_client(), "query_user_devices", @@ -898,7 +904,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): return_value=make_awaitable(response_body), ) - with mock_get_rooms, mock_request as mocked_federation_request: + with mock_get_rooms, mock_get_users, mock_request as mocked_federation_request: # Make the first query and sanity check it succeeds. response_1 = self.get_success( e2e_handler.query_devices( -- cgit 1.5.1 From 03c2bfb7f89d637930da52723161ce74d4f89233 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 23 Sep 2022 13:44:03 +0100 Subject: Send device list updates out to servers in partially joined rooms (#13874) Use the provided list of servers in the room from the `/send_join` response, since we will not know which users are in the room. This isn't sufficient to ensure that all remote servers receive the right device list updates, since the `/send_join` response may be inaccurate or we may calculate the membership state of new users in the room incorrectly. 
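A rough sketch of the approximation described above: start from the hosts implied by the (possibly partial) current state, then append any servers listed in the `/send_join` response that are not already present. For full-state rooms the second list is empty, so the result and its ordering match `get_current_hosts_in_room`. The function below is illustrative; the real logic is added as `get_current_hosts_in_room_or_partial_state_approximation` in the diff that follows.

    def approximate_hosts_in_room(hosts_from_state: list, hosts_at_join: list) -> list:
        # Hosts derived from current state come first; for full-state rooms
        # this is already the complete answer, in the usual order.
        seen = set(hosts_from_state)
        hosts = list(hosts_from_state)
        # Then add servers from the /send_join response, covering members
        # whose join events we may not have seen yet during a partial join.
        hosts.extend(host for host in hosts_at_join if host not in seen)
        return hosts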
Signed-off-by: Sean Quah --- changelog.d/13874.misc | 1 + synapse/handlers/device.py | 6 ++++- synapse/storage/controllers/state.py | 44 +++++++++++++++++++++++++++++++++- synapse/storage/databases/main/room.py | 17 +++++++++++++ 4 files changed, 66 insertions(+), 2 deletions(-) create mode 100644 changelog.d/13874.misc (limited to 'synapse/handlers/device.py') diff --git a/changelog.d/13874.misc b/changelog.d/13874.misc new file mode 100644 index 0000000000..499e488c35 --- /dev/null +++ b/changelog.d/13874.misc @@ -0,0 +1 @@ +Faster room joins: Send device list updates to most servers in rooms with partial state. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 901e2310b7..6566b3bf3d 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -688,11 +688,15 @@ class DeviceHandler(DeviceWorkerHandler): # Ignore any users that aren't ours if self.hs.is_mine_id(user_id): hosts = set( - await self._storage_controllers.state.get_current_hosts_in_room( + await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation( room_id ) ) hosts.discard(self.server_name) + # For rooms with partial state, `hosts` is merely an + # approximation. When we transition to a full state room, we + # will have to send out device list updates to any servers we + # missed. # Check if we've already sent this update to some hosts if current_stream_id == stream_id: diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index bbe568bf05..b1aa17047c 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -23,6 +23,7 @@ from typing import ( List, Mapping, Optional, + Sequence, Tuple, ) @@ -524,12 +525,53 @@ class StateStorageController: return state_map.get(key) async def get_current_hosts_in_room(self, room_id: str) -> List[str]: - """Get current hosts in room based on current state.""" + """Get current hosts in room based on current state. + + Blocks until we have full state for the given room. This only happens for rooms + with partial state. + + Returns: + A list of hosts in the room, sorted by longest in the room first. (aka. + sorted by join with the lowest depth first). + """ await self._partial_state_room_tracker.await_full_state(room_id) return await self.stores.main.get_current_hosts_in_room(room_id) + async def get_current_hosts_in_room_or_partial_state_approximation( + self, room_id: str + ) -> Sequence[str]: + """Get approximation of current hosts in room based on current state. + + For rooms with full state, this is equivalent to `get_current_hosts_in_room`, + with the same order of results. + + For rooms with partial state, no blocking occurs. Instead, the list of hosts + in the room at the time of joining is combined with the list of hosts which + joined the room afterwards. The returned list may include hosts that are not + actually in the room and exclude hosts that are in the room, since we may + calculate state incorrectly during the partial state phase. The order of results + is arbitrary for rooms with partial state. + """ + # We have to read this list first to mitigate races with un-partial stating. + # This will be empty for rooms with full state. + hosts_at_join = await self.stores.main.get_partial_state_servers_at_join( + room_id + ) + + hosts_from_state = await self.stores.main.get_current_hosts_in_room(room_id) + hosts_from_state_set = set(hosts_from_state) + + # First take the list of hosts based on the current state. 
+ # For rooms with partial state, this will be missing most hosts. + hosts = list(hosts_from_state) + # Then add in the list of hosts in the room at the time we joined. + # This will be an empty list for rooms with full state. + hosts.extend(host for host in hosts_at_join if host not in hosts_from_state_set) + + return hosts + async def get_users_in_room_with_profiles( self, room_id: str ) -> Dict[str, ProfileInfo]: diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index bef66f1992..5dd116d766 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -25,6 +25,7 @@ from typing import ( List, Mapping, Optional, + Sequence, Tuple, Union, cast, @@ -1133,6 +1134,22 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): get_rooms_for_retention_period_in_range_txn, ) + async def get_partial_state_servers_at_join(self, room_id: str) -> Sequence[str]: + """Gets the list of servers in a partial state room at the time we joined it. + + Returns: + The `servers_in_room` list from the `/send_join` response for partial state + rooms. May not be accurate or complete, as it comes from a remote + homeserver. + An empty list for full state rooms. + """ + return await self.db_pool.simple_select_onecol( + "partial_state_rooms_servers", + keyvalues={"room_id": room_id}, + retcol="server_name", + desc="get_partial_state_servers_at_join", + ) + async def get_partial_state_rooms_and_servers( self, ) -> Mapping[str, Collection[str]]: -- cgit 1.5.1 From f49f73c0da5502792c65d3de1ffd352ceb6af562 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 23 Sep 2022 17:55:15 +0100 Subject: Faster room joins: Avoid blocking `/keys/changes` (#13888) Part of the work for #12993. Once #12993 is fully resolved, we expect `/keys/changes` to behave sensibly when joined to a room with partial state. Signed-off-by: Sean Quah --- changelog.d/13888.misc | 1 + synapse/handlers/device.py | 7 +++++-- synapse/storage/controllers/state.py | 7 ++++++- 3 files changed, 12 insertions(+), 3 deletions(-) create mode 100644 changelog.d/13888.misc (limited to 'synapse/handlers/device.py') diff --git a/changelog.d/13888.misc b/changelog.d/13888.misc new file mode 100644 index 0000000000..4ffd9bcede --- /dev/null +++ b/changelog.d/13888.misc @@ -0,0 +1 @@ +Faster room joins: Avoid waiting for full state when processing `/keys/changes` requests. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 6566b3bf3d..bad262731c 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -195,7 +195,9 @@ class DeviceWorkerHandler: possibly_changed = set(changed) possibly_left = set() for room_id in rooms_changed: - current_state_ids = await self._state_storage.get_current_state_ids(room_id) + current_state_ids = await self._state_storage.get_current_state_ids( + room_id, await_full_state=False + ) # The user may have left the room # TODO: Check if they actually did or if we were just invited. @@ -234,7 +236,8 @@ class DeviceWorkerHandler: # mapping from event_id -> state_dict prev_state_ids = await self._state_storage.get_state_ids_for_events( - event_ids + event_ids, + await_full_state=False, ) # Check if we've joined the room? 
If so we just blindly add all the users to diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index b1aa17047c..bb60130afe 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -407,6 +407,7 @@ class StateStorageController: self, room_id: str, state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, on_invalidate: Optional[Callable[[], None]] = None, ) -> StateMap[str]: """Get the current state event ids for a room based on the @@ -419,13 +420,17 @@ class StateStorageController: room_id: The room to get the state IDs of. state_filter: The state filter used to fetch state from the database. + await_full_state: if true, will block if we do not yet have complete + state for the room. on_invalidate: Callback for when the `get_current_state_ids` cache for the room gets invalidated. Returns: The current state of the room. """ - if not state_filter or state_filter.must_await_full_state(self._is_mine_id): + if await_full_state and ( + not state_filter or state_filter.must_await_full_state(self._is_mine_id) + ): await self._partial_state_room_tracker.await_full_state(room_id) if state_filter and not state_filter.is_full(): -- cgit 1.5.1 From 4b17a5ace846d82b09fccce79da77a8207a6765f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 28 Sep 2022 14:42:43 +0100 Subject: Handle remote device list updates during partial join (#13913) c.f. #12993 (comment), point 3 This stores all device list updates that we receive while partial joins are ongoing, and processes them once we have the full state. Note: We don't actually process the device lists in the same ways as if we weren't partially joined. Instead of updating the device list remote cache, we simply notify local users that a change in the remote user's devices has happened. I think this is safe as if the local user requests the keys for the remote user and we don't have them we'll simply fetch them as normal. --- changelog.d/13913.misc | 1 + synapse/handlers/device.py | 62 ++++++++++++++++++++++ synapse/handlers/federation.py | 4 ++ synapse/storage/databases/main/devices.py | 55 +++++++++++++++++++ synapse/storage/databases/main/room.py | 20 +++++++ .../delta/73/04pending_device_list_updates.sql | 28 ++++++++++ 6 files changed, 170 insertions(+) create mode 100644 changelog.d/13913.misc create mode 100644 synapse/storage/schema/main/delta/73/04pending_device_list_updates.sql (limited to 'synapse/handlers/device.py') diff --git a/changelog.d/13913.misc b/changelog.d/13913.misc new file mode 100644 index 0000000000..30b4401049 --- /dev/null +++ b/changelog.d/13913.misc @@ -0,0 +1 @@ +Faster remote room joins: correctly handle remote device list updates during a partial join. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index bad262731c..f2ef591103 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -309,6 +309,17 @@ class DeviceWorkerHandler: "self_signing_key": self_signing_key, } + async def handle_room_un_partial_stated(self, room_id: str) -> None: + """Handles sending appropriate device list updates in a room that has + gone from partial to full state. + """ + + # TODO(faster_joins): worker mode support + # https://github.com/matrix-org/synapse/issues/12994 + logger.error( + "Trying handling device list state for partial join: not supported on workers." 
+ ) + class DeviceHandler(DeviceWorkerHandler): def __init__(self, hs: "HomeServer"): @@ -746,6 +757,15 @@ class DeviceHandler(DeviceWorkerHandler): finally: self._handle_new_device_update_is_processing = False + async def handle_room_un_partial_stated(self, room_id: str) -> None: + """Handles sending appropriate device list updates in a room that has + gone from partial to full state. + """ + + # We defer to the device list updater implementation as we're on the + # right worker. + await self.device_list_updater.handle_room_un_partial_stated(room_id) + def _update_device_from_client_ips( device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]] @@ -836,6 +856,16 @@ class DeviceListUpdater: ) return + # Check if we are partially joining any rooms. If so we need to store + # all device list updates so that we can handle them correctly once we + # know who is in the room. + partial_rooms = await self.store.get_partial_state_rooms_and_servers() + if partial_rooms: + await self.store.add_remote_device_list_to_pending( + user_id, + device_id, + ) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: # We don't share any rooms with this user. Ignore update, as we @@ -1175,3 +1205,35 @@ class DeviceListUpdater: device_ids.append(verify_key.version) return device_ids + + async def handle_room_un_partial_stated(self, room_id: str) -> None: + """Handles sending appropriate device list updates in a room that has + gone from partial to full state. + """ + + pending_updates = ( + await self.store.get_pending_remote_device_list_updates_for_room(room_id) + ) + + for user_id, device_id in pending_updates: + logger.info( + "Got pending device list update in room %s: %s / %s", + room_id, + user_id, + device_id, + ) + position = await self.store.add_device_change_to_streams( + user_id, + [device_id], + room_ids=[room_id], + ) + + if not position: + # This should only happen if there are no updates, which + # shouldn't happen when we've passed in a non-empty set of + # device IDs. 
+ continue + + self.device_handler.notifier.on_new_event( + StreamKeyType.DEVICE_LIST, position, rooms=[room_id] + ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 8f847ff845..360ab6fee2 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -149,6 +149,7 @@ class FederationHandler: self.http_client = hs.get_proxied_blacklisted_http_client() self._replication = hs.get_replication_data_handler() self._federation_event_handler = hs.get_federation_event_handler() + self._device_handler = hs.get_device_handler() self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator() self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client( @@ -1631,6 +1632,9 @@ class FederationHandler: # https://github.com/matrix-org/synapse/issues/12994 await self.state_handler.update_current_state(room_id) + logger.info("Handling any pending device list updates") + await self._device_handler.handle_room_un_partial_stated(room_id) + logger.info("Clearing partial-state flag for %s", room_id) success = await self.store.clear_partial_state_room(room_id) if success: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 1151fb0cc3..1e562d4a40 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1995,3 +1995,58 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): add_device_list_outbound_pokes_txn, stream_ids, ) + + async def add_remote_device_list_to_pending( + self, user_id: str, device_id: str + ) -> None: + """Add a device list update to the table tracking remote device list + updates during partial joins. + """ + + async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined] + await self.db_pool.simple_upsert( + table="device_lists_remote_pending", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + values={"stream_id": stream_id}, + desc="add_remote_device_list_to_pending", + ) + + async def get_pending_remote_device_list_updates_for_room( + self, room_id: str + ) -> Collection[Tuple[str, str]]: + """Get the set of remote device list updates from the pending table for + the room. + """ + + min_device_stream_id = await self.db_pool.simple_select_one_onecol( + table="partial_state_rooms", + keyvalues={ + "room_id": room_id, + }, + retcol="device_lists_stream_id", + desc="get_pending_remote_device_list_updates_for_room_device", + ) + + sql = """ + SELECT user_id, device_id FROM device_lists_remote_pending AS d + INNER JOIN current_state_events AS c ON + type = 'm.room.member' + AND state_key = user_id + AND membership = 'join' + WHERE + room_id = ? AND stream_id > ? 
+ """ + + def get_pending_remote_device_list_updates_for_room_txn( + txn: LoggingTransaction, + ) -> Collection[Tuple[str, str]]: + txn.execute(sql, (room_id, min_device_stream_id)) + return cast(Collection[Tuple[str, str]], txn.fetchall()) + + return await self.db_pool.runInteraction( + "get_pending_remote_device_list_updates_for_room", + get_pending_remote_device_list_updates_for_room_txn, + ) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 064c332fb7..672c9a03fc 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1217,6 +1217,26 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): ) self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,)) + # We now delete anything from `device_lists_remote_pending` with a + # stream ID less than the minimum + # `partial_state_rooms.device_lists_stream_id`, as we no longer need them. + device_lists_stream_id = DatabasePool.simple_select_one_onecol_txn( + txn, + table="partial_state_rooms", + keyvalues={}, + retcol="MIN(device_lists_stream_id)", + allow_none=True, + ) + if device_lists_stream_id is None: + # There are no rooms being currently partially joined, so we delete everything. + txn.execute("DELETE FROM device_lists_remote_pending") + else: + sql = """ + DELETE FROM device_lists_remote_pending + WHERE stream_id <= ? + """ + txn.execute(sql, (device_lists_stream_id,)) + @cached() async def is_partial_state_room(self, room_id: str) -> bool: """Checks if this room has partial state. diff --git a/synapse/storage/schema/main/delta/73/04pending_device_list_updates.sql b/synapse/storage/schema/main/delta/73/04pending_device_list_updates.sql new file mode 100644 index 0000000000..dbd78d677d --- /dev/null +++ b/synapse/storage/schema/main/delta/73/04pending_device_list_updates.sql @@ -0,0 +1,28 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Stores remote device lists we have received for remote users while a partial +-- join is in progress. +-- +-- This allows us to replay any device list updates if it turns out the remote +-- user was in the partially joined room +CREATE TABLE device_lists_remote_pending( + stream_id BIGINT PRIMARY KEY, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL +); + +-- We only keep the most recent update for a given user/device pair. 
+CREATE UNIQUE INDEX device_lists_remote_pending_user_device_id ON device_lists_remote_pending(user_id, device_id); -- cgit 1.5.1 From 5f659d4a88e602ca8519984808dcf4df036c781b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 28 Sep 2022 23:22:35 +0100 Subject: Handle local device list updates during partial join (#13934) --- changelog.d/13934.misc | 1 + synapse/handlers/device.py | 84 ++++++++++++++++++++++++++++++- synapse/storage/databases/main/devices.py | 55 +++++++++++++++----- synapse/storage/databases/main/room.py | 16 ++++++ 4 files changed, 141 insertions(+), 15 deletions(-) create mode 100644 changelog.d/13934.misc (limited to 'synapse/handlers/device.py') diff --git a/changelog.d/13934.misc b/changelog.d/13934.misc new file mode 100644 index 0000000000..6610a9f567 --- /dev/null +++ b/changelog.d/13934.misc @@ -0,0 +1 @@ +Correctly handle sending local device list updates to remote servers during a partial join. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index f2ef591103..03082fce42 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -762,10 +762,90 @@ class DeviceHandler(DeviceWorkerHandler): gone from partial to full state. """ - # We defer to the device list updater implementation as we're on the - # right worker. + # We defer to the device list updater to handle pending remote device + # list updates. await self.device_list_updater.handle_room_un_partial_stated(room_id) + # Replay local updates. + ( + join_event_id, + device_lists_stream_id, + ) = await self.store.get_join_event_id_and_device_lists_stream_id_for_partial_state( + room_id + ) + + # Get the local device list changes that have happened in the room since + # we started joining. If there are no updates there's nothing left to do. + changes = await self.store.get_device_list_changes_in_room( + room_id, device_lists_stream_id + ) + local_changes = {(u, d) for u, d in changes if self.hs.is_mine_id(u)} + if not local_changes: + return + + # Note: We have persisted the full state at this point, we just haven't + # cleared the `partial_room` flag. + join_state_ids = await self._state_storage.get_state_ids_for_event( + join_event_id, await_full_state=False + ) + current_state_ids = await self.store.get_partial_current_state_ids(room_id) + + # Now we need to work out all servers that might have been in the room + # at any point during our join. + + # First we look for any membership states that have changed between the + # initial join and now... + all_keys = set(join_state_ids) + all_keys.update(current_state_ids) + + potentially_changed_hosts = set() + for etype, state_key in all_keys: + if etype != EventTypes.Member: + continue + + prev = join_state_ids.get((etype, state_key)) + current = current_state_ids.get((etype, state_key)) + + if prev != current: + potentially_changed_hosts.add(get_domain_from_id(state_key)) + + # ... then we add all the hosts that are currently joined to the room... + current_hosts_in_room = await self.store.get_current_hosts_in_room(room_id) + potentially_changed_hosts.update(current_hosts_in_room) + + # ... and finally we remove any hosts that we were told about, as we + # will have sent device list updates to those hosts when they happened. + known_hosts_at_join = await self.store.get_partial_state_servers_at_join( + room_id + ) + potentially_changed_hosts.difference_update(known_hosts_at_join) + + potentially_changed_hosts.discard(self.server_name) + + if not potentially_changed_hosts: + # Nothing to do. 
+ return + + logger.info( + "Found %d changed hosts to send device list updates to", + len(potentially_changed_hosts), + ) + + for user_id, device_id in local_changes: + await self.store.add_device_list_outbound_pokes( + user_id=user_id, + device_id=device_id, + room_id=room_id, + stream_id=None, + hosts=potentially_changed_hosts, + context=None, + ) + + # Notify things that device lists need to be sent out. + self.notifier.notify_replication() + for host in potentially_changed_hosts: + self.federation_sender.send_device_messages(host, immediate=False) + def _update_device_from_client_ips( device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]] diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 1e562d4a40..18358eca46 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1307,6 +1307,33 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return changes + async def get_device_list_changes_in_room( + self, room_id: str, min_stream_id: int + ) -> Collection[Tuple[str, str]]: + """Get all device list changes that happened in the room since the given + stream ID. + + Returns: + Collection of user ID/device ID tuples of all devices that have + changed + """ + + sql = """ + SELECT DISTINCT user_id, device_id FROM device_lists_changes_in_room + WHERE room_id = ? AND stream_id > ? + """ + + def get_device_list_changes_in_room_txn( + txn: LoggingTransaction, + ) -> Collection[Tuple[str, str]]: + txn.execute(sql, (room_id, min_stream_id)) + return cast(Collection[Tuple[str, str]], txn.fetchall()) + + return await self.db_pool.runInteraction( + "get_device_list_changes_in_room", + get_device_list_changes_in_room_txn, + ) + class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__( @@ -1946,14 +1973,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: str, device_id: str, room_id: str, - stream_id: int, + stream_id: Optional[int], hosts: Collection[str], context: Optional[Dict[str, str]], ) -> None: """Queue the device update to be sent to the given set of hosts, calculated from the room ID. - Marks the associated row in `device_lists_changes_in_room` as handled. + Marks the associated row in `device_lists_changes_in_room` as handled, + if `stream_id` is provided. """ def add_device_list_outbound_pokes_txn( @@ -1969,17 +1997,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): context=context, ) - self.db_pool.simple_update_txn( - txn, - table="device_lists_changes_in_room", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - "stream_id": stream_id, - "room_id": room_id, - }, - updatevalues={"converted_to_destinations": True}, - ) + if stream_id: + self.db_pool.simple_update_txn( + txn, + table="device_lists_changes_in_room", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "stream_id": stream_id, + "room_id": room_id, + }, + updatevalues={"converted_to_destinations": True}, + ) if not hosts: # If there are no hosts then we don't try and generate stream IDs. 
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 672c9a03fc..059eef5c22 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1256,6 +1256,22 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): return entry is not None + async def get_join_event_id_and_device_lists_stream_id_for_partial_state( + self, room_id: str + ) -> Tuple[str, int]: + """Get the event ID of the initial join that started the partial + join, and the device list stream ID at the point we started the partial + join. + """ + + result = await self.db_pool.simple_select_one( + table="partial_state_rooms", + keyvalues={"room_id": room_id}, + retcols=("join_event_id", "device_lists_stream_id"), + desc="get_join_event_id_for_partial_state", + ) + return result["join_event_id"], result["device_lists_stream_id"] + class _BackgroundUpdates: REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" -- cgit 1.5.1 From a466164647b969efd2e85168144cd75693443c05 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Thu, 29 Sep 2022 14:55:12 +0100 Subject: Optimise get_rooms_for_user (drop with_stream_ordering) (#13787) --- changelog.d/13787.misc | 1 + synapse/handlers/device.py | 6 +- synapse/handlers/sync.py | 14 +--- synapse/storage/_base.py | 1 + synapse/storage/databases/main/cache.py | 1 + synapse/storage/databases/main/roommember.py | 117 +++++++++++++-------------- tests/handlers/test_sync.py | 1 + 7 files changed, 66 insertions(+), 75 deletions(-) create mode 100644 changelog.d/13787.misc (limited to 'synapse/handlers/device.py') diff --git a/changelog.d/13787.misc b/changelog.d/13787.misc new file mode 100644 index 0000000000..a9b93717f0 --- /dev/null +++ b/changelog.d/13787.misc @@ -0,0 +1 @@ +Optimise get rooms for user calls. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 03082fce42..f9cc5bddbc 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -273,11 +273,9 @@ class DeviceWorkerHandler: possibly_left = possibly_changed | possibly_left # Double check if we still share rooms with the given user. - users_rooms = await self.store.get_rooms_for_users_with_stream_ordering( - possibly_left - ) + users_rooms = await self.store.get_rooms_for_users(possibly_left) for changed_user_id, entries in users_rooms.items(): - if any(e.room_id in room_ids for e in entries): + if any(rid in room_ids for rid in entries): possibly_left.discard(changed_user_id) else: possibly_joined.discard(changed_user_id) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index e75fc6b947..4abb9b6127 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1490,16 +1490,14 @@ class SyncHandler: since_token.device_list_key ) if changed_users is not None: - result = await self.store.get_rooms_for_users_with_stream_ordering( - changed_users - ) + result = await self.store.get_rooms_for_users(changed_users) for changed_user_id, entries in result.items(): # Check if the changed user shares any rooms with the user, # or if the changed user is the syncing user (as we always # want to include device list updates of their own devices). 
if user_id == changed_user_id or any( - e.room_id in joined_rooms for e in entries + rid in joined_rooms for rid in entries ): users_that_have_changed.add(changed_user_id) else: @@ -1533,13 +1531,9 @@ class SyncHandler: newly_left_users.update(left_users) # Remove any users that we still share a room with. - left_users_rooms = ( - await self.store.get_rooms_for_users_with_stream_ordering( - newly_left_users - ) - ) + left_users_rooms = await self.store.get_rooms_for_users(newly_left_users) for user_id, entries in left_users_rooms.items(): - if any(e.room_id in joined_rooms for e in entries): + if any(rid in joined_rooms for rid in entries): newly_left_users.discard(user_id) return DeviceListUpdates( diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 313e8aca7d..bf42aeb8d1 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -94,6 +94,7 @@ class SQLBaseStore(metaclass=ABCMeta): self._attempt_to_invalidate_cache( "get_rooms_for_user_with_stream_ordering", (user_id,) ) + self._attempt_to_invalidate_cache("get_rooms_for_user", (user_id,)) # Purge other caches based on room state. self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index db6ce83a2b..3b8ed1f7ee 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -205,6 +205,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_rooms_for_user_with_stream_ordering.invalidate( (data.state_key,) ) + self.get_rooms_for_user.invalidate((data.state_key,)) else: raise Exception("Unknown events stream row type %s" % (row.type,)) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 8ada3cdac3..982e1f08e3 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -15,7 +15,6 @@ import logging from typing import ( TYPE_CHECKING, - Callable, Collection, Dict, FrozenSet, @@ -52,7 +51,6 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList -from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -600,58 +598,6 @@ class RoomMemberWorkerStore(EventsWorkerStore): for room_id, instance, stream_id in txn ) - @cachedList( - cached_method_name="get_rooms_for_user_with_stream_ordering", - list_name="user_ids", - ) - async def get_rooms_for_users_with_stream_ordering( - self, user_ids: Collection[str] - ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: - """A batched version of `get_rooms_for_user_with_stream_ordering`. - - Returns: - Map from user_id to set of rooms that is currently in. 
- """ - return await self.db_pool.runInteraction( - "get_rooms_for_users_with_stream_ordering", - self._get_rooms_for_users_with_stream_ordering_txn, - user_ids, - ) - - def _get_rooms_for_users_with_stream_ordering_txn( - self, txn: LoggingTransaction, user_ids: Collection[str] - ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: - - clause, args = make_in_list_sql_clause( - self.database_engine, - "c.state_key", - user_ids, - ) - - sql = f""" - SELECT c.state_key, room_id, e.instance_name, e.stream_ordering - FROM current_state_events AS c - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND c.membership = ? - AND {clause} - """ - - txn.execute(sql, [Membership.JOIN] + args) - - result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = { - user_id: set() for user_id in user_ids - } - for user_id, room_id, instance, stream_id in txn: - result[user_id].add( - GetRoomsForUserWithStreamOrdering( - room_id, PersistedEventPosition(instance, stream_id) - ) - ) - - return {user_id: frozenset(v) for user_id, v in result.items()} - async def get_users_server_still_shares_room_with( self, user_ids: Collection[str] ) -> Set[str]: @@ -693,19 +639,68 @@ class RoomMemberWorkerStore(EventsWorkerStore): return {row[0] for row in txn} - @cancellable - async def get_rooms_for_user( - self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None - ) -> FrozenSet[str]: + @cached(max_entries=500000, iterable=True) + async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]: """Returns a set of room_ids the user is currently joined to. If a remote user only returns rooms this server is currently participating in. """ - rooms = await self.get_rooms_for_user_with_stream_ordering( - user_id, on_invalidate=on_invalidate + rooms = self.get_rooms_for_user_with_stream_ordering.cache.get_immediate( + (user_id,), + None, + update_metrics=False, + ) + if rooms: + return frozenset(r.room_id for r in rooms) + + room_ids = await self.db_pool.simple_select_onecol( + table="current_state_events", + keyvalues={ + "type": EventTypes.Member, + "membership": Membership.JOIN, + "state_key": user_id, + }, + retcol="room_id", + desc="get_rooms_for_user", ) - return frozenset(r.room_id for r in rooms) + + return frozenset(room_ids) + + @cachedList( + cached_method_name="get_rooms_for_user", + list_name="user_ids", + ) + async def get_rooms_for_users( + self, user_ids: Collection[str] + ) -> Dict[str, FrozenSet[str]]: + """A batched version of `get_rooms_for_user`. + + Returns: + Map from user_id to set of rooms that is currently in. + """ + + rows = await self.db_pool.simple_select_many_batch( + table="current_state_events", + column="state_key", + iterable=user_ids, + retcols=( + "state_key", + "room_id", + ), + keyvalues={ + "type": EventTypes.Member, + "membership": Membership.JOIN, + }, + desc="get_rooms_for_users", + ) + + user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids} + + for row in rows: + user_rooms[row["state_key"]].add(row["room_id"]) + + return {key: frozenset(rooms) for key, rooms in user_rooms.items()} @cached(max_entries=10000) async def does_pair_of_users_share_a_room( diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index e3f38fbcc5..ab5c101eb7 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -159,6 +159,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # Blow away caches (supported room versions can only change due to a restart). 
self.store.get_rooms_for_user_with_stream_ordering.invalidate_all() + self.store.get_rooms_for_user.invalidate_all() self.get_success(self.store._get_event_cache.clear()) self.store._event_ref.clear() -- cgit 1.5.1 From c3a4780080a5bcb04132283c0f32f7452655792a Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 18 Oct 2022 12:33:18 +0100 Subject: When restarting a partial join resync, prioritise the server which actioned a partial join (#14126) --- changelog.d/14126.misc | 1 + synapse/handlers/device.py | 5 +- synapse/handlers/federation.py | 57 +++++++++++++--------- synapse/storage/database.py | 2 +- synapse/storage/databases/main/room.py | 43 +++++++++++++--- .../delta/73/09partial_joined_via_destination.sql | 18 +++++++ 6 files changed, 95 insertions(+), 31 deletions(-) create mode 100644 changelog.d/14126.misc create mode 100644 synapse/storage/schema/main/delta/73/09partial_joined_via_destination.sql (limited to 'synapse/handlers/device.py') diff --git a/changelog.d/14126.misc b/changelog.d/14126.misc new file mode 100644 index 0000000000..30b3482fbd --- /dev/null +++ b/changelog.d/14126.misc @@ -0,0 +1 @@ +Faster joins: prioritise the server we joined by when restarting a partial join resync. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index f9cc5bddbc..c597639a7f 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -937,7 +937,10 @@ class DeviceListUpdater: # Check if we are partially joining any rooms. If so we need to store # all device list updates so that we can handle them correctly once we # know who is in the room. - partial_rooms = await self.store.get_partial_state_rooms_and_servers() + # TODO(faster joins): this fetches and processes a bunch of data that we don't + # use. Could be replaced by a tighter query e.g. + # SELECT EXISTS(SELECT 1 FROM partial_state_rooms) + partial_rooms = await self.store.get_partial_state_room_resync_info() if partial_rooms: await self.store.add_remote_device_list_to_pending( user_id, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 5f7e0a1f79..ccc045d36f 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -632,6 +632,7 @@ class FederationHandler: room_id=room_id, servers=ret.servers_in_room, device_lists_stream_id=self.store.get_device_stream_token(), + joined_via=origin, ) try: @@ -1615,13 +1616,13 @@ class FederationHandler: """Resumes resyncing of all partial-state rooms after a restart.""" assert not self.config.worker.worker_app - partial_state_rooms = await self.store.get_partial_state_rooms_and_servers() - for room_id, servers_in_room in partial_state_rooms.items(): + partial_state_rooms = await self.store.get_partial_state_room_resync_info() + for room_id, resync_info in partial_state_rooms.items(): run_as_background_process( desc="sync_partial_state_room", func=self._sync_partial_state_room, - initial_destination=None, - other_destinations=servers_in_room, + initial_destination=resync_info.joined_via, + other_destinations=resync_info.servers_in_room, room_id=room_id, ) @@ -1650,28 +1651,12 @@ class FederationHandler: # really leave, that might mean we have difficulty getting the room state over # federation. # https://github.com/matrix-org/synapse/issues/12802 - # - # TODO(faster_joins): we need some way of prioritising which homeservers in - # `other_destinations` to try first, otherwise we'll spend ages trying dead - # homeservers for large rooms. 
- # https://github.com/matrix-org/synapse/issues/12999 - - if initial_destination is None and len(other_destinations) == 0: - raise ValueError( - f"Cannot resync state of {room_id}: no destinations provided" - ) # Make an infinite iterator of destinations to try. Once we find a working # destination, we'll stick with it until it flakes. - destinations: Collection[str] - if initial_destination is not None: - # Move `initial_destination` to the front of the list. - destinations = list(other_destinations) - if initial_destination in destinations: - destinations.remove(initial_destination) - destinations = [initial_destination] + destinations - else: - destinations = other_destinations + destinations = _prioritise_destinations_for_partial_state_resync( + initial_destination, other_destinations, room_id + ) destination_iter = itertools.cycle(destinations) # `destination` is the current remote homeserver we're pulling from. @@ -1769,3 +1754,29 @@ class FederationHandler: room_id, destination, ) + + +def _prioritise_destinations_for_partial_state_resync( + initial_destination: Optional[str], + other_destinations: Collection[str], + room_id: str, +) -> Collection[str]: + """Work out the order in which we should ask servers to resync events. + + If an `initial_destination` is given, it takes top priority. Otherwise + all servers are treated equally. + + :raises ValueError: if no destination is provided at all. + """ + if initial_destination is None and len(other_destinations) == 0: + raise ValueError(f"Cannot resync state of {room_id}: no destinations provided") + + if initial_destination is None: + return other_destinations + + # Move `initial_destination` to the front of the list. + destinations = list(other_destinations) + if initial_destination in destinations: + destinations.remove(initial_destination) + destinations = [initial_destination] + destinations + return destinations diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 7bb21f8f81..4717c9728a 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1658,7 +1658,7 @@ class DatabasePool: table: string giving the table name keyvalues: dict of column names and values to select the row with retcol: string giving the name of the column to return - allow_none: If true, return None instead of failing if the SELECT + allow_none: If true, return None instead of raising StoreError if the SELECT statement returns no rows desc: description of the transaction, for logging and metrics """ diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index e41c99027a..7d97f8f60e 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -97,6 +97,12 @@ class RoomSortOrder(Enum): STATE_EVENTS = "state_events" +@attr.s(slots=True, frozen=True, auto_attribs=True) +class PartialStateResyncInfo: + joined_via: Optional[str] + servers_in_room: List[str] = attr.ib(factory=list) + + class RoomWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -1160,17 +1166,29 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): desc="get_partial_state_servers_at_join", ) - async def get_partial_state_rooms_and_servers( + async def get_partial_state_room_resync_info( self, - ) -> Mapping[str, Collection[str]]: - """Get all rooms containing events with partial state, and the servers known - to be in the room. 
+ ) -> Mapping[str, PartialStateResyncInfo]: + """Get all rooms containing events with partial state, and the information + needed to restart a "resync" of those rooms. Returns: A dictionary of rooms with partial state, with room IDs as keys and lists of servers in rooms as values. """ - room_servers: Dict[str, List[str]] = {} + room_servers: Dict[str, PartialStateResyncInfo] = {} + + rows = await self.db_pool.simple_select_list( + table="partial_state_rooms", + keyvalues={}, + retcols=("room_id", "joined_via"), + desc="get_server_which_served_partial_join", + ) + + for row in rows: + room_id = row["room_id"] + joined_via = row["joined_via"] + room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via) rows = await self.db_pool.simple_select_list( "partial_state_rooms_servers", @@ -1182,7 +1200,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): for row in rows: room_id = row["room_id"] server_name = row["server_name"] - room_servers.setdefault(room_id, []).append(server_name) + entry = room_servers.get(room_id) + if entry is None: + # There is a foreign key constraint which enforces that every room_id in + # partial_state_rooms_servers appears in partial_state_rooms. So we + # expect `entry` to be non-null. (This reasoning fails if we've + # partial-joined between the two SELECTs, but this is unlikely to happen + # in practice.) + continue + entry.servers_in_room.append(server_name) return room_servers @@ -1827,6 +1853,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): room_id: str, servers: Collection[str], device_lists_stream_id: int, + joined_via: str, ) -> None: """Mark the given room as containing events with partial state. @@ -1842,6 +1869,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): servers: other servers known to be in the room device_lists_stream_id: the device_lists stream ID at the time when we first joined the room. + joined_via: the server name we requested a partial join from. """ await self.db_pool.runInteraction( "store_partial_state_room", @@ -1849,6 +1877,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): room_id, servers, device_lists_stream_id, + joined_via, ) def _store_partial_state_room_txn( @@ -1857,6 +1886,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): room_id: str, servers: Collection[str], device_lists_stream_id: int, + joined_via: str, ) -> None: DatabasePool.simple_insert_txn( txn, @@ -1866,6 +1896,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): "device_lists_stream_id": device_lists_stream_id, # To be updated later once the join event is persisted. "join_event_id": None, + "joined_via": joined_via, }, ) DatabasePool.simple_insert_many_txn( diff --git a/synapse/storage/schema/main/delta/73/09partial_joined_via_destination.sql b/synapse/storage/schema/main/delta/73/09partial_joined_via_destination.sql new file mode 100644 index 0000000000..066d602b18 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/09partial_joined_via_destination.sql @@ -0,0 +1,18 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- When we resync partial state, we prioritise doing so using the server we +-- partial-joined from. To do this we need to record which server that was! +ALTER TABLE partial_state_rooms ADD COLUMN joined_via TEXT; -- cgit 1.5.1 From 9cae44f49e6bf4f6b8a20ab11a65da417bb1565f Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 22 Nov 2022 16:46:52 +0000 Subject: Track unconverted device list outbound pokes using a position instead (#14516) When a local device list change is added to `device_lists_changes_in_room`, the `converted_to_destinations` flag is set to `FALSE` and the `_handle_new_device_update_async` background process is started. This background process looks for unconverted rows in `device_lists_changes_in_room`, copies them to `device_lists_outbound_pokes` and updates the flag. To update the `converted_to_destinations` flag, the database performs a `DELETE` and `INSERT` internally, which fragments the table. To avoid this, track unconverted rows using a `(stream ID, room ID)` position instead of the flag. From now on, the `converted_to_destinations` column indicates rows that need converting to outbound pokes, but does not indicate whether the conversion has already taken place. Closes #14037. Signed-off-by: Sean Quah --- changelog.d/14516.misc | 1 + synapse/handlers/device.py | 30 +++++- synapse/storage/database.py | 13 +-- synapse/storage/databases/main/devices.py | 107 +++++++++++++-------- .../73/12refactor_device_list_outbound_pokes.sql | 53 ++++++++++ tests/storage/test_devices.py | 3 +- 6 files changed, 158 insertions(+), 49 deletions(-) create mode 100644 changelog.d/14516.misc create mode 100644 synapse/storage/schema/main/delta/73/12refactor_device_list_outbound_pokes.sql (limited to 'synapse/handlers/device.py') diff --git a/changelog.d/14516.misc b/changelog.d/14516.misc new file mode 100644 index 0000000000..51666c6ffc --- /dev/null +++ b/changelog.d/14516.misc @@ -0,0 +1 @@ +Refactor conversion of device list changes in room to outbound pokes to track unconverted rows using a `(stream ID, room ID)` position instead of updating the `converted_to_destinations` flag on every row. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index c597639a7f..da3ddafeae 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -682,13 +682,33 @@ class DeviceHandler(DeviceWorkerHandler): hosts_already_sent_to: Set[str] = set() try: + stream_id, room_id = await self.store.get_device_change_last_converted_pos() + while True: self._handle_new_device_update_new_data = False - rows = await self.store.get_uncoverted_outbound_room_pokes() + max_stream_id = self.store.get_device_stream_token() + rows = await self.store.get_uncoverted_outbound_room_pokes( + stream_id, room_id + ) if not rows: # If the DB returned nothing then there is nothing left to # do, *unless* a new device list update happened during the # DB query. + + # Advance `(stream_id, room_id)`. 
+ # `max_stream_id` comes from *before* the query for unconverted + # rows, which means that any unconverted rows must have a larger + # stream ID. + if max_stream_id > stream_id: + stream_id, room_id = max_stream_id, "" + await self.store.set_device_change_last_converted_pos( + stream_id, room_id + ) + else: + assert max_stream_id == stream_id + # Avoid moving `room_id` backwards. + pass + if self._handle_new_device_update_new_data: continue else: @@ -718,7 +738,6 @@ class DeviceHandler(DeviceWorkerHandler): user_id=user_id, device_id=device_id, room_id=room_id, - stream_id=stream_id, hosts=hosts, context=opentracing_context, ) @@ -752,6 +771,12 @@ class DeviceHandler(DeviceWorkerHandler): hosts_already_sent_to.update(hosts) current_stream_id = stream_id + # Advance `(stream_id, room_id)`. + _, _, room_id, stream_id, _ = rows[-1] + await self.store.set_device_change_last_converted_pos( + stream_id, room_id + ) + finally: self._handle_new_device_update_is_processing = False @@ -834,7 +859,6 @@ class DeviceHandler(DeviceWorkerHandler): user_id=user_id, device_id=device_id, room_id=room_id, - stream_id=None, hosts=potentially_changed_hosts, context=None, ) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 0dc44b246c..a14b13aec8 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -2075,13 +2075,14 @@ class DatabasePool: retcols: Collection[str], allow_none: bool = False, ) -> Optional[Dict[str, Any]]: - select_sql = "SELECT %s FROM %s WHERE %s" % ( - ", ".join(retcols), - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) + select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table) + + if keyvalues: + select_sql += " WHERE %s" % (" AND ".join("%s = ?" % k for k in keyvalues),) + txn.execute(select_sql, list(keyvalues.values())) + else: + txn.execute(select_sql) - txn.execute(select_sql, list(keyvalues.values())) row = txn.fetchone() if not row: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 57230df5ae..37629115ab 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -2008,27 +2008,48 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def get_uncoverted_outbound_room_pokes( - self, limit: int = 10 + self, start_stream_id: int, start_room_id: str, limit: int = 10 ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: """Get device list changes by room that have not yet been handled and written to `device_lists_outbound_pokes`. + Args: + start_stream_id: Together with `start_room_id`, indicates the position after + which to return device list changes. + start_room_id: Together with `start_stream_id`, indicates the position after + which to return device list changes. + limit: The maximum number of device list changes to return. + Returns: - A list of user ID, device ID, room ID, stream ID and optional opentracing context. + A list of user ID, device ID, room ID, stream ID and optional opentracing + context, in order of ascending (stream ID, room ID). """ sql = """ SELECT user_id, device_id, room_id, stream_id, opentracing_context FROM device_lists_changes_in_room - WHERE NOT converted_to_destinations - ORDER BY stream_id + WHERE + (stream_id, room_id) > (?, ?) AND + stream_id <= ? AND + NOT converted_to_destinations + ORDER BY stream_id ASC, room_id ASC LIMIT ? 
""" def get_uncoverted_outbound_room_pokes_txn( txn: LoggingTransaction, ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: - txn.execute(sql, (limit,)) + txn.execute( + sql, + ( + start_stream_id, + start_room_id, + # Avoid returning rows if there may be uncommitted device list + # changes with smaller stream IDs. + self._device_list_id_gen.get_current_token(), + limit, + ), + ) return [ ( @@ -2050,49 +2071,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: str, device_id: str, room_id: str, - stream_id: Optional[int], hosts: Collection[str], context: Optional[Dict[str, str]], ) -> None: """Queue the device update to be sent to the given set of hosts, calculated from the room ID. - - Marks the associated row in `device_lists_changes_in_room` as handled, - if `stream_id` is provided. """ + if not hosts: + return def add_device_list_outbound_pokes_txn( txn: LoggingTransaction, stream_ids: List[int] ) -> None: - if hosts: - self._add_device_outbound_poke_to_stream_txn( - txn, - user_id=user_id, - device_id=device_id, - hosts=hosts, - stream_ids=stream_ids, - context=context, - ) - - if stream_id: - self.db_pool.simple_update_txn( - txn, - table="device_lists_changes_in_room", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - "stream_id": stream_id, - "room_id": room_id, - }, - updatevalues={"converted_to_destinations": True}, - ) - - if not hosts: - # If there are no hosts then we don't try and generate stream IDs. - return await self.db_pool.runInteraction( - "add_device_list_outbound_pokes", - add_device_list_outbound_pokes_txn, - [], + self._add_device_outbound_poke_to_stream_txn( + txn, + user_id=user_id, + device_id=device_id, + hosts=hosts, + stream_ids=stream_ids, + context=context, ) async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: @@ -2156,3 +2153,37 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): "get_pending_remote_device_list_updates_for_room", get_pending_remote_device_list_updates_for_room_txn, ) + + async def get_device_change_last_converted_pos(self) -> Tuple[int, str]: + """ + Get the position of the last row in `device_list_changes_in_room` that has been + converted to `device_lists_outbound_pokes`. + + Rows with a strictly greater position where `converted_to_destinations` is + `FALSE` have not been converted. + """ + + row = await self.db_pool.simple_select_one( + table="device_lists_changes_converted_stream_position", + keyvalues={}, + retcols=["stream_id", "room_id"], + desc="get_device_change_last_converted_pos", + ) + return row["stream_id"], row["room_id"] + + async def set_device_change_last_converted_pos( + self, + stream_id: int, + room_id: str, + ) -> None: + """ + Set the position of the last row in `device_list_changes_in_room` that has been + converted to `device_lists_outbound_pokes`. 
+ """ + + await self.db_pool.simple_update_one( + table="device_lists_changes_converted_stream_position", + keyvalues={}, + updatevalues={"stream_id": stream_id, "room_id": room_id}, + desc="set_device_change_last_converted_pos", + ) diff --git a/synapse/storage/schema/main/delta/73/12refactor_device_list_outbound_pokes.sql b/synapse/storage/schema/main/delta/73/12refactor_device_list_outbound_pokes.sql new file mode 100644 index 0000000000..93d7fcb79b --- /dev/null +++ b/synapse/storage/schema/main/delta/73/12refactor_device_list_outbound_pokes.sql @@ -0,0 +1,53 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Prior to this schema delta, we tracked the set of unconverted rows in +-- `device_lists_changes_in_room` using the `converted_to_destinations` flag. When rows +-- were converted to `device_lists_outbound_pokes`, the `converted_to_destinations` flag +-- would be set. +-- +-- After this schema delta, the `converted_to_destinations` is still populated like +-- before, but the set of unconverted rows is determined by the `stream_id` in the new +-- `device_lists_changes_converted_stream_position` table. +-- +-- If rolled back, Synapse will re-send all device list changes that happened since the +-- schema delta. + +CREATE TABLE IF NOT EXISTS device_lists_changes_converted_stream_position( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + -- The (stream id, room id) of the last row in `device_lists_changes_in_room` that + -- has been converted to `device_lists_outbound_pokes`. Rows with a strictly larger + -- (stream id, room id) where `converted_to_destinations` is `FALSE` have not been + -- converted. + stream_id BIGINT NOT NULL, + -- `room_id` may be an empty string, which compares less than all valid room IDs. + room_id TEXT NOT NULL, + CHECK (Lock='X') +); + +INSERT INTO device_lists_changes_converted_stream_position (stream_id, room_id) VALUES ( + ( + SELECT COALESCE( + -- The last converted stream id is the smallest unconverted stream id minus + -- one. + MIN(stream_id) - 1, + -- If there is no unconverted stream id, the last converted stream id is the + -- largest stream id. + -- Otherwise, pick 1, since stream ids start at 2. 
+ (SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_in_room) + ) FROM device_lists_changes_in_room WHERE NOT converted_to_destinations + ), + '' +); diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index f37505b6cf..8e7db2c4ec 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -28,7 +28,7 @@ class DeviceStoreTestCase(HomeserverTestCase): """ for device_id in device_ids: - stream_id = self.get_success( + self.get_success( self.store.add_device_change_to_streams( user_id, [device_id], ["!some:room"] ) @@ -39,7 +39,6 @@ class DeviceStoreTestCase(HomeserverTestCase): user_id=user_id, device_id=device_id, room_id="!some:room", - stream_id=stream_id, hosts=[host], context={}, ) -- cgit 1.5.1 From 6d47b7e32589e816eb766446cc1ff19ea73fc7c1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Nov 2022 14:08:04 -0500 Subject: Add a type hint for `get_device_handler()` and fix incorrect types. (#14055) This was the last untyped handler from the HomeServer object. Since it was being treated as Any (and thus unchecked) it was being used incorrectly in a few places. --- changelog.d/14055.misc | 1 + synapse/handlers/deactivate_account.py | 4 +++ synapse/handlers/device.py | 65 ++++++++++++++++++++++++++-------- synapse/handlers/e2e_keys.py | 61 ++++++++++++++++--------------- synapse/handlers/register.py | 4 +++ synapse/handlers/set_password.py | 6 +++- synapse/handlers/sso.py | 9 +++++ synapse/module_api/__init__.py | 10 +++++- synapse/replication/http/devices.py | 11 ++++-- synapse/rest/admin/__init__.py | 26 ++++++++------ synapse/rest/admin/devices.py | 13 +++++-- synapse/rest/client/devices.py | 17 ++++++--- synapse/rest/client/logout.py | 9 +++-- synapse/server.py | 2 +- tests/handlers/test_device.py | 19 ++++++---- tests/rest/admin/test_device.py | 5 ++- 16 files changed, 185 insertions(+), 77 deletions(-) create mode 100644 changelog.d/14055.misc (limited to 'synapse/handlers/device.py') diff --git a/changelog.d/14055.misc b/changelog.d/14055.misc new file mode 100644 index 0000000000..02980bc528 --- /dev/null +++ b/changelog.d/14055.misc @@ -0,0 +1 @@ +Add missing type hints to `HomeServer`. diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 816e1a6d79..d74d135c0c 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -16,6 +16,7 @@ import logging from typing import TYPE_CHECKING, Optional from synapse.api.errors import SynapseError +from synapse.handlers.device import DeviceHandler from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import Codes, Requester, UserID, create_requester @@ -76,6 +77,9 @@ class DeactivateAccountHandler: True if identity server supports removing threepids, otherwise False. """ + # This can only be called on the main process. 
+ assert isinstance(self._device_handler, DeviceHandler) + # Check if this user can be deactivated if not await self._third_party_rules.check_can_deactivate_user( user_id, by_admin diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index da3ddafeae..b1e55e1b9e 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -65,6 +65,8 @@ DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000 class DeviceWorkerHandler: + device_list_updater: "DeviceListWorkerUpdater" + def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.hs = hs @@ -76,6 +78,8 @@ class DeviceWorkerHandler: self.server_name = hs.hostname self._msc3852_enabled = hs.config.experimental.msc3852_enabled + self.device_list_updater = DeviceListWorkerUpdater(hs) + @trace async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: """ @@ -99,6 +103,19 @@ class DeviceWorkerHandler: log_kv(device_map) return devices + async def get_dehydrated_device( + self, user_id: str + ) -> Optional[Tuple[str, JsonDict]]: + """Retrieve the information for a dehydrated device. + + Args: + user_id: the user whose dehydrated device we are looking for + Returns: + a tuple whose first item is the device ID, and the second item is + the dehydrated device information + """ + return await self.store.get_dehydrated_device(user_id) + @trace async def get_device(self, user_id: str, device_id: str) -> JsonDict: """Retrieve the given device @@ -127,7 +144,7 @@ class DeviceWorkerHandler: @cancellable async def get_device_changes_in_shared_rooms( self, user_id: str, room_ids: Collection[str], from_token: StreamToken - ) -> Collection[str]: + ) -> Set[str]: """Get the set of users whose devices have changed who share a room with the given user. """ @@ -320,6 +337,8 @@ class DeviceWorkerHandler: class DeviceHandler(DeviceWorkerHandler): + device_list_updater: "DeviceListUpdater" + def __init__(self, hs: "HomeServer"): super().__init__(hs) @@ -606,19 +625,6 @@ class DeviceHandler(DeviceWorkerHandler): await self.delete_devices(user_id, [old_device_id]) return device_id - async def get_dehydrated_device( - self, user_id: str - ) -> Optional[Tuple[str, JsonDict]]: - """Retrieve the information for a dehydrated device. - - Args: - user_id: the user whose dehydrated device we are looking for - Returns: - a tuple whose first item is the device ID, and the second item is - the dehydrated device information - """ - return await self.store.get_dehydrated_device(user_id) - async def rehydrate_device( self, user_id: str, access_token: str, device_id: str ) -> dict: @@ -882,7 +888,36 @@ def _update_device_from_client_ips( ) -class DeviceListUpdater: +class DeviceListWorkerUpdater: + "Handles incoming device list updates from federation and contacts the main process over replication" + + def __init__(self, hs: "HomeServer"): + from synapse.replication.http.devices import ( + ReplicationUserDevicesResyncRestServlet, + ) + + self._user_device_resync_client = ( + ReplicationUserDevicesResyncRestServlet.make_client(hs) + ) + + async def user_device_resync( + self, user_id: str, mark_failed_as_stale: bool = True + ) -> Optional[JsonDict]: + """Fetches all devices for a user and updates the device cache with them. + + Args: + user_id: The user's id whose device_list will be updated. + mark_failed_as_stale: Whether to mark the user's device list as stale + if the attempt to resync failed. 
+ Returns: + A dict with device info as under the "devices" in the result of this + request: + https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid + """ + return await self._user_device_resync_client(user_id=user_id) + + +class DeviceListUpdater(DeviceListWorkerUpdater): "Handles incoming device list updates from federation and updates the DB" def __init__(self, hs: "HomeServer", device_handler: DeviceHandler): diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index bf1221f523..5fe102e2f2 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -27,9 +27,9 @@ from twisted.internet import defer from synapse.api.constants import EduTypes from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace -from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.types import ( JsonDict, UserID, @@ -56,27 +56,23 @@ class E2eKeysHandler: self.is_mine = hs.is_mine self.clock = hs.get_clock() - self._edu_updater = SigningKeyEduUpdater(hs, self) - federation_registry = hs.get_federation_registry() - self._is_master = hs.config.worker.worker_app is None - if not self._is_master: - self._user_device_resync_client = ( - ReplicationUserDevicesResyncRestServlet.make_client(hs) - ) - else: + is_master = hs.config.worker.worker_app is None + if is_master: + edu_updater = SigningKeyEduUpdater(hs) + # Only register this edu handler on master as it requires writing # device updates to the db federation_registry.register_edu_handler( EduTypes.SIGNING_KEY_UPDATE, - self._edu_updater.incoming_signing_key_update, + edu_updater.incoming_signing_key_update, ) # also handle the unstable version # FIXME: remove this when enough servers have upgraded federation_registry.register_edu_handler( EduTypes.UNSTABLE_SIGNING_KEY_UPDATE, - self._edu_updater.incoming_signing_key_update, + edu_updater.incoming_signing_key_update, ) # doesn't really work as part of the generic query API, because the @@ -319,14 +315,13 @@ class E2eKeysHandler: # probably be tracking their device lists. However, we haven't # done an initial sync on the device list so we do it now. try: - if self._is_master: - resync_results = await self.device_handler.device_list_updater.user_device_resync( + resync_results = ( + await self.device_handler.device_list_updater.user_device_resync( user_id ) - else: - resync_results = await self._user_device_resync_client( - user_id=user_id - ) + ) + if resync_results is None: + raise ValueError("Device resync failed") # Add the device keys to the results. user_devices = resync_results["devices"] @@ -605,6 +600,8 @@ class E2eKeysHandler: async def upload_keys_for_user( self, user_id: str, device_id: str, keys: JsonDict ) -> JsonDict: + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) time_now = self.clock.time_msec() @@ -732,6 +729,8 @@ class E2eKeysHandler: user_id: the user uploading the keys keys: the signing keys """ + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) # if a master key is uploaded, then check it. 
Otherwise, load the # stored master key, to check signatures on other keys @@ -823,6 +822,9 @@ class E2eKeysHandler: Raises: SynapseError: if the signatures dict is not valid. """ + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) + failures = {} # signatures to be stored. Each item will be a SignatureListItem @@ -1200,6 +1202,9 @@ class E2eKeysHandler: A tuple of the retrieved key content, the key's ID and the matching VerifyKey. If the key cannot be retrieved, all values in the tuple will instead be None. """ + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) + try: remote_result = await self.federation.query_user_devices( user.domain, user.to_string() @@ -1396,11 +1401,14 @@ class SignatureListItem: class SigningKeyEduUpdater: """Handles incoming signing key updates from federation and updates the DB""" - def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.clock = hs.get_clock() - self.e2e_keys_handler = e2e_keys_handler + + device_handler = hs.get_device_handler() + assert isinstance(device_handler, DeviceHandler) + self._device_handler = device_handler self._remote_edu_linearizer = Linearizer(name="remote_signing_key") @@ -1445,9 +1453,6 @@ class SigningKeyEduUpdater: user_id: the user whose updates we are processing """ - device_handler = self.e2e_keys_handler.device_handler - device_list_updater = device_handler.device_list_updater - async with self._remote_edu_linearizer.queue(user_id): pending_updates = self._pending_updates.pop(user_id, []) if not pending_updates: @@ -1459,13 +1464,11 @@ class SigningKeyEduUpdater: logger.info("pending updates: %r", pending_updates) for master_key, self_signing_key in pending_updates: - new_device_ids = ( - await device_list_updater.process_cross_signing_key_update( - user_id, - master_key, - self_signing_key, - ) + new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update( + user_id, + master_key, + self_signing_key, ) device_ids = device_ids + new_device_ids - await device_handler.notify_device_update(user_id, device_ids) + await self._device_handler.notify_device_update(user_id, device_ids) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ca1c7a1866..6307fa9c5d 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -38,6 +38,7 @@ from synapse.api.errors import ( ) from synapse.appservice import ApplicationService from synapse.config.server import is_threepid_reserved +from synapse.handlers.device import DeviceHandler from synapse.http.servlet import assert_params_in_dict from synapse.replication.http.login import RegisterDeviceReplicationServlet from synapse.replication.http.register import ( @@ -841,6 +842,9 @@ class RegistrationHandler: refresh_token = None refresh_token_id = None + # This can only run on the main process. 
+ assert isinstance(self.device_handler, DeviceHandler) + registered_device_id = await self.device_handler.check_device_registered( user_id, device_id, diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 73861bbd40..bd9d0bb34b 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Optional from synapse.api.errors import Codes, StoreError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.types import Requester if TYPE_CHECKING: @@ -29,7 +30,10 @@ class SetPasswordHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + # This can only be instantiated on the main process. + device_handler = hs.get_device_handler() + assert isinstance(device_handler, DeviceHandler) + self._device_handler = device_handler async def set_password( self, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 749d7e93b0..e1c0bff1b2 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -37,6 +37,7 @@ from twisted.web.server import Request from synapse.api.constants import LoginType from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.config.sso import SsoAttributeRequirement +from synapse.handlers.device import DeviceHandler from synapse.handlers.register import init_counters_for_auth_provider from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent @@ -1035,6 +1036,8 @@ class SsoHandler: ) -> None: """Revoke any devices and in-flight logins tied to a provider session. + Can only be called from the main process. + Args: auth_provider_id: A unique identifier for this SSO provider, e.g. "oidc" or "saml". @@ -1042,6 +1045,12 @@ class SsoHandler: expected_user_id: The user we're expecting to logout. If set, it will ignore sessions belonging to other users and log an error. """ + + # It is expected that this is the main process. + assert isinstance( + self._device_handler, DeviceHandler + ), "revoking SSO sessions can only be called on the main process" + # Invalidate any running user-mapping sessions to_delete = [] for session_id, session in self._username_mapping_sessions.items(): diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 1adc1fd64f..96a661177a 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -86,6 +86,7 @@ from synapse.handlers.auth import ( ON_LOGGED_OUT_CALLBACK, AuthHandler, ) +from synapse.handlers.device import DeviceHandler from synapse.handlers.push_rules import RuleSpec, check_actions from synapse.http.client import SimpleHttpClient from synapse.http.server import ( @@ -207,6 +208,7 @@ class ModuleApi: self._registration_handler = hs.get_registration_handler() self._send_email_handler = hs.get_send_email_handler() self._push_rules_handler = hs.get_push_rules_handler() + self._device_handler = hs.get_device_handler() self.custom_template_dir = hs.config.server.custom_template_directory try: @@ -784,6 +786,8 @@ class ModuleApi: ) -> Generator["defer.Deferred[Any]", Any, None]: """Invalidate an access token for a user + Can only be called from the main process. + Added in Synapse v0.25.0. 
Args: @@ -796,6 +800,10 @@ class ModuleApi: Raises: synapse.api.errors.AuthError: the access token is invalid """ + assert isinstance( + self._device_handler, DeviceHandler + ), "invalidate_access_token can only be called on the main process" + # see if the access token corresponds to a device user_info = yield defer.ensureDeferred( self._auth.get_user_by_access_token(access_token) @@ -805,7 +813,7 @@ class ModuleApi: if device_id: # delete the device, which will also delete its access tokens yield defer.ensureDeferred( - self._hs.get_device_handler().delete_devices(user_id, [device_id]) + self._device_handler.delete_devices(user_id, [device_id]) ) else: # no associated device. Just delete the access token. diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index c21629def8..7c4941c3d3 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Optional, Tuple from twisted.web.server import Request @@ -63,7 +63,12 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.device_list_updater = hs.get_device_handler().device_list_updater + from synapse.handlers.device import DeviceHandler + + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_list_updater = handler.device_list_updater + self.store = hs.get_datastores().main self.clock = hs.get_clock() @@ -73,7 +78,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): async def _handle_request( # type: ignore[override] self, request: Request, user_id: str - ) -> Tuple[int, JsonDict]: + ) -> Tuple[int, Optional[JsonDict]]: user_devices = await self.device_list_updater.user_device_resync(user_id) return 200, user_devices diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index c62ea22116..fb73886df0 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -238,6 +238,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: """ Register all the admin servlets. """ + # Admin servlets aren't registered on workers. + if hs.config.worker.worker_app is not None: + return + register_servlets_for_client_rest_resource(hs, http_server) BlockRoomRestServlet(hs).register(http_server) ListRoomRestServlet(hs).register(http_server) @@ -254,9 +258,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: UserTokenRestServlet(hs).register(http_server) UserRestServletV2(hs).register(http_server) UsersRestServletV2(hs).register(http_server) - DeviceRestServlet(hs).register(http_server) - DevicesRestServlet(hs).register(http_server) - DeleteDevicesRestServlet(hs).register(http_server) UserMediaStatisticsRestServlet(hs).register(http_server) EventReportDetailRestServlet(hs).register(http_server) EventReportsRestServlet(hs).register(http_server) @@ -280,12 +281,13 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: UserByExternalId(hs).register(http_server) UserByThreePid(hs).register(http_server) - # Some servlets only get registered for the main process. 
- if hs.config.worker.worker_app is None: - SendServerNoticeServlet(hs).register(http_server) - BackgroundUpdateEnabledRestServlet(hs).register(http_server) - BackgroundUpdateRestServlet(hs).register(http_server) - BackgroundUpdateStartJobRestServlet(hs).register(http_server) + DeviceRestServlet(hs).register(http_server) + DevicesRestServlet(hs).register(http_server) + DeleteDevicesRestServlet(hs).register(http_server) + SendServerNoticeServlet(hs).register(http_server) + BackgroundUpdateEnabledRestServlet(hs).register(http_server) + BackgroundUpdateRestServlet(hs).register(http_server) + BackgroundUpdateStartJobRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource( @@ -294,9 +296,11 @@ def register_servlets_for_client_rest_resource( """Register only the servlets which need to be exposed on /_matrix/client/xxx""" WhoisRestServlet(hs).register(http_server) PurgeHistoryStatusRestServlet(hs).register(http_server) - DeactivateAccountRestServlet(hs).register(http_server) PurgeHistoryRestServlet(hs).register(http_server) - ResetPasswordRestServlet(hs).register(http_server) + # The following resources can only be run on the main process. + if hs.config.worker.worker_app is None: + DeactivateAccountRestServlet(hs).register(http_server) + ResetPasswordRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server) UserRegisterServlet(hs).register(http_server) AccountValidityRenewServlet(hs).register(http_server) diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index d934880102..3b2f2d9abb 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -16,6 +16,7 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import NotFoundError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -43,7 +44,9 @@ class DeviceRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine @@ -112,7 +115,9 @@ class DevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine @@ -143,7 +148,9 @@ class DeleteDevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 8f3cbd4ea2..69b803f9f8 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -20,6 +20,7 @@ from pydantic import Extra, StrictStr from synapse.api import errors from synapse.api.errors import NotFoundError +from synapse.handlers.device import DeviceHandler from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -80,7 +81,9 @@ class DeleteDevicesRestServlet(RestServlet): super().__init__() self.hs = hs 
self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.auth_handler = hs.get_auth_handler() class PostBody(RequestBodyModel): @@ -125,7 +128,9 @@ class DeviceRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.auth_handler = hs.get_auth_handler() self._msc3852_enabled = hs.config.experimental.msc3852_enabled @@ -256,7 +261,9 @@ class DehydratedDeviceServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -313,7 +320,9 @@ class ClaimDehydratedDeviceServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler class PostBody(RequestBodyModel): device_id: StrictStr diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py index 23dfa4518f..6d34625ad5 100644 --- a/synapse/rest/client/logout.py +++ b/synapse/rest/client/logout.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple +from synapse.handlers.device import DeviceHandler from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet from synapse.http.site import SynapseRequest @@ -34,7 +35,9 @@ class LogoutRestServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self._device_handler = handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) @@ -59,7 +62,9 @@ class LogoutAllRestServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self._device_handler = handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) diff --git a/synapse/server.py b/synapse/server.py index f0a60d0056..5baae2325e 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -510,7 +510,7 @@ class HomeServer(metaclass=abc.ABCMeta): ) @cache_in_self - def get_device_handler(self): + def get_device_handler(self) -> DeviceWorkerHandler: if self.config.worker.worker_app: return DeviceWorkerHandler(self) else: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index b8b465d35b..ce7525e29c 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -19,7 +19,7 @@ from typing import Optional from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import NotFoundError, SynapseError -from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN +from synapse.handlers.device 
import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler from synapse.server import HomeServer from synapse.util import Clock @@ -32,7 +32,9 @@ user2 = "@theresa:bbb" class DeviceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.store = hs.get_datastores().main return hs @@ -61,6 +63,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): self.assertEqual(res, "fco") dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) + assert dev is not None self.assertEqual(dev["display_name"], "display name") def test_device_is_preserved_if_exists(self) -> None: @@ -83,6 +86,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): self.assertEqual(res2, "fco") dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) + assert dev is not None self.assertEqual(dev["display_name"], "display name") def test_device_id_is_made_up_if_unspecified(self) -> None: @@ -95,6 +99,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): ) dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id)) + assert dev is not None self.assertEqual(dev["display_name"], "display") def test_get_devices_by_user(self) -> None: @@ -264,7 +269,9 @@ class DeviceTestCase(unittest.HomeserverTestCase): class DehydrationTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.registration = hs.get_registration_handler() self.auth = hs.get_auth() self.store = hs.get_datastores().main @@ -284,9 +291,9 @@ class DehydrationTestCase(unittest.HomeserverTestCase): ) ) - retrieved_device_id, device_data = self.get_success( - self.handler.get_dehydrated_device(user_id=user_id) - ) + result = self.get_success(self.handler.get_dehydrated_device(user_id=user_id)) + assert result is not None + retrieved_device_id, device_data = result self.assertEqual(retrieved_device_id, stored_dehydrated_device_id) self.assertEqual(device_data, {"device_data": {"foo": "bar"}}) diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index d52aee8f92..03f2112b07 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -19,6 +19,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes +from synapse.handlers.device import DeviceHandler from synapse.rest.client import login from synapse.server import HomeServer from synapse.util import Clock @@ -34,7 +35,9 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") -- cgit 1.5.1
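Stepping back to the `(stream ID, room ID)` position introduced by #14516 above: the pattern is easier to see in isolation than in the storage-layer diff. The snippet below is a toy, in-memory sketch only (none of these names are Synapse APIs), showing how a single persisted cursor replaces the per-row `converted_to_destinations` flag.

from typing import List, Tuple

Change = Tuple[int, str, str, str]  # (stream_id, room_id, user_id, device_id)


class ConversionCursor:
    """Toy model of position-based conversion tracking.

    Instead of flipping a per-row flag (which, as the commit message notes,
    fragments the table because the UPDATE becomes a DELETE + INSERT), we only
    record the last (stream_id, room_id) that has been fully converted.
    """

    def __init__(self) -> None:
        self.changes: List[Change] = []
        self.position: Tuple[int, str] = (0, "")

    def add_change(
        self, stream_id: int, room_id: str, user_id: str, device_id: str
    ) -> None:
        self.changes.append((stream_id, room_id, user_id, device_id))

    def convert_pending(self, batch_size: int = 10) -> List[Change]:
        """Convert everything after the stored position, advancing it as we go."""
        converted: List[Change] = []
        while True:
            # Rows strictly after the cursor, in (stream_id, room_id) order:
            # the in-memory analogue of
            #   WHERE (stream_id, room_id) > (?, ?) ... ORDER BY stream_id, room_id
            batch = sorted(
                row for row in self.changes if (row[0], row[1]) > self.position
            )[:batch_size]
            if not batch:
                return converted
            converted.extend(batch)  # stand-in for writing outbound pokes
            last = batch[-1]
            self.position = (last[0], last[1])  # persist the new cursor

For example, after adding changes at stream IDs 3 and 5, convert_pending() returns both rows and leaves the position at stream ID 5 with that row's room ID. A crash between batches simply re-converts the unacknowledged rows, which is safe because re-sending a device list update is harmless (the schema delta above makes the same point about rollback).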