diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/appservice/__init__.py | 18 | ||||
-rw-r--r-- | synapse/appservice/api.py | 18 | ||||
-rw-r--r-- | synapse/appservice/scheduler.py | 106 | ||||
-rw-r--r-- | synapse/config/appservice.py | 13 | ||||
-rw-r--r-- | synapse/config/experimental.py | 6 | ||||
-rw-r--r-- | synapse/storage/databases/main/appservice.py | 37 | ||||
-rw-r--r-- | synapse/storage/databases/main/end_to_end_keys.py | 110 |
7 files changed, 293 insertions, 15 deletions
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 01db2b2ae3..632d5d133c 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -14,7 +14,7 @@ import logging import re from enum import Enum -from typing import TYPE_CHECKING, Iterable, List, Match, Optional +from typing import TYPE_CHECKING, Dict, Iterable, List, Match, Optional from synapse.api.constants import EventTypes from synapse.events import EventBase @@ -27,6 +27,14 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# Type for the `device_one_time_key_counts` field in an appservice transaction +# user ID -> {device ID -> {algorithm -> count}} +TransactionOneTimeKeyCounts = Dict[str, Dict[str, Dict[str, int]]] + +# Type for the `device_unused_fallback_keys` field in an appservice transaction +# user ID -> {device ID -> [algorithm]} +TransactionUnusedFallbackKeys = Dict[str, Dict[str, List[str]]] + class ApplicationServiceState(Enum): DOWN = "down" @@ -61,6 +69,7 @@ class ApplicationService: rate_limited=True, ip_range_whitelist=None, supports_ephemeral=False, + msc3202_transaction_extensions: bool = False, ): self.token = token self.url = ( @@ -73,6 +82,7 @@ class ApplicationService: self.id = id self.ip_range_whitelist = ip_range_whitelist self.supports_ephemeral = supports_ephemeral + self.msc3202_transaction_extensions = msc3202_transaction_extensions if "|" in self.id: raise Exception("application service ID cannot contain '|' character") @@ -371,6 +381,8 @@ class AppServiceTransaction: ephemeral: List[JsonDict], to_device_messages: List[JsonDict], device_list_summary: DeviceLists, + one_time_key_counts: TransactionOneTimeKeyCounts, + unused_fallback_keys: TransactionUnusedFallbackKeys, ): self.service = service self.id = id @@ -378,6 +390,8 @@ class AppServiceTransaction: self.ephemeral = ephemeral self.to_device_messages = to_device_messages self.device_list_summary = device_list_summary + self.one_time_key_counts = one_time_key_counts + self.unused_fallback_keys = unused_fallback_keys async def send(self, as_api: "ApplicationServiceApi") -> bool: """Sends this transaction using the provided AS API interface. @@ -393,6 +407,8 @@ class AppServiceTransaction: ephemeral=self.ephemeral, to_device_messages=self.to_device_messages, device_list_summary=self.device_list_summary, + one_time_key_counts=self.one_time_key_counts, + unused_fallback_keys=self.unused_fallback_keys, txn_id=self.id, ) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 3ae59c7a04..0333832a9c 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -20,6 +20,11 @@ from prometheus_client import Counter from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind from synapse.api.errors import CodeMessageException +from synapse.appservice import ( + ApplicationService, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.events import EventBase from synapse.events.utils import serialize_event from synapse.http.client import SimpleHttpClient @@ -27,7 +32,6 @@ from synapse.types import DeviceLists, JsonDict, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: - from synapse.appservice import ApplicationService from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -207,6 +211,8 @@ class ApplicationServiceApi(SimpleHttpClient): ephemeral: List[JsonDict], to_device_messages: List[JsonDict], device_list_summary: DeviceLists, + one_time_key_counts: TransactionOneTimeKeyCounts, + unused_fallback_keys: TransactionUnusedFallbackKeys, txn_id: Optional[int] = None, ) -> bool: """ @@ -258,6 +264,16 @@ class ApplicationServiceApi(SimpleHttpClient): } ) + if service.msc3202_transaction_extensions: + if one_time_key_counts: + body[ + "org.matrix.msc3202.device_one_time_key_counts" + ] = one_time_key_counts + if unused_fallback_keys: + body[ + "org.matrix.msc3202.device_unused_fallback_keys" + ] = unused_fallback_keys + try: await self.put_json( uri=uri, diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index d49636d926..d59bf3e7a0 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -48,14 +48,22 @@ This is all tied together by the AppServiceScheduler which DIs the required components. """ import logging -from typing import Dict, Iterable, List, Optional - -from synapse.appservice import ApplicationService, ApplicationServiceState +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple + +from synapse.appservice import ( + ApplicationService, + ApplicationServiceState, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.events import EventBase from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import DeviceLists, JsonDict +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -81,7 +89,7 @@ class ApplicationServiceScheduler: self.as_api = hs.get_application_service_api() self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api) - self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock) + self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock, hs) async def start(self): logger.info("Starting appservice scheduler") @@ -151,7 +159,7 @@ class _ServiceQueuer: appservice at a given time. """ - def __init__(self, txn_ctrl, clock): + def __init__(self, txn_ctrl, clock, hs: "HomeServer"): # dict of {service_id: [events]} self.queued_events: Dict[str, List[EventBase]] = {} # dict of {service_id: [event_json]} @@ -162,9 +170,13 @@ class _ServiceQueuer: self.queued_device_list_summaries: Dict[str, List[DeviceLists]] = {} # the appservices which currently have a transaction in flight - self.requests_in_flight = set() + self.requests_in_flight: Set[str] = set() self.txn_ctrl = txn_ctrl self.clock = clock + self._msc3202_transaction_extensions_enabled: bool = ( + hs.config.experimental.msc3202_transaction_extensions + ) + self._store = hs.get_datastore() def start_background_request(self, service): # start a sender for this appservice if we don't already have one @@ -235,6 +247,26 @@ class _ServiceQueuer: ): return + one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None + unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None + + if ( + self._msc3202_transaction_extensions_enabled + and service.msc3202_transaction_extensions + ): + # Lazily compute the one-time key counts and fallback key + # usage states for the users which are mentioned in this + # transaction, as well as the appservice's sender. + interesting_users = await self._determine_interesting_users_for_msc3202_otk_counts_and_fallback_keys( + service, events, ephemeral, to_device_messages_to_send + ) + ( + one_time_key_counts, + unused_fallback_keys, + ) = await self._compute_msc3202_otk_counts_and_fallback_keys( + interesting_users + ) + try: await self.txn_ctrl.send( service, @@ -242,12 +274,66 @@ class _ServiceQueuer: ephemeral, to_device_messages_to_send, device_list_summary, + one_time_key_counts, + unused_fallback_keys, ) except Exception: logger.exception("AS request failed") finally: self.requests_in_flight.discard(service.id) + async def _determine_interesting_users_for_msc3202_otk_counts_and_fallback_keys( + self, + service: ApplicationService, + events: Iterable[EventBase], + ephemerals: Iterable[JsonDict], + to_device_messages: Iterable[JsonDict], + ) -> Set[str]: + """ + Given a list of the events, ephemeral messages and to-device messaeges, + compute a list of application services users that may have interesting + updates to the one-time key counts or fallback key usage. + """ + interesting_users: Set[str] = set() + + # The sender is always included + interesting_users.add(service.sender) + + # All AS users that would receive the PDUs or EDUs sent to these rooms + # are classed as 'interesting'. + rooms_of_interesting_users: Set[str] = set() + # PDUs + rooms_of_interesting_users.update(event.room_id for event in events) + # EDUs + rooms_of_interesting_users.update( + ephemeral["room_id"] for ephemeral in ephemerals + ) + + # Look up the AS users in those rooms + for room_id in rooms_of_interesting_users: + interesting_users.update( + await self._store.get_app_service_users_in_room(room_id, service) + ) + + # Add recipients of to-device messages. + # device_message["user_id"] is the ID of the recipient. + interesting_users.update( + device_message["user_id"] for device_message in to_device_messages + ) + + return interesting_users + + async def _compute_msc3202_otk_counts_and_fallback_keys( + self, users: Set[str] + ) -> Tuple[TransactionOneTimeKeyCounts, TransactionUnusedFallbackKeys]: + """ + Given a list of application service users that are interesting, + compute one-time key counts and fallback key usages for the users. + """ + otk_counts = await self._store.count_bulk_e2e_one_time_keys_for_as(users) + unused_fbks = await self._store.get_e2e_bulk_unused_fallback_key_types(users) + return otk_counts, unused_fbks + class _TransactionController: """Transaction manager. @@ -281,6 +367,8 @@ class _TransactionController: ephemeral: Optional[List[JsonDict]] = None, to_device_messages: Optional[List[JsonDict]] = None, device_list_summary: Optional[DeviceLists] = None, + one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None, + unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None, ) -> None: """ Create a transaction with the given data and send to the provided @@ -292,6 +380,10 @@ class _TransactionController: ephemeral: The ephemeral events to include in the transaction. to_device_messages: The to-device messages to include in the transaction. device_list_summary: The device list summary to include in the transaction. + one_time_key_counts: Counts of remaining one-time keys for relevant + appservice devices in the transaction. + unused_fallback_keys: Lists of unused fallback keys for relevant + appservice devices in the transaction. """ try: txn = await self.store.create_appservice_txn( @@ -300,6 +392,8 @@ class _TransactionController: ephemeral=ephemeral or [], to_device_messages=to_device_messages or [], device_list_summary=device_list_summary or DeviceLists(), + one_time_key_counts=one_time_key_counts or {}, + unused_fallback_keys=unused_fallback_keys or {}, ) service_is_up = await self._is_service_up(service) if service_is_up: diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index e4bb7224a4..c77f4058b3 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -167,6 +167,16 @@ def _load_appservice( supports_ephemeral = as_info.get("de.sorunome.msc2409.push_ephemeral", False) + # Opt-in flag for the MSC3202-specific transactional behaviour. + # When enabled, appservice transactions contain the following information: + # - device One-Time Key counts + # - device unused fallback key usage states + msc3202_transaction_extensions = as_info.get("org.matrix.msc3202", False) + if not isinstance(msc3202_transaction_extensions, bool): + raise ValueError( + "The `org.matrix.msc3202` option should be true or false if specified." + ) + return ApplicationService( token=as_info["as_token"], hostname=hostname, @@ -175,8 +185,9 @@ def _load_appservice( hs_token=as_info["hs_token"], sender=user_id, id=as_info["id"], - supports_ephemeral=supports_ephemeral, protocols=protocols, rate_limited=rate_limited, ip_range_whitelist=ip_range_whitelist, + supports_ephemeral=supports_ephemeral, + msc3202_transaction_extensions=msc3202_transaction_extensions, ) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d19165e5b4..058abeb194 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -66,3 +66,9 @@ class ExperimentalConfig(Config): self.msc3202_device_masquerading_enabled: bool = experimental.get( "msc3202_device_masquerading", False ) + + # Portion of MSC3202 related to transaction extensions: + # sending one-time key counts and fallback key usage to application services. + self.msc3202_transaction_extensions: bool = experimental.get( + "msc3202_transaction_extensions", False + ) diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 0ac2005bee..d43153ea23 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -20,15 +20,19 @@ from synapse.appservice import ( ApplicationService, ApplicationServiceState, AppServiceTransaction, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, ) from synapse.config.appservice import load_appservices from synapse.events import EventBase -from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool +from synapse.storage._base import db_to_json +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.types import Connection from synapse.types import DeviceLists, JsonDict +from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.util import json_encoder +from synapse.util.caches.descriptors import cached if TYPE_CHECKING: from synapse.server import HomeServer @@ -57,8 +61,13 @@ def _make_exclusive_regex( return exclusive_user_pattern -class ApplicationServiceWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): +class ApplicationServiceWorkerStore(RoomMemberWorkerStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): self.services_cache = load_appservices( hs.hostname, hs.config.appservice.app_service_config_files ) @@ -120,6 +129,14 @@ class ApplicationServiceWorkerStore(SQLBaseStore): return service return None + # OSTD cache invalidation + @cached(iterable=True, prune_unread_entries=False) + async def get_app_service_users_in_room( + self, room_id: str, app_service: "ApplicationService" + ) -> List[str]: + users_in_room = await self.get_users_in_room(room_id) + return list(filter(app_service.is_interested_in_user, users_in_room)) + class ApplicationServiceStore(ApplicationServiceWorkerStore): # This is currently empty due to there not being any AS storage functions @@ -196,6 +213,8 @@ class ApplicationServiceTransactionWorkerStore( ephemeral: List[JsonDict], to_device_messages: List[JsonDict], device_list_summary: DeviceLists, + one_time_key_counts: TransactionOneTimeKeyCounts, + unused_fallback_keys: TransactionUnusedFallbackKeys, ) -> AppServiceTransaction: """Atomically creates a new transaction for this application service with the given list of events. Ephemeral events are NOT persisted to the @@ -207,6 +226,10 @@ class ApplicationServiceTransactionWorkerStore( ephemeral: A list of ephemeral events to put in the transaction. to_device_messages: A list of to-device messages to put in the transaction. device_list_summary: The device list summary to include in the transaction. + one_time_key_counts: Counts of remaining one-time keys for relevant + appservice devices in the transaction. + unused_fallback_keys: Lists of unused fallback keys for relevant + appservice devices in the transaction. Returns: A new transaction. @@ -243,6 +266,8 @@ class ApplicationServiceTransactionWorkerStore( ephemeral=ephemeral, to_device_messages=to_device_messages, device_list_summary=device_list_summary, + one_time_key_counts=one_time_key_counts, + unused_fallback_keys=unused_fallback_keys, ) return await self.db_pool.runInteraction( @@ -334,6 +359,8 @@ class ApplicationServiceTransactionWorkerStore( events = await self.get_events_as_list(event_ids) + # TODO: should we recalculate one-time key counts and unused fallback + # key counts here? return AppServiceTransaction( service=service, id=entry["txn_id"], @@ -341,6 +368,8 @@ class ApplicationServiceTransactionWorkerStore( ephemeral=[], to_device_messages=[], device_list_summary=DeviceLists(), + one_time_key_counts={}, + unused_fallback_keys={}, ) def _get_last_txn(self, txn, service_id: Optional[str]) -> int: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index b06c1dc45b..d55697c093 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Tuple import attr from canonicaljson import encode_canonical_json @@ -22,9 +22,17 @@ from canonicaljson import encode_canonical_json from twisted.enterprise.adbapi import Connection from synapse.api.constants import DeviceKeyAlgorithms +from synapse.appservice import ( + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingTransaction, + make_in_list_sql_clause, +) from synapse.storage.engines import PostgresEngine from synapse.storage.types import Cursor from synapse.types import JsonDict @@ -397,6 +405,104 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): "count_e2e_one_time_keys", _count_e2e_one_time_keys ) + async def count_bulk_e2e_one_time_keys_for_as( + self, user_ids: Collection[str] + ) -> TransactionOneTimeKeyCounts: + """ + Counts, in bulk, the one-time keys for all the users specified. + Intended to be used by application services for populating OTK counts in + transactions. + + Return structure is of the shape: + user_id -> device_id -> algorithm -> count + """ + + def _count_bulk_e2e_one_time_keys_txn( + txn: LoggingTransaction, + ) -> TransactionOneTimeKeyCounts: + user_in_where_clause, user_parameters = make_in_list_sql_clause( + self.database_engine, "user_id", user_ids + ) + sql = f""" + SELECT user_id, device_id, algorithm, COUNT(key_id) + FROM devices + LEFT JOIN e2e_one_time_keys_json USING (user_id, device_id) + WHERE {user_in_where_clause} + GROUP BY user_id, device_id, algorithm + """ + txn.execute(sql, user_parameters) + + result = {} + + for user_id, device_id, algorithm, count in txn: + device_count_by_algo = result.setdefault(user_id, {}).setdefault( + device_id, {} + ) + if algorithm is not None: + # algorithm will be None if this device has no keys. + device_count_by_algo[algorithm] = count + + return result + + return await self.db_pool.runInteraction( + "count_bulk_e2e_one_time_keys", _count_bulk_e2e_one_time_keys_txn + ) + + async def get_e2e_bulk_unused_fallback_key_types( + self, user_ids: Collection[str] + ) -> TransactionUnusedFallbackKeys: + """ + Finds, in bulk, the types of unused fallback keys for all the users specified. + Intended to be used by application services for populating unused fallback + keys in transactions. + + Return structure is of the shape: + user_id -> device_id -> algorithms + """ + if len(user_ids) == 0: + return {} + + def _get_bulk_e2e_unused_fallback_keys_txn( + txn: LoggingTransaction, + ) -> TransactionUnusedFallbackKeys: + user_in_where_clause, user_parameters = make_in_list_sql_clause( + self.database_engine, "devices.user_id", user_ids + ) + # We can't use USING here because we require the `.used` condition + # to be part of the JOIN condition so that we generate empty lists + # when all keys are used (as opposed to just when there are no keys at all). + sql = f""" + SELECT devices.user_id, devices.device_id, algorithm + FROM devices + LEFT JOIN e2e_fallback_keys_json AS fallback_keys + /* We can't use USING here because we require the `.used` + condition to be part of the JOIN condition so that we + generate empty lists when all keys are used (as opposed + to just when there are no keys at all). */ + ON devices.user_id = fallback_keys.user_id + AND devices.device_id = fallback_keys.device_id + AND NOT fallback_keys.used + WHERE + {user_in_where_clause} + """ + txn.execute(sql, user_parameters) + + result = {} + + for user_id, device_id, algorithm in txn: + device_unused_keys = result.setdefault(user_id, {}).setdefault( + device_id, [] + ) + if algorithm is not None: + # algorithm will be None if this device has no keys. + device_unused_keys.append(algorithm) + + return result + + return await self.db_pool.runInteraction( + "_get_bulk_e2e_unused_fallback_keys", _get_bulk_e2e_unused_fallback_keys_txn + ) + async def set_e2e_fallback_keys( self, user_id: str, device_id: str, fallback_keys: JsonDict ) -> None: |