diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index add4e3ea0e..fdf394c612 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -377,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore):
THe new stream ID.
"""
- with await self._device_list_id_gen.get_next() as stream_id:
+ async with self._device_list_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
@@ -481,7 +481,7 @@ class DeviceWorkerStore(SQLBaseStore):
}
async def get_users_whose_devices_changed(
- self, from_key: str, user_ids: Iterable[str]
+ self, from_key: int, user_ids: Iterable[str]
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
are in the given list of user_ids.
@@ -493,7 +493,6 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
The set of user_ids whose devices have changed since `from_key`
"""
- from_key = int(from_key)
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
@@ -527,7 +526,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
async def get_users_whose_signatures_changed(
- self, user_id: str, from_key: str
+ self, user_id: str, from_key: int
) -> Set[str]:
"""Get the users who have new cross-signing signatures made by `user_id` since
`from_key`.
@@ -539,7 +538,7 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
A set of user IDs with updated signatures.
"""
- from_key = int(from_key)
+
if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
sql = """
SELECT DISTINCT user_ids FROM user_signature_stream
@@ -702,7 +701,7 @@ class DeviceWorkerStore(SQLBaseStore):
class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
"device_lists_stream_idx",
@@ -827,7 +826,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(DeviceStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
@@ -1094,7 +1093,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not device_ids:
return
- with await self._device_list_id_gen.get_next_mult(
+ async with self._device_list_id_gen.get_next_mult(
len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
@@ -1109,7 +1108,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return stream_ids[-1]
context = get_active_span_text_map()
- with await self._device_list_id_gen.get_next_mult(
+ async with self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
|