diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/devices.py | 50 |
1 files changed, 38 insertions, 12 deletions
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 3413a46675..d2b113a4e7 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -24,6 +24,7 @@ from synapse.api.errors import StoreError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import Cache, SQLBaseStore, db_to_json from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.util import batch_iter from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList logger = logging.getLogger(__name__) @@ -391,22 +392,47 @@ class DeviceWorkerStore(SQLBaseStore): return now_stream_id, [] - @defer.inlineCallbacks - def get_user_whose_devices_changed(self, from_key): - """Get set of users whose devices have changed since `from_key`. + def get_users_whose_devices_changed(self, from_key, user_ids): + """Get set of users whose devices have changed since `from_key` that + are in the given list of user_ids. + + Args: + from_key (str): The device lists stream token + user_ids (Iterable[str]) + + Returns: + Deferred[set[str]]: The set of user_ids whose devices have changed + since `from_key` """ from_key = int(from_key) - changed = self._device_list_stream_cache.get_all_entities_changed(from_key) - if changed is not None: - defer.returnValue(set(changed)) - sql = """ - SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ? - """ - rows = yield self._execute( - "get_user_whose_devices_changed", None, sql, from_key + # Get set of users who *may* have changed. Users not in the returned + # list have definitely not changed. + to_check = list( + self._device_list_stream_cache.get_entities_changed(user_ids, from_key) + ) + + if not to_check: + return defer.succeed(set()) + + def _get_users_whose_devices_changed_txn(txn): + changes = set() + + sql = """ + SELECT DISTINCT user_id FROM device_lists_stream + WHERE stream_id > ? + AND user_id IN (%s) + """ + + for chunk in batch_iter(to_check, 100): + txn.execute(sql % (",".join("?" for _ in chunk),), (from_key,) + chunk) + changes.update(user_id for user_id, in txn) + + return changes + + return self.runInteraction( + "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn ) - defer.returnValue(set(row[0] for row in rows)) def get_all_device_list_changes_for_remotes(self, from_key, to_key): """Return a list of `(stream_id, user_id, destination)` which is the |