-rw-r--r--  changelog.d/11215.feature | 1
-rw-r--r--  synapse/appservice/__init__.py | 149
-rw-r--r--  synapse/appservice/api.py | 47
-rw-r--r--  synapse/appservice/scheduler.py | 151
-rw-r--r--  synapse/config/experimental.py | 12
-rw-r--r--  synapse/handlers/appservice.py | 250
-rw-r--r--  synapse/handlers/device.py | 8
-rw-r--r--  synapse/handlers/directory.py | 6
-rw-r--r--  synapse/handlers/sync.py | 30
-rw-r--r--  synapse/handlers/typing.py | 6
-rw-r--r--  synapse/notifier.py | 4
-rw-r--r--  synapse/replication/tcp/client.py | 10
-rw-r--r--  synapse/storage/databases/main/appservice.py | 30
-rw-r--r--  synapse/storage/databases/main/deviceinbox.py | 76
-rw-r--r--  synapse/storage/databases/main/devices.py | 85
-rw-r--r--  synapse/storage/schema/main/delta/65/06_msc2409_add_device_id_appservice_stream_type.sql | 23
-rw-r--r--  synapse/storage/schema/main/delta/65/07_msc3202_add_device_list_appservice_stream_type.sql | 18
-rw-r--r--  synapse/types.py | 21
-rw-r--r--  tests/appservice/test_appservice.py | 50
-rw-r--r--  tests/appservice/test_scheduler.py | 141
-rw-r--r--  tests/handlers/test_appservice.py | 309
-rw-r--r--  tests/storage/test_appservice.py | 27
22 files changed, 1190 insertions, 264 deletions
diff --git a/changelog.d/11215.feature b/changelog.d/11215.feature
new file mode 100644

index 0000000000..468020834b
--- /dev/null
+++ b/changelog.d/11215.feature
@@ -0,0 +1 @@
+Add experimental support for sending to-device messages to application services, as specified by [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409). Disabled by default.
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index f9d3bd337d..01db2b2ae3 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Iterable, List, Match, Optional from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id +from synapse.types import DeviceLists, GroupID, JsonDict, UserID, get_domain_from_id from synapse.util.caches.descriptors import _CacheContext, cached if TYPE_CHECKING: @@ -144,26 +144,6 @@ class ApplicationService: return regex_obj["exclusive"] return False - async def _matches_user( - self, event: Optional[EventBase], store: Optional["DataStore"] = None - ) -> bool: - if not event: - return False - - if self.is_interested_in_user(event.sender): - return True - # also check m.room.member state key - if event.type == EventTypes.Member and self.is_interested_in_user( - event.state_key - ): - return True - - if not store: - return False - - does_match = await self.matches_user_in_member_list(event.room_id, store) - return does_match - @cached(num_args=1, cache_context=True) async def matches_user_in_member_list( self, @@ -171,14 +151,15 @@ class ApplicationService: store: "DataStore", cache_context: _CacheContext, ) -> bool: - """Check if this service is interested a room based upon it's membership + """Check if this appservice is interested a room based upon whether any members + fall into the appservice's user namespace. Args: room_id: The room to check. store: The datastore to query. Returns: - True if this service would like to know about this room. + True if this appservice would like to know about this room. """ member_list = await store.get_users_in_room( room_id, on_invalidate=cache_context.invalidate @@ -190,28 +171,82 @@ class ApplicationService: return True return False - def _matches_room_id(self, event: EventBase) -> bool: - if hasattr(event, "room_id"): - return self.is_interested_in_room(event.room_id) - return False + def is_interested_in_user( + self, + user_id: str, + ) -> bool: + """ + Returns whether the application is interested in a given user ID. + + The appservice is considered to be interested in a user if either: the + user ID is in the appservice's user namespace, or if the user is the + appservice's configured sender_localpart. + + Args: + user_id: The ID of the user to check. - async def _matches_aliases( - self, event: EventBase, store: Optional["DataStore"] = None + Returns: + True if the application service is interested in the user, False if not. + """ + return ( + # User is the appservice's sender_localpart user + user_id == self.sender + # User is in a defined namespace + or self.is_user_in_namespace(user_id) + ) + + @cached(num_args=1, cache_context=True) + async def is_interested_in_room( + self, + room_id: str, + store: "DataStore", + cache_context: _CacheContext, ) -> bool: - if not store or not event: - return False + """ + Returns whether the application service is interested in a given room ID. + + The appservice is considered to be interested in the room if either: the ID or one + of the aliases of the room is in the appservice's room ID or alias namespace + respectively, or if one of the members of the room fall into the appservice's user + namespace. + + Args: + room_id: The ID of the room to check. + store: The homeserver's datastore class. + + Returns: + True if the application service is interested in the room, False if not. 
+ """ + # Check if we have interest in this room ID + if self.is_room_id_in_namespace(room_id): + return True - alias_list = await store.get_aliases_for_room(event.room_id) + # or any of the aliases this room has + alias_list = await store.get_aliases_for_room(room_id) for alias in alias_list: - if self.is_interested_in_alias(alias): + if self.is_room_alias_in_namespace(alias): return True - return False - async def is_interested( - self, event: EventBase, store: Optional["DataStore"] = None + # And finally, perform an expensive check on whether the appservice + # is interested in any users in the room based on their user ID + # and the appservice's user namespace. + return await self.matches_user_in_member_list( + room_id, store, on_invalidate=cache_context.invalidate + ) + + @cached(num_args=1, cache_context=True) + async def is_interested_in_event( + self, + event: EventBase, + store: "DataStore", + cache_context: _CacheContext, ) -> bool: """Check if this service is interested in this event. + Interest in an event is determined by whether this appservice is interested in + either the room the event was sent in, the sender of the event or - if the + event is of type "m.room.member", the user referenced by the event's state key. + Args: event: The event to check. store: The datastore to query. @@ -220,23 +255,28 @@ class ApplicationService: True if this service would like to know about this event. """ # Do cheap checks first - if self._matches_room_id(event): + + # Check if we're interested in this user by namespace (or if they're the + # sender_localpart user) + if self.is_interested_in_user(event.sender): return True - # This will check the namespaces first before - # checking the store, so should be run before _matches_aliases - if await self._matches_user(event, store): + # or, if this is a membership event, the user it references by namespace + if event.type == EventTypes.Member and self.is_interested_in_user( + event.state_key + ): return True - # This will check the store, so should be run last - if await self._matches_aliases(event, store): + if await self.is_interested_in_room( + event.room_id, store, on_invalidate=cache_context.invalidate + ): return True return False - @cached(num_args=1) + @cached(num_args=1, cache_context=True) async def is_interested_in_presence( - self, user_id: UserID, store: "DataStore" + self, user_id: UserID, store: "DataStore", cache_context: _CacheContext ) -> bool: """Check if this service is interested a user's presence @@ -254,20 +294,19 @@ class ApplicationService: # Then find out if the appservice is interested in any of those rooms for room_id in room_ids: - if await self.matches_user_in_member_list(room_id, store): + if await self.matches_user_in_member_list( + room_id, store, on_invalidate=cache_context.invalidate + ): return True return False - def is_interested_in_user(self, user_id: str) -> bool: - return ( - bool(self._matches_regex(user_id, ApplicationService.NS_USERS)) - or user_id == self.sender - ) + def is_user_in_namespace(self, user_id: str) -> bool: + return bool(self._matches_regex(user_id, ApplicationService.NS_USERS)) - def is_interested_in_alias(self, alias: str) -> bool: + def is_room_alias_in_namespace(self, alias: str) -> bool: return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES)) - def is_interested_in_room(self, room_id: str) -> bool: + def is_room_id_in_namespace(self, room_id: str) -> bool: return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS)) def is_exclusive_user(self, user_id: str) 
-> bool: @@ -330,11 +369,15 @@ class AppServiceTransaction: id: int, events: List[EventBase], ephemeral: List[JsonDict], + to_device_messages: List[JsonDict], + device_list_summary: DeviceLists, ): self.service = service self.id = id self.events = events self.ephemeral = ephemeral + self.to_device_messages = to_device_messages + self.device_list_summary = device_list_summary async def send(self, as_api: "ApplicationServiceApi") -> bool: """Sends this transaction using the provided AS API interface. @@ -348,6 +391,8 @@ class AppServiceTransaction: service=self.service, events=self.events, ephemeral=self.ephemeral, + to_device_messages=self.to_device_messages, + device_list_summary=self.device_list_summary, txn_id=self.id, ) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
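The refactored interest checks in synapse/appservice/__init__.py above separate the pure namespace regex tests (is_user_in_namespace, is_room_alias_in_namespace, is_room_id_in_namespace) from the higher-level interest questions, which may also consult the datastore and are now cached. A minimal, self-contained sketch of the namespace side of that logic, using a hypothetical IRC-bridge registration (the real class additionally treats the appservice's sender_localpart user as interesting and performs the cached room/member checks shown above):

    import re
    from typing import Dict, List

    # Hypothetical registration namespaces, in the shape appservice registration files use.
    NAMESPACES: Dict[str, List[Dict[str, object]]] = {
        "users": [{"regex": "@irc_.*:example\\.org", "exclusive": True}],
        "aliases": [{"regex": "#irc_.*:example\\.org", "exclusive": True}],
        "rooms": [],
    }

    def _matches(namespace: str, test_string: str) -> bool:
        # True if any regex in the given namespace matches the whole string.
        return any(
            re.fullmatch(str(obj["regex"]), test_string) is not None
            for obj in NAMESPACES.get(namespace, [])
        )

    def is_user_in_namespace(user_id: str) -> bool:
        return _matches("users", user_id)

    def is_room_alias_in_namespace(alias: str) -> bool:
        return _matches("aliases", alias)

    print(is_user_in_namespace("@irc_alice:example.org"))      # True
    print(is_room_alias_in_namespace("#general:example.org"))  # False
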
index f51b636417..3ae59c7a04 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -1,4 +1,5 @@ # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +14,7 @@ # limitations under the License. import logging import urllib -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from prometheus_client import Counter @@ -22,7 +23,7 @@ from synapse.api.errors import CodeMessageException from synapse.events import EventBase from synapse.events.utils import serialize_event from synapse.http.client import SimpleHttpClient -from synapse.types import JsonDict, ThirdPartyInstanceID +from synapse.types import DeviceLists, JsonDict, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: @@ -204,12 +205,26 @@ class ApplicationServiceApi(SimpleHttpClient): service: "ApplicationService", events: List[EventBase], ephemeral: List[JsonDict], + to_device_messages: List[JsonDict], + device_list_summary: DeviceLists, txn_id: Optional[int] = None, - ): + ) -> bool: + """ + Push data to an application service. + Args: + service: The application service to send to. + events: The persistent events to send. + ephemeral: The ephemeral events to send. + to_device_messages: The to-device messages to send. + txn_id: An unique ID to assign to this transaction. Application services should + deduplicate transactions received with identitical IDs. + Returns: + True if the task succeeded, False if it failed. + """ if service.url is None: return True - events = self._serialize(service, events) + serialized_events = self._serialize(service, events) if txn_id is None: logger.warning( @@ -220,10 +235,28 @@ class ApplicationServiceApi(SimpleHttpClient): uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id))) # Never send ephemeral events to appservices that do not support it + body: Dict[str, Union[JsonDict, List[JsonDict]]] = {"events": serialized_events} if service.supports_ephemeral: - body = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral} - else: - body = {"events": events} + body.update( + { + # TODO: Update to stable prefixes once MSC2409 completes FCP merge. + "de.sorunome.msc2409.ephemeral": ephemeral, + "de.sorunome.msc2409.to_device": to_device_messages, + } + ) + + # Send device list summaries if needed + if device_list_summary: + logger.info("Sending device list summary: %s", device_list_summary) + body.update( + { + # TODO: Update to stable prefix once MSC3202 completes FCP merge + "org.matrix.msc3202.device_lists": { + "changed": list(device_list_summary.changed), + "left": list(device_list_summary.left), + } + } + ) try: await self.put_json( diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
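With the api.py changes above, the body PUT to an appservice's /transactions/{txn_id} endpoint always carries "events", plus the unstable MSC2409 fields for appservices that support ephemeral events, and the MSC3202 device list summary when it is non-empty. A rough sketch of the resulting payload shape (the field names are taken from the hunk; the example values are invented):

    # Illustrative transaction body as assembled by push_bulk.
    transaction_body = {
        "events": [
            {"type": "m.room.message", "room_id": "!abc:example.org", "sender": "@irc_alice:example.org"},
        ],
        # Unstable prefixes; the TODOs above note they become stable once the MSCs finish FCP.
        "de.sorunome.msc2409.ephemeral": [
            {"type": "m.typing", "room_id": "!abc:example.org", "content": {"user_ids": []}},
        ],
        "de.sorunome.msc2409.to_device": [
            {"type": "m.room_key_request", "to_user_id": "@irc_alice:example.org", "to_device_id": "DEVICE1", "content": {}},
        ],
        # Only present when the consolidated device list summary is non-empty.
        "org.matrix.msc3202.device_lists": {
            "changed": ["@bob:example.org"],
            "left": ["@carol:example.org"],
        },
    }
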
index 6a2ce99b55..d49636d926 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -48,13 +48,13 @@ This is all tied together by the AppServiceScheduler which DIs the required components. """ import logging -from typing import List, Optional +from typing import Dict, Iterable, List, Optional from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.events import EventBase from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import JsonDict +from synapse.types import DeviceLists, JsonDict logger = logging.getLogger(__name__) @@ -65,6 +65,9 @@ MAX_PERSISTENT_EVENTS_PER_TRANSACTION = 100 # Maximum number of ephemeral events to provide in an AS transaction. MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100 +# Maximum number of to-device messages to provide in an AS transaction. +MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION = 100 + class ApplicationServiceScheduler: """Public facing API for this module. Does the required DI to tie the @@ -91,13 +94,53 @@ class ApplicationServiceScheduler: for service in services: self.txn_ctrl.start_recoverer(service) - def submit_event_for_as(self, service: ApplicationService, event: EventBase): - self.queuer.enqueue_event(service, event) + def enqueue_for_appservice( + self, + appservice: ApplicationService, + events: Optional[Iterable[EventBase]] = None, + ephemeral: Optional[Iterable[JsonDict]] = None, + to_device_messages: Optional[Iterable[JsonDict]] = None, + device_list_summary: Optional[DeviceLists] = None, + ) -> None: + """ + Enqueue some data to be sent off to an application service. + + Args: + appservice: The application service to create and send a transaction to. + events: The persistent room events to send. + ephemeral: The ephemeral events to send. + to_device_messages: The to-device messages to send. These differ from normal + to-device messages sent to clients, as they have 'to_device_id' and + 'to_user_id' fields. + device_list_summary: A summary of users that the application service either needs + to refresh the device lists of, or those that the application service need no + longer track the device lists of. + """ + # We purposefully allow this method to run with empty events/ephemeral + # iterables, so that callers do not need to check iterable size themselves. 
+ if ( + not events + and not ephemeral + and not to_device_messages + and not device_list_summary + ): + return + + if events: + self.queuer.queued_events.setdefault(appservice.id, []).extend(events) + if ephemeral: + self.queuer.queued_ephemeral.setdefault(appservice.id, []).extend(ephemeral) + if to_device_messages: + self.queuer.queued_to_device_messages.setdefault(appservice.id, []).extend( + to_device_messages + ) + if device_list_summary: + self.queuer.queued_device_list_summaries.setdefault( + appservice.id, [] + ).append(device_list_summary) - def submit_ephemeral_events_for_as( - self, service: ApplicationService, events: List[JsonDict] - ): - self.queuer.enqueue_ephemeral(service, events) + # Kick off a new application service transaction + self.queuer.start_background_request(appservice) class _ServiceQueuer: @@ -109,15 +152,21 @@ class _ServiceQueuer: """ def __init__(self, txn_ctrl, clock): - self.queued_events = {} # dict of {service_id: [events]} - self.queued_ephemeral = {} # dict of {service_id: [events]} + # dict of {service_id: [events]} + self.queued_events: Dict[str, List[EventBase]] = {} + # dict of {service_id: [event_json]} + self.queued_ephemeral: Dict[str, List[JsonDict]] = {} + # dict of {service_id: [to_device_message_json]} + self.queued_to_device_messages: Dict[str, List[JsonDict]] = {} + # dict of {service_id: [device_list_summary]} + self.queued_device_list_summaries: Dict[str, List[DeviceLists]] = {} # the appservices which currently have a transaction in flight self.requests_in_flight = set() self.txn_ctrl = txn_ctrl self.clock = clock - def _start_background_request(self, service): + def start_background_request(self, service): # start a sender for this appservice if we don't already have one if service.id in self.requests_in_flight: return @@ -126,14 +175,6 @@ class _ServiceQueuer: "as-sender-%s" % (service.id,), self._send_request, service ) - def enqueue_event(self, service: ApplicationService, event: EventBase): - self.queued_events.setdefault(service.id, []).append(event) - self._start_background_request(service) - - def enqueue_ephemeral(self, service: ApplicationService, events: List[JsonDict]): - self.queued_ephemeral.setdefault(service.id, []).extend(events) - self._start_background_request(service) - async def _send_request(self, service: ApplicationService): # sanity-check: we shouldn't get here if this service already has a sender # running. @@ -150,11 +191,58 @@ class _ServiceQueuer: ephemeral = all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION] del all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION] - if not events and not ephemeral: + all_to_device_messages = self.queued_to_device_messages.get( + service.id, [] + ) + to_device_messages_to_send = all_to_device_messages[ + :MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION + ] + del all_to_device_messages[:MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION] + + # Consolidate any pending device list summaries into a single, up-to-date + # summary. + # Note: this code assumes that in a single DeviceLists, a user will + # never be in both "changed" and "left" sets. + device_list_summary = DeviceLists() + while self.queued_device_list_summaries.get(service.id, []): + # Pop a summary off the front of the queue + summary = self.queued_device_list_summaries[service.id].pop(0) + + # For every user in the incoming "changed" set: + # * Remove them from the existing "left" set if necessary + # (as we need to start tracking them again) + # * Add them to the existing "changed" set if necessary. 
+ for user_id in summary.changed: + if user_id in device_list_summary.left: + device_list_summary.left.remove(user_id) + device_list_summary.changed.add(user_id) + + # For every user in the incoming "left" set: + # * Remove them from the existing "changed" set if necessary + # (we no longer need to track them) + # * Add them to the existing "left" set if necessary. + for user_id in summary.left: + if user_id in device_list_summary.changed: + device_list_summary.changed.remove(user_id) + device_list_summary.left.add(user_id) + + if ( + not events + and not ephemeral + and not to_device_messages_to_send + # Note that DeviceLists implements __bool__ + and not device_list_summary + ): return try: - await self.txn_ctrl.send(service, events, ephemeral) + await self.txn_ctrl.send( + service, + events, + ephemeral, + to_device_messages_to_send, + device_list_summary, + ) except Exception: logger.exception("AS request failed") finally: @@ -191,10 +279,27 @@ class _TransactionController: service: ApplicationService, events: List[EventBase], ephemeral: Optional[List[JsonDict]] = None, - ): + to_device_messages: Optional[List[JsonDict]] = None, + device_list_summary: Optional[DeviceLists] = None, + ) -> None: + """ + Create a transaction with the given data and send to the provided + application service. + + Args: + service: The application service to send the transaction to. + events: The persistent events to include in the transaction. + ephemeral: The ephemeral events to include in the transaction. + to_device_messages: The to-device messages to include in the transaction. + device_list_summary: The device list summary to include in the transaction. + """ try: txn = await self.store.create_appservice_txn( - service=service, events=events, ephemeral=ephemeral or [] + service=service, + events=events, + ephemeral=ephemeral or [], + to_device_messages=to_device_messages or [], + device_list_summary=device_list_summary or DeviceLists(), ) service_is_up = await self._is_service_up(service) if service_is_up: diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
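The while-loop in scheduler.py above folds every queued device list summary for an appservice into a single one, keeping the "changed" and "left" sets disjoint so the appservice only ever receives the most recent state for each user. The same merge, as a standalone sketch (DeviceLists is stubbed out here; the real class is added to synapse.types later in this diff):

    from typing import Iterable, Optional, Set

    class DeviceLists:
        # Minimal stand-in for synapse.types.DeviceLists: two sets of user IDs.
        def __init__(self, changed: Optional[Set[str]] = None, left: Optional[Set[str]] = None) -> None:
            self.changed = set(changed or ())
            self.left = set(left or ())

    def consolidate(summaries: Iterable[DeviceLists]) -> DeviceLists:
        merged = DeviceLists()
        for summary in summaries:
            # A user entering "changed" is removed from "left" and vice versa,
            # so later summaries override earlier ones for that user.
            for user_id in summary.changed:
                merged.left.discard(user_id)
                merged.changed.add(user_id)
            for user_id in summary.left:
                merged.changed.discard(user_id)
                merged.left.add(user_id)
        return merged

    result = consolidate([DeviceLists(changed={"@a:x"}), DeviceLists(left={"@a:x"})])
    print(result.changed, result.left)  # set() {'@a:x'}
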
index 678c78d565..d19165e5b4 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -50,6 +50,18 @@ class ExperimentalConfig(Config): # MSC3030 (Jump to date API endpoint) self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False) + # MSC2409 (this setting only relates to optionally sending to-device messages). + # Presence, typing and read receipt EDUs are already sent to application services that + # have opted in to receive them. This setting, if enabled, adds to-device messages + # to that list. + self.msc2409_to_device_messages_enabled: bool = experimental.get( + "msc2409_to_device_messages_enabled", False + ) + + # MSC3202 (device list updates and OTK counts / fallback keys to appservices). + # Only device lists are supported currently. + self.msc3202_enabled: bool = experimental.get("msc3202_enabled", False) + # The portion of MSC3202 which is related to device masquerading. self.msc3202_device_masquerading_enabled: bool = experimental.get( "msc3202_device_masquerading", False diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
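Both new flags in config/experimental.py default to off. A sketch of what enabling them looks like, expressed as the parsed experimental config mapping that the code above reads from (the enclosing experimental_features section name is an assumption based on Synapse's config layout, not shown in this hunk):

    # Hypothetical parsed `experimental_features` section of the homeserver config.
    experimental = {
        "msc2409_to_device_messages_enabled": True,  # also push to-device messages to opted-in appservices
        "msc3202_enabled": True,                     # push device list summaries to appservices
    }

    msc2409_to_device_messages_enabled = experimental.get("msc2409_to_device_messages_enabled", False)
    msc3202_enabled = experimental.get("msc3202_enabled", False)
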
index 9abdad262b..fb533188a2 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -33,7 +33,7 @@ from synapse.metrics.background_process_metrics import ( wrap_as_background_process, ) from synapse.storage.databases.main.directory import RoomAliasMapping -from synapse.types import JsonDict, RoomAlias, RoomStreamToken, UserID +from synapse.types import DeviceLists, JsonDict, RoomAlias, RoomStreamToken, UserID from synapse.util.async_helpers import Linearizer from synapse.util.metrics import Measure @@ -55,6 +55,10 @@ class ApplicationServicesHandler: self.clock = hs.get_clock() self.notify_appservices = hs.config.appservice.notify_appservices self.event_sources = hs.get_event_sources() + self._msc2409_to_device_messages_enabled = ( + hs.config.experimental.msc2409_to_device_messages_enabled + ) + self._msc3202_enabled = hs.config.experimental.msc3202_enabled self.current_max = 0 self.is_processing = False @@ -132,7 +136,9 @@ class ApplicationServicesHandler: # Fork off pushes to these services for service in services: - self.scheduler.submit_event_for_as(service, event) + self.scheduler.enqueue_for_appservice( + service, events=[event] + ) now = self.clock.time_msec() ts = await self.store.get_received_ts(event.event_id) @@ -199,8 +205,9 @@ class ApplicationServicesHandler: Args: stream_key: The stream the event came from. - `stream_key` can be "typing_key", "receipt_key" or "presence_key". Any other - value for `stream_key` will cause this function to return early. + `stream_key` can be "typing_key", "receipt_key", "presence_key", + "to_device_key" or "device_list_key". Any other value fo + `stream_key` will cause this function to return early. Ephemeral events will only be pushed to appservices that have opted into receiving them by setting `push_ephemeral` to true in their registration @@ -216,8 +223,16 @@ class ApplicationServicesHandler: if not self.notify_appservices: return - # Ignore any unsupported streams - if stream_key not in ("typing_key", "receipt_key", "presence_key"): + # Notify appservices of updates in ephemeral event streams. + # Only the following streams are currently supported. + # FIXME: We should use constants for these values. + if stream_key not in ( + "typing_key", + "receipt_key", + "presence_key", + "to_device_key", + "device_list_key", + ): return # Assert that new_token is an integer (and not a RoomStreamToken). @@ -233,6 +248,17 @@ class ApplicationServicesHandler: # Additional context: https://github.com/matrix-org/synapse/pull/11137 assert isinstance(new_token, int) + # Ignore to-device messages if the feature flag is not enabled + if ( + stream_key == "to_device_key" + and not self._msc2409_to_device_messages_enabled + ): + return + + # Ignore device lists if the feature flag is not enabled + if stream_key == "device_list_key" and not self._msc3202_enabled: + return + # Check whether there are any appservices which have registered to receive # ephemeral events. # @@ -266,7 +292,7 @@ class ApplicationServicesHandler: with Measure(self.clock, "notify_interested_services_ephemeral"): for service in services: if stream_key == "typing_key": - # Note that we don't persist the token (via set_type_stream_id_for_appservice) + # Note that we don't persist the token (via set_appservice_stream_type_pos) # for typing_key due to performance reasons and due to their highly # ephemeral nature. # @@ -274,7 +300,7 @@ class ApplicationServicesHandler: # and, if they apply to this application service, send it off. 
events = await self._handle_typing(service, new_token) if events: - self.scheduler.submit_ephemeral_events_for_as(service, events) + self.scheduler.enqueue_for_appservice(service, ephemeral=events) continue # Since we read/update the stream position for this AS/stream @@ -285,28 +311,51 @@ class ApplicationServicesHandler: ): if stream_key == "receipt_key": events = await self._handle_receipts(service, new_token) - if events: - self.scheduler.submit_ephemeral_events_for_as( - service, events - ) + self.scheduler.enqueue_for_appservice(service, ephemeral=events) # Persist the latest handled stream token for this appservice - await self.store.set_type_stream_id_for_appservice( + await self.store.set_appservice_stream_type_pos( service, "read_receipt", new_token ) elif stream_key == "presence_key": events = await self._handle_presence(service, users, new_token) - if events: - self.scheduler.submit_ephemeral_events_for_as( - service, events - ) + self.scheduler.enqueue_for_appservice(service, ephemeral=events) # Persist the latest handled stream token for this appservice - await self.store.set_type_stream_id_for_appservice( + await self.store.set_appservice_stream_type_pos( service, "presence", new_token ) + elif stream_key == "to_device_key": + # Retrieve a list of to-device message events, as well as the + # maximum stream token of the messages we were able to retrieve. + to_device_messages = await self._get_to_device_messages( + service, new_token, users + ) + self.scheduler.enqueue_for_appservice( + service, to_device_messages=to_device_messages + ) + + # Persist the latest handled stream token for this appservice + await self.store.set_appservice_stream_type_pos( + service, "to_device", new_token + ) + + elif stream_key == "device_list_key": + device_list_summary = await self._get_device_list_summary( + service, new_token + ) + if device_list_summary: + self.scheduler.enqueue_for_appservice( + service, device_list_summary=device_list_summary + ) + + # Persist the latest handled stream token for this appservice + await self.store.set_appservice_stream_type_pos( + service, "device_list", new_token + ) + async def _handle_typing( self, service: ApplicationService, new_token: int ) -> List[JsonDict]: @@ -440,6 +489,167 @@ class ApplicationServicesHandler: return events + async def _get_to_device_messages( + self, + service: ApplicationService, + new_token: int, + users: Collection[Union[str, UserID]], + ) -> List[JsonDict]: + """ + Given an application service, determine which events it should receive + from those between the last-recorded typing event stream token for this + appservice and the given stream token. + + Args: + service: The application service to check for which events it should receive. + new_token: The latest to-device event stream token. + users: The users that should receive new to-device messages. + + Returns: + A list of JSON dictionaries containing data derived from the typing events + that should be sent to the given application service. + """ + # Get the stream token that this application service has processed up until + from_key = await self.store.get_type_stream_id_for_appservice( + service, "to_device" + ) + + # Filter out users that this appservice is not interested in + users_appservice_is_interested_in: List[str] = [] + for user in users: + # FIXME: We should do this farther up the call stack. We currently repeat + # this operation in _handle_presence. 
+ if isinstance(user, UserID): + user = user.to_string() + + if service.is_interested_in_user(user): + users_appservice_is_interested_in.append(user) + + if not users_appservice_is_interested_in: + # Return early if the AS was not interested in any of these users + return [] + + # Retrieve the to-device messages for each user + recipient_user_id_device_id_to_messages = await self.store.get_new_messages( + users_appservice_is_interested_in, + from_key, + new_token, + ) + + # According to MSC2409, we'll need to add 'to_user_id' and 'to_device_id' fields + # to the event JSON so that the application service will know which user/device + # combination this messages was intended for. + # + # So we mangle this dict into a flat list of to-device messages with the relevant + # user ID and device ID embedded inside each message dict. + message_payload: List[JsonDict] = [] + for ( + user_id, + device_id, + ), messages in recipient_user_id_device_id_to_messages.items(): + for message_json in messages: + # Remove 'message_id' from the to-device message, as it's an internal ID + message_json.pop("message_id", None) + + message_payload.append( + { + "to_user_id": user_id, + "to_device_id": device_id, + **message_json, + } + ) + + return message_payload + + async def _get_device_list_summary( + self, + appservice: ApplicationService, + new_key: int, + ) -> DeviceLists: + """ + Retrieve a list of users who have changed their device lists. + + Args: + appservice: The application service to retrieve device list changes for. + new_key: The stream key of the device list change that triggered this method call. + + Returns: + A set of device list updates, comprised of users that the appservices needs to: + * resync the device list of, and + * stop tracking the device list of. + """ + # Fetch the last successfully processed device list update stream ID + # for this appservice. + from_key = await self.store.get_type_stream_id_for_appservice( + appservice, "device_list" + ) + + # Fetch the users who have modified their device list since then. + users_with_changed_device_lists = ( + await self.store.get_users_whose_devices_changed( + from_key, filter_user_ids=None, to_key=new_key + ) + ) + + # Filter out any users the application service is not interested in + # + # For each user who changed their device list, we want to check whether this + # appservice would be interested in the change. + filtered_users_with_changed_device_lists = { + user_id + for user_id in users_with_changed_device_lists + if self._is_appservice_interested_in_device_lists_of_user( + appservice, user_id + ) + } + + # Create a summary of "changed" and "left" users. + # TODO: Calculate "left" users. + device_list_summary = DeviceLists( + changed=filtered_users_with_changed_device_lists + ) + + return device_list_summary + + async def _is_appservice_interested_in_device_lists_of_user( + self, + appservice: ApplicationService, + user_id: str, + ) -> bool: + """ + Returns whether a given application service is interested in the device lists of a + given user. + + The application service is interested in the user's device lists if any of the + following are true: + * The user is in the appservice's user namespace. + * At least one member of one room that the user is a part of is in the + appservice's user namespace. + * The appservice is explicitly (via room ID or alias) interested in at + least one room that the user is in. + + Args: + appservice: The application service to gauge interest of. 
+ user_id: The ID of the user whose device list interest is in question. + + Returns: + True if the application service is interested in the user's device lists, False + otherwise. + """ + if appservice.is_interested_in_user(user_id): + return True + + # FIXME: This is quite an expensive check. This method is called per device + # list change. + room_ids = await self.store.get_rooms_for_user(user_id) + for room_id in room_ids: + # This method covers checking room members for appservice interest as well as + # room ID and alias checks. + if await appservice.is_interested_in_room(room_id, self.store): + return True + + return False + async def query_user_exists(self, user_id: str) -> bool: """Check if any application service knows this user_id exists. @@ -469,7 +679,7 @@ class ApplicationServicesHandler: room_alias_str = room_alias.to_string() services = self.store.get_app_services() alias_query_services = [ - s for s in services if (s.is_interested_in_alias(room_alias_str)) + s for s in services if (s.is_room_alias_in_namespace(room_alias_str)) ] for alias_service in alias_query_services: is_known_alias = await self.appservice_api.query_alias( @@ -558,7 +768,7 @@ class ApplicationServicesHandler: # inside of a list comprehension anymore. interested_list = [] for s in services: - if await s.is_interested(event, self.store): + if await s.is_interested_in_event(event, self.store): interested_list.append(s) return interested_list diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
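_get_to_device_messages in handlers/appservice.py above reshapes the store's per-(user, device) mapping into the flat list MSC2409 expects, tagging each message with its recipient and dropping the internal message_id. A condensed sketch of that reshaping step (this version copies each message rather than mutating it in place):

    from typing import Dict, List, Tuple

    JsonDict = dict

    def flatten_to_device_messages(
        by_recipient: Dict[Tuple[str, str], List[JsonDict]]
    ) -> List[JsonDict]:
        payload: List[JsonDict] = []
        for (user_id, device_id), messages in by_recipient.items():
            for message in messages:
                message = dict(message)          # copy, so the caller's dict is untouched
                message.pop("message_id", None)  # internal ID, not part of MSC2409
                message["to_user_id"] = user_id
                message["to_device_id"] = device_id
                payload.append(message)
        return payload

    example = {
        ("@irc_alice:example.org", "DEVICE1"): [
            {"type": "m.room_key_request", "content": {}, "message_id": "internal-id"},
        ],
    }
    print(flatten_to_device_messages(example))
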
index 82ee11e921..2c07d31dfd 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -495,13 +495,11 @@ class DeviceHandler(DeviceWorkerHandler): "Notifying about update %r/%r, ID: %r", user_id, device_id, position ) - room_ids = await self.store.get_rooms_for_user(user_id) - # specify the user ID too since the user should always get their own device list # updates, even if they aren't in any rooms. - self.notifier.on_new_event( - "device_list_key", position, users=[user_id], rooms=room_ids - ) + users_to_notify = users_who_share_room.union(user_id) + + self.notifier.on_new_event("device_list_key", position, users=users_to_notify) if hosts: logger.info( diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 7ee5c47fd9..f49bb806a8 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -119,7 +119,7 @@ class DirectoryHandler: service = requester.app_service if service: - if not service.is_interested_in_alias(room_alias_str): + if not service.is_room_alias_in_namespace(room_alias_str): raise SynapseError( 400, "This application service has not reserved this kind of alias.", @@ -221,7 +221,7 @@ class DirectoryHandler: async def delete_appservice_association( self, service: ApplicationService, room_alias: RoomAlias ) -> None: - if not service.is_interested_in_alias(room_alias.to_string()): + if not service.is_room_alias_in_namespace(room_alias.to_string()): raise SynapseError( 400, "This application service has not reserved this kind of alias", @@ -374,7 +374,7 @@ class DirectoryHandler: # non-exclusive locks on the alias (or there are no interested services) services = self.store.get_app_services() interested_services = [ - s for s in services if s.is_interested_in_alias(alias.to_string()) + s for s in services if s.is_room_alias_in_namespace(alias.to_string()) ] for service in interested_services: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index f3039c3c3f..d004c42885 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -13,17 +13,7 @@ # limitations under the License. import itertools import logging -from typing import ( - TYPE_CHECKING, - Any, - Collection, - Dict, - FrozenSet, - List, - Optional, - Set, - Tuple, -) +from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple import attr from prometheus_client import Counter @@ -39,6 +29,7 @@ from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( + DeviceLists, JsonDict, MutableStateMap, Requester, @@ -183,21 +174,6 @@ class GroupsSyncResult: return bool(self.join or self.invite or self.leave) -@attr.s(slots=True, frozen=True, auto_attribs=True) -class DeviceLists: - """ - Attributes: - changed: List of user_ids whose devices may have changed - left: List of user_ids whose devices we no longer track - """ - - changed: Collection[str] - left: Collection[str] - - def __bool__(self) -> bool: - return bool(self.changed or self.left) - - @attr.s(slots=True, auto_attribs=True) class _RoomChanges: """The set of room entries to include in the sync, plus the set of joined @@ -1329,7 +1305,7 @@ class SyncHandler: return DeviceLists(changed=users_that_have_changed, left=newly_left_users) else: - return DeviceLists(changed=[], left=[]) + return DeviceLists() async def _generate_sync_entry_for_to_device( self, sync_result_builder: "SyncResultBuilder" diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 1676ebd057..985b8ff3be 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -442,7 +442,7 @@ class TypingWriterHandler(FollowerTypingHandler): class TypingNotificationEventSource(EventSource[int, JsonDict]): def __init__(self, hs: "HomeServer"): - self.hs = hs + self.store = hs.get_datastore() self.clock = hs.get_clock() # We can't call get_typing_handler here because there's a cycle: # @@ -482,9 +482,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): if handler._room_serials[room_id] <= from_key: continue - if not await service.matches_user_in_member_list( - room_id, handler.store - ): + if not await service.matches_user_in_member_list(room_id, self.store): continue events.append(self._make_event_for(room_id)) diff --git a/synapse/notifier.py b/synapse/notifier.py
index 60e5409895..3b24a7f4ba 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -452,7 +452,9 @@ class Notifier: users, ) except Exception: - logger.exception("Error notifying application services of event") + logger.exception( + "Error notifying application services of ephemeral event" + ) def on_new_replication_data(self) -> None: """Used to inform replication listeners that something has happened diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e29ae1e375..679df5602f 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -173,12 +173,14 @@ class ReplicationDataHandler: if entities: self.notifier.on_new_event("to_device_key", token, users=entities) elif stream_name == DeviceListsStream.NAME: - all_room_ids: Set[str] = set() + users_to_notify: Set[str] = set() for row in rows: if row.entity.startswith("@"): - room_ids = await self.store.get_rooms_for_user(row.entity) - all_room_ids.update(room_ids) - self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids) + user_ids = await self.store.get_users_who_share_room_with_user( + row.entity + ) + users_to_notify.update(user_ids) + self.notifier.on_new_event("device_list_key", token, users=users_to_notify) elif stream_name == GroupServerStream.NAME: self.notifier.on_new_event( "groups_key", token, users=[row.user_id for row in rows] diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 4a883dc166..0ac2005bee 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -27,7 +27,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.types import Connection -from synapse.types import JsonDict +from synapse.types import DeviceLists, JsonDict from synapse.util import json_encoder if TYPE_CHECKING: @@ -194,6 +194,8 @@ class ApplicationServiceTransactionWorkerStore( service: ApplicationService, events: List[EventBase], ephemeral: List[JsonDict], + to_device_messages: List[JsonDict], + device_list_summary: DeviceLists, ) -> AppServiceTransaction: """Atomically creates a new transaction for this application service with the given list of events. Ephemeral events are NOT persisted to the @@ -203,6 +205,8 @@ class ApplicationServiceTransactionWorkerStore( service: The service who the transaction is for. events: A list of persistent events to put in the transaction. ephemeral: A list of ephemeral events to put in the transaction. + to_device_messages: A list of to-device messages to put in the transaction. + device_list_summary: The device list summary to include in the transaction. Returns: A new transaction. @@ -233,7 +237,12 @@ class ApplicationServiceTransactionWorkerStore( (service.id, new_txn_id, event_ids), ) return AppServiceTransaction( - service=service, id=new_txn_id, events=events, ephemeral=ephemeral + service=service, + id=new_txn_id, + events=events, + ephemeral=ephemeral, + to_device_messages=to_device_messages, + device_list_summary=device_list_summary, ) return await self.db_pool.runInteraction( @@ -326,7 +335,12 @@ class ApplicationServiceTransactionWorkerStore( events = await self.get_events_as_list(event_ids) return AppServiceTransaction( - service=service, id=entry["txn_id"], events=events, ephemeral=[] + service=service, + id=entry["txn_id"], + events=events, + ephemeral=[], + to_device_messages=[], + device_list_summary=DeviceLists(), ) def _get_last_txn(self, txn, service_id: Optional[str]) -> int: @@ -387,7 +401,7 @@ class ApplicationServiceTransactionWorkerStore( async def get_type_stream_id_for_appservice( self, service: ApplicationService, type: str ) -> int: - if type not in ("read_receipt", "presence"): + if type not in ("read_receipt", "presence", "to_device", "device_list"): raise ValueError( "Expected type to be a valid application stream id type, got %s" % (type,) @@ -411,16 +425,16 @@ class ApplicationServiceTransactionWorkerStore( "get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn ) - async def set_type_stream_id_for_appservice( + async def set_appservice_stream_type_pos( self, service: ApplicationService, stream_type: str, pos: Optional[int] ) -> None: - if stream_type not in ("read_receipt", "presence"): + if stream_type not in ("read_receipt", "presence", "to_device", "device_list"): raise ValueError( "Expected type to be a valid application stream id type, got %s" % (stream_type,) ) - def set_type_stream_id_for_appservice_txn(txn): + def set_appservice_stream_type_pos_txn(txn): stream_id_type = "%s_stream_id" % stream_type txn.execute( "UPDATE application_services_state SET %s = ? WHERE as_id=?" 
@@ -429,7 +443,7 @@ class ApplicationServiceTransactionWorkerStore( ) await self.db_pool.runInteraction( - "set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn + "set_appservice_stream_type_pos", set_appservice_stream_type_pos_txn ) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
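With "to_device" and "device_list" now accepted alongside "read_receipt" and "presence", the per-appservice stream positions are read with get_type_stream_id_for_appservice and written with the renamed set_appservice_stream_type_pos. An assumed usage sketch of the to-device path (store is an ApplicationServiceTransactionWorkerStore):

    from synapse.appservice import ApplicationService

    async def advance_to_device_position(store, service: ApplicationService, new_token: int) -> int:
        # Read the last position recorded for this appservice and stream type ...
        from_key = await store.get_type_stream_id_for_appservice(service, "to_device")
        # ... fetch and enqueue everything with from_key < stream_id <= new_token here ...
        # ... then persist the new position once the data has been handed to the scheduler.
        await store.set_appservice_stream_type_pos(service, "to_device", new_token)
        return from_key
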
index ab8766c75b..d2b285e852 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple from synapse.logging import issue9533_logger from synapse.logging.opentracing import log_kv, set_tag, trace @@ -24,6 +24,7 @@ from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + make_in_list_sql_clause, ) from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( @@ -136,6 +137,79 @@ class DeviceInboxWorkerStore(SQLBaseStore): def get_to_device_stream_token(self): return self._device_inbox_id_gen.get_current_token() + async def get_new_messages( + self, + user_ids: Collection[str], + from_stream_id: int, + to_stream_id: int, + ) -> Dict[Tuple[str, str], List[JsonDict]]: + """ + Retrieve to-device messages for a given set of user IDs. + + Only to-device messages with stream ids between the given boundaries + (from < X <= to) are returned. + + Note that a stream ID can be shared by multiple copies of the same message with + different recipient devices. Each (device, message_content) tuple has their own + row in the device_inbox table. + + Args: + user_ids: The users to retrieve to-device messages for. + from_stream_id: The lower boundary of stream id to filter with (exclusive). + to_stream_id: The upper boundary of stream id to filter with (inclusive). + + Returns: + A list of to-device messages. + """ + # Bail out if none of these users have any messages + for user_id in user_ids: + if self._device_inbox_stream_cache.has_entity_changed( + user_id, from_stream_id + ): + break + else: + return {} + + def get_new_messages_txn(txn: LoggingTransaction): + # Build a query to select messages from any of the given users that are between + # the given stream id bounds + + # Scope to only the given users. We need to use this method as doing so is + # different across database engines. + many_clause_sql, many_clause_args = make_in_list_sql_clause( + self.database_engine, "user_id", user_ids + ) + + sql = f""" + SELECT user_id, device_id, message_json FROM device_inbox + WHERE {many_clause_sql} + AND ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + """ + + txn.execute(sql, (*many_clause_args, from_stream_id, to_stream_id)) + + # Create a dictionary of (user ID, device ID) -> list of messages that + # that device is meant to receive. + recipient_user_id_device_id_to_messages: Dict[ + Tuple[str, str], List[JsonDict] + ] = {} + + for row in txn: + recipient_user_id = row[0] + recipient_device_id = row[1] + message_dict = db_to_json(row[2]) + + recipient_user_id_device_id_to_messages.setdefault( + (recipient_user_id, recipient_device_id), [] + ).append(message_dict) + + return recipient_user_id_device_id_to_messages + + return await self.db_pool.runInteraction( + "get_new_messages", get_new_messages_txn + ) + async def get_new_messages_for_device( self, user_id: str, diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
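get_new_messages in deviceinbox.py above selects rows with from_stream_id < stream_id <= to_stream_id for the given users and groups them by (recipient user ID, recipient device ID), which is exactly the shape the appservice handler then flattens. A tiny in-memory model of those semantics (the rows are invented; the real method runs the SQL shown above):

    from typing import Dict, List, Tuple

    # (user_id, device_id, stream_id, message_json) rows, standing in for device_inbox.
    ROWS = [
        ("@a:x", "D1", 11, {"type": "m.room_key_request"}),
        ("@a:x", "D2", 11, {"type": "m.room_key_request"}),  # same stream_id, different device
        ("@a:x", "D1", 16, {"type": "m.room_key_request"}),  # outside the requested range
    ]

    def get_new_messages_model(user_ids, from_stream_id, to_stream_id):
        out: Dict[Tuple[str, str], List[dict]] = {}
        for user_id, device_id, stream_id, message in ROWS:
            # Mirrors the SQL bounds: exclusive lower, inclusive upper, scoped to the users.
            if user_id in user_ids and from_stream_id < stream_id <= to_stream_id:
                out.setdefault((user_id, device_id), []).append(message)
        return out

    print(get_new_messages_model(["@a:x"], 10, 15))
    # {('@a:x', 'D1'): [{'type': 'm.room_key_request'}], ('@a:x', 'D2'): [{'type': 'm.room_key_request'}]}
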
index afc516a978..db22501c09 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -579,42 +579,67 @@ class DeviceWorkerStore(SQLBaseStore): } async def get_users_whose_devices_changed( - self, from_key: int, user_ids: Iterable[str] + self, + from_key: int, + filter_user_ids: Optional[Iterable[str]] = None, + to_key: Optional[int] = None, ) -> Set[str]: """Get set of users whose devices have changed since `from_key` that are in the given list of user_ids. Args: - from_key: The device lists stream token - user_ids: The user IDs to query for devices. + from_key: The minimum device lists stream token to query device list changes for, + exclusive. + filter_user_ids: If provided, only check if these users have changed their + device lists. + to_key: The maximum device lists stream token to query device list changes for, + inclusive. Returns: - The set of user_ids whose devices have changed since `from_key` + The set of user_ids whose devices have changed since `from_key` (exclusive) + until `to_key` (inclusive). """ - # Get set of users who *may* have changed. Users not in the returned - # list have definitely not changed. - to_check = self._device_list_stream_cache.get_entities_changed( - user_ids, from_key - ) + user_ids_to_check = [] + if filter_user_ids is not None: + # Get set of users who *may* have changed. Users not in the returned + # list have definitely not changed. + user_ids_to_check = self._device_list_stream_cache.get_entities_changed( + filter_user_ids, from_key + ) - if not to_check: - return set() + if not user_ids_to_check: + return set() def _get_users_whose_devices_changed_txn(txn): changes = set() - sql = """ + sql_args = [from_key] + if to_key: + stream_id_where_clause = "stream_id > ? AND stream_id <= ?" + sql_args += [to_key] + else: + stream_id_where_clause = "stream_id > ?" + + sql = f""" SELECT DISTINCT user_id FROM device_lists_stream - WHERE stream_id > ? - AND + WHERE {stream_id_where_clause} """ - for chunk in batch_iter(to_check, 100): - clause, args = make_in_list_sql_clause( - txn.database_engine, "user_id", chunk - ) - txn.execute(sql + clause, (from_key,) + tuple(args)) + # TODO: This is starting to get a bit messy + if filter_user_ids: + sql += " AND " + + for chunk in batch_iter(user_ids_to_check, 100): + clause, args = make_in_list_sql_clause( + txn.database_engine, "user_id", chunk + ) + sql_args += args + + txn.execute(sql + clause, sql_args) + changes.update(user_id for user_id, in txn) + else: + txn.execute(sql, sql_args) changes.update(user_id for user_id, in txn) return changes @@ -1393,13 +1418,23 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def add_device_change_to_streams( - self, user_id: str, device_ids: Collection[str], hosts: List[str] - ) -> int: + self, user_id: str, device_ids: Collection[str], hosts: Collection[str] + ) -> Optional[int]: """Persist that a user's devices have been updated, and which hosts (if any) should be poked. + + Args: + user_id: The ID of the user whose device changed. + device_ids: The IDs of any changed devices. If empty, this function will return + None. + hosts: The remote destinations that should be notified of the change. + + Returns: + The maximum device list update stream ID which was added to the database, or + None if no updates were added. 
""" if not device_ids: - return + return None async with self._device_list_id_gen.get_next_mult( len(device_ids) @@ -1469,11 +1504,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self, txn: LoggingTransaction, user_id: str, - device_ids: Collection[str], - hosts: List[str], + device_ids: Iterable[str], + hosts: Iterable[str], stream_ids: List[str], context: Dict[str, str], - ): + ) -> None: for host in hosts: txn.call_after( self._device_list_federation_stream_cache.entity_has_changed, diff --git a/synapse/storage/schema/main/delta/65/06_msc2409_add_device_id_appservice_stream_type.sql b/synapse/storage/schema/main/delta/65/06_msc2409_add_device_id_appservice_stream_type.sql new file mode 100644
index 0000000000..7b40241282
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/06_msc2409_add_device_id_appservice_stream_type.sql
@@ -0,0 +1,23 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a column to track what to_device stream id that this application +-- service has been caught up to. + +-- We explicitly don't set this field as "NOT NULL", as having NULL as a possible +-- state is useful for determining if we've ever sent traffic for a stream type +-- to an appservice. See https://github.com/matrix-org/synapse/issues/10836 for +-- one way this can be used. +ALTER TABLE application_services_state ADD COLUMN to_device_stream_id BIGINT; \ No newline at end of file diff --git a/synapse/storage/schema/main/delta/65/07_msc3202_add_device_list_appservice_stream_type.sql b/synapse/storage/schema/main/delta/65/07_msc3202_add_device_list_appservice_stream_type.sql new file mode 100644
index 0000000000..a8f518c08b
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/07_msc3202_add_device_list_appservice_stream_type.sql
@@ -0,0 +1,18 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a column to track what device list changes stream id that this application +-- service has been caught up to. +ALTER TABLE application_services_state ADD COLUMN device_list_stream_id BIGINT; \ No newline at end of file diff --git a/synapse/types.py b/synapse/types.py
index fb72f19343..c06ae59c91 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -24,6 +24,7 @@ from typing import ( Mapping, MutableMapping, Optional, + Set, Tuple, Type, TypeVar, @@ -751,6 +752,26 @@ class ReadReceipt: data = attr.ib() +@attr.s(slots=True, frozen=True, auto_attribs=True) +class DeviceLists: + """ + Attributes: + changed: List of user_ids whose devices may have changed + left: List of user_ids whose devices we no longer track + """ + + # We need to use a factory here, otherwise `set` is not evaluated at + # object instantiation, but instead at class definition instantiation. + # The latter happening only once, thus always giving you the same sets + # across multiple DeviceLists instances. + # Also see: don't define mutable default arguments. + changed: Set[str] = attr.ib(factory=set) + left: Set[str] = attr.ib(factory=set) + + def __bool__(self) -> bool: + return bool(self.changed or self.left) + + def get_verify_key_from_cross_signing_key(key_info): """Get the key ID and signedjson verify key from a cross-signing key dict diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
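The factory comment added to synapse/types.py above is guarding against the usual mutable-default pitfall: if set() were evaluated once at class-definition time, every DeviceLists instance would share the same two sets. A short demonstration of the behaviour the attrs factory (and the added __bool__) provides:

    from typing import Set

    import attr

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class DeviceLists:
        # factory=set -> a fresh, independent set per instance
        changed: Set[str] = attr.ib(factory=set)
        left: Set[str] = attr.ib(factory=set)

        def __bool__(self) -> bool:
            return bool(self.changed or self.left)

    a = DeviceLists()
    b = DeviceLists(changed={"@alice:example.org"})
    print(a.changed is b.changed)  # False: each instance gets its own set
    print(bool(a), bool(b))        # False True: empty summaries are falsy, which the scheduler relies on
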
index f386b5e128..9dd4f26b35 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.appservice import ApplicationService from tests import unittest +from tests.test_utils import simple_async_mock def _regex(regex, exclusive=True): @@ -44,13 +45,19 @@ class ApplicationServiceTestCase(unittest.TestCase): ) self.store = Mock() + self.store.get_aliases_for_room = simple_async_mock(return_value=[]) + self.store.get_users_in_room = simple_async_mock(return_value=[]) @defer.inlineCallbacks def test_regex_user_id_prefix_match(self): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" self.assertTrue( - (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ( + yield defer.ensureDeferred( + self.service.is_interested_in_event(self.event, self.store) + ) + ) ) @defer.inlineCallbacks @@ -58,7 +65,11 @@ class ApplicationServiceTestCase(unittest.TestCase): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@someone_else:matrix.org" self.assertFalse( - (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ( + yield defer.ensureDeferred( + self.service.is_interested_in_event(self.event, self.store) + ) + ) ) @defer.inlineCallbacks @@ -68,7 +79,11 @@ class ApplicationServiceTestCase(unittest.TestCase): self.event.type = "m.room.member" self.event.state_key = "@irc_foobar:matrix.org" self.assertTrue( - (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ( + yield defer.ensureDeferred( + self.service.is_interested_in_event(self.event, self.store) + ) + ) ) @defer.inlineCallbacks @@ -78,7 +93,12 @@ class ApplicationServiceTestCase(unittest.TestCase): ) self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org" self.assertTrue( - (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ( + yield defer.ensureDeferred( + # We need to provide the store here in order to carry out room checks + self.service.is_interested_in_event(self.event, self.store) + ) + ) ) @defer.inlineCallbacks @@ -88,7 +108,11 @@ class ApplicationServiceTestCase(unittest.TestCase): ) self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org" self.assertFalse( - (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ( + yield defer.ensureDeferred( + self.service.is_interested_in_event(self.event, self.store) + ) + ) ) @defer.inlineCallbacks @@ -103,7 +127,7 @@ class ApplicationServiceTestCase(unittest.TestCase): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event(self.event, self.store) ) ) ) @@ -156,7 +180,7 @@ class ApplicationServiceTestCase(unittest.TestCase): self.assertFalse( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event(self.event, self.store) ) ) ) @@ -175,7 +199,7 @@ class ApplicationServiceTestCase(unittest.TestCase): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(self.event, self.store) + self.service.is_interested_in_event(self.event, self.store) ) ) ) @@ -189,7 +213,11 @@ class ApplicationServiceTestCase(unittest.TestCase): self.event.content = {"membership": "invite"} self.event.state_key = self.service.sender self.assertTrue( - (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ( + yield defer.ensureDeferred( + self.service.is_interested_in_event(self.event, self.store) + ) + ) ) @defer.inlineCallbacks @@ -205,7 +233,9 @@ class 
ApplicationServiceTestCase(unittest.TestCase): self.assertTrue( ( yield defer.ensureDeferred( - self.service.is_interested(event=self.event, store=self.store) + self.service.is_interested_in_event( + event=self.event, store=self.store + ) ) ) ) diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 55f0899bae..1c12cc5f49 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py
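These scheduler tests now go through ApplicationServiceScheduler.enqueue_for_appservice() and assert on a wider _TransactionController.send() call: alongside persistent and ephemeral events, every transaction carries a list of to-device messages and a DeviceLists summary. A minimal sketch of the call shape the mocks assert on (not part of the diff; the service and event values are placeholders):

    from unittest.mock import Mock

    from synapse.types import DeviceLists  # added to synapse/types.py in this change

    txn_ctrl = Mock()
    service = Mock(id=4)
    event = Mock(event_id="$placeholder")

    # Positional arguments: events, ephemeral, to_device_messages,
    # device_list_summary.
    txn_ctrl.send(service, [event], [], [], DeviceLists())
    txn_ctrl.send.assert_called_once_with(service, [event], [], [], DeviceLists())

As before, the queuer batches at most 100 events (and 100 ephemeral events) per transaction, which is what the large-transaction tests below continue to check.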
@@ -14,14 +14,18 @@ from unittest.mock import Mock from twisted.internet import defer +from twisted.internet.testing import MemoryReactor from synapse.appservice import ApplicationServiceState from synapse.appservice.scheduler import ( + ApplicationServiceScheduler, _Recoverer, - _ServiceQueuer, _TransactionController, ) from synapse.logging.context import make_deferred_yieldable +from synapse.server import HomeServer +from synapse.types import DeviceLists +from synapse.util import Clock from tests import unittest from tests.test_utils import simple_async_mock @@ -57,8 +61,13 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): # actual call self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) + # txn made and saved self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events, ephemeral=[] # txn made and saved + service=service, + events=events, + ephemeral=[], + to_device_messages=[], + device_list_summary=DeviceLists(), ) self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made txn.complete.assert_called_once_with(self.store) # txn completed @@ -78,8 +87,13 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): # actual call self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) + # txn made and saved self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events, ephemeral=[] # txn made and saved + service=service, + events=events, + ephemeral=[], + to_device_messages=[], + device_list_summary=DeviceLists(), ) self.assertEquals(0, txn.send.call_count) # txn not sent though self.assertEquals(0, txn.complete.call_count) # or completed @@ -102,7 +116,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events, ephemeral=[] + service=service, + events=events, + ephemeral=[], + to_device_messages=[], + device_list_summary=DeviceLists(), ) self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made self.assertEquals(1, self.recoverer.recover.call_count) # and invoked @@ -189,38 +207,45 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.callback.assert_called_once_with(self.recoverer) -class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): - def setUp(self): +class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + self.scheduler = ApplicationServiceScheduler(hs) self.txn_ctrl = Mock() self.txn_ctrl.send = simple_async_mock() - self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock()) + + # Replace instantiated _TransactionController instances with our Mock + self.scheduler.txn_ctrl = self.txn_ctrl + self.scheduler.queuer.txn_ctrl = self.txn_ctrl def test_send_single_event_no_queue(self): # Expect the event to be sent immediately. 
service = Mock(id=4) event = Mock() - self.queuer.enqueue_event(service, event) - self.txn_ctrl.send.assert_called_once_with(service, [event], []) + self.scheduler.enqueue_for_appservice(service, events=[event]) + self.txn_ctrl.send.assert_called_once_with( + service, [event], [], [], DeviceLists() + ) def test_send_single_event_with_queue(self): d = defer.Deferred() - self.txn_ctrl.send = Mock( - side_effect=lambda x, y, z: make_deferred_yieldable(d) - ) + self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d)) service = Mock(id=4) event = Mock(event_id="first") event2 = Mock(event_id="second") event3 = Mock(event_id="third") # Send an event and don't resolve it just yet. - self.queuer.enqueue_event(service, event) + self.scheduler.enqueue_for_appservice(service, events=[event]) # Send more events: expect send() to NOT be called multiple times. - self.queuer.enqueue_event(service, event2) - self.queuer.enqueue_event(service, event3) - self.txn_ctrl.send.assert_called_with(service, [event], []) + # (call enqueue_for_appservice multiple times deliberately) + self.scheduler.enqueue_for_appservice(service, events=[event2]) + self.scheduler.enqueue_for_appservice(service, events=[event3]) + self.txn_ctrl.send.assert_called_with(service, [event], [], [], DeviceLists()) self.assertEquals(1, self.txn_ctrl.send.call_count) # Resolve the send event: expect the queued events to be sent d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [event2, event3], []) + self.txn_ctrl.send.assert_called_with( + service, [event2, event3], [], [], DeviceLists() + ) self.assertEquals(2, self.txn_ctrl.send.call_count) def test_multiple_service_queues(self): @@ -238,23 +263,29 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): send_return_list = [srv_1_defer, srv_2_defer] - def do_send(x, y, z): + def do_send(*args, **kwargs): return make_deferred_yieldable(send_return_list.pop(0)) self.txn_ctrl.send = Mock(side_effect=do_send) # send events for different ASes and make sure they are sent - self.queuer.enqueue_event(srv1, srv_1_event) - self.queuer.enqueue_event(srv1, srv_1_event2) - self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], []) - self.queuer.enqueue_event(srv2, srv_2_event) - self.queuer.enqueue_event(srv2, srv_2_event2) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], []) + self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event]) + self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2]) + self.txn_ctrl.send.assert_called_with( + srv1, [srv_1_event], [], [], DeviceLists() + ) + self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event]) + self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2]) + self.txn_ctrl.send.assert_called_with( + srv2, [srv_2_event], [], [], DeviceLists() + ) # make sure callbacks for a service only send queued events for THAT # service srv_2_defer.callback(srv2) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], []) + self.txn_ctrl.send.assert_called_with( + srv2, [srv_2_event2], [], [], DeviceLists() + ) self.assertEquals(3, self.txn_ctrl.send.call_count) def test_send_large_txns(self): @@ -262,7 +293,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): srv_2_defer = defer.Deferred() send_return_list = [srv_1_defer, srv_2_defer] - def do_send(x, y, z): + def do_send(*args, **kwargs): return make_deferred_yieldable(send_return_list.pop(0)) self.txn_ctrl.send = Mock(side_effect=do_send) @@ -270,67 +301,81 @@ class 
ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): service = Mock(id=4, name="service") event_list = [Mock(name="event%i" % (i + 1)) for i in range(200)] for event in event_list: - self.queuer.enqueue_event(service, event) + self.scheduler.enqueue_for_appservice(service, [event], []) # Expect the first event to be sent immediately. - self.txn_ctrl.send.assert_called_with(service, [event_list[0]], []) + self.txn_ctrl.send.assert_called_with( + service, [event_list[0]], [], [], DeviceLists() + ) srv_1_defer.callback(service) # Then send the next 100 events - self.txn_ctrl.send.assert_called_with(service, event_list[1:101], []) + self.txn_ctrl.send.assert_called_with( + service, event_list[1:101], [], [], DeviceLists() + ) srv_2_defer.callback(service) # Then the final 99 events - self.txn_ctrl.send.assert_called_with(service, event_list[101:], []) + self.txn_ctrl.send.assert_called_with( + service, event_list[101:], [], [], DeviceLists() + ) self.assertEquals(3, self.txn_ctrl.send.call_count) def test_send_single_ephemeral_no_queue(self): # Expect the event to be sent immediately. service = Mock(id=4, name="service") event_list = [Mock(name="event")] - self.queuer.enqueue_ephemeral(service, event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], event_list) + self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) + self.txn_ctrl.send.assert_called_once_with( + service, [], event_list, [], DeviceLists() + ) def test_send_multiple_ephemeral_no_queue(self): # Expect the event to be sent immediately. service = Mock(id=4, name="service") event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")] - self.queuer.enqueue_ephemeral(service, event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], event_list) + self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) + self.txn_ctrl.send.assert_called_once_with( + service, [], event_list, [], DeviceLists() + ) def test_send_single_ephemeral_with_queue(self): d = defer.Deferred() - self.txn_ctrl.send = Mock( - side_effect=lambda x, y, z: make_deferred_yieldable(d) - ) + self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d)) service = Mock(id=4) event_list_1 = [Mock(event_id="event1"), Mock(event_id="event2")] event_list_2 = [Mock(event_id="event3"), Mock(event_id="event4")] event_list_3 = [Mock(event_id="event5"), Mock(event_id="event6")] # Send an event and don't resolve it just yet. - self.queuer.enqueue_ephemeral(service, event_list_1) + self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_1) # Send more events: expect send() to NOT be called multiple times. 
- self.queuer.enqueue_ephemeral(service, event_list_2) - self.queuer.enqueue_ephemeral(service, event_list_3) - self.txn_ctrl.send.assert_called_with(service, [], event_list_1) + self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2) + self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3) + self.txn_ctrl.send.assert_called_with( + service, [], event_list_1, [], DeviceLists() + ) self.assertEquals(1, self.txn_ctrl.send.call_count) # Resolve txn_ctrl.send d.callback(service) # Expect the queued events to be sent - self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3) + self.txn_ctrl.send.assert_called_with( + service, [], event_list_2 + event_list_3, [], DeviceLists() + ) self.assertEquals(2, self.txn_ctrl.send.call_count) def test_send_large_txns_ephemeral(self): d = defer.Deferred() - self.txn_ctrl.send = Mock( - side_effect=lambda x, y, z: make_deferred_yieldable(d) - ) + self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d)) # Expect the event to be sent immediately. service = Mock(id=4, name="service") first_chunk = [Mock(name="event%i" % (i + 1)) for i in range(100)] second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)] event_list = first_chunk + second_chunk - self.queuer.enqueue_ephemeral(service, event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk) + self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) + self.txn_ctrl.send.assert_called_once_with( + service, [], first_chunk, [], DeviceLists() + ) d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [], second_chunk) + self.txn_ctrl.send.assert_called_with( + service, [], second_chunk, [], DeviceLists() + ) self.assertEquals(2, self.txn_ctrl.send.call_count) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index d6f14e2dba..def93caf17 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py
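Besides renaming the mocked appservice-interest methods, this file gains ApplicationServicesHandlerSendEventsTestCase, which enables MSC2409 through override_config and registers an appservice whose exclusive user namespace covers the recipient of the to-device messages. A minimal sketch of the two pieces of configuration involved (not part of the diff; the flag name and regex are copied from the tests below):

    from synapse.appservice import ApplicationService

    # Per-test homeserver config override enabling MSC2409 to-device delivery.
    config_override = {
        "experimental_features": {"msc2409_to_device_messages_enabled": True},
    }

    # User namespace claiming exclusive interest in the target user.
    namespaces = {
        ApplicationService.NS_USERS: [
            {"regex": "@exclusive_as_user:.+", "exclusive": True},
        ],
    }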
@@ -1,4 +1,4 @@ -# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2015-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,18 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Iterable, List, Optional from unittest.mock import Mock from twisted.internet import defer +import synapse.rest.admin +import synapse.storage +from synapse.appservice import ApplicationService from synapse.handlers.appservice import ApplicationServicesHandler +from synapse.rest.client import login, receipts, room, sendtodevice from synapse.types import RoomStreamToken +from synapse.util.stringutils import random_string -from tests.test_utils import make_awaitable +from tests import unittest +from tests.test_utils import make_awaitable, simple_async_mock from tests.utils import MockClock -from .. import unittest - class AppServiceHandlerTestCase(unittest.TestCase): """Tests the ApplicationServicesHandler.""" @@ -36,6 +41,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): hs.get_datastore.return_value = self.mock_store self.mock_store.get_received_ts.return_value = make_awaitable(0) self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None) + self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable( + None + ) hs.get_application_service_api.return_value = self.mock_as_api hs.get_application_service_scheduler.return_value = self.mock_scheduler hs.get_clock.return_value = MockClock() @@ -63,8 +71,8 @@ class AppServiceHandlerTestCase(unittest.TestCase): ] self.handler.notify_interested_services(RoomStreamToken(None, 1)) - self.mock_scheduler.submit_event_for_as.assert_called_once_with( - interested_service, event + self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( + interested_service, events=[event] ) def test_query_user_exists_unknown_user(self): @@ -111,11 +119,11 @@ class AppServiceHandlerTestCase(unittest.TestCase): room_id = "!alpha:bet" servers = ["aperture"] - interested_service = self._mkservice_alias(is_interested_in_alias=True) + interested_service = self._mkservice_alias(is_room_alias_in_namespace=True) services = [ - self._mkservice_alias(is_interested_in_alias=False), + self._mkservice_alias(is_room_alias_in_namespace=False), interested_service, - self._mkservice_alias(is_interested_in_alias=False), + self._mkservice_alias(is_room_alias_in_namespace=False), ] self.mock_as_api.query_alias.return_value = make_awaitable(True) @@ -261,7 +269,6 @@ class AppServiceHandlerTestCase(unittest.TestCase): """ interested_service = self._mkservice(is_interested=True) services = [interested_service] - self.mock_store.get_app_services.return_value = services self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable( 579 @@ -275,10 +282,10 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.handler.notify_interested_services_ephemeral( "receipt_key", 580, ["@fakerecipient:example.com"] ) - self.mock_scheduler.submit_ephemeral_events_for_as.assert_called_once_with( - interested_service, [event] + self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( + interested_service, ephemeral=[event] ) - self.mock_store.set_type_stream_id_for_appservice.assert_called_once_with( + self.mock_store.set_appservice_stream_type_pos.assert_called_once_with( interested_service, "read_receipt", 580, @@ -305,19 +312,287 @@ class 
AppServiceHandlerTestCase(unittest.TestCase): self.handler.notify_interested_services_ephemeral( "receipt_key", 580, ["@fakerecipient:example.com"] ) - self.mock_scheduler.submit_ephemeral_events_for_as.assert_not_called() + # This method will be called, but with an empty list of events + self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( + interested_service, ephemeral=[] ) def _mkservice(self, is_interested, protocols=None): service = Mock() - service.is_interested.return_value = make_awaitable(is_interested) + service.is_interested_in_event.return_value = make_awaitable(is_interested) + service.is_interested_in_room.return_value = make_awaitable(is_interested) + service.is_interested_in_presence.return_value = make_awaitable(is_interested) service.token = "mock_service_token" service.url = "mock_service_url" service.protocols = protocols return service - def _mkservice_alias(self, is_interested_in_alias): + def _mkservice_alias(self, is_room_alias_in_namespace): service = Mock() - service.is_interested_in_alias.return_value = is_interested_in_alias + service.is_room_alias_in_namespace.return_value = is_room_alias_in_namespace service.token = "mock_service_token" service.url = "mock_service_url" return service + + +class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): + """ + Tests that the ApplicationServicesHandler sends events to application + services correctly. + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + room.register_servlets, + sendtodevice.register_servlets, + receipts.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + # Mock the ApplicationServiceScheduler's _TransactionController's send method so that + # we can track any outgoing ephemeral events + self.send_mock = simple_async_mock() + hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock + + # Mock out application services, and allow defining our own in tests + self._services: List[ApplicationService] = [] + self.hs.get_datastore().get_app_services = Mock(return_value=self._services) + + # A user on the homeserver. + self.local_user_device_id = "local_device" + self.local_user = self.register_user("local_user", "password") + self.local_user_token = self.login( + "local_user", "password", self.local_user_device_id + ) + + # A user on the homeserver which lies within an appservice's exclusive user namespace. + self.exclusive_as_user_device_id = "exclusive_as_device" + self.exclusive_as_user = self.register_user("exclusive_as_user", "password") + self.exclusive_as_user_token = self.login( + "exclusive_as_user", "password", self.exclusive_as_user_device_id + ) + + # Ensure that the mock is reset after creating devices (and thus updating device lists) + self.send_mock.reset_mock() + + @unittest.override_config( + {"experimental_features": {"msc2409_to_device_messages_enabled": True}} + ) + def test_application_services_receive_local_to_device(self): + """ + Test that when a user sends a to-device message to another user + that is within an application service's user namespace, the + application service will receive it.
+ """ + interested_appservice = self._register_application_service( + namespaces={ + ApplicationService.NS_USERS: [ + { + "regex": "@exclusive_as_user:.+", + "exclusive": True, + } + ], + }, + ) + + # Have local_user send a to-device message to exclusive_as_user + message_content = {"some_key": "some really interesting value"} + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.room_key_request/3", + content={ + "messages": { + self.exclusive_as_user: { + self.exclusive_as_user_device_id: message_content + } + } + }, + access_token=self.local_user_token, + ) + self.assertEqual(chan.code, 200, chan.result) + + # Have exclusive_as_user send a to-device message to local_user + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.room_key_request/4", + content={ + "messages": { + self.local_user: {self.local_user_device_id: message_content} + } + }, + access_token=self.exclusive_as_user_token, + ) + self.assertEqual(chan.code, 200, chan.result) + + # Check if our application service - that is interested in exclusive_as_user - received + # the to-device message as part of an AS transaction. + # Only the local_user -> exclusive_as_user to-device message should have been forwarded to the AS. + # + # The uninterested application service should not have been notified at all. + self.send_mock.assert_called_once() + ( + service, + _events, + _ephemeral, + to_device_messages, + _device_list_summary, + ) = self.send_mock.call_args[0] + + # Assert that this was the same to-device message that local_user sent + self.assertEqual(service, interested_appservice) + self.assertEqual(to_device_messages[0]["type"], "m.room_key_request") + self.assertEqual(to_device_messages[0]["sender"], self.local_user) + + # Additional fields 'to_user_id' and 'to_device_id' specifically for + # to-device messages via the AS API + self.assertEqual(to_device_messages[0]["to_user_id"], self.exclusive_as_user) + self.assertEqual( + to_device_messages[0]["to_device_id"], self.exclusive_as_user_device_id + ) + self.assertEqual(to_device_messages[0]["content"], message_content) + + @unittest.override_config( + {"experimental_features": {"msc2409_to_device_messages_enabled": True}} + ) + def test_application_services_receive_bursts_of_to_device(self): + """ + Test that when a user sends >100 to-device messages at once, any + interested AS's will receive them in separate transactions. + + Also tests that uninterested application services do not receive messages. + """ + # Register two application services with exclusive interest in a user + interested_appservices = [] + for _ in range(2): + appservice = self._register_application_service( + namespaces={ + ApplicationService.NS_USERS: [ + { + "regex": "@exclusive_as_user:.+", + "exclusive": True, + } + ], + }, + ) + interested_appservices.append(appservice) + + # ...and an application service which does not have any user interest. + self._register_application_service() + + to_device_message_content = { + "some key": "some interesting value", + } + + # We need to send a large burst of to-device messages. We also would like to + # include them all in the same application service transaction so that we can + # test large transactions. + # + # To do this, we can send a single to-device message to many user devices at + # once. + # + # We insert number_of_messages - 1 messages into the database directly. 
We'll then + # send a final to-device message to the real device, which will also kick off + # an AS transaction (as just inserting messages into the DB won't). + number_of_messages = 150 + fake_device_ids = [f"device_{num}" for num in range(number_of_messages - 1)] + messages = { + self.exclusive_as_user: { + device_id: to_device_message_content for device_id in fake_device_ids + } + } + + # Create a fake device per message. We can't send to-device messages to + # a device that doesn't exist. + self.get_success( + self.hs.get_datastore().db_pool.simple_insert_many( + desc="test_application_services_receive_burst_of_to_device", + table="devices", + values=[ + { + "user_id": self.exclusive_as_user, + "device_id": device_id, + } + for device_id in fake_device_ids + ], + ) + ) + + # Seed the device_inbox table with our fake messages + self.get_success( + self.hs.get_datastore().add_messages_to_device_inbox(messages, {}) + ) + + # Now have local_user send a final to-device message to exclusive_as_user. All unsent + # to-device messages should be sent to any application services + # interested in exclusive_as_user. + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.room_key_request/4", + content={ + "messages": { + self.exclusive_as_user: { + self.exclusive_as_user_device_id: to_device_message_content + } + } + }, + access_token=self.local_user_token, + ) + self.assertEqual(chan.code, 200, chan.result) + + self.send_mock.assert_called() + + # Count the total number of to-device messages that were sent out per-service. + # Ensure that we only sent to-device messages to interested services, and that + # each interested service received the full count of to-device messages. + service_id_to_message_count: Dict[str, int] = {} + + for call in self.send_mock.call_args_list: + ( + service, + _events, + _ephemeral, + to_device_messages, + _device_list_summary, + ) = call[0] + + # Check that this was made to an interested service + self.assertIn(service, interested_appservices) + + # Add to the count of messages for this application service + service_id_to_message_count.setdefault(service.id, 0) + service_id_to_message_count[service.id] += len(to_device_messages) + + # Assert that each interested service received the full count of messages + for count in service_id_to_message_count.values(): + self.assertEqual(count, number_of_messages) + + def _register_application_service( + self, + namespaces: Optional[Dict[str, Iterable[Dict]]] = None, + ) -> ApplicationService: + """ + Register a new application service, with the given namespaces of interest. + + Args: + namespaces: A dictionary containing any user, room or alias namespaces that + the application service is interested in. + + Returns: + The registered application service. + """ + # Create an application service + appservice = ApplicationService( + token=None, + hostname="example.com", + id=random_string(10), + sender="@as:example.com", + rate_limited=False, + namespaces=namespaces, + supports_ephemeral=True, + ) + + # Register the application service + self._services.append(appservice) + + return appservice diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 329490caad..bb1411232a 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py
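The storage tests below pick up the widened create_appservice_txn() signature and the rename of set_type_stream_id_for_appservice to set_appservice_stream_type_pos. A minimal sketch of the new transaction-creation call (not part of the diff; the helper and its arguments are illustrative only):

    from typing import List

    from synapse.events import EventBase
    from synapse.types import DeviceLists

    async def save_txn(store, service, events: List[EventBase]):
        # Positional arguments: events, ephemeral, to_device_messages, followed
        # by the device-list summary for this transaction.
        return await store.create_appservice_txn(
            service, events, [], [], DeviceLists()
        )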
@@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) +from synapse.types import DeviceLists from synapse.util import Clock from tests import unittest @@ -266,7 +267,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): service = Mock(id=self.as_list[0]["id"]) events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) txn = self.get_success( - defer.ensureDeferred(self.store.create_appservice_txn(service, events, [])) + defer.ensureDeferred( + self.store.create_appservice_txn(service, events, [], [], DeviceLists()) + ) ) self.assertEquals(txn.id, 1) self.assertEquals(txn.events, events) @@ -280,7 +283,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._set_last_txn(service.id, 9643)) # AS is falling behind self.get_success(self._insert_txn(service.id, 9644, events)) self.get_success(self._insert_txn(service.id, 9645, events)) - txn = self.get_success(self.store.create_appservice_txn(service, events, [])) + txn = self.get_success( + self.store.create_appservice_txn(service, events, [], [], DeviceLists()) + ) self.assertEquals(txn.id, 9646) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) @@ -291,7 +296,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): service = Mock(id=self.as_list[0]["id"]) events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) self.get_success(self._set_last_txn(service.id, 9643)) - txn = self.get_success(self.store.create_appservice_txn(service, events, [])) + txn = self.get_success( + self.store.create_appservice_txn(service, events, [], [], DeviceLists()) + ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) @@ -313,7 +320,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._insert_txn(self.as_list[2]["id"], 10, events)) self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events)) - txn = self.get_success(self.store.create_appservice_txn(service, events, [])) + txn = self.get_success( + self.store.create_appservice_txn(service, events, [], [], DeviceLists()) + ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) @@ -481,10 +490,10 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): ValueError, ) - def test_set_type_stream_id_for_appservice(self) -> None: + def test_set_appservice_stream_type_pos(self) -> None: read_receipt_value = 1024 self.get_success( - self.store.set_type_stream_id_for_appservice( + self.store.set_appservice_stream_type_pos( self.service, "read_receipt", read_receipt_value ) ) @@ -494,7 +503,7 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): self.assertEqual(result, read_receipt_value) self.get_success( - self.store.set_type_stream_id_for_appservice( + self.store.set_appservice_stream_type_pos( self.service, "presence", read_receipt_value ) ) @@ -503,9 +512,9 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): ) self.assertEqual(result, read_receipt_value) - def test_set_type_stream_id_for_appservice_invalid_type(self) -> None: + def test_set_appservice_stream_type_pos_invalid_type(self) -> None: self.get_failure( - self.store.set_type_stream_id_for_appservice(self.service, "foobar", 1024), + 
self.store.set_appservice_stream_type_pos(self.service, "foobar", 1024), ValueError, )