From 72d37f2d8f7172aea02445a568d8b3bbab2df693 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 1 Dec 2021 14:49:05 +0000 Subject: Support opting-in to MSC3202 transactional behaviour using the registration file --- synapse/appservice/__init__.py | 2 ++ synapse/config/appservice.py | 13 ++++++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index e33e69eed1..9672e34f92 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -61,6 +61,7 @@ class ApplicationService: rate_limited=True, ip_range_whitelist=None, supports_ephemeral=False, + msc3202_transaction_extensions: bool = False, ): self.token = token self.url = ( @@ -73,6 +74,7 @@ class ApplicationService: self.id = id self.ip_range_whitelist = ip_range_whitelist self.supports_ephemeral = supports_ephemeral + self.msc3202_transaction_extensions = msc3202_transaction_extensions if "|" in self.id: raise Exception("application service ID cannot contain '|' character") diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index e4bb7224a4..c77f4058b3 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -167,6 +167,16 @@ def _load_appservice( supports_ephemeral = as_info.get("de.sorunome.msc2409.push_ephemeral", False) + # Opt-in flag for the MSC3202-specific transactional behaviour. + # When enabled, appservice transactions contain the following information: + # - device One-Time Key counts + # - device unused fallback key usage states + msc3202_transaction_extensions = as_info.get("org.matrix.msc3202", False) + if not isinstance(msc3202_transaction_extensions, bool): + raise ValueError( + "The `org.matrix.msc3202` option should be true or false if specified." + ) + return ApplicationService( token=as_info["as_token"], hostname=hostname, @@ -175,8 +185,9 @@ def _load_appservice( hs_token=as_info["hs_token"], sender=user_id, id=as_info["id"], - supports_ephemeral=supports_ephemeral, protocols=protocols, rate_limited=rate_limited, ip_range_whitelist=ip_range_whitelist, + supports_ephemeral=supports_ephemeral, + msc3202_transaction_extensions=msc3202_transaction_extensions, ) -- cgit 1.4.1 From dcdfa6f46d4c2ebc2075547899bf22b91b841e82 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 1 Dec 2021 15:00:13 +0000 Subject: Add type aliases for one-time key counts and unused fallback keys that will be sent --- synapse/appservice/__init__.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 9672e34f92..811f21d634 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -27,6 +27,14 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# Type for the `device_one_time_key_counts` field in an appservice transaction +# user ID -> {device ID -> {algorithm -> count}} +TransactionOneTimeKeyCounts = Dict[str, Dict[str, Dict[str, int]]] + +# Type for the `device_unused_fallback_keys` field in an appservice transaction +# user ID -> {device ID -> [algorithm]} +TransactionUnusedFallbackKeys = Dict[str, Dict[str, List[str]]] + class ApplicationServiceState(Enum): DOWN = "down" -- cgit 1.4.1 From 9e171fd15e330f50d4987c8b880a962e85eff556 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 1 Dec 2021 15:00:53 +0000 Subject: Feed one-time key counts and unused fallback keys through the transaction --- synapse/appservice/__init__.py | 8 +++++++- synapse/appservice/api.py | 8 +++++++- synapse/storage/databases/main/appservice.py | 10 ++++++++++ 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 811f21d634..268a3ae3ba 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -14,7 +14,7 @@ import logging import re from enum import Enum -from typing import TYPE_CHECKING, Iterable, List, Match, Optional +from typing import TYPE_CHECKING, Dict, Iterable, List, Match, Optional from synapse.api.constants import EventTypes from synapse.events import EventBase @@ -341,12 +341,16 @@ class AppServiceTransaction: events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], + one_time_key_counts: TransactionOneTimeKeyCounts, + unused_fallback_keys: TransactionUnusedFallbackKeys, ): self.service = service self.id = id self.events = events self.ephemeral = ephemeral self.to_device_messages = to_device_messages + self.one_time_key_counts = one_time_key_counts + self.unused_fallback_keys = unused_fallback_keys async def send(self, as_api: "ApplicationServiceApi") -> bool: """Sends this transaction using the provided AS API interface. @@ -361,6 +365,8 @@ class AppServiceTransaction: events=self.events, ephemeral=self.ephemeral, to_device_messages=self.to_device_messages, + one_time_key_counts=self.one_time_key_counts, + unused_fallback_keys=self.unused_fallback_keys, txn_id=self.id, ) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index ca58f92339..d77effebfa 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -19,6 +19,11 @@ from prometheus_client import Counter from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind from synapse.api.errors import CodeMessageException +from synapse.appservice import ( + ApplicationService, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.events import EventBase from synapse.events.utils import serialize_event from synapse.http.client import SimpleHttpClient @@ -26,7 +31,6 @@ from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: - from synapse.appservice import ApplicationService from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -205,6 +209,8 @@ class ApplicationServiceApi(SimpleHttpClient): events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], + one_time_key_counts: TransactionOneTimeKeyCounts, + unused_fallback_keys: TransactionUnusedFallbackKeys, txn_id: Optional[int] = None, ) -> bool: """ diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 68ba330432..46d8448ce4 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -20,6 +20,8 @@ from synapse.appservice import ( ApplicationService, ApplicationServiceState, AppServiceTransaction, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, ) from synapse.config.appservice import load_appservices from synapse.events import EventBase @@ -195,6 +197,8 @@ class ApplicationServiceTransactionWorkerStore( events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], + one_time_key_counts: TransactionOneTimeKeyCounts, + unused_fallback_keys: TransactionUnusedFallbackKeys, ) -> AppServiceTransaction: """Atomically creates a new transaction for this application service with the given list of events. Ephemeral events are NOT persisted to the @@ -205,6 +209,10 @@ class ApplicationServiceTransactionWorkerStore( events: A list of persistent events to put in the transaction. ephemeral: A list of ephemeral events to put in the transaction. to_device_messages: A list of to-device messages to put in the transaction. + one_time_key_counts: Counts of remaining one-time keys for relevant + appservice devices in the transaction. + unused_fallback_keys: Lists of unused fallback keys for relevant + appservice devices in the transaction. Returns: A new transaction. @@ -240,6 +248,8 @@ class ApplicationServiceTransactionWorkerStore( events=events, ephemeral=ephemeral, to_device_messages=to_device_messages, + one_time_key_counts=one_time_key_counts, + unused_fallback_keys=unused_fallback_keys, ) return await self.db_pool.runInteraction( -- cgit 1.4.1 From 473d1a399e6af1d39219ead6b9d81ca51b981185 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 1 Dec 2021 15:04:00 +0000 Subject: Emit the one-time key counts and fallback keys over federation --- synapse/appservice/api.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index d77effebfa..d62871bb07 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -249,6 +249,16 @@ class ApplicationServiceApi(SimpleHttpClient): } ) + if service.msc3202_transaction_extensions: + if one_time_key_counts: + body[ + "org.matrix.msc3202.device_one_time_key_counts" + ] = one_time_key_counts + if unused_fallback_keys: + body[ + "org.matrix.msc3202.device_unused_fallback_keys" + ] = unused_fallback_keys + try: await self.put_json( uri=uri, -- cgit 1.4.1 From 24340342503ba25278cb1d5bbba26facce103498 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 1 Dec 2021 15:08:02 +0000 Subject: During AS catch-up, send empty OTK counts and fallback keys for now --- synapse/storage/databases/main/appservice.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 46d8448ce4..a0404e2391 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -341,12 +341,16 @@ class ApplicationServiceTransactionWorkerStore( events = await self.get_events_as_list(event_ids) + # TODO: should we recalculate one-time key counts and unused fallback + # key counts here? return AppServiceTransaction( service=service, id=entry["txn_id"], events=events, ephemeral=[], to_device_messages=[], + one_time_key_counts={}, + unused_fallback_keys={}, ) def _get_last_txn(self, txn, service_id: Optional[str]) -> int: -- cgit 1.4.1 From 37215edd817a4f4267302ef57df0f1ebb3dbf330 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 10 Dec 2021 14:46:27 +0000 Subject: Fix up type after rebase onto anoa's branch --- synapse/appservice/api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index d62871bb07..2ec1b8f48a 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -13,7 +13,7 @@ # limitations under the License. import logging import urllib -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from prometheus_client import Counter @@ -239,7 +239,7 @@ class ApplicationServiceApi(SimpleHttpClient): uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id))) # Never send ephemeral events to appservices that do not support it - body: Dict[str, List[JsonDict]] = {"events": serialized_events} + body: Dict[str, Union[JsonDict, List[JsonDict]]] = {"events": serialized_events} if service.supports_ephemeral: body.update( { -- cgit 1.4.1 From 06455cf91adae637e66910d3f6887d77b173f002 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 10 Dec 2021 14:47:25 +0000 Subject: Boring piping --- synapse/appservice/scheduler.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index dae952dc13..74e6c5855d 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -50,7 +50,12 @@ components. import logging from typing import Dict, Iterable, List, Optional -from synapse.appservice import ApplicationService, ApplicationServiceState +from synapse.appservice import ( + ApplicationService, + ApplicationServiceState, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.events import EventBase from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process @@ -186,9 +191,17 @@ class _ServiceQueuer: if not events and not ephemeral and not to_device_messages_to_send: return + one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None + unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None + try: await self.txn_ctrl.send( - service, events, ephemeral, to_device_messages_to_send + service, + events, + ephemeral, + to_device_messages_to_send, + one_time_key_counts, + unused_fallback_keys, ) except Exception: logger.exception("AS request failed") @@ -227,6 +240,8 @@ class _TransactionController: events: List[EventBase], ephemeral: Optional[List[JsonDict]] = None, to_device_messages: Optional[List[JsonDict]] = None, + one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None, + unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None, ) -> None: """ Create a transaction with the given data and send to the provided @@ -237,6 +252,10 @@ class _TransactionController: events: The persistent events to include in the transaction. ephemeral: The ephemeral events to include in the transaction. to_device_messages: The to-device messages to include in the transaction. + one_time_key_counts: Counts of remaining one-time keys for relevant + appservice devices in the transaction. + unused_fallback_keys: Lists of unused fallback keys for relevant + appservice devices in the transaction. """ try: txn = await self.store.create_appservice_txn( @@ -244,6 +263,8 @@ class _TransactionController: events=events, ephemeral=ephemeral or [], to_device_messages=to_device_messages or [], + one_time_key_counts=one_time_key_counts or {}, + unused_fallback_keys=unused_fallback_keys or {}, ) service_is_up = await self._is_service_up(service) if service_is_up: -- cgit 1.4.1 From cab682fe895b69760e15eaa987e910547374eddf Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 10 Dec 2021 14:47:35 +0000 Subject: Fix up some tests --- tests/storage/test_appservice.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index ddcb7f5549..a27cd05e26 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -267,7 +267,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) txn = self.get_success( defer.ensureDeferred( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) ) self.assertEquals(txn.id, 1) @@ -283,7 +283,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._insert_txn(service.id, 9644, events)) self.get_success(self._insert_txn(service.id, 9645, events)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) self.assertEquals(txn.id, 9646) self.assertEquals(txn.events, events) @@ -296,7 +296,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) self.get_success(self._set_last_txn(service.id, 9643)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) @@ -320,7 +320,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) -- cgit 1.4.1 From c5e072fad5e6eed570f03370aee06041abf4f274 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 10 Dec 2021 14:47:54 +0000 Subject: Add feature flag for experimental MSC3202 transaction extensions --- synapse/config/experimental.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index e481fc16b6..9bc205338d 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -57,3 +57,9 @@ class ExperimentalConfig(Config): self.msc2409_to_device_messages_enabled: bool = experimental.get( "msc2409_to_device_messages_enabled", False ) + + # Portion of MSC3202 related to transaction extensions: + # sending one-time key counts and fallback key usage to application services. + self.msc3202_transaction_extensions: bool = experimental.get( + "msc3202_transaction_extensions", False + ) -- cgit 1.4.1 From 36595a7cdc6e1d0d53c1e01de952f1438821a9db Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 10 Dec 2021 14:48:14 +0000 Subject: Pipe through the feature flag --- synapse/appservice/scheduler.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 74e6c5855d..77194f41fd 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -48,7 +48,7 @@ This is all tied together by the AppServiceScheduler which DIs the required components. """ import logging -from typing import Dict, Iterable, List, Optional +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set from synapse.appservice import ( ApplicationService, @@ -61,6 +61,9 @@ from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import JsonDict +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -86,7 +89,7 @@ class ApplicationServiceScheduler: self.as_api = hs.get_application_service_api() self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api) - self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock) + self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock, hs) async def start(self): logger.info("Starting appservice scheduler") @@ -142,7 +145,7 @@ class _ServiceQueuer: appservice at a given time. """ - def __init__(self, txn_ctrl, clock): + def __init__(self, txn_ctrl, clock, hs: "HomeServer"): # dict of {service_id: [events]} self.queued_events: Dict[str, List[EventBase]] = {} # dict of {service_id: [event_json]} @@ -151,9 +154,12 @@ class _ServiceQueuer: self.queued_to_device_messages: Dict[str, List[JsonDict]] = {} # the appservices which currently have a transaction in flight - self.requests_in_flight = set() + self.requests_in_flight: Set[str] = set() self.txn_ctrl = txn_ctrl self.clock = clock + self._msc3202_transaction_extensions_enabled: bool = ( + hs.config.experimental.msc3202_transaction_extensions + ) def start_background_request(self, service): # start a sender for this appservice if we don't already have one -- cgit 1.4.1 From b99b3117952dabe933cd01ac1aa6e7440961aa0d Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Mon, 13 Dec 2021 16:09:09 +0000 Subject: Add some method stubs and add the OTKs and FBKs to the response --- synapse/appservice/scheduler.py | 40 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 77194f41fd..8657aae7d0 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -48,7 +48,7 @@ This is all tied together by the AppServiceScheduler which DIs the required components. """ import logging -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple from synapse.appservice import ( ApplicationService, @@ -200,6 +200,20 @@ class _ServiceQueuer: one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None + if service.msc3202_transaction_extensions: + # Lazily compute the one-time key counts and fallback key + # usage states for the users which are mentioned in this + # transaction, as well as the appservice's sender. + interesting_users = self._determine_interesting_users_for_msc3202_otk_counts_and_fallback_keys( + service, events, ephemeral, to_device_messages_to_send + ) + ( + one_time_key_counts, + unused_fallback_keys, + ) = await self._compute_msc3202_otk_counts_and_fallback_keys( + interesting_users + ) + try: await self.txn_ctrl.send( service, @@ -214,6 +228,30 @@ class _ServiceQueuer: finally: self.requests_in_flight.discard(service.id) + def _determine_interesting_users_for_msc3202_otk_counts_and_fallback_keys( + self, + service: ApplicationService, + events: Iterable[EventBase], + ephemeral: Iterable[JsonDict], + to_device_messages: Iterable[JsonDict], + ) -> Set[str]: + """ + Given a list of the events, ephemeral messages and to-device messaeges, + compute a list of application services users that may have interesting + updates to the one-time key counts or fallback key usage. + """ + # OSTD implement me! + return set() + async def _compute_msc3202_otk_counts_and_fallback_keys( + self, users: Set[str] + ) -> Tuple[TransactionOneTimeKeyCounts, TransactionUnusedFallbackKeys]: + """ + Given a list of application service users that are interesting, + compute one-time key counts and fallback key usages for the users. + """ + # OSTD implement me! + return {}, {} + class _TransactionController: """Transaction manager. -- cgit 1.4.1 From ec81f8d38f4fb2de664e41b2779d494ad0b682c4 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Mon, 13 Dec 2021 16:39:40 +0000 Subject: Find interesting users for the AS when sending OTKs and FBKs --- synapse/appservice/scheduler.py | 38 ++++++++++++++++++++++++---- synapse/storage/databases/main/appservice.py | 27 ++++++++++++++++---- 2 files changed, 55 insertions(+), 10 deletions(-) diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 8657aae7d0..9c8bec5ed7 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -160,6 +160,7 @@ class _ServiceQueuer: self._msc3202_transaction_extensions_enabled: bool = ( hs.config.experimental.msc3202_transaction_extensions ) + self._store = hs.get_datastore() def start_background_request(self, service): # start a sender for this appservice if we don't already have one @@ -204,7 +205,7 @@ class _ServiceQueuer: # Lazily compute the one-time key counts and fallback key # usage states for the users which are mentioned in this # transaction, as well as the appservice's sender. - interesting_users = self._determine_interesting_users_for_msc3202_otk_counts_and_fallback_keys( + interesting_users = await self._determine_interesting_users_for_msc3202_otk_counts_and_fallback_keys( service, events, ephemeral, to_device_messages_to_send ) ( @@ -228,11 +229,11 @@ class _ServiceQueuer: finally: self.requests_in_flight.discard(service.id) - def _determine_interesting_users_for_msc3202_otk_counts_and_fallback_keys( + async def _determine_interesting_users_for_msc3202_otk_counts_and_fallback_keys( self, service: ApplicationService, events: Iterable[EventBase], - ephemeral: Iterable[JsonDict], + ephemerals: Iterable[JsonDict], to_device_messages: Iterable[JsonDict], ) -> Set[str]: """ @@ -240,8 +241,35 @@ class _ServiceQueuer: compute a list of application services users that may have interesting updates to the one-time key counts or fallback key usage. """ - # OSTD implement me! - return set() + interesting_users: Set[str] = set() + + # The sender is always included + interesting_users.add(service.sender) + + # All AS users that would receive the PDUs or EDUs sent to these rooms + # are classed as 'interesting'. + rooms_of_interesting_users: Set[str] = set() + # PDUs + rooms_of_interesting_users.update(event.room_id for event in events) + # EDUs + rooms_of_interesting_users.update( + ephemeral["room_id"] for ephemeral in ephemerals + ) + + # Look up the AS users in those rooms + for room_id in rooms_of_interesting_users: + interesting_users.update( + await self._store.get_app_service_users_in_room(room_id, service) + ) + + # Add recipients of to-device messages. + # device_message["user_id"] is the ID of the recipient. + interesting_users.update( + device_message["user_id"] for device_message in to_device_messages + ) + + return interesting_users + async def _compute_msc3202_otk_counts_and_fallback_keys( self, users: Set[str] ) -> Tuple[TransactionOneTimeKeyCounts, TransactionUnusedFallbackKeys]: diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index a0404e2391..d9a79d6f4f 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -25,12 +25,13 @@ from synapse.appservice import ( ) from synapse.config.appservice import load_appservices from synapse.events import EventBase -from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool +from synapse.storage._base import db_to_json +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.events_worker import EventsWorkerStore -from synapse.storage.types import Connection +from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.types import JsonDict from synapse.util import json_encoder +from synapse.util.caches.descriptors import cached if TYPE_CHECKING: from synapse.server import HomeServer @@ -59,8 +60,13 @@ def _make_exclusive_regex( return exclusive_user_pattern -class ApplicationServiceWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): +class ApplicationServiceWorkerStore(RoomMemberWorkerStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): self.services_cache = load_appservices( hs.hostname, hs.config.appservice.app_service_config_files ) @@ -122,6 +128,17 @@ class ApplicationServiceWorkerStore(SQLBaseStore): return service return None + # OSTD cache invalidation + @cached(iterable=True, prune_unread_entries=False) + async def get_app_service_users_in_room( + self, room_id: str, app_service: "ApplicationService" + ) -> List[str]: + return list( + filter( + app_service.is_interested_in_user, await self.get_users_in_room(room_id) + ) + ) + class ApplicationServiceStore(ApplicationServiceWorkerStore): # This is currently empty due to there not being any AS storage functions -- cgit 1.4.1 From 5d8949419850d13c9161d3bde209374803e9399b Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Mon, 13 Dec 2021 16:54:54 +0000 Subject: Fix two tests --- tests/handlers/test_appservice.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index e4ec149273..a59f37da69 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -426,7 +426,14 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # # The uninterested application service should not have been notified at all. self.send_mock.assert_called_once() - service, _events, _ephemeral, to_device_messages = self.send_mock.call_args[0] + ( + service, + _events, + _ephemeral, + to_device_messages, + _otks, + _fbks, + ) = self.send_mock.call_args[0] # Assert that this was the same to-device message that local_user sent self.assertEqual(service, interested_appservice) @@ -537,7 +544,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): service_id_to_message_count: Dict[str, int] = {} for call in self.send_mock.call_args_list: - service, _events, _ephemeral, to_device_messages = call[0] + service, _events, _ephemeral, to_device_messages, _otks, _fbks = call[0] # Check that this was made to an interested service self.assertIn(service, interested_appservices) -- cgit 1.4.1 From 67e438af779f40b1001eeffb9de623a47f31b390 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Mon, 13 Dec 2021 17:06:17 +0000 Subject: Fix up tests that weren't expecting extra call arguments --- synapse/appservice/scheduler.py | 5 +++- tests/appservice/test_scheduler.py | 55 +++++++++++++++++++++++++++----------- 2 files changed, 43 insertions(+), 17 deletions(-) diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 9c8bec5ed7..686d98e791 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -201,7 +201,10 @@ class _ServiceQueuer: one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None - if service.msc3202_transaction_extensions: + if ( + self._msc3202_transaction_extensions_enabled + and service.msc3202_transaction_extensions + ): # Lazily compute the one-time key counts and fallback key # usage states for the users which are mentioned in this # transaction, as well as the appservice's sender. diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 8f9afa8538..bac6a66605 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -65,6 +65,8 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): events=events, ephemeral=[], to_device_messages=[], # txn made and saved + one_time_key_counts={}, + unused_fallback_keys={}, ) self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made txn.complete.assert_called_once_with(self.store) # txn completed @@ -89,6 +91,8 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): events=events, ephemeral=[], to_device_messages=[], # txn made and saved + one_time_key_counts={}, + unused_fallback_keys={}, ) self.assertEquals(0, txn.send.call_count) # txn not sent though self.assertEquals(0, txn.complete.call_count) # or completed @@ -111,7 +115,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events, ephemeral=[], to_device_messages=[] + service=service, + events=events, + ephemeral=[], + to_device_messages=[], + one_time_key_counts={}, + unused_fallback_keys={}, ) self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made self.assertEquals(1, self.recoverer.recover.call_count) # and invoked @@ -213,7 +222,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): service = Mock(id=4) event = Mock() self.scheduler.enqueue_for_appservice(service, events=[event]) - self.txn_ctrl.send.assert_called_once_with(service, [event], [], []) + self.txn_ctrl.send.assert_called_once_with(service, [event], [], [], None, None) def test_send_single_event_with_queue(self): d = defer.Deferred() @@ -228,11 +237,13 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # (call enqueue_for_appservice multiple times deliberately) self.scheduler.enqueue_for_appservice(service, events=[event2]) self.scheduler.enqueue_for_appservice(service, events=[event3]) - self.txn_ctrl.send.assert_called_with(service, [event], [], []) + self.txn_ctrl.send.assert_called_with(service, [event], [], [], None, None) self.assertEquals(1, self.txn_ctrl.send.call_count) # Resolve the send event: expect the queued events to be sent d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [event2, event3], [], []) + self.txn_ctrl.send.assert_called_with( + service, [event2, event3], [], [], None, None + ) self.assertEquals(2, self.txn_ctrl.send.call_count) def test_multiple_service_queues(self): @@ -258,15 +269,15 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # send events for different ASes and make sure they are sent self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event]) self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2]) - self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], []) + self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [], None, None) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event]) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2]) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], []) + self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [], None, None) # make sure callbacks for a service only send queued events for THAT # service srv_2_defer.callback(srv2) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], []) + self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [], None, None) self.assertEquals(3, self.txn_ctrl.send.call_count) def test_send_large_txns(self): @@ -285,13 +296,19 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): self.scheduler.enqueue_for_appservice(service, [event], []) # Expect the first event to be sent immediately. - self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [], []) + self.txn_ctrl.send.assert_called_with( + service, [event_list[0]], [], [], None, None + ) srv_1_defer.callback(service) # Then send the next 100 events - self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [], []) + self.txn_ctrl.send.assert_called_with( + service, event_list[1:101], [], [], None, None + ) srv_2_defer.callback(service) # Then the final 99 events - self.txn_ctrl.send.assert_called_with(service, event_list[101:], [], []) + self.txn_ctrl.send.assert_called_with( + service, event_list[101:], [], [], None, None + ) self.assertEquals(3, self.txn_ctrl.send.call_count) def test_send_single_ephemeral_no_queue(self): @@ -299,14 +316,18 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): service = Mock(id=4, name="service") event_list = [Mock(name="event")] self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], event_list, []) + self.txn_ctrl.send.assert_called_once_with( + service, [], event_list, [], None, None + ) def test_send_multiple_ephemeral_no_queue(self): # Expect the event to be sent immediately. service = Mock(id=4, name="service") event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")] self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], event_list, []) + self.txn_ctrl.send.assert_called_once_with( + service, [], event_list, [], None, None + ) def test_send_single_ephemeral_with_queue(self): d = defer.Deferred() @@ -321,13 +342,13 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # Send more events: expect send() to NOT be called multiple times. self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3) - self.txn_ctrl.send.assert_called_with(service, [], event_list_1, []) + self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [], None, None) self.assertEquals(1, self.txn_ctrl.send.call_count) # Resolve txn_ctrl.send d.callback(service) # Expect the queued events to be sent self.txn_ctrl.send.assert_called_with( - service, [], event_list_2 + event_list_3, [] + service, [], event_list_2 + event_list_3, [], None, None ) self.assertEquals(2, self.txn_ctrl.send.call_count) @@ -340,7 +361,9 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)] event_list = first_chunk + second_chunk self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk, []) + self.txn_ctrl.send.assert_called_once_with( + service, [], first_chunk, [], None, None + ) d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [], second_chunk, []) + self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [], None, None) self.assertEquals(2, self.txn_ctrl.send.call_count) -- cgit 1.4.1 From b17f575d4212f8c7aaeb75195e67f4fa63d68e71 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Mon, 13 Dec 2021 17:34:26 +0000 Subject: Count the OTKs in bulk --- synapse/appservice/scheduler.py | 3 +- synapse/storage/databases/main/end_to_end_keys.py | 52 ++++++++++++++++++++++- 2 files changed, 52 insertions(+), 3 deletions(-) diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 686d98e791..0d11297f2b 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -280,8 +280,9 @@ class _ServiceQueuer: Given a list of application service users that are interesting, compute one-time key counts and fallback key usages for the users. """ + otk_counts = await self._store.count_bulk_e2e_one_time_keys_for_as(users) # OSTD implement me! - return {}, {} + return otk_counts, {} class _TransactionController: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index b06c1dc45b..9b1c0f12d4 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Tuple import attr from canonicaljson import encode_canonical_json @@ -22,9 +22,14 @@ from canonicaljson import encode_canonical_json from twisted.enterprise.adbapi import Connection from synapse.api.constants import DeviceKeyAlgorithms +from synapse.appservice import TransactionOneTimeKeyCounts from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingTransaction, + make_in_list_sql_clause, +) from synapse.storage.engines import PostgresEngine from synapse.storage.types import Cursor from synapse.types import JsonDict @@ -397,6 +402,49 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): "count_e2e_one_time_keys", _count_e2e_one_time_keys ) + async def count_bulk_e2e_one_time_keys_for_as( + self, user_ids: Collection[str] + ) -> TransactionOneTimeKeyCounts: + """ + Counts, in bulk, the one-time keys for all the users specified. + Intended to be used by application services for populating OTK counts in + transactions. + + Return structure is of the shape: + user_id -> device_id -> algorithm -> count + """ + + def _count_bulk_e2e_one_time_keys_txn( + txn: LoggingTransaction, + ) -> TransactionOneTimeKeyCounts: + user_in_where_clause, user_parameters = make_in_list_sql_clause( + self.database_engine, "user_id", user_ids + ) + sql = f""" + SELECT user_id, device_id, algorithm, COUNT(key_id) + FROM devices + LEFT JOIN e2e_one_time_keys_json USING (user_id, device_id) + WHERE {user_in_where_clause} + GROUP BY user_id, device_id, algorithm + """ + txn.execute(sql, user_parameters) + + result = {} + + for user_id, device_id, algorithm, count in txn: + device_count_by_algo = result.setdefault(user_id, {}).setdefault( + device_id, {} + ) + if algorithm is not None: + # algorithm will be None if this device has no keys. + device_count_by_algo[algorithm] = count + + return result + + return await self.db_pool.runInteraction( + "count_bulk_e2e_one_time_keys", _count_bulk_e2e_one_time_keys_txn + ) + async def set_e2e_fallback_keys( self, user_id: str, device_id: str, fallback_keys: JsonDict ) -> None: -- cgit 1.4.1 From d091b4de115526f7a3ccd8555b7899273fb2ab30 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Mon, 13 Dec 2021 17:47:57 +0000 Subject: Get unused fallback key types in bulk and send them out --- synapse/appservice/scheduler.py | 4 +- synapse/storage/databases/main/end_to_end_keys.py | 49 ++++++++++++++++++++++- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 0d11297f2b..e37f07d9e0 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -281,8 +281,8 @@ class _ServiceQueuer: compute one-time key counts and fallback key usages for the users. """ otk_counts = await self._store.count_bulk_e2e_one_time_keys_for_as(users) - # OSTD implement me! - return otk_counts, {} + unused_fbks = await self._store.get_e2e_bulk_unused_fallback_key_types(users) + return otk_counts, unused_fbks class _TransactionController: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 9b1c0f12d4..0646a6756d 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -22,7 +22,10 @@ from canonicaljson import encode_canonical_json from twisted.enterprise.adbapi import Connection from synapse.api.constants import DeviceKeyAlgorithms -from synapse.appservice import TransactionOneTimeKeyCounts +from synapse.appservice import ( + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -445,6 +448,50 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): "count_bulk_e2e_one_time_keys", _count_bulk_e2e_one_time_keys_txn ) + async def get_e2e_bulk_unused_fallback_key_types( + self, user_ids: Collection[str] + ) -> TransactionUnusedFallbackKeys: + """ + Finds, in bulk, the types of unused fallback keys for all the users specified. + Intended to be used by application services for populating unused fallback + keys in transactions. + + Return structure is of the shape: + user_id -> device_id -> algorithms + """ + + def _get_bulk_e2e_unused_fallback_keys_txn( + txn: LoggingTransaction, + ) -> TransactionUnusedFallbackKeys: + user_in_where_clause, user_parameters = make_in_list_sql_clause( + self.database_engine, "user_id", user_ids + ) + sql = f""" + SELECT user_id, device_id, algorithm + FROM devices + LEFT JOIN e2e_fallback_keys_json USING (user_id, device_id) + WHERE + {user_in_where_clause} + AND NOT used + """ + txn.execute(sql, user_parameters) + + result = {} + + for user_id, device_id, algorithm in txn: + device_unused_keys = result.setdefault(user_id, {}).setdefault( + device_id, [] + ) + if algorithm is not None: + # algorithm will be None if this device has no keys. + device_unused_keys.append(algorithm) + + return result + + return await self.db_pool.runInteraction( + "_get_bulk_e2e_unused_fallback_keys", _get_bulk_e2e_unused_fallback_keys_txn + ) + async def set_e2e_fallback_keys( self, user_id: str, device_id: str, fallback_keys: JsonDict ) -> None: -- cgit 1.4.1 From 7486fdb463b048b5a6c1baac26cf1dd8088179c0 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 17 Dec 2021 18:08:16 +0000 Subject: Break up `get_app_service_users_in_room` to make it easier to debug --- synapse/storage/databases/main/appservice.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index d9a79d6f4f..0a73a75aed 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -133,11 +133,8 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore): async def get_app_service_users_in_room( self, room_id: str, app_service: "ApplicationService" ) -> List[str]: - return list( - filter( - app_service.is_interested_in_user, await self.get_users_in_room(room_id) - ) - ) + users_in_room = await self.get_users_in_room(room_id) + return list(filter(app_service.is_interested_in_user, users_in_room)) class ApplicationServiceStore(ApplicationServiceWorkerStore): -- cgit 1.4.1 From b50b46034d71334f0514fe579d8765017ad5b673 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 17 Dec 2021 18:08:49 +0000 Subject: Fix the get_bulk_e2e_unused_fallback_keys query to return devices with only used keys --- synapse/storage/databases/main/end_to_end_keys.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 0646a6756d..d55697c093 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -459,20 +459,31 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): Return structure is of the shape: user_id -> device_id -> algorithms """ + if len(user_ids) == 0: + return {} def _get_bulk_e2e_unused_fallback_keys_txn( txn: LoggingTransaction, ) -> TransactionUnusedFallbackKeys: user_in_where_clause, user_parameters = make_in_list_sql_clause( - self.database_engine, "user_id", user_ids + self.database_engine, "devices.user_id", user_ids ) + # We can't use USING here because we require the `.used` condition + # to be part of the JOIN condition so that we generate empty lists + # when all keys are used (as opposed to just when there are no keys at all). sql = f""" - SELECT user_id, device_id, algorithm + SELECT devices.user_id, devices.device_id, algorithm FROM devices - LEFT JOIN e2e_fallback_keys_json USING (user_id, device_id) + LEFT JOIN e2e_fallback_keys_json AS fallback_keys + /* We can't use USING here because we require the `.used` + condition to be part of the JOIN condition so that we + generate empty lists when all keys are used (as opposed + to just when there are no keys at all). */ + ON devices.user_id = fallback_keys.user_id + AND devices.device_id = fallback_keys.device_id + AND NOT fallback_keys.used WHERE {user_in_where_clause} - AND NOT used """ txn.execute(sql, user_parameters) -- cgit 1.4.1 From 7392d2f02febe2a1492036026acddf75de1f00aa Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 17 Dec 2021 18:18:48 +0000 Subject: Return the device ID when registering an appservice user in test helpers --- tests/handlers/test_user_directory.py | 6 ++++-- tests/storage/test_user_directory.py | 4 +++- tests/unittest.py | 8 ++++---- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 70c621b825..482c90ef68 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -169,7 +169,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Register an AS user. user = self.register_user("user", "pass") token = self.login(user, "pass") - as_user = self.register_appservice_user("as_user_potato", self.appservice.token) + as_user, _ = self.register_appservice_user( + "as_user_potato", self.appservice.token + ) # Join the AS user to rooms owned by the normal user. public, private = self._create_rooms_and_inject_memberships( @@ -388,7 +390,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def test_handle_local_profile_change_with_appservice_user(self) -> None: # create user - as_user_id = self.register_appservice_user( + as_user_id, _ = self.register_appservice_user( "as_user_alice", self.appservice.token ) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 7f5b28aed8..48f1e9d841 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -341,7 +341,9 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): # Register an AS user. user = self.register_user("user", "pass") token = self.login(user, "pass") - as_user = self.register_appservice_user("as_user_potato", self.appservice.token) + as_user, _ = self.register_appservice_user( + "as_user_potato", self.appservice.token + ) # Join the AS user to rooms owned by the normal user. public, private = self._create_rooms_and_inject_memberships( diff --git a/tests/unittest.py b/tests/unittest.py index eea0903f05..780a8ba434 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -623,18 +623,18 @@ class HomeserverTestCase(TestCase): self, username: str, appservice_token: str, - ) -> str: + ) -> Tuple[str, str]: """Register an appservice user as an application service. Requires the client-facing registration API be registered. Args: username: the user to be registered by an application service. - Should be a full username, i.e. ""@localpart:hostname" as opposed to just "localpart" + Should NOT be a full username, i.e. just "localpart" as opposed to "@localpart:hostname" appservice_token: the acccess token for that application service. Raises: if the request to '/register' does not return 200 OK. - Returns: the MXID of the new user. + Returns: the MXID of the new user, the device ID of the new user's first device. """ channel = self.make_request( "POST", @@ -646,7 +646,7 @@ class HomeserverTestCase(TestCase): access_token=appservice_token, ) self.assertEqual(channel.code, 200, channel.json_body) - return channel.json_body["user_id"] + return channel.json_body["user_id"], channel.json_body["device_id"] def login( self, -- cgit 1.4.1 From ff8555efc29e733f1f89b9ef71fe41c2a52617ec Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 17 Dec 2021 18:19:27 +0000 Subject: Allow passing an appservice_user_id to the join helper --- tests/rest/client/utils.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 1af5e5cee5..3d382761a8 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -147,12 +147,20 @@ class RestHelper: expect_code=expect_code, ) - def join(self, room=None, user=None, expect_code=200, tok=None): + def join( + self, + room=None, + user=None, + expect_code=200, + tok=None, + appservice_user_id: Optional[str] = None, + ): self.change_membership( room=room, src=user, targ=user, tok=tok, + appservice_user_id=appservice_user_id, membership=Membership.JOIN, expect_code=expect_code, ) @@ -204,6 +212,7 @@ class RestHelper: membership: str, extra_data: Optional[dict] = None, tok: Optional[str] = None, + appservice_user_id: Optional[str] = None, expect_code: int = 200, expect_errcode: Optional[str] = None, ) -> None: @@ -217,6 +226,9 @@ class RestHelper: membership: The type of membership event extra_data: Extra information to include in the content of the event tok: The user access token to use + appservice_user_id: The `user_id` URL parameter to pass. + This allows driving an application service user + using an application service access token in `tok`. expect_code: The expected HTTP response code expect_errcode: The expected Matrix error code """ @@ -224,8 +236,14 @@ class RestHelper: self.auth_user_id = src path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ) + next_arg_char = "?" + if tok: - path = path + "?access_token=%s" % tok + path += "?access_token=%s" % tok + next_arg_char = "&" + + if appservice_user_id: + path += f"{next_arg_char}user_id={appservice_user_id}" data = {"membership": membership} data.update(extra_data or {}) -- cgit 1.4.1 From db7bd678f05e6c786ce6c0c1e25b51a1d0ef004c Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 17 Dec 2021 18:19:55 +0000 Subject: Add a test for sending OTKs and UFBKs to ASes upon receiving PDUs --- tests/handlers/test_appservice.py | 185 +++++++++++++++++++++++++++++++++++++- 1 file changed, 183 insertions(+), 2 deletions(-) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index a59f37da69..6314179aef 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -16,17 +16,25 @@ from typing import Dict, Iterable, List, Optional from unittest.mock import Mock from twisted.internet import defer +from twisted.internet.testing import MemoryReactor import synapse.rest.admin import synapse.storage -from synapse.appservice import ApplicationService +from synapse.appservice import ( + ApplicationService, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.handlers.appservice import ApplicationServicesHandler -from synapse.rest.client import login, receipts, room, sendtodevice +from synapse.rest.client import login, receipts, register, room, sendtodevice +from synapse.server import HomeServer from synapse.types import RoomStreamToken +from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest from tests.test_utils import make_awaitable, simple_async_mock +from tests.unittest import override_config from tests.utils import MockClock @@ -586,3 +594,176 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): self._services.append(appservice) return appservice + + +class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): + # Argument indices for pulling out arguments from a `send_mock`. + ARG_OTK_COUNTS = 4 + ARG_FALLBACK_KEYS = 5 + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + room.register_servlets, + sendtodevice.register_servlets, + receipts.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + # Mock the ApplicationServiceScheduler's _TransactionController's send method so that + # we can track what's going out + self.send_mock = simple_async_mock() + hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method. + + # Define an application service for the tests + self._service_token = "VERYSECRET" + self._service = ApplicationService( + self._service_token, + "as1.invalid", + "as1", + "@as.sender:test", + namespaces={ + "users": [ + {"regex": "@_as_.*:test", "exclusive": True}, + {"regex": "@as.sender:test", "exclusive": True}, + ] + }, + msc3202_transaction_extensions=True, + ) + self.hs.get_datastore().services_cache = [self._service] + + # Register some appservice users + self._sender_user, self._sender_device = self.register_appservice_user( + "as.sender", self._service_token + ) + self._namespaced_user, self._namespaced_device = self.register_appservice_user( + "_as_user1", self._service_token + ) + + # Register a real user as well. + self._real_user = self.register_user("real.user", "meow") + self._real_user_token = self.login("real.user", "meow") + + async def _add_otks_for_device( + self, user_id: str, device_id: str, otk_count: int + ) -> None: + """ + Add some dummy keys. It doesn't matter if they're not a real algorithm; + that should be opaque to the server anyway. + """ + await self.hs.get_datastore().add_e2e_one_time_keys( + user_id, + device_id, + self.clock.time_msec(), + [("algo", f"k{i}", "{}") for i in range(otk_count)], + ) + + async def _add_fallback_key_for_device( + self, user_id: str, device_id: str, used: bool + ) -> None: + """ + Adds a fake fallback key to a device, optionally marking it as used + right away. + """ + store = self.hs.get_datastore() + await store.set_e2e_fallback_keys(user_id, device_id, {"algo:fk": "fall back!"}) + if used is True: + # Mark the key as used + await store.db_pool.simple_update_one( + table="e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": "algo", + "key_id": "fk", + }, + updatevalues={"used": True}, + desc="_get_fallback_key_set_used", + ) + + def _set_up_devices_and_a_room(self) -> str: + """ + Helper to set up devices for all the users + and a room for the users to talk in. + """ + + async def preparation(): + await self._add_otks_for_device(self._sender_user, self._sender_device, 42) + await self._add_fallback_key_for_device( + self._sender_user, self._sender_device, used=True + ) + await self._add_otks_for_device( + self._namespaced_user, self._namespaced_device, 36 + ) + await self._add_fallback_key_for_device( + self._namespaced_user, self._namespaced_device, used=False + ) + + # Register a device for the real user, too, so that we can later ensure + # that we don't leak information to the AS about the non-AS user. + await self.hs.get_datastore().store_device( + self._real_user, "REALDEV", "UltraMatrix 3000" + ) + await self._add_otks_for_device(self._real_user, "REALDEV", 50) + + self.get_success(preparation()) + + room_id = self.helper.create_room_as( + self._real_user, is_public=True, tok=self._real_user_token + ) + self.helper.join( + room_id, + self._namespaced_user, + tok=self._service_token, + appservice_user_id=self._namespaced_user, + ) + + # Check it was called for sanity. (This was to send the join event to the AS.) + self.send_mock.assert_called() + self.send_mock.reset_mock() + + return room_id + + @override_config( + {"experimental_features": {"msc3202_transaction_extensions": True}} + ) + def test_application_services_receive_otk_counts_and_fallback_key_usages_with_pdus( + self, + ) -> None: + """ + Tests that: + - the AS receives one-time key counts and unused fallback keys for: + - the specified sender; and + - any user who is in receipt of the PDUs + """ + + room_id = self._set_up_devices_and_a_room() + + # Send a message into the AS's room + self.helper.send(room_id, "woof woof", tok=self._real_user_token) + + # Capture what was sent as an AS transaction. + self.send_mock.assert_called() + last_args = self.send_mock.call_args + otks: Optional[TransactionOneTimeKeyCounts] = last_args.args[ + self.ARG_OTK_COUNTS + ] + unused_fallbacks: Optional[TransactionUnusedFallbackKeys] = last_args.args[ + self.ARG_FALLBACK_KEYS + ] + + self.assertEqual( + otks, + { + "@as.sender:test": {self._sender_device: {"algo": 42}}, + "@_as_user1:test": {self._namespaced_device: {"algo": 36}}, + }, + ) + self.assertEqual( + unused_fallbacks, + { + "@as.sender:test": {self._sender_device: []}, + "@_as_user1:test": {self._namespaced_device: ["algo"]}, + }, + ) -- cgit 1.4.1 From 537adac389ff682725ae5444b8bf942ecf398494 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Mon, 20 Dec 2021 16:14:25 +0000 Subject: Newsfile Signed-off-by: Olivier Wilkinson (reivilibre) --- changelog.d/11617.feature | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/11617.feature diff --git a/changelog.d/11617.feature b/changelog.d/11617.feature new file mode 100644 index 0000000000..cf03f00e7c --- /dev/null +++ b/changelog.d/11617.feature @@ -0,0 +1 @@ +Add support for MSC3202: sending one-time key counts and fallback key usage states to Application Services. \ No newline at end of file -- cgit 1.4.1