summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/appservice.py14
-rw-r--r--synapse/storage/databases/main/devices.py48
2 files changed, 44 insertions, 18 deletions
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index abea4383c7..55e1ab099d 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -29,7 +29,7 @@ from synapse.storage._base import db_to_json
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
-from synapse.types import JsonDict
+from synapse.types import DeviceListUpdates, JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import _CacheContext, cached
 
@@ -217,6 +217,7 @@ class ApplicationServiceTransactionWorkerStore(
         to_device_messages: List[JsonDict],
         one_time_key_counts: TransactionOneTimeKeyCounts,
         unused_fallback_keys: TransactionUnusedFallbackKeys,
+        device_list_summary: DeviceListUpdates,
     ) -> AppServiceTransaction:
         """Atomically creates a new transaction for this application service
         with the given list of events. Ephemeral events are NOT persisted to the
@@ -231,6 +232,7 @@ class ApplicationServiceTransactionWorkerStore(
                 appservice devices in the transaction.
             unused_fallback_keys: Lists of unused fallback keys for relevant
                 appservice devices in the transaction.
+            device_list_summary: The device list summary to include in the transaction.
 
         Returns:
             A new transaction.
@@ -268,6 +270,7 @@ class ApplicationServiceTransactionWorkerStore(
                 to_device_messages=to_device_messages,
                 one_time_key_counts=one_time_key_counts,
                 unused_fallback_keys=unused_fallback_keys,
+                device_list_summary=device_list_summary,
             )
 
         return await self.db_pool.runInteraction(
@@ -359,8 +362,8 @@ class ApplicationServiceTransactionWorkerStore(
 
         events = await self.get_events_as_list(event_ids)
 
-        # TODO: to-device messages, one-time key counts and unused fallback keys
-        #       are not yet populated for catch-up transactions.
+        # TODO: to-device messages, one-time key counts, device list summaries and unused
+        #       fallback keys are not yet populated for catch-up transactions.
         #       We likely want to populate those for reliability.
         return AppServiceTransaction(
             service=service,
@@ -370,6 +373,7 @@ class ApplicationServiceTransactionWorkerStore(
             to_device_messages=[],
             one_time_key_counts={},
             unused_fallback_keys={},
+            device_list_summary=DeviceListUpdates(),
         )
 
     def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
@@ -430,7 +434,7 @@ class ApplicationServiceTransactionWorkerStore(
     async def get_type_stream_id_for_appservice(
         self, service: ApplicationService, type: str
     ) -> int:
-        if type not in ("read_receipt", "presence", "to_device"):
+        if type not in ("read_receipt", "presence", "to_device", "device_list"):
             raise ValueError(
                 "Expected type to be a valid application stream id type, got %s"
                 % (type,)
@@ -458,7 +462,7 @@ class ApplicationServiceTransactionWorkerStore(
     async def set_appservice_stream_type_pos(
         self, service: ApplicationService, stream_type: str, pos: Optional[int]
     ) -> None:
-        if stream_type not in ("read_receipt", "presence", "to_device"):
+        if stream_type not in ("read_receipt", "presence", "to_device", "device_list"):
             raise ValueError(
                 "Expected type to be a valid application stream id type, got %s"
                 % (stream_type,)
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 3b3a089b76..f08f7834d3 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -681,42 +681,64 @@ class DeviceWorkerStore(SQLBaseStore):
         return self._device_list_stream_cache.get_all_entities_changed(from_key)
 
     async def get_users_whose_devices_changed(
-        self, from_key: int, user_ids: Iterable[str]
+        self,
+        from_key: int,
+        user_ids: Optional[Iterable[str]] = None,
+        to_key: Optional[int] = None,
     ) -> Set[str]:
         """Get set of users whose devices have changed since `from_key` that
         are in the given list of user_ids.
 
         Args:
-            from_key: The device lists stream token
-            user_ids: The user IDs to query for devices.
+            from_key: The minimum device lists stream token to query device list changes for,
+                exclusive.
+            user_ids: If provided, only check if these users have changed their device lists.
+                Otherwise changes from all users are returned.
+            to_key: The maximum device lists stream token to query device list changes for,
+                inclusive.
 
         Returns:
-            The set of user_ids whose devices have changed since `from_key`
+            The set of user_ids whose devices have changed since `from_key` (exclusive)
+                until `to_key` (inclusive).
         """
-
         # Get set of users who *may* have changed. Users not in the returned
         # list have definitely not changed.
-        to_check = self._device_list_stream_cache.get_entities_changed(
-            user_ids, from_key
-        )
+        if user_ids is None:
+            # Get set of all users that have had device list changes since 'from_key'
+            user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
+                from_key
+            )
+        else:
+            # The same as above, but filter results to only those users in 'user_ids'
+            user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
+                user_ids, from_key
+            )
 
-        if not to_check:
+        if not user_ids_to_check:
             return set()
 
         def _get_users_whose_devices_changed_txn(txn):
             changes = set()
 
-            sql = """
+            stream_id_where_clause = "stream_id > ?"
+            sql_args = [from_key]
+
+            if to_key:
+                stream_id_where_clause += " AND stream_id <= ?"
+                sql_args.append(to_key)
+
+            sql = f"""
                 SELECT DISTINCT user_id FROM device_lists_stream
-                WHERE stream_id > ?
+                WHERE {stream_id_where_clause}
                 AND
             """
 
-            for chunk in batch_iter(to_check, 100):
+            # Query device changes with a batch of users at a time
+            for chunk in batch_iter(user_ids_to_check, 100):
                 clause, args = make_in_list_sql_clause(
                     txn.database_engine, "user_id", chunk
                 )
-                txn.execute(sql + clause, (from_key,) + tuple(args))
+                txn.execute(sql + clause, sql_args + args)
                 changes.update(user_id for user_id, in txn)
 
             return changes