summary refs log tree commit diff
path: root/synapse/storage/data_stores/main/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/data_stores/main/devices.py')
-rw-r--r--synapse/storage/data_stores/main/devices.py64
1 files changed, 41 insertions, 23 deletions
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py

index fe6d6ecfe0..417ac8dc7c 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py
@@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List, Tuple +from typing import List, Optional, Set, Tuple from six import iteritems @@ -649,21 +649,31 @@ class DeviceWorkerStore(SQLBaseStore): return results @defer.inlineCallbacks - def get_user_ids_requiring_device_list_resync(self, user_ids: Collection[str]): + def get_user_ids_requiring_device_list_resync( + self, user_ids: Optional[Collection[str]] = None, + ) -> Set[str]: """Given a list of remote users return the list of users that we - should resync the device lists for. + should resync the device lists for. If None is given instead of a list, + return every user that we should resync the device lists for. Returns: - Deferred[Set[str]] + The IDs of users whose device lists need resync. """ - - rows = yield self.db.simple_select_many_batch( - table="device_lists_remote_resync", - column="user_id", - iterable=user_ids, - retcols=("user_id",), - desc="get_user_ids_requiring_device_list_resync", - ) + if user_ids: + rows = yield self.db.simple_select_many_batch( + table="device_lists_remote_resync", + column="user_id", + iterable=user_ids, + retcols=("user_id",), + desc="get_user_ids_requiring_device_list_resync_with_iterable", + ) + else: + rows = yield self.db.simple_select_list( + table="device_lists_remote_resync", + keyvalues=None, + retcols=("user_id",), + desc="get_user_ids_requiring_device_list_resync", + ) return {row["user_id"] for row in rows} @@ -679,6 +689,25 @@ class DeviceWorkerStore(SQLBaseStore): desc="make_remote_user_device_cache_as_stale", ) + def mark_remote_user_device_list_as_unsubscribed(self, user_id): + """Mark that we no longer track device lists for remote user. + """ + + def _mark_remote_user_device_list_as_unsubscribed_txn(txn): + self.db.simple_delete_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={"user_id": user_id}, + ) + self._invalidate_cache_and_stream( + txn, self.get_device_list_last_stream_id_for_remote, (user_id,) + ) + + return self.db.runInteraction( + "mark_remote_user_device_list_as_unsubscribed", + _mark_remote_user_device_list_as_unsubscribed_txn, + ) + class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: Database, db_conn, hs): @@ -959,17 +988,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): desc="update_device", ) - @defer.inlineCallbacks - def mark_remote_user_device_list_as_unsubscribed(self, user_id): - """Mark that we no longer track device lists for remote user. - """ - yield self.db.simple_delete( - table="device_lists_remote_extremeties", - keyvalues={"user_id": user_id}, - desc="mark_remote_user_device_list_as_unsubscribed", - ) - self.get_device_list_last_stream_id_for_remote.invalidate((user_id,)) - def update_remote_device_list_cache_entry( self, user_id, device_id, content, stream_id ):