author | Travis Ralston <travisr@matrix.org> | 2021-12-28 20:31:10 -0700
committer | Travis Ralston <travisr@matrix.org> | 2021-12-28 20:31:10 -0700
commit | 73d91b5d34f4df5f6a1ff668113d6a844ad65d8d (patch)
tree | 2ce890b5045ef17c4c71a219aa76053ad1fa62b0
parent | Merge remote-tracking branch 'origin/rei/as_device_masquerading_msc3202' into... (diff)
parent | Newsfile (diff)
download | synapse-travis/alt-todev-masq-otk-fbkey.tar.xz
Merge remote-tracking branch 'origin/rei/msc3202_otks_fbks' into anoa/e2e_as_internal_testing (refs: github/travis/alt-todev-masq-otk-fbkey, travis/alt-todev-masq-otk-fbkey)
-rw-r--r-- | changelog.d/11617.feature | 1
-rw-r--r-- | synapse/appservice/__init__.py | 18
-rw-r--r-- | synapse/appservice/api.py | 18
-rw-r--r-- | synapse/appservice/scheduler.py | 106
-rw-r--r-- | synapse/config/appservice.py | 13
-rw-r--r-- | synapse/config/experimental.py | 6
-rw-r--r-- | synapse/storage/databases/main/appservice.py | 37
-rw-r--r-- | synapse/storage/databases/main/end_to_end_keys.py | 110
-rw-r--r-- | tests/appservice/test_scheduler.py | 48
-rw-r--r-- | tests/handlers/test_appservice.py | 189
-rw-r--r-- | tests/handlers/test_user_directory.py | 6
-rw-r--r-- | tests/rest/client/utils.py | 22
-rw-r--r-- | tests/storage/test_appservice.py | 8
-rw-r--r-- | tests/storage/test_user_directory.py | 4
-rw-r--r-- | tests/unittest.py | 8
15 files changed, 537 insertions, 57 deletions
diff --git a/changelog.d/11617.feature b/changelog.d/11617.feature new file mode 100644 index 0000000000..cf03f00e7c --- /dev/null +++ b/changelog.d/11617.feature @@ -0,0 +1 @@ +Add support for MSC3202: sending one-time key counts and fallback key usage states to Application Services. \ No newline at end of file diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 01db2b2ae3..632d5d133c 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -14,7 +14,7 @@ import logging import re from enum import Enum -from typing import TYPE_CHECKING, Iterable, List, Match, Optional +from typing import TYPE_CHECKING, Dict, Iterable, List, Match, Optional from synapse.api.constants import EventTypes from synapse.events import EventBase @@ -27,6 +27,14 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# Type for the `device_one_time_key_counts` field in an appservice transaction +# user ID -> {device ID -> {algorithm -> count}} +TransactionOneTimeKeyCounts = Dict[str, Dict[str, Dict[str, int]]] + +# Type for the `device_unused_fallback_keys` field in an appservice transaction +# user ID -> {device ID -> [algorithm]} +TransactionUnusedFallbackKeys = Dict[str, Dict[str, List[str]]] + class ApplicationServiceState(Enum): DOWN = "down" @@ -61,6 +69,7 @@ class ApplicationService: rate_limited=True, ip_range_whitelist=None, supports_ephemeral=False, + msc3202_transaction_extensions: bool = False, ): self.token = token self.url = ( @@ -73,6 +82,7 @@ class ApplicationService: self.id = id self.ip_range_whitelist = ip_range_whitelist self.supports_ephemeral = supports_ephemeral + self.msc3202_transaction_extensions = msc3202_transaction_extensions if "|" in self.id: raise Exception("application service ID cannot contain '|' character") @@ -371,6 +381,8 @@ class AppServiceTransaction: ephemeral: List[JsonDict], to_device_messages: List[JsonDict], device_list_summary: DeviceLists, + one_time_key_counts: TransactionOneTimeKeyCounts, + unused_fallback_keys: TransactionUnusedFallbackKeys, ): self.service = service self.id = id @@ -378,6 +390,8 @@ class AppServiceTransaction: self.ephemeral = ephemeral self.to_device_messages = to_device_messages self.device_list_summary = device_list_summary + self.one_time_key_counts = one_time_key_counts + self.unused_fallback_keys = unused_fallback_keys async def send(self, as_api: "ApplicationServiceApi") -> bool: """Sends this transaction using the provided AS API interface. 
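As an aside on the two type aliases introduced above (`TransactionOneTimeKeyCounts` and `TransactionUnusedFallbackKeys`): the sketch below just spells out the nesting they describe. The user ID, device ID and algorithm name are invented examples, not values from this change.

```python
# Illustrative values only; the IDs and algorithm name are made up.
from typing import Dict, List

TransactionOneTimeKeyCounts = Dict[str, Dict[str, Dict[str, int]]]
TransactionUnusedFallbackKeys = Dict[str, Dict[str, List[str]]]

# user ID -> device ID -> algorithm -> remaining one-time key count
otk_counts: TransactionOneTimeKeyCounts = {
    "@_irc_alice:example.org": {"DEVICEID": {"signed_curve25519": 50}},
}

# user ID -> device ID -> list of algorithms with an unused fallback key
unused_fallback_keys: TransactionUnusedFallbackKeys = {
    "@_irc_alice:example.org": {"DEVICEID": ["signed_curve25519"]},
}
```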
@@ -393,6 +407,8 @@ class AppServiceTransaction: ephemeral=self.ephemeral, to_device_messages=self.to_device_messages, device_list_summary=self.device_list_summary, + one_time_key_counts=self.one_time_key_counts, + unused_fallback_keys=self.unused_fallback_keys, txn_id=self.id, ) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 3ae59c7a04..0333832a9c 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -20,6 +20,11 @@ from prometheus_client import Counter from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind from synapse.api.errors import CodeMessageException +from synapse.appservice import ( + ApplicationService, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.events import EventBase from synapse.events.utils import serialize_event from synapse.http.client import SimpleHttpClient @@ -27,7 +32,6 @@ from synapse.types import DeviceLists, JsonDict, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: - from synapse.appservice import ApplicationService from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -207,6 +211,8 @@ class ApplicationServiceApi(SimpleHttpClient): ephemeral: List[JsonDict], to_device_messages: List[JsonDict], device_list_summary: DeviceLists, + one_time_key_counts: TransactionOneTimeKeyCounts, + unused_fallback_keys: TransactionUnusedFallbackKeys, txn_id: Optional[int] = None, ) -> bool: """ @@ -258,6 +264,16 @@ class ApplicationServiceApi(SimpleHttpClient): } ) + if service.msc3202_transaction_extensions: + if one_time_key_counts: + body[ + "org.matrix.msc3202.device_one_time_key_counts" + ] = one_time_key_counts + if unused_fallback_keys: + body[ + "org.matrix.msc3202.device_unused_fallback_keys" + ] = unused_fallback_keys + try: await self.put_json( uri=uri, diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index d49636d926..d59bf3e7a0 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -48,14 +48,22 @@ This is all tied together by the AppServiceScheduler which DIs the required components. """ import logging -from typing import Dict, Iterable, List, Optional - -from synapse.appservice import ApplicationService, ApplicationServiceState +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple + +from synapse.appservice import ( + ApplicationService, + ApplicationServiceState, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.events import EventBase from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import DeviceLists, JsonDict +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -81,7 +89,7 @@ class ApplicationServiceScheduler: self.as_api = hs.get_application_service_api() self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api) - self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock) + self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock, hs) async def start(self): logger.info("Starting appservice scheduler") @@ -151,7 +159,7 @@ class _ServiceQueuer: appservice at a given time. 
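For context on the body assembled in `push_bulk` above, here is a rough sketch of what a transaction payload could look like once a service has opted in. Only the two `org.matrix.msc3202.*` field names come from this diff; all concrete values are invented.

```python
# Hypothetical transaction body; values are illustrative only.
body = {
    "events": [],  # PDUs, unchanged by this commit
    "org.matrix.msc3202.device_one_time_key_counts": {
        "@_irc_alice:example.org": {"DEVICEID": {"signed_curve25519": 50}},
    },
    "org.matrix.msc3202.device_unused_fallback_keys": {
        "@_irc_alice:example.org": {"DEVICEID": ["signed_curve25519"]},
    },
}
```

Note that the diff only attaches each field when it is non-empty, and only for services whose registration sets `msc3202_transaction_extensions`.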
""" - def __init__(self, txn_ctrl, clock): + def __init__(self, txn_ctrl, clock, hs: "HomeServer"): # dict of {service_id: [events]} self.queued_events: Dict[str, List[EventBase]] = {} # dict of {service_id: [event_json]} @@ -162,9 +170,13 @@ class _ServiceQueuer: self.queued_device_list_summaries: Dict[str, List[DeviceLists]] = {} # the appservices which currently have a transaction in flight - self.requests_in_flight = set() + self.requests_in_flight: Set[str] = set() self.txn_ctrl = txn_ctrl self.clock = clock + self._msc3202_transaction_extensions_enabled: bool = ( + hs.config.experimental.msc3202_transaction_extensions + ) + self._store = hs.get_datastore() def start_background_request(self, service): # start a sender for this appservice if we don't already have one @@ -235,6 +247,26 @@ class _ServiceQueuer: ): return + one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None + unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None + + if ( + self._msc3202_transaction_extensions_enabled + and service.msc3202_transaction_extensions + ): + # Lazily compute the one-time key counts and fallback key + # usage states for the users which are mentioned in this + # transaction, as well as the appservice's sender. + interesting_users = await self._determine_interesting_users_for_msc3202_otk_counts_and_fallback_keys( + service, events, ephemeral, to_device_messages_to_send + ) + ( + one_time_key_counts, + unused_fallback_keys, + ) = await self._compute_msc3202_otk_counts_and_fallback_keys( + interesting_users + ) + try: await self.txn_ctrl.send( service, @@ -242,12 +274,66 @@ class _ServiceQueuer: ephemeral, to_device_messages_to_send, device_list_summary, + one_time_key_counts, + unused_fallback_keys, ) except Exception: logger.exception("AS request failed") finally: self.requests_in_flight.discard(service.id) + async def _determine_interesting_users_for_msc3202_otk_counts_and_fallback_keys( + self, + service: ApplicationService, + events: Iterable[EventBase], + ephemerals: Iterable[JsonDict], + to_device_messages: Iterable[JsonDict], + ) -> Set[str]: + """ + Given a list of the events, ephemeral messages and to-device messaeges, + compute a list of application services users that may have interesting + updates to the one-time key counts or fallback key usage. + """ + interesting_users: Set[str] = set() + + # The sender is always included + interesting_users.add(service.sender) + + # All AS users that would receive the PDUs or EDUs sent to these rooms + # are classed as 'interesting'. + rooms_of_interesting_users: Set[str] = set() + # PDUs + rooms_of_interesting_users.update(event.room_id for event in events) + # EDUs + rooms_of_interesting_users.update( + ephemeral["room_id"] for ephemeral in ephemerals + ) + + # Look up the AS users in those rooms + for room_id in rooms_of_interesting_users: + interesting_users.update( + await self._store.get_app_service_users_in_room(room_id, service) + ) + + # Add recipients of to-device messages. + # device_message["user_id"] is the ID of the recipient. + interesting_users.update( + device_message["user_id"] for device_message in to_device_messages + ) + + return interesting_users + + async def _compute_msc3202_otk_counts_and_fallback_keys( + self, users: Set[str] + ) -> Tuple[TransactionOneTimeKeyCounts, TransactionUnusedFallbackKeys]: + """ + Given a list of application service users that are interesting, + compute one-time key counts and fallback key usages for the users. 
+ """ + otk_counts = await self._store.count_bulk_e2e_one_time_keys_for_as(users) + unused_fbks = await self._store.get_e2e_bulk_unused_fallback_key_types(users) + return otk_counts, unused_fbks + class _TransactionController: """Transaction manager. @@ -281,6 +367,8 @@ class _TransactionController: ephemeral: Optional[List[JsonDict]] = None, to_device_messages: Optional[List[JsonDict]] = None, device_list_summary: Optional[DeviceLists] = None, + one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None, + unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None, ) -> None: """ Create a transaction with the given data and send to the provided @@ -292,6 +380,10 @@ class _TransactionController: ephemeral: The ephemeral events to include in the transaction. to_device_messages: The to-device messages to include in the transaction. device_list_summary: The device list summary to include in the transaction. + one_time_key_counts: Counts of remaining one-time keys for relevant + appservice devices in the transaction. + unused_fallback_keys: Lists of unused fallback keys for relevant + appservice devices in the transaction. """ try: txn = await self.store.create_appservice_txn( @@ -300,6 +392,8 @@ class _TransactionController: ephemeral=ephemeral or [], to_device_messages=to_device_messages or [], device_list_summary=device_list_summary or DeviceLists(), + one_time_key_counts=one_time_key_counts or {}, + unused_fallback_keys=unused_fallback_keys or {}, ) service_is_up = await self._is_service_up(service) if service_is_up: diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index e4bb7224a4..c77f4058b3 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -167,6 +167,16 @@ def _load_appservice( supports_ephemeral = as_info.get("de.sorunome.msc2409.push_ephemeral", False) + # Opt-in flag for the MSC3202-specific transactional behaviour. + # When enabled, appservice transactions contain the following information: + # - device One-Time Key counts + # - device unused fallback key usage states + msc3202_transaction_extensions = as_info.get("org.matrix.msc3202", False) + if not isinstance(msc3202_transaction_extensions, bool): + raise ValueError( + "The `org.matrix.msc3202` option should be true or false if specified." + ) + return ApplicationService( token=as_info["as_token"], hostname=hostname, @@ -175,8 +185,9 @@ def _load_appservice( hs_token=as_info["hs_token"], sender=user_id, id=as_info["id"], - supports_ephemeral=supports_ephemeral, protocols=protocols, rate_limited=rate_limited, ip_range_whitelist=ip_range_whitelist, + supports_ephemeral=supports_ephemeral, + msc3202_transaction_extensions=msc3202_transaction_extensions, ) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d19165e5b4..058abeb194 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -66,3 +66,9 @@ class ExperimentalConfig(Config): self.msc3202_device_masquerading_enabled: bool = experimental.get( "msc3202_device_masquerading", False ) + + # Portion of MSC3202 related to transaction extensions: + # sending one-time key counts and fallback key usage to application services. 
+ self.msc3202_transaction_extensions: bool = experimental.get( + "msc3202_transaction_extensions", False + ) diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 0ac2005bee..d43153ea23 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -20,15 +20,19 @@ from synapse.appservice import ( ApplicationService, ApplicationServiceState, AppServiceTransaction, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, ) from synapse.config.appservice import load_appservices from synapse.events import EventBase -from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool +from synapse.storage._base import db_to_json +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.types import Connection from synapse.types import DeviceLists, JsonDict +from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.util import json_encoder +from synapse.util.caches.descriptors import cached if TYPE_CHECKING: from synapse.server import HomeServer @@ -57,8 +61,13 @@ def _make_exclusive_regex( return exclusive_user_pattern -class ApplicationServiceWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): +class ApplicationServiceWorkerStore(RoomMemberWorkerStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): self.services_cache = load_appservices( hs.hostname, hs.config.appservice.app_service_config_files ) @@ -120,6 +129,14 @@ class ApplicationServiceWorkerStore(SQLBaseStore): return service return None + # OSTD cache invalidation + @cached(iterable=True, prune_unread_entries=False) + async def get_app_service_users_in_room( + self, room_id: str, app_service: "ApplicationService" + ) -> List[str]: + users_in_room = await self.get_users_in_room(room_id) + return list(filter(app_service.is_interested_in_user, users_in_room)) + class ApplicationServiceStore(ApplicationServiceWorkerStore): # This is currently empty due to there not being any AS storage functions @@ -196,6 +213,8 @@ class ApplicationServiceTransactionWorkerStore( ephemeral: List[JsonDict], to_device_messages: List[JsonDict], device_list_summary: DeviceLists, + one_time_key_counts: TransactionOneTimeKeyCounts, + unused_fallback_keys: TransactionUnusedFallbackKeys, ) -> AppServiceTransaction: """Atomically creates a new transaction for this application service with the given list of events. Ephemeral events are NOT persisted to the @@ -207,6 +226,10 @@ class ApplicationServiceTransactionWorkerStore( ephemeral: A list of ephemeral events to put in the transaction. to_device_messages: A list of to-device messages to put in the transaction. device_list_summary: The device list summary to include in the transaction. + one_time_key_counts: Counts of remaining one-time keys for relevant + appservice devices in the transaction. + unused_fallback_keys: Lists of unused fallback keys for relevant + appservice devices in the transaction. Returns: A new transaction. 
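The feature is double-gated: the homeserver must enable the experimental flag and the individual appservice registration must opt in. Below is a minimal sketch of the two switches and how they combine, assuming the option names parsed above; the helper function itself is not part of the diff.

```python
# In the appservice registration file (YAML):
#     org.matrix.msc3202: true
#
# In the homeserver config (YAML):
#     experimental_features:
#       msc3202_transaction_extensions: true


def msc3202_extensions_active(hs_config, service) -> bool:
    """Both the homeserver and the individual service must opt in."""
    return bool(
        hs_config.experimental.msc3202_transaction_extensions
        and service.msc3202_transaction_extensions
    )
```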
@@ -243,6 +266,8 @@ class ApplicationServiceTransactionWorkerStore( ephemeral=ephemeral, to_device_messages=to_device_messages, device_list_summary=device_list_summary, + one_time_key_counts=one_time_key_counts, + unused_fallback_keys=unused_fallback_keys, ) return await self.db_pool.runInteraction( @@ -334,6 +359,8 @@ class ApplicationServiceTransactionWorkerStore( events = await self.get_events_as_list(event_ids) + # TODO: should we recalculate one-time key counts and unused fallback + # key counts here? return AppServiceTransaction( service=service, id=entry["txn_id"], @@ -341,6 +368,8 @@ class ApplicationServiceTransactionWorkerStore( ephemeral=[], to_device_messages=[], device_list_summary=DeviceLists(), + one_time_key_counts={}, + unused_fallback_keys={}, ) def _get_last_txn(self, txn, service_id: Optional[str]) -> int: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index b06c1dc45b..d55697c093 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Tuple import attr from canonicaljson import encode_canonical_json @@ -22,9 +22,17 @@ from canonicaljson import encode_canonical_json from twisted.enterprise.adbapi import Connection from synapse.api.constants import DeviceKeyAlgorithms +from synapse.appservice import ( + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingTransaction, + make_in_list_sql_clause, +) from synapse.storage.engines import PostgresEngine from synapse.storage.types import Cursor from synapse.types import JsonDict @@ -397,6 +405,104 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): "count_e2e_one_time_keys", _count_e2e_one_time_keys ) + async def count_bulk_e2e_one_time_keys_for_as( + self, user_ids: Collection[str] + ) -> TransactionOneTimeKeyCounts: + """ + Counts, in bulk, the one-time keys for all the users specified. + Intended to be used by application services for populating OTK counts in + transactions. + + Return structure is of the shape: + user_id -> device_id -> algorithm -> count + """ + + def _count_bulk_e2e_one_time_keys_txn( + txn: LoggingTransaction, + ) -> TransactionOneTimeKeyCounts: + user_in_where_clause, user_parameters = make_in_list_sql_clause( + self.database_engine, "user_id", user_ids + ) + sql = f""" + SELECT user_id, device_id, algorithm, COUNT(key_id) + FROM devices + LEFT JOIN e2e_one_time_keys_json USING (user_id, device_id) + WHERE {user_in_where_clause} + GROUP BY user_id, device_id, algorithm + """ + txn.execute(sql, user_parameters) + + result = {} + + for user_id, device_id, algorithm, count in txn: + device_count_by_algo = result.setdefault(user_id, {}).setdefault( + device_id, {} + ) + if algorithm is not None: + # algorithm will be None if this device has no keys. 
+ device_count_by_algo[algorithm] = count + + return result + + return await self.db_pool.runInteraction( + "count_bulk_e2e_one_time_keys", _count_bulk_e2e_one_time_keys_txn + ) + + async def get_e2e_bulk_unused_fallback_key_types( + self, user_ids: Collection[str] + ) -> TransactionUnusedFallbackKeys: + """ + Finds, in bulk, the types of unused fallback keys for all the users specified. + Intended to be used by application services for populating unused fallback + keys in transactions. + + Return structure is of the shape: + user_id -> device_id -> algorithms + """ + if len(user_ids) == 0: + return {} + + def _get_bulk_e2e_unused_fallback_keys_txn( + txn: LoggingTransaction, + ) -> TransactionUnusedFallbackKeys: + user_in_where_clause, user_parameters = make_in_list_sql_clause( + self.database_engine, "devices.user_id", user_ids + ) + # We can't use USING here because we require the `.used` condition + # to be part of the JOIN condition so that we generate empty lists + # when all keys are used (as opposed to just when there are no keys at all). + sql = f""" + SELECT devices.user_id, devices.device_id, algorithm + FROM devices + LEFT JOIN e2e_fallback_keys_json AS fallback_keys + /* We can't use USING here because we require the `.used` + condition to be part of the JOIN condition so that we + generate empty lists when all keys are used (as opposed + to just when there are no keys at all). */ + ON devices.user_id = fallback_keys.user_id + AND devices.device_id = fallback_keys.device_id + AND NOT fallback_keys.used + WHERE + {user_in_where_clause} + """ + txn.execute(sql, user_parameters) + + result = {} + + for user_id, device_id, algorithm in txn: + device_unused_keys = result.setdefault(user_id, {}).setdefault( + device_id, [] + ) + if algorithm is not None: + # algorithm will be None if this device has no keys. 
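Both bulk helpers above fold their query rows into a nested mapping with the same pattern; the sketch below shows it for the one-time key counts and why the LEFT JOIN matters: a device with no one-time keys at all still produces a row (with a NULL algorithm), so it appears in the result as an empty dict rather than being omitted. The rows are invented.

```python
# Made-up query rows: (user_id, device_id, algorithm, count)
rows = [
    ("@_as_user1:test", "DEV1", "signed_curve25519", 36),
    ("@as.sender:test", "DEV2", None, 0),  # device with no OTKs uploaded
]

result: dict = {}
for user_id, device_id, algorithm, count in rows:
    device_count_by_algo = result.setdefault(user_id, {}).setdefault(device_id, {})
    if algorithm is not None:
        device_count_by_algo[algorithm] = count

print(result)
# {'@_as_user1:test': {'DEV1': {'signed_curve25519': 36}},
#  '@as.sender:test': {'DEV2': {}}}
```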
+ device_unused_keys.append(algorithm) + + return result + + return await self.db_pool.runInteraction( + "_get_bulk_e2e_unused_fallback_keys", _get_bulk_e2e_unused_fallback_keys_txn + ) + async def set_e2e_fallback_keys( self, user_id: str, device_id: str, fallback_keys: JsonDict ) -> None: diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 1c12cc5f49..ba56b12fe4 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -68,6 +68,8 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): ephemeral=[], to_device_messages=[], device_list_summary=DeviceLists(), + one_time_key_counts={}, + unused_fallback_keys={}, ) self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made txn.complete.assert_called_once_with(self.store) # txn completed @@ -94,6 +96,8 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): ephemeral=[], to_device_messages=[], device_list_summary=DeviceLists(), + one_time_key_counts={}, + unused_fallback_keys={}, ) self.assertEquals(0, txn.send.call_count) # txn not sent though self.assertEquals(0, txn.complete.call_count) # or completed @@ -121,6 +125,8 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): ephemeral=[], to_device_messages=[], device_list_summary=DeviceLists(), + one_time_key_counts={}, + unused_fallback_keys={}, ) self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made self.assertEquals(1, self.recoverer.recover.call_count) # and invoked @@ -222,9 +228,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): service = Mock(id=4) event = Mock() self.scheduler.enqueue_for_appservice(service, events=[event]) - self.txn_ctrl.send.assert_called_once_with( - service, [event], [], [], DeviceLists() - ) + self.txn_ctrl.send.assert_called_once_with(service, [event], [], [], DeviceLists(), None, None) def test_send_single_event_with_queue(self): d = defer.Deferred() @@ -239,12 +243,12 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # (call enqueue_for_appservice multiple times deliberately) self.scheduler.enqueue_for_appservice(service, events=[event2]) self.scheduler.enqueue_for_appservice(service, events=[event3]) - self.txn_ctrl.send.assert_called_with(service, [event], [], [], DeviceLists()) + self.txn_ctrl.send.assert_called_with(service, [event], [], [], DeviceLists(), None, None) self.assertEquals(1, self.txn_ctrl.send.call_count) # Resolve the send event: expect the queued events to be sent d.callback(service) self.txn_ctrl.send.assert_called_with( - service, [event2, event3], [], [], DeviceLists() + service, [event2, event3], [], [], DeviceLists(), None, None ) self.assertEquals(2, self.txn_ctrl.send.call_count) @@ -271,21 +275,15 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # send events for different ASes and make sure they are sent self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event]) self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2]) - self.txn_ctrl.send.assert_called_with( - srv1, [srv_1_event], [], [], DeviceLists() - ) + self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [], DeviceLists(), None, None) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event]) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2]) - self.txn_ctrl.send.assert_called_with( - srv2, [srv_2_event], [], [], DeviceLists() - ) + 
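The comment in `get_e2e_bulk_unused_fallback_key_types` about keeping `NOT fallback_keys.used` in the JOIN condition is worth a concrete illustration. The self-contained SQLite sketch below (not Synapse code; the schemas and rows are invented and heavily simplified) shows that a device whose only fallback key has been used still comes back, with a NULL algorithm, which the Python loop then turns into an empty list; moving that condition into the WHERE clause would drop the device entirely.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE devices (user_id TEXT, device_id TEXT);
    CREATE TABLE e2e_fallback_keys_json (
        user_id TEXT, device_id TEXT, algorithm TEXT, used BOOLEAN
    );
    INSERT INTO devices VALUES ('@u:test', 'FRESH'), ('@u:test', 'SPENT');
    INSERT INTO e2e_fallback_keys_json VALUES
        ('@u:test', 'FRESH', 'signed_curve25519', 0),
        ('@u:test', 'SPENT', 'signed_curve25519', 1);
    """
)

rows = conn.execute(
    """
    SELECT devices.user_id, devices.device_id, algorithm
    FROM devices
    LEFT JOIN e2e_fallback_keys_json AS fallback_keys
        ON devices.user_id = fallback_keys.user_id
        AND devices.device_id = fallback_keys.device_id
        AND NOT fallback_keys.used
    ORDER BY devices.device_id
    """
).fetchall()

print(rows)
# [('@u:test', 'FRESH', 'signed_curve25519'), ('@u:test', 'SPENT', None)]
# 'SPENT' still shows up -> serialised as an empty algorithm list.
```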
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [], DeviceLists(), None, None) # make sure callbacks for a service only send queued events for THAT # service srv_2_defer.callback(srv2) - self.txn_ctrl.send.assert_called_with( - srv2, [srv_2_event2], [], [], DeviceLists() - ) + self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [], DeviceLists(), None, None) self.assertEquals(3, self.txn_ctrl.send.call_count) def test_send_large_txns(self): @@ -305,17 +303,17 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # Expect the first event to be sent immediately. self.txn_ctrl.send.assert_called_with( - service, [event_list[0]], [], [], DeviceLists() + service, [event_list[0]], [], [], DeviceLists(), None, None ) srv_1_defer.callback(service) # Then send the next 100 events self.txn_ctrl.send.assert_called_with( - service, event_list[1:101], [], [], DeviceLists() + service, event_list[1:101], [], [], DeviceLists(), None, None ) srv_2_defer.callback(service) # Then the final 99 events self.txn_ctrl.send.assert_called_with( - service, event_list[101:], [], [], DeviceLists() + service, event_list[101:], [], [], DeviceLists(), None, None ) self.assertEquals(3, self.txn_ctrl.send.call_count) @@ -325,7 +323,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): event_list = [Mock(name="event")] self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) self.txn_ctrl.send.assert_called_once_with( - service, [], event_list, [], DeviceLists() + service, [], event_list, [], DeviceLists(), None, None ) def test_send_multiple_ephemeral_no_queue(self): @@ -334,7 +332,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")] self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) self.txn_ctrl.send.assert_called_once_with( - service, [], event_list, [], DeviceLists() + service, [], event_list, [], DeviceLists(), None, None ) def test_send_single_ephemeral_with_queue(self): @@ -350,15 +348,13 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # Send more events: expect send() to NOT be called multiple times. 
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3) - self.txn_ctrl.send.assert_called_with( - service, [], event_list_1, [], DeviceLists() - ) + self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [], DeviceLists(), None, None) self.assertEquals(1, self.txn_ctrl.send.call_count) # Resolve txn_ctrl.send d.callback(service) # Expect the queued events to be sent self.txn_ctrl.send.assert_called_with( - service, [], event_list_2 + event_list_3, [], DeviceLists() + service, [], event_list_2 + event_list_3, [], DeviceLists(), None, None ) self.assertEquals(2, self.txn_ctrl.send.call_count) @@ -372,10 +368,8 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): event_list = first_chunk + second_chunk self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) self.txn_ctrl.send.assert_called_once_with( - service, [], first_chunk, [], DeviceLists() + service, [], first_chunk, [], DeviceLists(), None, None ) d.callback(service) - self.txn_ctrl.send.assert_called_with( - service, [], second_chunk, [], DeviceLists() - ) + self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [], DeviceLists(), None, None) self.assertEquals(2, self.txn_ctrl.send.call_count) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index def93caf17..ab368b4c18 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -16,17 +16,25 @@ from typing import Dict, Iterable, List, Optional from unittest.mock import Mock from twisted.internet import defer +from twisted.internet.testing import MemoryReactor import synapse.rest.admin import synapse.storage -from synapse.appservice import ApplicationService +from synapse.appservice import ( + ApplicationService, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.handlers.appservice import ApplicationServicesHandler -from synapse.rest.client import login, receipts, room, sendtodevice +from synapse.rest.client import login, receipts, register, room, sendtodevice +from synapse.server import HomeServer from synapse.types import RoomStreamToken +from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest from tests.test_utils import make_awaitable, simple_async_mock +from tests.unittest import override_config from tests.utils import MockClock @@ -437,6 +445,8 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): _ephemeral, to_device_messages, _device_list_summary, + _otks, + _fbks, ) = self.send_mock.call_args[0] # Assert that this was the same to-device message that local_user sent @@ -554,6 +564,8 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): _ephemeral, to_device_messages, _device_list_summary, + _otks, + _fbks, ) = call[0] # Check that this was made to an interested service @@ -596,3 +608,176 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): self._services.append(appservice) return appservice + + +class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): + # Argument indices for pulling out arguments from a `send_mock`. 
+ ARG_OTK_COUNTS = 4 + ARG_FALLBACK_KEYS = 5 + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + room.register_servlets, + sendtodevice.register_servlets, + receipts.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + # Mock the ApplicationServiceScheduler's _TransactionController's send method so that + # we can track what's going out + self.send_mock = simple_async_mock() + hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method. + + # Define an application service for the tests + self._service_token = "VERYSECRET" + self._service = ApplicationService( + self._service_token, + "as1.invalid", + "as1", + "@as.sender:test", + namespaces={ + "users": [ + {"regex": "@_as_.*:test", "exclusive": True}, + {"regex": "@as.sender:test", "exclusive": True}, + ] + }, + msc3202_transaction_extensions=True, + ) + self.hs.get_datastore().services_cache = [self._service] + + # Register some appservice users + self._sender_user, self._sender_device = self.register_appservice_user( + "as.sender", self._service_token + ) + self._namespaced_user, self._namespaced_device = self.register_appservice_user( + "_as_user1", self._service_token + ) + + # Register a real user as well. + self._real_user = self.register_user("real.user", "meow") + self._real_user_token = self.login("real.user", "meow") + + async def _add_otks_for_device( + self, user_id: str, device_id: str, otk_count: int + ) -> None: + """ + Add some dummy keys. It doesn't matter if they're not a real algorithm; + that should be opaque to the server anyway. + """ + await self.hs.get_datastore().add_e2e_one_time_keys( + user_id, + device_id, + self.clock.time_msec(), + [("algo", f"k{i}", "{}") for i in range(otk_count)], + ) + + async def _add_fallback_key_for_device( + self, user_id: str, device_id: str, used: bool + ) -> None: + """ + Adds a fake fallback key to a device, optionally marking it as used + right away. + """ + store = self.hs.get_datastore() + await store.set_e2e_fallback_keys(user_id, device_id, {"algo:fk": "fall back!"}) + if used is True: + # Mark the key as used + await store.db_pool.simple_update_one( + table="e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": "algo", + "key_id": "fk", + }, + updatevalues={"used": True}, + desc="_get_fallback_key_set_used", + ) + + def _set_up_devices_and_a_room(self) -> str: + """ + Helper to set up devices for all the users + and a room for the users to talk in. + """ + + async def preparation(): + await self._add_otks_for_device(self._sender_user, self._sender_device, 42) + await self._add_fallback_key_for_device( + self._sender_user, self._sender_device, used=True + ) + await self._add_otks_for_device( + self._namespaced_user, self._namespaced_device, 36 + ) + await self._add_fallback_key_for_device( + self._namespaced_user, self._namespaced_device, used=False + ) + + # Register a device for the real user, too, so that we can later ensure + # that we don't leak information to the AS about the non-AS user. 
+ await self.hs.get_datastore().store_device( + self._real_user, "REALDEV", "UltraMatrix 3000" + ) + await self._add_otks_for_device(self._real_user, "REALDEV", 50) + + self.get_success(preparation()) + + room_id = self.helper.create_room_as( + self._real_user, is_public=True, tok=self._real_user_token + ) + self.helper.join( + room_id, + self._namespaced_user, + tok=self._service_token, + appservice_user_id=self._namespaced_user, + ) + + # Check it was called for sanity. (This was to send the join event to the AS.) + self.send_mock.assert_called() + self.send_mock.reset_mock() + + return room_id + + @override_config( + {"experimental_features": {"msc3202_transaction_extensions": True}} + ) + def test_application_services_receive_otk_counts_and_fallback_key_usages_with_pdus( + self, + ) -> None: + """ + Tests that: + - the AS receives one-time key counts and unused fallback keys for: + - the specified sender; and + - any user who is in receipt of the PDUs + """ + + room_id = self._set_up_devices_and_a_room() + + # Send a message into the AS's room + self.helper.send(room_id, "woof woof", tok=self._real_user_token) + + # Capture what was sent as an AS transaction. + self.send_mock.assert_called() + last_args = self.send_mock.call_args + otks: Optional[TransactionOneTimeKeyCounts] = last_args.args[ + self.ARG_OTK_COUNTS + ] + unused_fallbacks: Optional[TransactionUnusedFallbackKeys] = last_args.args[ + self.ARG_FALLBACK_KEYS + ] + + self.assertEqual( + otks, + { + "@as.sender:test": {self._sender_device: {"algo": 42}}, + "@_as_user1:test": {self._namespaced_device: {"algo": 36}}, + }, + ) + self.assertEqual( + unused_fallbacks, + { + "@as.sender:test": {self._sender_device: []}, + "@_as_user1:test": {self._namespaced_device: ["algo"]}, + }, + ) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 70c621b825..482c90ef68 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -169,7 +169,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Register an AS user. user = self.register_user("user", "pass") token = self.login(user, "pass") - as_user = self.register_appservice_user("as_user_potato", self.appservice.token) + as_user, _ = self.register_appservice_user( + "as_user_potato", self.appservice.token + ) # Join the AS user to rooms owned by the normal user. 
public, private = self._create_rooms_and_inject_memberships( @@ -388,7 +390,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def test_handle_local_profile_change_with_appservice_user(self) -> None: # create user - as_user_id = self.register_appservice_user( + as_user_id, _ = self.register_appservice_user( "as_user_alice", self.appservice.token ) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 1af5e5cee5..3d382761a8 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -147,12 +147,20 @@ class RestHelper: expect_code=expect_code, ) - def join(self, room=None, user=None, expect_code=200, tok=None): + def join( + self, + room=None, + user=None, + expect_code=200, + tok=None, + appservice_user_id: Optional[str] = None, + ): self.change_membership( room=room, src=user, targ=user, tok=tok, + appservice_user_id=appservice_user_id, membership=Membership.JOIN, expect_code=expect_code, ) @@ -204,6 +212,7 @@ class RestHelper: membership: str, extra_data: Optional[dict] = None, tok: Optional[str] = None, + appservice_user_id: Optional[str] = None, expect_code: int = 200, expect_errcode: Optional[str] = None, ) -> None: @@ -217,6 +226,9 @@ class RestHelper: membership: The type of membership event extra_data: Extra information to include in the content of the event tok: The user access token to use + appservice_user_id: The `user_id` URL parameter to pass. + This allows driving an application service user + using an application service access token in `tok`. expect_code: The expected HTTP response code expect_errcode: The expected Matrix error code """ @@ -224,8 +236,14 @@ class RestHelper: self.auth_user_id = src path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ) + next_arg_char = "?" 
+ if tok: - path = path + "?access_token=%s" % tok + path += "?access_token=%s" % tok + next_arg_char = "&" + + if appservice_user_id: + path += f"{next_arg_char}user_id={appservice_user_id}" data = {"membership": membership} data.update(extra_data or {}) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index bb1411232a..84de9c962f 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -268,7 +268,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) txn = self.get_success( defer.ensureDeferred( - self.store.create_appservice_txn(service, events, [], [], DeviceLists()) + self.store.create_appservice_txn(service, events, [], [], DeviceLists(), {}, {}) ) ) self.assertEquals(txn.id, 1) @@ -284,7 +284,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._insert_txn(service.id, 9644, events)) self.get_success(self._insert_txn(service.id, 9645, events)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], [], DeviceLists()) + self.store.create_appservice_txn(service, events, [], [], DeviceLists(), {}, {}) ) self.assertEquals(txn.id, 9646) self.assertEquals(txn.events, events) @@ -297,7 +297,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) self.get_success(self._set_last_txn(service.id, 9643)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], [], DeviceLists()) + self.store.create_appservice_txn(service, events, [], [], DeviceLists(), {}, {}) ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) @@ -321,7 +321,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], [], DeviceLists()) + self.store.create_appservice_txn(service, events, [], [], DeviceLists(), {}, {}) ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 7f5b28aed8..48f1e9d841 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -341,7 +341,9 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): # Register an AS user. user = self.register_user("user", "pass") token = self.login(user, "pass") - as_user = self.register_appservice_user("as_user_potato", self.appservice.token) + as_user, _ = self.register_appservice_user( + "as_user_potato", self.appservice.token + ) # Join the AS user to rooms owned by the normal user. public, private = self._create_rooms_and_inject_memberships( diff --git a/tests/unittest.py b/tests/unittest.py index 1431848367..0698b3d13a 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -620,18 +620,18 @@ class HomeserverTestCase(TestCase): self, username: str, appservice_token: str, - ) -> str: + ) -> Tuple[str, str]: """Register an appservice user as an application service. Requires the client-facing registration API be registered. Args: username: the user to be registered by an application service. - Should be a full username, i.e. ""@localpart:hostname" as opposed to just "localpart" + Should NOT be a full username, i.e. 
just "localpart" as opposed to "@localpart:hostname" appservice_token: the acccess token for that application service. Raises: if the request to '/register' does not return 200 OK. - Returns: the MXID of the new user. + Returns: the MXID of the new user, the device ID of the new user's first device. """ channel = self.make_request( "POST", @@ -643,7 +643,7 @@ class HomeserverTestCase(TestCase): access_token=appservice_token, ) self.assertEqual(channel.code, 200, channel.json_body) - return channel.json_body["user_id"] + return channel.json_body["user_id"], channel.json_body["device_id"] def login( self, |