diff options
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r-- | synapse/storage/databases/main/appservice.py | 14 | ||||
-rw-r--r-- | synapse/storage/databases/main/devices.py | 48 |
2 files changed, 44 insertions, 18 deletions
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index abea4383c7..55e1ab099d 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -29,7 +29,7 @@ from synapse.storage._base import db_to_json from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore -from synapse.types import JsonDict +from synapse.types import DeviceListUpdates, JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import _CacheContext, cached @@ -217,6 +217,7 @@ class ApplicationServiceTransactionWorkerStore( to_device_messages: List[JsonDict], one_time_key_counts: TransactionOneTimeKeyCounts, unused_fallback_keys: TransactionUnusedFallbackKeys, + device_list_summary: DeviceListUpdates, ) -> AppServiceTransaction: """Atomically creates a new transaction for this application service with the given list of events. Ephemeral events are NOT persisted to the @@ -231,6 +232,7 @@ class ApplicationServiceTransactionWorkerStore( appservice devices in the transaction. unused_fallback_keys: Lists of unused fallback keys for relevant appservice devices in the transaction. + device_list_summary: The device list summary to include in the transaction. Returns: A new transaction. @@ -268,6 +270,7 @@ class ApplicationServiceTransactionWorkerStore( to_device_messages=to_device_messages, one_time_key_counts=one_time_key_counts, unused_fallback_keys=unused_fallback_keys, + device_list_summary=device_list_summary, ) return await self.db_pool.runInteraction( @@ -359,8 +362,8 @@ class ApplicationServiceTransactionWorkerStore( events = await self.get_events_as_list(event_ids) - # TODO: to-device messages, one-time key counts and unused fallback keys - # are not yet populated for catch-up transactions. + # TODO: to-device messages, one-time key counts, device list summaries and unused + # fallback keys are not yet populated for catch-up transactions. # We likely want to populate those for reliability. return AppServiceTransaction( service=service, @@ -370,6 +373,7 @@ class ApplicationServiceTransactionWorkerStore( to_device_messages=[], one_time_key_counts={}, unused_fallback_keys={}, + device_list_summary=DeviceListUpdates(), ) def _get_last_txn(self, txn, service_id: Optional[str]) -> int: @@ -430,7 +434,7 @@ class ApplicationServiceTransactionWorkerStore( async def get_type_stream_id_for_appservice( self, service: ApplicationService, type: str ) -> int: - if type not in ("read_receipt", "presence", "to_device"): + if type not in ("read_receipt", "presence", "to_device", "device_list"): raise ValueError( "Expected type to be a valid application stream id type, got %s" % (type,) @@ -458,7 +462,7 @@ class ApplicationServiceTransactionWorkerStore( async def set_appservice_stream_type_pos( self, service: ApplicationService, stream_type: str, pos: Optional[int] ) -> None: - if stream_type not in ("read_receipt", "presence", "to_device"): + if stream_type not in ("read_receipt", "presence", "to_device", "device_list"): raise ValueError( "Expected type to be a valid application stream id type, got %s" % (stream_type,) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 3b3a089b76..f08f7834d3 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -681,42 +681,64 @@ class DeviceWorkerStore(SQLBaseStore): return self._device_list_stream_cache.get_all_entities_changed(from_key) async def get_users_whose_devices_changed( - self, from_key: int, user_ids: Iterable[str] + self, + from_key: int, + user_ids: Optional[Iterable[str]] = None, + to_key: Optional[int] = None, ) -> Set[str]: """Get set of users whose devices have changed since `from_key` that are in the given list of user_ids. Args: - from_key: The device lists stream token - user_ids: The user IDs to query for devices. + from_key: The minimum device lists stream token to query device list changes for, + exclusive. + user_ids: If provided, only check if these users have changed their device lists. + Otherwise changes from all users are returned. + to_key: The maximum device lists stream token to query device list changes for, + inclusive. Returns: - The set of user_ids whose devices have changed since `from_key` + The set of user_ids whose devices have changed since `from_key` (exclusive) + until `to_key` (inclusive). """ - # Get set of users who *may* have changed. Users not in the returned # list have definitely not changed. - to_check = self._device_list_stream_cache.get_entities_changed( - user_ids, from_key - ) + if user_ids is None: + # Get set of all users that have had device list changes since 'from_key' + user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed( + from_key + ) + else: + # The same as above, but filter results to only those users in 'user_ids' + user_ids_to_check = self._device_list_stream_cache.get_entities_changed( + user_ids, from_key + ) - if not to_check: + if not user_ids_to_check: return set() def _get_users_whose_devices_changed_txn(txn): changes = set() - sql = """ + stream_id_where_clause = "stream_id > ?" + sql_args = [from_key] + + if to_key: + stream_id_where_clause += " AND stream_id <= ?" + sql_args.append(to_key) + + sql = f""" SELECT DISTINCT user_id FROM device_lists_stream - WHERE stream_id > ? + WHERE {stream_id_where_clause} AND """ - for chunk in batch_iter(to_check, 100): + # Query device changes with a batch of users at a time + for chunk in batch_iter(user_ids_to_check, 100): clause, args = make_in_list_sql_clause( txn.database_engine, "user_id", chunk ) - txn.execute(sql + clause, (from_key,) + tuple(args)) + txn.execute(sql + clause, sql_args + args) changes.update(user_id for user_id, in txn) return changes |