 changelog.d/11215.feature                                                                 |   1
 synapse/appservice/__init__.py                                                            |   3
 synapse/appservice/api.py                                                                 |  30
 synapse/appservice/scheduler.py                                                           |  93
 synapse/config/experimental.py                                                            |   8
 synapse/handlers/appservice.py                                                            | 135
 synapse/notifier.py                                                                       |   4
 synapse/storage/databases/main/appservice.py                                              |  24
 synapse/storage/databases/main/deviceinbox.py                                             |  76
 synapse/storage/schema/main/delta/65/06_msc2409_add_device_id_appservice_stream_type.sql  |  23
 tests/appservice/test_scheduler.py                                                        | 106
 tests/handlers/test_appservice.py                                                         | 280
 tests/storage/test_appservice.py                                                          |  10
 13 files changed, 675 insertions(+), 118 deletions(-)
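Before the diff itself, a quick sketch of the effect of this change on the wire: an application service that has opted into ephemeral events will now also receive to-device messages in its transactions, under an unstable MSC2409 key. The keys below mirror those built in `push_bulk()` and `_get_to_device_messages()` in the diff; the event values themselves are purely illustrative.

```python
# Illustrative shape of a transaction body pushed to an appservice once this
# change is enabled. Values are made up; keys come from the diff below.
example_transaction_body = {
    "events": [
        # Persistent room events, serialized as before.
        {"type": "m.room.message", "room_id": "!a:example.com", "content": {"body": "hi"}},
    ],
    # Unstable prefixes, to be replaced with stable names once MSC2409 merges.
    "de.sorunome.msc2409.ephemeral": [
        {"type": "m.typing", "room_id": "!a:example.com", "content": {"user_ids": []}},
    ],
    "de.sorunome.msc2409.to_device": [
        {
            "type": "m.room_key_request",
            "sender": "@local_user:example.com",
            "content": {"some_key": "some really interesting value"},
            # Added by the handler so the appservice knows the intended recipient;
            # the internal "message_id" field is stripped before sending.
            "to_user_id": "@exclusive_as_user:example.com",
            "to_device_id": "exclusive_as_device",
        },
    ],
}
```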
diff --git a/changelog.d/11215.feature b/changelog.d/11215.feature
new file mode 100644
index 0000000000..468020834b
--- /dev/null
+++ b/changelog.d/11215.feature
@@ -0,0 +1 @@
+Add experimental support for sending to-device messages to application services, as specified by [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409). Disabled by default.
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index f9d3bd337d..e33e69eed1 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -330,11 +330,13 @@ class AppServiceTransaction:
         id: int,
         events: List[EventBase],
         ephemeral: List[JsonDict],
+        to_device_messages: List[JsonDict],
     ):
         self.service = service
         self.id = id
         self.events = events
         self.ephemeral = ephemeral
+        self.to_device_messages = to_device_messages
     async def send(self, as_api: "ApplicationServiceApi") -> bool:
         """Sends this transaction using the provided AS API interface.
@@ -348,6 +350,7 @@ class AppServiceTransaction:
             service=self.service,
             events=self.events,
             ephemeral=self.ephemeral,
+            to_device_messages=self.to_device_messages,
             txn_id=self.id,
         )
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index f51b636417..ca58f92339 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import logging
 import urllib
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 from prometheus_client import Counter
@@ -204,12 +204,25 @@ class ApplicationServiceApi(SimpleHttpClient):
         service: "ApplicationService",
         events: List[EventBase],
         ephemeral: List[JsonDict],
+        to_device_messages: List[JsonDict],
         txn_id: Optional[int] = None,
-    ):
+    ) -> bool:
+        """
+        Push data to an application service.
+        Args:
+            service: The application service to send to.
+            events: The persistent events to send.
+            ephemeral: The ephemeral events to send.
+            to_device_messages: The to-device messages to send.
+            txn_id: A unique ID to assign to this transaction. Application services should
+                deduplicate transactions received with identical IDs.
+        Returns:
+            True if the task succeeded, False if it failed.
+        """
         if service.url is None:
             return True
-        events = self._serialize(service, events)
+        serialized_events = self._serialize(service, events)
         if txn_id is None:
             logger.warning(
@@ -220,10 +233,15 @@ class ApplicationServiceApi(SimpleHttpClient):
         uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id)))
         # Never send ephemeral events to appservices that do not support it
+        body: Dict[str, List[JsonDict]] = {"events": serialized_events}
         if service.supports_ephemeral:
-            body = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral}
-        else:
-            body = {"events": events}
+            body.update(
+                {
+                    # TODO: Update to stable prefixes once MSC2409 completes FCP merge.
+                    "de.sorunome.msc2409.ephemeral": ephemeral,
+                    "de.sorunome.msc2409.to_device": to_device_messages,
+                }
+            )
         try:
             await self.put_json(
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 6a2ce99b55..dae952dc13 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -48,7 +48,7 @@ This is all tied together by the AppServiceScheduler which DIs the required
 components.
""" import logging -from typing import List, Optional +from typing import Dict, Iterable, List, Optional from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.events import EventBase @@ -65,6 +65,9 @@ MAX_PERSISTENT_EVENTS_PER_TRANSACTION = 100 # Maximum number of ephemeral events to provide in an AS transaction. MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100 +# Maximum number of to-device messages to provide in an AS transaction. +MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION = 100 + class ApplicationServiceScheduler: """Public facing API for this module. Does the required DI to tie the @@ -91,13 +94,39 @@ class ApplicationServiceScheduler: for service in services: self.txn_ctrl.start_recoverer(service) - def submit_event_for_as(self, service: ApplicationService, event: EventBase): - self.queuer.enqueue_event(service, event) + def enqueue_for_appservice( + self, + appservice: ApplicationService, + events: Optional[Iterable[EventBase]] = None, + ephemeral: Optional[Iterable[JsonDict]] = None, + to_device_messages: Optional[Iterable[JsonDict]] = None, + ) -> None: + """ + Enqueue some data to be sent off to an application service. + Args: + appservice: The application service to create and send a transaction to. + events: The persistent room events to send. + ephemeral: The ephemeral events to send. + to_device_messages: The to-device messages to send. These differ from normal + to-device messages sent to clients, as they have 'to_device_id' and + 'to_user_id' fields. + """ + # We purposefully allow this method to run with empty events/ephemeral + # iterables, so that callers do not need to check iterable size themselves. + if not events and not ephemeral and not to_device_messages: + return + + if events: + self.queuer.queued_events.setdefault(appservice.id, []).extend(events) + if ephemeral: + self.queuer.queued_ephemeral.setdefault(appservice.id, []).extend(ephemeral) + if to_device_messages: + self.queuer.queued_to_device_messages.setdefault(appservice.id, []).extend( + to_device_messages + ) - def submit_ephemeral_events_for_as( - self, service: ApplicationService, events: List[JsonDict] - ): - self.queuer.enqueue_ephemeral(service, events) + # Kick off a new application service transaction + self.queuer.start_background_request(appservice) class _ServiceQueuer: @@ -109,15 +138,19 @@ class _ServiceQueuer: """ def __init__(self, txn_ctrl, clock): - self.queued_events = {} # dict of {service_id: [events]} - self.queued_ephemeral = {} # dict of {service_id: [events]} + # dict of {service_id: [events]} + self.queued_events: Dict[str, List[EventBase]] = {} + # dict of {service_id: [event_json]} + self.queued_ephemeral: Dict[str, List[JsonDict]] = {} + # dict of {service_id: [to_device_message_json]} + self.queued_to_device_messages: Dict[str, List[JsonDict]] = {} # the appservices which currently have a transaction in flight self.requests_in_flight = set() self.txn_ctrl = txn_ctrl self.clock = clock - def _start_background_request(self, service): + def start_background_request(self, service): # start a sender for this appservice if we don't already have one if service.id in self.requests_in_flight: return @@ -126,14 +159,6 @@ class _ServiceQueuer: "as-sender-%s" % (service.id,), self._send_request, service ) - def enqueue_event(self, service: ApplicationService, event: EventBase): - self.queued_events.setdefault(service.id, []).append(event) - self._start_background_request(service) - - def enqueue_ephemeral(self, service: ApplicationService, events: List[JsonDict]): 
- self.queued_ephemeral.setdefault(service.id, []).extend(events) - self._start_background_request(service) - async def _send_request(self, service: ApplicationService): # sanity-check: we shouldn't get here if this service already has a sender # running. @@ -150,11 +175,21 @@ class _ServiceQueuer: ephemeral = all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION] del all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION] - if not events and not ephemeral: + all_to_device_messages = self.queued_to_device_messages.get( + service.id, [] + ) + to_device_messages_to_send = all_to_device_messages[ + :MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION + ] + del all_to_device_messages[:MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION] + + if not events and not ephemeral and not to_device_messages_to_send: return try: - await self.txn_ctrl.send(service, events, ephemeral) + await self.txn_ctrl.send( + service, events, ephemeral, to_device_messages_to_send + ) except Exception: logger.exception("AS request failed") finally: @@ -191,10 +226,24 @@ class _TransactionController: service: ApplicationService, events: List[EventBase], ephemeral: Optional[List[JsonDict]] = None, - ): + to_device_messages: Optional[List[JsonDict]] = None, + ) -> None: + """ + Create a transaction with the given data and send to the provided + application service. + + Args: + service: The application service to send the transaction to. + events: The persistent events to include in the transaction. + ephemeral: The ephemeral events to include in the transaction. + to_device_messages: The to-device messages to include in the transaction. + """ try: txn = await self.store.create_appservice_txn( - service=service, events=events, ephemeral=ephemeral or [] + service=service, + events=events, + ephemeral=ephemeral or [], + to_device_messages=to_device_messages or [], ) service_is_up = await self._is_service_up(service) if service_is_up: diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d78a15097c..e481fc16b6 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -49,3 +49,11 @@ class ExperimentalConfig(Config): # MSC3030 (Jump to date API endpoint) self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False) + + # MSC2409 (this setting only relates to optionally sending to-device messages). + # Presence, typing and read receipt EDUs are already sent to application services that + # have opted in to receive them. This setting, if enabled, adds to-device messages + # to that list. 
+ self.msc2409_to_device_messages_enabled: bool = experimental.get( + "msc2409_to_device_messages_enabled", False + ) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 9abdad262b..c92668642a 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -55,6 +55,9 @@ class ApplicationServicesHandler: self.clock = hs.get_clock() self.notify_appservices = hs.config.appservice.notify_appservices self.event_sources = hs.get_event_sources() + self._msc2409_to_device_messages_enabled = ( + hs.config.experimental.msc2409_to_device_messages_enabled + ) self.current_max = 0 self.is_processing = False @@ -132,7 +135,9 @@ class ApplicationServicesHandler: # Fork off pushes to these services for service in services: - self.scheduler.submit_event_for_as(service, event) + self.scheduler.enqueue_for_appservice( + service, events=[event] + ) now = self.clock.time_msec() ts = await self.store.get_received_ts(event.event_id) @@ -199,8 +204,9 @@ class ApplicationServicesHandler: Args: stream_key: The stream the event came from. - `stream_key` can be "typing_key", "receipt_key" or "presence_key". Any other - value for `stream_key` will cause this function to return early. + `stream_key` can be "typing_key", "receipt_key", "presence_key" or + "to_device_key". Any other value for `stream_key` will cause this function + to return early. Ephemeral events will only be pushed to appservices that have opted into receiving them by setting `push_ephemeral` to true in their registration @@ -216,8 +222,15 @@ class ApplicationServicesHandler: if not self.notify_appservices: return - # Ignore any unsupported streams - if stream_key not in ("typing_key", "receipt_key", "presence_key"): + # Notify appservices of updates in ephemeral event streams. + # Only the following streams are currently supported. + # FIXME: We should use constants for these values. + if stream_key not in ( + "typing_key", + "receipt_key", + "presence_key", + "to_device_key", + ): return # Assert that new_token is an integer (and not a RoomStreamToken). @@ -233,6 +246,13 @@ class ApplicationServicesHandler: # Additional context: https://github.com/matrix-org/synapse/pull/11137 assert isinstance(new_token, int) + # Ignore to-device messages if the feature flag is not enabled + if ( + stream_key == "to_device_key" + and not self._msc2409_to_device_messages_enabled + ): + return + # Check whether there are any appservices which have registered to receive # ephemeral events. # @@ -266,7 +286,7 @@ class ApplicationServicesHandler: with Measure(self.clock, "notify_interested_services_ephemeral"): for service in services: if stream_key == "typing_key": - # Note that we don't persist the token (via set_type_stream_id_for_appservice) + # Note that we don't persist the token (via set_appservice_stream_type_pos) # for typing_key due to performance reasons and due to their highly # ephemeral nature. # @@ -274,7 +294,7 @@ class ApplicationServicesHandler: # and, if they apply to this application service, send it off. 
                        events = await self._handle_typing(service, new_token)
                        if events:
-                           self.scheduler.submit_ephemeral_events_for_as(service, events)
+                           self.scheduler.enqueue_for_appservice(service, ephemeral=events)
                        continue
                    # Since we read/update the stream position for this AS/stream
@@ -285,28 +305,37 @@ class ApplicationServicesHandler:
                    ):
                        if stream_key == "receipt_key":
                            events = await self._handle_receipts(service, new_token)
-                           if events:
-                               self.scheduler.submit_ephemeral_events_for_as(
-                                   service, events
-                               )
+                           self.scheduler.enqueue_for_appservice(service, ephemeral=events)
                            # Persist the latest handled stream token for this appservice
-                           await self.store.set_type_stream_id_for_appservice(
+                           await self.store.set_appservice_stream_type_pos(
                                service, "read_receipt", new_token
                            )
                        elif stream_key == "presence_key":
                            events = await self._handle_presence(service, users, new_token)
-                           if events:
-                               self.scheduler.submit_ephemeral_events_for_as(
-                                   service, events
-                               )
+                           self.scheduler.enqueue_for_appservice(service, ephemeral=events)
                            # Persist the latest handled stream token for this appservice
-                           await self.store.set_type_stream_id_for_appservice(
+                           await self.store.set_appservice_stream_type_pos(
                                service, "presence", new_token
                            )
+                       elif stream_key == "to_device_key":
+                           # Retrieve a list of to-device message events, as well as the
+                           # maximum stream token of the messages we were able to retrieve.
+                           to_device_messages = await self._get_to_device_messages(
+                               service, new_token, users
+                           )
+                           self.scheduler.enqueue_for_appservice(
+                               service, to_device_messages=to_device_messages
+                           )
+
+                           # Persist the latest handled stream token for this appservice
+                           await self.store.set_appservice_stream_type_pos(
+                               service, "to_device", new_token
+                           )
+
    async def _handle_typing(
        self, service: ApplicationService, new_token: int
    ) -> List[JsonDict]:
@@ -440,6 +469,78 @@ class ApplicationServicesHandler:
        return events
+    async def _get_to_device_messages(
+        self,
+        service: ApplicationService,
+        new_token: int,
+        users: Collection[Union[str, UserID]],
+    ) -> List[JsonDict]:
+        """
+        Given an application service, determine which events it should receive
+        from those between the last-recorded to-device stream token for this
+        appservice and the given stream token.
+
+        Args:
+            service: The application service to check for which events it should receive.
+            new_token: The latest to-device event stream token.
+            users: The users that should receive new to-device messages.
+
+        Returns:
+            A list of JSON dictionaries containing data derived from the to-device messages
+            that should be sent to the given application service.
+        """
+        # Get the stream token that this application service has processed up until
+        from_key = await self.store.get_type_stream_id_for_appservice(
+            service, "to_device"
+        )
+
+        # Filter out users that this appservice is not interested in
+        users_appservice_is_interested_in: List[str] = []
+        for user in users:
+            # FIXME: We should do this farther up the call stack. We currently repeat
+            # this operation in _handle_presence.
+            if isinstance(user, UserID):
+                user = user.to_string()
+
+            if service.is_interested_in_user(user):
+                users_appservice_is_interested_in.append(user)
+
+        if not users_appservice_is_interested_in:
+            # Return early if the AS was not interested in any of these users
+            return []
+
+        # Retrieve the to-device messages for each user
+        recipient_user_id_device_id_to_messages = await self.store.get_new_messages(
+            users_appservice_is_interested_in,
+            from_key,
+            new_token,
+        )
+
+        # According to MSC2409, we'll need to add 'to_user_id' and 'to_device_id' fields
+        # to the event JSON so that the application service will know which user/device
+        # combination this message was intended for.
+        #
+        # So we mangle this dict into a flat list of to-device messages with the relevant
+        # user ID and device ID embedded inside each message dict.
+        message_payload: List[JsonDict] = []
+        for (
+            user_id,
+            device_id,
+        ), messages in recipient_user_id_device_id_to_messages.items():
+            for message_json in messages:
+                # Remove 'message_id' from the to-device message, as it's an internal ID
+                message_json.pop("message_id", None)
+
+                message_payload.append(
+                    {
+                        "to_user_id": user_id,
+                        "to_device_id": device_id,
+                        **message_json,
+                    }
+                )
+
+        return message_payload
+
    async def query_user_exists(self, user_id: str) -> bool:
        """Check if any application service knows this user_id exists.
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 60e5409895..3b24a7f4ba 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -452,7 +452,9 @@ class Notifier:
                users,
            )
        except Exception:
-            logger.exception("Error notifying application services of event")
+            logger.exception(
+                "Error notifying application services of ephemeral event"
+            )
    def on_new_replication_data(self) -> None:
        """Used to inform replication listeners that something has happened
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 4a883dc166..68ba330432 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -194,6 +194,7 @@ class ApplicationServiceTransactionWorkerStore(
        service: ApplicationService,
        events: List[EventBase],
        ephemeral: List[JsonDict],
+        to_device_messages: List[JsonDict],
    ) -> AppServiceTransaction:
        """Atomically creates a new transaction for this application service
        with the given list of events. Ephemeral events are NOT persisted to the
@@ -203,6 +204,7 @@ class ApplicationServiceTransactionWorkerStore(
            service: The service who the transaction is for.
            events: A list of persistent events to put in the transaction.
            ephemeral: A list of ephemeral events to put in the transaction.
+            to_device_messages: A list of to-device messages to put in the transaction.
        Returns:
            A new transaction.
@@ -233,7 +235,11 @@ class ApplicationServiceTransactionWorkerStore( (service.id, new_txn_id, event_ids), ) return AppServiceTransaction( - service=service, id=new_txn_id, events=events, ephemeral=ephemeral + service=service, + id=new_txn_id, + events=events, + ephemeral=ephemeral, + to_device_messages=to_device_messages, ) return await self.db_pool.runInteraction( @@ -326,7 +332,11 @@ class ApplicationServiceTransactionWorkerStore( events = await self.get_events_as_list(event_ids) return AppServiceTransaction( - service=service, id=entry["txn_id"], events=events, ephemeral=[] + service=service, + id=entry["txn_id"], + events=events, + ephemeral=[], + to_device_messages=[], ) def _get_last_txn(self, txn, service_id: Optional[str]) -> int: @@ -387,7 +397,7 @@ class ApplicationServiceTransactionWorkerStore( async def get_type_stream_id_for_appservice( self, service: ApplicationService, type: str ) -> int: - if type not in ("read_receipt", "presence"): + if type not in ("read_receipt", "presence", "to_device"): raise ValueError( "Expected type to be a valid application stream id type, got %s" % (type,) @@ -411,16 +421,16 @@ class ApplicationServiceTransactionWorkerStore( "get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn ) - async def set_type_stream_id_for_appservice( + async def set_appservice_stream_type_pos( self, service: ApplicationService, stream_type: str, pos: Optional[int] ) -> None: - if stream_type not in ("read_receipt", "presence"): + if stream_type not in ("read_receipt", "presence", "to_device"): raise ValueError( "Expected type to be a valid application stream id type, got %s" % (stream_type,) ) - def set_type_stream_id_for_appservice_txn(txn): + def set_appservice_stream_type_pos_txn(txn): stream_id_type = "%s_stream_id" % stream_type txn.execute( "UPDATE application_services_state SET %s = ? WHERE as_id=?" @@ -429,7 +439,7 @@ class ApplicationServiceTransactionWorkerStore( ) await self.db_pool.runInteraction( - "set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn + "set_appservice_stream_type_pos", set_appservice_stream_type_pos_txn ) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index ab8766c75b..d2b285e852 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple from synapse.logging import issue9533_logger from synapse.logging.opentracing import log_kv, set_tag, trace @@ -24,6 +24,7 @@ from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + make_in_list_sql_clause, ) from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( @@ -136,6 +137,79 @@ class DeviceInboxWorkerStore(SQLBaseStore): def get_to_device_stream_token(self): return self._device_inbox_id_gen.get_current_token() + async def get_new_messages( + self, + user_ids: Collection[str], + from_stream_id: int, + to_stream_id: int, + ) -> Dict[Tuple[str, str], List[JsonDict]]: + """ + Retrieve to-device messages for a given set of user IDs. + + Only to-device messages with stream ids between the given boundaries + (from < X <= to) are returned. + + Note that a stream ID can be shared by multiple copies of the same message with + different recipient devices. 
Each (device, message_content) tuple has its own
+        row in the device_inbox table.
+
+        Args:
+            user_ids: The users to retrieve to-device messages for.
+            from_stream_id: The lower boundary of stream id to filter with (exclusive).
+            to_stream_id: The upper boundary of stream id to filter with (inclusive).
+
+        Returns:
+            A dictionary of (user ID, device ID) -> list of to-device messages for
+            that device.
+        """
+        # Bail out if none of these users have any messages
+        for user_id in user_ids:
+            if self._device_inbox_stream_cache.has_entity_changed(
+                user_id, from_stream_id
+            ):
+                break
+        else:
+            return {}
+
+        def get_new_messages_txn(txn: LoggingTransaction):
+            # Build a query to select messages from any of the given users that are between
+            # the given stream id bounds
+
+            # Scope to only the given users. We need to use this method as doing so is
+            # different across database engines.
+            many_clause_sql, many_clause_args = make_in_list_sql_clause(
+                self.database_engine, "user_id", user_ids
+            )
+
+            sql = f"""
+                SELECT user_id, device_id, message_json FROM device_inbox
+                WHERE {many_clause_sql}
+                AND ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC
+            """
+
+            txn.execute(sql, (*many_clause_args, from_stream_id, to_stream_id))
+
+            # Create a dictionary of (user ID, device ID) -> list of messages that
+            # that device is meant to receive.
+            recipient_user_id_device_id_to_messages: Dict[
+                Tuple[str, str], List[JsonDict]
+            ] = {}
+
+            for row in txn:
+                recipient_user_id = row[0]
+                recipient_device_id = row[1]
+                message_dict = db_to_json(row[2])
+
+                recipient_user_id_device_id_to_messages.setdefault(
+                    (recipient_user_id, recipient_device_id), []
+                ).append(message_dict)
+
+            return recipient_user_id_device_id_to_messages
+
+        return await self.db_pool.runInteraction(
+            "get_new_messages", get_new_messages_txn
+        )
+
    async def get_new_messages_for_device(
        self,
        user_id: str,
diff --git a/synapse/storage/schema/main/delta/65/06_msc2409_add_device_id_appservice_stream_type.sql b/synapse/storage/schema/main/delta/65/06_msc2409_add_device_id_appservice_stream_type.sql
new file mode 100644
index 0000000000..7b40241282
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/06_msc2409_add_device_id_appservice_stream_type.sql
@@ -0,0 +1,23 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add a column to track the to_device stream id that this application
+-- service has been caught up to.
+
+-- We explicitly don't set this field as "NOT NULL", as having NULL as a possible
+-- state is useful for determining if we've ever sent traffic for a stream type
+-- to an appservice. See https://github.com/matrix-org/synapse/issues/10836 for
+-- one way this can be used.
+ALTER TABLE application_services_state ADD COLUMN to_device_stream_id BIGINT; \ No newline at end of file diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 55f0899bae..8f9afa8538 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -14,14 +14,17 @@ from unittest.mock import Mock from twisted.internet import defer +from twisted.internet.testing import MemoryReactor from synapse.appservice import ApplicationServiceState from synapse.appservice.scheduler import ( + ApplicationServiceScheduler, _Recoverer, - _ServiceQueuer, _TransactionController, ) from synapse.logging.context import make_deferred_yieldable +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.test_utils import simple_async_mock @@ -58,7 +61,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events, ephemeral=[] # txn made and saved + service=service, + events=events, + ephemeral=[], + to_device_messages=[], # txn made and saved ) self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made txn.complete.assert_called_once_with(self.store) # txn completed @@ -79,7 +85,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events, ephemeral=[] # txn made and saved + service=service, + events=events, + ephemeral=[], + to_device_messages=[], # txn made and saved ) self.assertEquals(0, txn.send.call_count) # txn not sent though self.assertEquals(0, txn.complete.call_count) # or completed @@ -102,7 +111,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events, ephemeral=[] + service=service, events=events, ephemeral=[], to_device_messages=[] ) self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made self.assertEquals(1, self.recoverer.recover.call_count) # and invoked @@ -189,38 +198,41 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.callback.assert_called_once_with(self.recoverer) -class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): - def setUp(self): +class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + self.scheduler = ApplicationServiceScheduler(hs) self.txn_ctrl = Mock() self.txn_ctrl.send = simple_async_mock() - self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock()) + + # Replace instantiated _TransactionController instances with our Mock + self.scheduler.txn_ctrl = self.txn_ctrl + self.scheduler.queuer.txn_ctrl = self.txn_ctrl def test_send_single_event_no_queue(self): # Expect the event to be sent immediately. 
service = Mock(id=4) event = Mock() - self.queuer.enqueue_event(service, event) - self.txn_ctrl.send.assert_called_once_with(service, [event], []) + self.scheduler.enqueue_for_appservice(service, events=[event]) + self.txn_ctrl.send.assert_called_once_with(service, [event], [], []) def test_send_single_event_with_queue(self): d = defer.Deferred() - self.txn_ctrl.send = Mock( - side_effect=lambda x, y, z: make_deferred_yieldable(d) - ) + self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d)) service = Mock(id=4) event = Mock(event_id="first") event2 = Mock(event_id="second") event3 = Mock(event_id="third") # Send an event and don't resolve it just yet. - self.queuer.enqueue_event(service, event) + self.scheduler.enqueue_for_appservice(service, events=[event]) # Send more events: expect send() to NOT be called multiple times. - self.queuer.enqueue_event(service, event2) - self.queuer.enqueue_event(service, event3) - self.txn_ctrl.send.assert_called_with(service, [event], []) + # (call enqueue_for_appservice multiple times deliberately) + self.scheduler.enqueue_for_appservice(service, events=[event2]) + self.scheduler.enqueue_for_appservice(service, events=[event3]) + self.txn_ctrl.send.assert_called_with(service, [event], [], []) self.assertEquals(1, self.txn_ctrl.send.call_count) # Resolve the send event: expect the queued events to be sent d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [event2, event3], []) + self.txn_ctrl.send.assert_called_with(service, [event2, event3], [], []) self.assertEquals(2, self.txn_ctrl.send.call_count) def test_multiple_service_queues(self): @@ -238,23 +250,23 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): send_return_list = [srv_1_defer, srv_2_defer] - def do_send(x, y, z): + def do_send(*args, **kwargs): return make_deferred_yieldable(send_return_list.pop(0)) self.txn_ctrl.send = Mock(side_effect=do_send) # send events for different ASes and make sure they are sent - self.queuer.enqueue_event(srv1, srv_1_event) - self.queuer.enqueue_event(srv1, srv_1_event2) - self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], []) - self.queuer.enqueue_event(srv2, srv_2_event) - self.queuer.enqueue_event(srv2, srv_2_event2) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], []) + self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event]) + self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2]) + self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], []) + self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event]) + self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2]) + self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], []) # make sure callbacks for a service only send queued events for THAT # service srv_2_defer.callback(srv2) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], []) + self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], []) self.assertEquals(3, self.txn_ctrl.send.call_count) def test_send_large_txns(self): @@ -262,7 +274,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): srv_2_defer = defer.Deferred() send_return_list = [srv_1_defer, srv_2_defer] - def do_send(x, y, z): + def do_send(*args, **kwargs): return make_deferred_yieldable(send_return_list.pop(0)) self.txn_ctrl.send = Mock(side_effect=do_send) @@ -270,67 +282,65 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): service = Mock(id=4, name="service") event_list = [Mock(name="event%i" % (i + 1)) for i in 
range(200)] for event in event_list: - self.queuer.enqueue_event(service, event) + self.scheduler.enqueue_for_appservice(service, [event], []) # Expect the first event to be sent immediately. - self.txn_ctrl.send.assert_called_with(service, [event_list[0]], []) + self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [], []) srv_1_defer.callback(service) # Then send the next 100 events - self.txn_ctrl.send.assert_called_with(service, event_list[1:101], []) + self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [], []) srv_2_defer.callback(service) # Then the final 99 events - self.txn_ctrl.send.assert_called_with(service, event_list[101:], []) + self.txn_ctrl.send.assert_called_with(service, event_list[101:], [], []) self.assertEquals(3, self.txn_ctrl.send.call_count) def test_send_single_ephemeral_no_queue(self): # Expect the event to be sent immediately. service = Mock(id=4, name="service") event_list = [Mock(name="event")] - self.queuer.enqueue_ephemeral(service, event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], event_list) + self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) + self.txn_ctrl.send.assert_called_once_with(service, [], event_list, []) def test_send_multiple_ephemeral_no_queue(self): # Expect the event to be sent immediately. service = Mock(id=4, name="service") event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")] - self.queuer.enqueue_ephemeral(service, event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], event_list) + self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) + self.txn_ctrl.send.assert_called_once_with(service, [], event_list, []) def test_send_single_ephemeral_with_queue(self): d = defer.Deferred() - self.txn_ctrl.send = Mock( - side_effect=lambda x, y, z: make_deferred_yieldable(d) - ) + self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d)) service = Mock(id=4) event_list_1 = [Mock(event_id="event1"), Mock(event_id="event2")] event_list_2 = [Mock(event_id="event3"), Mock(event_id="event4")] event_list_3 = [Mock(event_id="event5"), Mock(event_id="event6")] # Send an event and don't resolve it just yet. - self.queuer.enqueue_ephemeral(service, event_list_1) + self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_1) # Send more events: expect send() to NOT be called multiple times. - self.queuer.enqueue_ephemeral(service, event_list_2) - self.queuer.enqueue_ephemeral(service, event_list_3) - self.txn_ctrl.send.assert_called_with(service, [], event_list_1) + self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2) + self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3) + self.txn_ctrl.send.assert_called_with(service, [], event_list_1, []) self.assertEquals(1, self.txn_ctrl.send.call_count) # Resolve txn_ctrl.send d.callback(service) # Expect the queued events to be sent - self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3) + self.txn_ctrl.send.assert_called_with( + service, [], event_list_2 + event_list_3, [] + ) self.assertEquals(2, self.txn_ctrl.send.call_count) def test_send_large_txns_ephemeral(self): d = defer.Deferred() - self.txn_ctrl.send = Mock( - side_effect=lambda x, y, z: make_deferred_yieldable(d) - ) + self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d)) # Expect the event to be sent immediately. 
service = Mock(id=4, name="service") first_chunk = [Mock(name="event%i" % (i + 1)) for i in range(100)] second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)] event_list = first_chunk + second_chunk - self.queuer.enqueue_ephemeral(service, event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk) + self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) + self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk, []) d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [], second_chunk) + self.txn_ctrl.send.assert_called_with(service, [], second_chunk, []) self.assertEquals(2, self.txn_ctrl.send.call_count) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index d6f14e2dba..e4ec149273 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -1,4 +1,4 @@ -# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2015-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,18 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Iterable, List, Optional from unittest.mock import Mock from twisted.internet import defer +import synapse.rest.admin +import synapse.storage +from synapse.appservice import ApplicationService from synapse.handlers.appservice import ApplicationServicesHandler +from synapse.rest.client import login, receipts, room, sendtodevice from synapse.types import RoomStreamToken +from synapse.util.stringutils import random_string -from tests.test_utils import make_awaitable +from tests import unittest +from tests.test_utils import make_awaitable, simple_async_mock from tests.utils import MockClock -from .. 
import unittest - class AppServiceHandlerTestCase(unittest.TestCase): """Tests the ApplicationServicesHandler.""" @@ -36,6 +41,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): hs.get_datastore.return_value = self.mock_store self.mock_store.get_received_ts.return_value = make_awaitable(0) self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None) + self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable( + None + ) hs.get_application_service_api.return_value = self.mock_as_api hs.get_application_service_scheduler.return_value = self.mock_scheduler hs.get_clock.return_value = MockClock() @@ -63,8 +71,8 @@ class AppServiceHandlerTestCase(unittest.TestCase): ] self.handler.notify_interested_services(RoomStreamToken(None, 1)) - self.mock_scheduler.submit_event_for_as.assert_called_once_with( - interested_service, event + self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( + interested_service, events=[event] ) def test_query_user_exists_unknown_user(self): @@ -261,7 +269,6 @@ class AppServiceHandlerTestCase(unittest.TestCase): """ interested_service = self._mkservice(is_interested=True) services = [interested_service] - self.mock_store.get_app_services.return_value = services self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable( 579 @@ -275,10 +282,10 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.handler.notify_interested_services_ephemeral( "receipt_key", 580, ["@fakerecipient:example.com"] ) - self.mock_scheduler.submit_ephemeral_events_for_as.assert_called_once_with( - interested_service, [event] + self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( + interested_service, ephemeral=[event] ) - self.mock_store.set_type_stream_id_for_appservice.assert_called_once_with( + self.mock_store.set_appservice_stream_type_pos.assert_called_once_with( interested_service, "read_receipt", 580, @@ -305,7 +312,10 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.handler.notify_interested_services_ephemeral( "receipt_key", 580, ["@fakerecipient:example.com"] ) - self.mock_scheduler.submit_ephemeral_events_for_as.assert_not_called() + # This method will be called, but with an empty list of events + self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( + interested_service, ephemeral=[] + ) def _mkservice(self, is_interested, protocols=None): service = Mock() @@ -321,3 +331,251 @@ class AppServiceHandlerTestCase(unittest.TestCase): service.token = "mock_service_token" service.url = "mock_service_url" return service + + +class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): + """ + Tests that the ApplicationServicesHandler sends events to application + services correctly. + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + room.register_servlets, + sendtodevice.register_servlets, + receipts.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + # Mock the ApplicationServiceScheduler's _TransactionController's send method so that + # we can track any outgoing ephemeral events + self.send_mock = simple_async_mock() + hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock + + # Mock out application services, and allow defining our own in tests + self._services: List[ApplicationService] = [] + self.hs.get_datastore().get_app_services = Mock(return_value=self._services) + + # A user on the homeserver. 
+ self.local_user_device_id = "local_device" + self.local_user = self.register_user("local_user", "password") + self.local_user_token = self.login( + "local_user", "password", self.local_user_device_id + ) + + # A user on the homeserver which lies within an appservice's exclusive user namespace. + self.exclusive_as_user_device_id = "exclusive_as_device" + self.exclusive_as_user = self.register_user("exclusive_as_user", "password") + self.exclusive_as_user_token = self.login( + "exclusive_as_user", "password", self.exclusive_as_user_device_id + ) + + @unittest.override_config( + {"experimental_features": {"msc2409_to_device_messages_enabled": True}} + ) + def test_application_services_receive_local_to_device(self): + """ + Test that when a user sends a to-device message to another user + that is an application service's user namespace, the + application service will receive it. + """ + interested_appservice = self._register_application_service( + namespaces={ + ApplicationService.NS_USERS: [ + { + "regex": "@exclusive_as_user:.+", + "exclusive": True, + } + ], + }, + ) + + # Have local_user send a to-device message to exclusive_as_user + message_content = {"some_key": "some really interesting value"} + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.room_key_request/3", + content={ + "messages": { + self.exclusive_as_user: { + self.exclusive_as_user_device_id: message_content + } + } + }, + access_token=self.local_user_token, + ) + self.assertEqual(chan.code, 200, chan.result) + + # Have exclusive_as_user send a to-device message to local_user + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.room_key_request/4", + content={ + "messages": { + self.local_user: {self.local_user_device_id: message_content} + } + }, + access_token=self.exclusive_as_user_token, + ) + self.assertEqual(chan.code, 200, chan.result) + + # Check if our application service - that is interested in exclusive_as_user - received + # the to-device message as part of an AS transaction. + # Only the local_user -> exclusive_as_user to-device message should have been forwarded to the AS. + # + # The uninterested application service should not have been notified at all. + self.send_mock.assert_called_once() + service, _events, _ephemeral, to_device_messages = self.send_mock.call_args[0] + + # Assert that this was the same to-device message that local_user sent + self.assertEqual(service, interested_appservice) + self.assertEqual(to_device_messages[0]["type"], "m.room_key_request") + self.assertEqual(to_device_messages[0]["sender"], self.local_user) + + # Additional fields 'to_user_id' and 'to_device_id' specifically for + # to-device messages via the AS API + self.assertEqual(to_device_messages[0]["to_user_id"], self.exclusive_as_user) + self.assertEqual( + to_device_messages[0]["to_device_id"], self.exclusive_as_user_device_id + ) + self.assertEqual(to_device_messages[0]["content"], message_content) + + @unittest.override_config( + {"experimental_features": {"msc2409_to_device_messages_enabled": True}} + ) + def test_application_services_receive_bursts_of_to_device(self): + """ + Test that when a user sends >100 to-device messages at once, any + interested AS's will receive them in separate transactions. + + Also tests that uninterested application services do not receive messages. 
+ """ + # Register two application services with exclusive interest in a user + interested_appservices = [] + for _ in range(2): + appservice = self._register_application_service( + namespaces={ + ApplicationService.NS_USERS: [ + { + "regex": "@exclusive_as_user:.+", + "exclusive": True, + } + ], + }, + ) + interested_appservices.append(appservice) + + # ...and an application service which does not have any user interest. + self._register_application_service() + + to_device_message_content = { + "some key": "some interesting value", + } + + # We need to send a large burst of to-device messages. We also would like to + # include them all in the same application service transaction so that we can + # test large transactions. + # + # To do this, we can send a single to-device message to many user devices at + # once. + # + # We insert number_of_messages - 1 messages into the database directly. We'll then + # send a final to-device message to the real device, which will also kick off + # an AS transaction (as just inserting messages into the DB won't). + number_of_messages = 150 + fake_device_ids = [f"device_{num}" for num in range(number_of_messages - 1)] + messages = { + self.exclusive_as_user: { + device_id: to_device_message_content for device_id in fake_device_ids + } + } + + # Create a fake device per message. We can't send to-device messages to + # a device that doesn't exist. + self.get_success( + self.hs.get_datastore().db_pool.simple_insert_many( + desc="test_application_services_receive_burst_of_to_device", + table="devices", + values=[ + { + "user_id": self.exclusive_as_user, + "device_id": device_id, + } + for device_id in fake_device_ids + ], + ) + ) + + # Seed the device_inbox table with our fake messages + self.get_success( + self.hs.get_datastore().add_messages_to_device_inbox(messages, {}) + ) + + # Now have local_user send a final to-device message to exclusive_as_user. All unsent + # to-device messages should be sent to any application services + # interested in exclusive_as_user. + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.room_key_request/4", + content={ + "messages": { + self.exclusive_as_user: { + self.exclusive_as_user_device_id: to_device_message_content + } + } + }, + access_token=self.local_user_token, + ) + self.assertEqual(chan.code, 200, chan.result) + + self.send_mock.assert_called() + + # Count the total number of to-device messages that were sent out per-service. + # Ensure that we only sent to-device messages to interested services, and that + # each interested service received the full count of to-device messages. + service_id_to_message_count: Dict[str, int] = {} + + for call in self.send_mock.call_args_list: + service, _events, _ephemeral, to_device_messages = call[0] + + # Check that this was made to an interested service + self.assertIn(service, interested_appservices) + + # Add to the count of messages for this application service + service_id_to_message_count.setdefault(service.id, 0) + service_id_to_message_count[service.id] += len(to_device_messages) + + # Assert that each interested service received the full count of messages + for count in service_id_to_message_count.values(): + self.assertEqual(count, number_of_messages) + + def _register_application_service( + self, + namespaces: Optional[Dict[str, Iterable[Dict]]] = None, + ) -> ApplicationService: + """ + Register a new application service, with the given namespaces of interest. 
+ + Args: + namespaces: A dictionary containing any user, room or alias namespaces that + the application service is interested in. + + Returns: + The registered application service. + """ + # Create an application service + appservice = ApplicationService( + token=None, + hostname="example.com", + id=random_string(10), + sender="@as:example.com", + rate_limited=False, + namespaces=namespaces, + supports_ephemeral=True, + ) + + # Register the application service + self._services.append(appservice) + + return appservice diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 329490caad..db157781fc 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -481,10 +481,10 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): ValueError, ) - def test_set_type_stream_id_for_appservice(self) -> None: + def test_set_appservice_stream_type_pos(self) -> None: read_receipt_value = 1024 self.get_success( - self.store.set_type_stream_id_for_appservice( + self.store.set_appservice_stream_type_pos( self.service, "read_receipt", read_receipt_value ) ) @@ -494,7 +494,7 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): self.assertEqual(result, read_receipt_value) self.get_success( - self.store.set_type_stream_id_for_appservice( + self.store.set_appservice_stream_type_pos( self.service, "presence", read_receipt_value ) ) @@ -503,9 +503,9 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): ) self.assertEqual(result, read_receipt_value) - def test_set_type_stream_id_for_appservice_invalid_type(self) -> None: + def test_set_appservice_stream_type_pos_invalid_type(self) -> None: self.get_failure( - self.store.set_type_stream_id_for_appservice(self.service, "foobar", 1024), + self.store.set_appservice_stream_type_pos(self.service, "foobar", 1024), ValueError, ) |
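As a closing note, the per-recipient flattening performed in `_get_to_device_messages` can be illustrated in isolation. This is a simplified, self-contained sketch of the loop in the handler above rather than the handler itself; the standalone function name and the defensive copy of the message dict are my own additions.

```python
from typing import Any, Dict, List, Tuple

JsonDict = Dict[str, Any]


def flatten_to_device_messages(
    messages_by_recipient: Dict[Tuple[str, str], List[JsonDict]],
) -> List[JsonDict]:
    """Flatten {(user_id, device_id): [message, ...]} into the list of to-device
    messages handed to the scheduler: each message gains 'to_user_id' and
    'to_device_id' fields and loses the internal 'message_id' field.
    """
    payload: List[JsonDict] = []
    for (user_id, device_id), messages in messages_by_recipient.items():
        for message_json in messages:
            message_json = dict(message_json)  # don't mutate the caller's copy
            message_json.pop("message_id", None)
            payload.append(
                {
                    "to_user_id": user_id,
                    "to_device_id": device_id,
                    **message_json,
                }
            )
    return payload


# Example: one message queued for one device of @exclusive_as_user becomes one
# entry in the payload handed to enqueue_for_appservice(..., to_device_messages=...).
print(
    flatten_to_device_messages(
        {
            ("@exclusive_as_user:example.com", "exclusive_as_device"): [
                {"type": "m.room_key_request", "content": {}, "message_id": "abc123"},
            ]
        }
    )
)
```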