diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 7a5f0bab05..2b33060480 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -15,9 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List, Optional, Set, Tuple
-
-from twisted.internet import defer
+from typing import Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
@@ -33,14 +31,9 @@ from synapse.storage.database import (
LoggingTransaction,
make_tuple_comparison_clause,
)
-from synapse.types import Collection, get_verify_key_from_cross_signing_key
+from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_encoder
-from synapse.util.caches.descriptors import (
- Cache,
- cached,
- cachedInlineCallbacks,
- cachedList,
-)
+from synapse.util.caches.descriptors import Cache, cached, cachedList
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@@ -54,13 +47,13 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore):
- def get_device(self, user_id, device_id):
+ def get_device(self, user_id: str, device_id: str):
"""Retrieve a device. Only returns devices that are not marked as
hidden.
Args:
- user_id (str): The ID of the user which owns the device
- device_id (str): The ID of the device to retrieve
+ user_id: The ID of the user which owns the device
+ device_id: The ID of the device to retrieve
Returns:
defer.Deferred for a dict containing the device information
Raises:
@@ -73,19 +66,17 @@ class DeviceWorkerStore(SQLBaseStore):
desc="get_device",
)
- @defer.inlineCallbacks
- def get_devices_by_user(self, user_id):
+ async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
"""Retrieve all of a user's registered devices. Only returns devices
that are not marked as hidden.
Args:
- user_id (str):
+ user_id:
Returns:
- defer.Deferred: resolves to a dict from device_id to a dict
- containing "device_id", "user_id" and "display_name" for each
- device.
+ A mapping from device_id to a dict containing "device_id", "user_id"
+ and "display_name" for each device.
"""
- devices = yield self.db_pool.simple_select_list(
+ devices = await self.db_pool.simple_select_list(
table="devices",
keyvalues={"user_id": user_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
@@ -95,19 +86,20 @@ class DeviceWorkerStore(SQLBaseStore):
return {d["device_id"]: d for d in devices}
@trace
- @defer.inlineCallbacks
- def get_device_updates_by_remote(self, destination, from_stream_id, limit):
+ async def get_device_updates_by_remote(
+ self, destination: str, from_stream_id: int, limit: int
+ ) -> Tuple[int, List[Tuple[str, dict]]]:
"""Get a stream of device updates to send to the given remote server.
Args:
- destination (str): The host the device updates are intended for
- from_stream_id (int): The minimum stream_id to filter updates by, exclusive
- limit (int): Maximum number of device updates to return
+ destination: The host the device updates are intended for
+ from_stream_id: The minimum stream_id to filter updates by, exclusive
+ limit: Maximum number of device updates to return
+
Returns:
- Deferred[tuple[int, list[tuple[string,dict]]]]:
- current stream id (ie, the stream id of the last update included in the
- response), and the list of updates, where each update is a pair of EDU
- type and EDU contents
+ A mapping from the current stream id (ie, the stream id of the last
+ update included in the response), and the list of updates, where
+ each update is a pair of EDU type and EDU contents.
"""
now_stream_id = self._device_list_id_gen.get_current_token()
@@ -117,7 +109,7 @@ class DeviceWorkerStore(SQLBaseStore):
if not has_changed:
return now_stream_id, []
- updates = yield self.db_pool.runInteraction(
+ updates = await self.db_pool.runInteraction(
"get_device_updates_by_remote",
self._get_device_updates_by_remote_txn,
destination,
@@ -136,9 +128,7 @@ class DeviceWorkerStore(SQLBaseStore):
master_key_by_user = {}
self_signing_key_by_user = {}
for user in users:
- cross_signing_key = yield defer.ensureDeferred(
- self.get_e2e_cross_signing_key(user, "master")
- )
+ cross_signing_key = await self.get_e2e_cross_signing_key(user, "master")
if cross_signing_key:
key_id, verify_key = get_verify_key_from_cross_signing_key(
cross_signing_key
@@ -151,8 +141,8 @@ class DeviceWorkerStore(SQLBaseStore):
"device_id": verify_key.version,
}
- cross_signing_key = yield defer.ensureDeferred(
- self.get_e2e_cross_signing_key(user, "self_signing")
+ cross_signing_key = await self.get_e2e_cross_signing_key(
+ user, "self_signing"
)
if cross_signing_key:
key_id, verify_key = get_verify_key_from_cross_signing_key(
@@ -202,7 +192,7 @@ class DeviceWorkerStore(SQLBaseStore):
if update_stream_id > previous_update_stream_id:
query_map[key] = (update_stream_id, update_context)
- results = yield self._get_device_update_edus_by_remote(
+ results = await self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
@@ -215,16 +205,21 @@ class DeviceWorkerStore(SQLBaseStore):
return now_stream_id, results
def _get_device_updates_by_remote_txn(
- self, txn, destination, from_stream_id, now_stream_id, limit
+ self,
+ txn: LoggingTransaction,
+ destination: str,
+ from_stream_id: int,
+ now_stream_id: int,
+ limit: int,
):
"""Return device update information for a given remote destination
Args:
- txn (LoggingTransaction): The transaction to execute
- destination (str): The host the device updates are intended for
- from_stream_id (int): The minimum stream_id to filter updates by, exclusive
- now_stream_id (int): The maximum stream_id to filter updates by, inclusive
- limit (int): Maximum number of device updates to return
+ txn: The transaction to execute
+ destination: The host the device updates are intended for
+ from_stream_id: The minimum stream_id to filter updates by, exclusive
+ now_stream_id: The maximum stream_id to filter updates by, inclusive
+ limit: Maximum number of device updates to return
Returns:
List: List of device updates
@@ -240,23 +235,26 @@ class DeviceWorkerStore(SQLBaseStore):
return list(txn)
- @defer.inlineCallbacks
- def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map):
+ async def _get_device_update_edus_by_remote(
+ self,
+ destination: str,
+ from_stream_id: int,
+ query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]],
+ ) -> List[Tuple[str, dict]]:
"""Returns a list of device update EDUs as well as E2EE keys
Args:
- destination (str): The host the device updates are intended for
- from_stream_id (int): The minimum stream_id to filter updates by, exclusive
+ destination: The host the device updates are intended for
+ from_stream_id: The minimum stream_id to filter updates by, exclusive
query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
user_id/device_id to update stream_id and the relevant json-encoded
opentracing context
Returns:
- List[Dict]: List of objects representing an device update EDU
-
+ List of objects representing an device update EDU
"""
devices = (
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_get_e2e_device_keys_txn",
self._get_e2e_device_keys_txn,
query_map.keys(),
@@ -271,7 +269,7 @@ class DeviceWorkerStore(SQLBaseStore):
for user_id, user_devices in devices.items():
# The prev_id for the first row is always the last row before
# `from_stream_id`
- prev_id = yield self._get_last_device_update_for_remote_user(
+ prev_id = await self._get_last_device_update_for_remote_user(
destination, user_id, from_stream_id
)
@@ -315,7 +313,7 @@ class DeviceWorkerStore(SQLBaseStore):
return results
def _get_last_device_update_for_remote_user(
- self, destination, user_id, from_stream_id
+ self, destination: str, user_id: str, from_stream_id: int
):
def f(txn):
prev_sent_id_sql = """
@@ -329,7 +327,7 @@ class DeviceWorkerStore(SQLBaseStore):
return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f)
- def mark_as_sent_devices_by_remote(self, destination, stream_id):
+ def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int):
"""Mark that updates have successfully been sent to the destination.
"""
return self.db_pool.runInteraction(
@@ -339,7 +337,9 @@ class DeviceWorkerStore(SQLBaseStore):
stream_id,
)
- def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
+ def _mark_as_sent_devices_by_remote_txn(
+ self, txn: LoggingTransaction, destination: str, stream_id: int
+ ) -> None:
# We update the device_lists_outbound_last_success with the successfully
# poked users.
sql = """
@@ -367,17 +367,21 @@ class DeviceWorkerStore(SQLBaseStore):
"""
txn.execute(sql, (destination, stream_id))
- @defer.inlineCallbacks
- def add_user_signature_change_to_streams(self, from_user_id, user_ids):
+ async def add_user_signature_change_to_streams(
+ self, from_user_id: str, user_ids: List[str]
+ ) -> int:
"""Persist that a user has made new signatures
Args:
- from_user_id (str): the user who made the signatures
- user_ids (list[str]): the users who were signed
+ from_user_id: the user who made the signatures
+ user_ids: the users who were signed
+
+ Returns:
+ THe new stream ID.
"""
with self._device_list_id_gen.get_next() as stream_id:
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
from_user_id,
@@ -386,7 +390,13 @@ class DeviceWorkerStore(SQLBaseStore):
)
return stream_id
- def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id):
+ def _add_user_signature_change_txn(
+ self,
+ txn: LoggingTransaction,
+ from_user_id: str,
+ user_ids: List[str],
+ stream_id: int,
+ ) -> None:
txn.call_after(
self._user_signature_stream_cache.entity_has_changed,
from_user_id,
@@ -402,29 +412,30 @@ class DeviceWorkerStore(SQLBaseStore):
},
)
- def get_device_stream_token(self):
+ def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
@trace
- @defer.inlineCallbacks
- def get_user_devices_from_cache(self, query_list):
+ async def get_user_devices_from_cache(
+ self, query_list: List[Tuple[str, str]]
+ ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
"""Get the devices (and keys if any) for remote users from the cache.
Args:
- query_list(list): List of (user_id, device_ids), if device_ids is
+ query_list: List of (user_id, device_ids), if device_ids is
falsey then return all device ids for that user.
Returns:
- (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
- a set of user_ids and results_map is a mapping of
- user_id -> device_id -> device_info
+ A tuple of (user_ids_not_in_cache, results_map), where
+ user_ids_not_in_cache is a set of user_ids and results_map is a
+ mapping of user_id -> device_id -> device_info.
"""
user_ids = {user_id for user_id, _ in query_list}
- user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
+ user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids))
# We go and check if any of the users need to have their device lists
# resynced. If they do then we remove them from the cached list.
- users_needing_resync = yield self.get_user_ids_requiring_device_list_resync(
+ users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
user_ids
)
user_ids_in_cache = {
@@ -438,19 +449,19 @@ class DeviceWorkerStore(SQLBaseStore):
continue
if device_id:
- device = yield self._get_cached_user_device(user_id, device_id)
+ device = await self._get_cached_user_device(user_id, device_id)
results.setdefault(user_id, {})[device_id] = device
else:
- results[user_id] = yield self.get_cached_devices_for_user(user_id)
+ results[user_id] = await self.get_cached_devices_for_user(user_id)
set_tag("in_cache", results)
set_tag("not_in_cache", user_ids_not_in_cache)
return user_ids_not_in_cache, results
- @cachedInlineCallbacks(num_args=2, tree=True)
- def _get_cached_user_device(self, user_id, device_id):
- content = yield self.db_pool.simple_select_one_onecol(
+ @cached(num_args=2, tree=True)
+ async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict:
+ content = await self.db_pool.simple_select_one_onecol(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="content",
@@ -458,9 +469,9 @@ class DeviceWorkerStore(SQLBaseStore):
)
return db_to_json(content)
- @cachedInlineCallbacks()
- def get_cached_devices_for_user(self, user_id):
- devices = yield self.db_pool.simple_select_list(
+ @cached()
+ async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]:
+ devices = await self.db_pool.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
retcols=("device_id", "content"),
@@ -470,11 +481,11 @@ class DeviceWorkerStore(SQLBaseStore):
device["device_id"]: db_to_json(device["content"]) for device in devices
}
- def get_devices_with_keys_by_user(self, user_id):
+ def get_devices_with_keys_by_user(self, user_id: str):
"""Get all devices (with any device keys) for a user
Returns:
- (stream_id, devices)
+ Deferred which resolves to (stream_id, devices)
"""
return self.db_pool.runInteraction(
"get_devices_with_keys_by_user",
@@ -482,7 +493,9 @@ class DeviceWorkerStore(SQLBaseStore):
user_id,
)
- def _get_devices_with_keys_by_user_txn(self, txn, user_id):
+ def _get_devices_with_keys_by_user_txn(
+ self, txn: LoggingTransaction, user_id: str
+ ) -> Tuple[int, List[JsonDict]]:
now_stream_id = self._device_list_id_gen.get_current_token()
devices = self._get_e2e_device_keys_txn(
@@ -515,17 +528,18 @@ class DeviceWorkerStore(SQLBaseStore):
return now_stream_id, []
- def get_users_whose_devices_changed(self, from_key, user_ids):
+ async def get_users_whose_devices_changed(
+ self, from_key: str, user_ids: Iterable[str]
+ ) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
are in the given list of user_ids.
Args:
- from_key (str): The device lists stream token
- user_ids (Iterable[str])
+ from_key: The device lists stream token
+ user_ids: The user IDs to query for devices.
Returns:
- Deferred[set[str]]: The set of user_ids whose devices have changed
- since `from_key`
+ The set of user_ids whose devices have changed since `from_key`
"""
from_key = int(from_key)
@@ -536,7 +550,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
if not to_check:
- return defer.succeed(set())
+ return set()
def _get_users_whose_devices_changed_txn(txn):
changes = set()
@@ -556,18 +570,22 @@ class DeviceWorkerStore(SQLBaseStore):
return changes
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
)
- @defer.inlineCallbacks
- def get_users_whose_signatures_changed(self, user_id, from_key):
+ async def get_users_whose_signatures_changed(
+ self, user_id: str, from_key: str
+ ) -> Set[str]:
"""Get the users who have new cross-signing signatures made by `user_id` since
`from_key`.
Args:
- user_id (str): the user who made the signatures
- from_key (str): The device lists stream token
+ user_id: the user who made the signatures
+ from_key: The device lists stream token
+
+ Returns:
+ A set of user IDs with updated signatures.
"""
from_key = int(from_key)
if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
@@ -575,7 +593,7 @@ class DeviceWorkerStore(SQLBaseStore):
SELECT DISTINCT user_ids FROM user_signature_stream
WHERE from_user_id = ? AND stream_id > ?
"""
- rows = yield self.db_pool.execute(
+ rows = await self.db_pool.execute(
"get_users_whose_signatures_changed", None, sql, user_id, from_key
)
return {user for row in rows for user in db_to_json(row[0])}
@@ -638,7 +656,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
@cached(max_entries=10000)
- def get_device_list_last_stream_id_for_remote(self, user_id):
+ def get_device_list_last_stream_id_for_remote(self, user_id: str):
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
@@ -655,7 +673,7 @@ class DeviceWorkerStore(SQLBaseStore):
list_name="user_ids",
inlineCallbacks=True,
)
- def get_device_list_last_stream_id_for_remotes(self, user_ids):
+ def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
rows = yield self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
@@ -669,8 +687,7 @@ class DeviceWorkerStore(SQLBaseStore):
return results
- @defer.inlineCallbacks
- def get_user_ids_requiring_device_list_resync(
+ async def get_user_ids_requiring_device_list_resync(
self, user_ids: Optional[Collection[str]] = None,
) -> Set[str]:
"""Given a list of remote users return the list of users that we
@@ -681,7 +698,7 @@ class DeviceWorkerStore(SQLBaseStore):
The IDs of users whose device lists need resync.
"""
if user_ids:
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_resync",
column="user_id",
iterable=user_ids,
@@ -689,7 +706,7 @@ class DeviceWorkerStore(SQLBaseStore):
desc="get_user_ids_requiring_device_list_resync_with_iterable",
)
else:
- rows = yield self.db_pool.simple_select_list(
+ rows = await self.db_pool.simple_select_list(
table="device_lists_remote_resync",
keyvalues=None,
retcols=("user_id",),
@@ -710,7 +727,7 @@ class DeviceWorkerStore(SQLBaseStore):
desc="make_remote_user_device_cache_as_stale",
)
- def mark_remote_user_device_list_as_unsubscribed(self, user_id):
+ def mark_remote_user_device_list_as_unsubscribed(self, user_id: str):
"""Mark that we no longer track device lists for remote user.
"""
@@ -779,16 +796,15 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
"drop_device_lists_outbound_last_success_non_unique_idx",
)
- @defer.inlineCallbacks
- def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
+ async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
txn.close()
- yield self.db_pool.runWithConnection(f)
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.runWithConnection(f)
+ await self.db_pool.updates._end_background_update(
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES
)
return 1
@@ -868,18 +884,20 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
- @defer.inlineCallbacks
- def store_device(self, user_id, device_id, initial_device_display_name):
+ async def store_device(
+ self, user_id: str, device_id: str, initial_device_display_name: str
+ ) -> bool:
"""Ensure the given device is known; add it to the store if not
Args:
- user_id (str): id of user associated with the device
- device_id (str): id of device
- initial_device_display_name (str): initial displayname of the
- device. Ignored if device exists.
+ user_id: id of user associated with the device
+ device_id: id of device
+ initial_device_display_name: initial displayname of the device.
+ Ignored if device exists.
+
Returns:
- defer.Deferred: boolean whether the device was inserted or an
- existing device existed with that ID.
+ Whether the device was inserted or an existing device existed with that ID.
+
Raises:
StoreError: if the device is already in use
"""
@@ -888,7 +906,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return False
try:
- inserted = yield self.db_pool.simple_insert(
+ inserted = await self.db_pool.simple_insert(
"devices",
values={
"user_id": user_id,
@@ -902,7 +920,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not inserted:
# if the device already exists, check if it's a real device, or
# if the device ID is reserved by something else
- hidden = yield self.db_pool.simple_select_one_onecol(
+ hidden = await self.db_pool.simple_select_one_onecol(
"devices",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="hidden",
@@ -927,17 +945,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
raise StoreError(500, "Problem storing device.")
- @defer.inlineCallbacks
- def delete_device(self, user_id, device_id):
+ async def delete_device(self, user_id: str, device_id: str) -> None:
"""Delete a device.
Args:
- user_id (str): The ID of the user which owns the device
- device_id (str): The ID of the device to delete
- Returns:
- defer.Deferred
+ user_id: The ID of the user which owns the device
+ device_id: The ID of the device to delete
"""
- yield self.db_pool.simple_delete_one(
+ await self.db_pool.simple_delete_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
desc="delete_device",
@@ -945,17 +960,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.device_id_exists_cache.invalidate((user_id, device_id))
- @defer.inlineCallbacks
- def delete_devices(self, user_id, device_ids):
+ async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
"""Deletes several devices.
Args:
- user_id (str): The ID of the user which owns the devices
- device_ids (list): The IDs of the devices to delete
- Returns:
- defer.Deferred
+ user_id: The ID of the user which owns the devices
+ device_ids: The IDs of the devices to delete
"""
- yield self.db_pool.simple_delete_many(
+ await self.db_pool.simple_delete_many(
table="devices",
column="device_id",
iterable=device_ids,
@@ -965,26 +977,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id))
- def update_device(self, user_id, device_id, new_display_name=None):
+ async def update_device(
+ self, user_id: str, device_id: str, new_display_name: Optional[str] = None
+ ) -> None:
"""Update a device. Only updates the device if it is not marked as
hidden.
Args:
- user_id (str): The ID of the user which owns the device
- device_id (str): The ID of the device to update
- new_display_name (str|None): new displayname for device; None
- to leave unchanged
+ user_id: The ID of the user which owns the device
+ device_id: The ID of the device to update
+ new_display_name: new displayname for device; None to leave unchanged
Raises:
StoreError: if the device is not found
- Returns:
- defer.Deferred
"""
updates = {}
if new_display_name is not None:
updates["display_name"] = new_display_name
if not updates:
- return defer.succeed(None)
- return self.db_pool.simple_update_one(
+ return None
+ await self.db_pool.simple_update_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
updatevalues=updates,
@@ -992,7 +1003,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
def update_remote_device_list_cache_entry(
- self, user_id, device_id, content, stream_id
+ self, user_id: str, device_id: str, content: JsonDict, stream_id: int
):
"""Updates a single device in the cache of a remote user's devicelist.
@@ -1000,10 +1011,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device list.
Args:
- user_id (str): User to update device list for
- device_id (str): ID of decivice being updated
- content (dict): new data on this device
- stream_id (int): the version of the device list
+ user_id: User to update device list for
+ device_id: ID of decivice being updated
+ content: new data on this device
+ stream_id: the version of the device list
Returns:
Deferred[None]
@@ -1018,8 +1029,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
def _update_remote_device_list_cache_entry_txn(
- self, txn, user_id, device_id, content, stream_id
- ):
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ content: JsonDict,
+ stream_id: int,
+ ) -> None:
if content.get("deleted"):
self.db_pool.simple_delete_txn(
txn,
@@ -1055,16 +1071,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
lock=False,
)
- def update_remote_device_list_cache(self, user_id, devices, stream_id):
+ def update_remote_device_list_cache(
+ self, user_id: str, devices: List[dict], stream_id: int
+ ):
"""Replace the entire cache of the remote user's devices.
Note: assumes that we are the only thread that can be updating this user's
device list.
Args:
- user_id (str): User to update device list for
- devices (list[dict]): list of device objects supplied over federation
- stream_id (int): the version of the device list
+ user_id: User to update device list for
+ devices: list of device objects supplied over federation
+ stream_id: the version of the device list
Returns:
Deferred[None]
@@ -1077,7 +1095,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
stream_id,
)
- def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id):
+ def _update_remote_device_list_cache_txn(
+ self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
+ ):
self.db_pool.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
)
@@ -1118,8 +1138,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id},
)
- @defer.inlineCallbacks
- def add_device_change_to_streams(self, user_id, device_ids, hosts):
+ async def add_device_change_to_streams(
+ self, user_id: str, device_ids: Collection[str], hosts: List[str]
+ ):
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
@@ -1127,7 +1148,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return
with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"add_device_change_to_stream",
self._add_device_change_to_stream_txn,
user_id,
@@ -1142,7 +1163,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
with self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids)
) as stream_ids:
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"add_device_outbound_poke_to_stream",
self._add_device_outbound_poke_to_stream_txn,
user_id,
@@ -1187,7 +1208,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
def _add_device_outbound_poke_to_stream_txn(
- self, txn, user_id, device_ids, hosts, stream_ids, context,
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_ids: Collection[str],
+ hosts: List[str],
+ stream_ids: List[str],
+ context: Dict[str, str],
):
for host in hosts:
txn.call_after(
@@ -1219,7 +1246,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
],
)
- def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000):
+ def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000):
"""Delete old entries out of the device_lists_outbound_pokes to ensure
that we don't fill up due to dead servers.
|