diff options
Diffstat (limited to 'synapse/storage/devices.py')
-rw-r--r-- | synapse/storage/devices.py | 62 |
1 files changed, 42 insertions, 20 deletions
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index d102e07372..d2b113a4e7 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -24,6 +24,7 @@ from synapse.api.errors import StoreError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import Cache, SQLBaseStore, db_to_json from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.util import batch_iter from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList logger = logging.getLogger(__name__) @@ -149,9 +150,7 @@ class DeviceWorkerStore(SQLBaseStore): defer.returnValue((stream_id_cutoff, [])) results = yield self._get_device_update_edus_by_remote( - destination, - from_stream_id, - query_map, + destination, from_stream_id, query_map ) defer.returnValue((now_stream_id, results)) @@ -182,9 +181,7 @@ class DeviceWorkerStore(SQLBaseStore): return list(txn) @defer.inlineCallbacks - def _get_device_update_edus_by_remote( - self, destination, from_stream_id, query_map, - ): + def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map): """Returns a list of device update EDUs as well as E2EE keys Args: @@ -210,7 +207,7 @@ class DeviceWorkerStore(SQLBaseStore): # The prev_id for the first row is always the last row before # `from_stream_id` prev_id = yield self._get_last_device_update_for_remote_user( - destination, user_id, from_stream_id, + destination, user_id, from_stream_id ) for device_id, device in iteritems(user_devices): stream_id = query_map[(user_id, device_id)] @@ -238,7 +235,7 @@ class DeviceWorkerStore(SQLBaseStore): defer.returnValue(results) def _get_last_device_update_for_remote_user( - self, destination, user_id, from_stream_id, + self, destination, user_id, from_stream_id ): def f(txn): prev_sent_id_sql = """ @@ -395,22 +392,47 @@ class DeviceWorkerStore(SQLBaseStore): return now_stream_id, [] - @defer.inlineCallbacks - def get_user_whose_devices_changed(self, from_key): - """Get set of users whose devices have changed since `from_key`. + def get_users_whose_devices_changed(self, from_key, user_ids): + """Get set of users whose devices have changed since `from_key` that + are in the given list of user_ids. + + Args: + from_key (str): The device lists stream token + user_ids (Iterable[str]) + + Returns: + Deferred[set[str]]: The set of user_ids whose devices have changed + since `from_key` """ from_key = int(from_key) - changed = self._device_list_stream_cache.get_all_entities_changed(from_key) - if changed is not None: - defer.returnValue(set(changed)) - sql = """ - SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ? - """ - rows = yield self._execute( - "get_user_whose_devices_changed", None, sql, from_key + # Get set of users who *may* have changed. Users not in the returned + # list have definitely not changed. + to_check = list( + self._device_list_stream_cache.get_entities_changed(user_ids, from_key) + ) + + if not to_check: + return defer.succeed(set()) + + def _get_users_whose_devices_changed_txn(txn): + changes = set() + + sql = """ + SELECT DISTINCT user_id FROM device_lists_stream + WHERE stream_id > ? + AND user_id IN (%s) + """ + + for chunk in batch_iter(to_check, 100): + txn.execute(sql % (",".join("?" for _ in chunk),), (from_key,) + chunk) + changes.update(user_id for user_id, in txn) + + return changes + + return self.runInteraction( + "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn ) - defer.returnValue(set(row[0] for row in rows)) def get_all_device_list_changes_for_remotes(self, from_key, to_key): """Return a list of `(stream_id, user_id, destination)` which is the |