diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index def96637a2..f8fe948122 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -14,6 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import logging
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
@@ -101,7 +102,7 @@ class DeviceWorkerStore(SQLBaseStore):
update included in the response), and the list of updates, where
each update is a pair of EDU type and EDU contents.
"""
- now_stream_id = self._device_list_id_gen.get_current_token()
+ now_stream_id = self.get_device_stream_token()
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
@@ -254,9 +255,7 @@ class DeviceWorkerStore(SQLBaseStore):
List of objects representing an device update EDU
"""
devices = (
- await self.db_pool.runInteraction(
- "_get_e2e_device_keys_txn",
- self._get_e2e_device_keys_txn,
+ await self.get_e2e_device_keys_and_signatures(
query_map.keys(),
include_all_devices=True,
include_deleted_devices=True,
@@ -292,17 +291,17 @@ class DeviceWorkerStore(SQLBaseStore):
prev_id = stream_id
if device is not None:
- key_json = device.get("key_json", None)
+ key_json = device.key_json
if key_json:
result["keys"] = db_to_json(key_json)
- if "signatures" in device:
- for sig_user_id, sigs in device["signatures"].items():
+ if device.signatures:
+ for sig_user_id, sigs in device.signatures.items():
result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
- device_display_name = device.get("device_display_name", None)
+ device_display_name = device.display_name
if device_display_name:
result["device_display_name"] = device_display_name
else:
@@ -312,9 +311,9 @@ class DeviceWorkerStore(SQLBaseStore):
return results
- def _get_last_device_update_for_remote_user(
+ async def _get_last_device_update_for_remote_user(
self, destination: str, user_id: str, from_stream_id: int
- ):
+ ) -> int:
def f(txn):
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
@@ -325,12 +324,16 @@ class DeviceWorkerStore(SQLBaseStore):
rows = txn.fetchall()
return rows[0][0]
- return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f)
+ return await self.db_pool.runInteraction(
+ "get_last_device_update_for_remote_user", f
+ )
- def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int):
+ async def mark_as_sent_devices_by_remote(
+ self, destination: str, stream_id: int
+ ) -> None:
"""Mark that updates have successfully been sent to the destination.
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"mark_as_sent_devices_by_remote",
self._mark_as_sent_devices_by_remote_txn,
destination,
@@ -412,8 +415,10 @@ class DeviceWorkerStore(SQLBaseStore):
},
)
+ @abc.abstractmethod
def get_device_stream_token(self) -> int:
- return self._device_list_id_gen.get_current_token()
+ """Get the current stream id from the _device_list_id_gen"""
+ ...
@trace
async def get_user_devices_from_cache(
@@ -481,51 +486,6 @@ class DeviceWorkerStore(SQLBaseStore):
device["device_id"]: db_to_json(device["content"]) for device in devices
}
- def get_devices_with_keys_by_user(self, user_id: str):
- """Get all devices (with any device keys) for a user
-
- Returns:
- Deferred which resolves to (stream_id, devices)
- """
- return self.db_pool.runInteraction(
- "get_devices_with_keys_by_user",
- self._get_devices_with_keys_by_user_txn,
- user_id,
- )
-
- def _get_devices_with_keys_by_user_txn(
- self, txn: LoggingTransaction, user_id: str
- ) -> Tuple[int, List[JsonDict]]:
- now_stream_id = self._device_list_id_gen.get_current_token()
-
- devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])
-
- if devices:
- user_devices = devices[user_id]
- results = []
- for device_id, device in user_devices.items():
- result = {"device_id": device_id}
-
- key_json = device.get("key_json", None)
- if key_json:
- result["keys"] = db_to_json(key_json)
-
- if "signatures" in device:
- for sig_user_id, sigs in device["signatures"].items():
- result["keys"].setdefault("signatures", {}).setdefault(
- sig_user_id, {}
- ).update(sigs)
-
- device_display_name = device.get("device_display_name", None)
- if device_display_name:
- result["device_display_name"] = device_display_name
-
- results.append(result)
-
- return now_stream_id, results
-
- return now_stream_id, []
-
async def get_users_whose_devices_changed(
self, from_key: str, user_ids: Iterable[str]
) -> Set[str]:
@@ -726,7 +686,7 @@ class DeviceWorkerStore(SQLBaseStore):
desc="make_remote_user_device_cache_as_stale",
)
- def mark_remote_user_device_list_as_unsubscribed(self, user_id: str):
+ async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
"""Mark that we no longer track device lists for remote user.
"""
@@ -740,7 +700,7 @@ class DeviceWorkerStore(SQLBaseStore):
txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"mark_remote_user_device_list_as_unsubscribed",
_mark_remote_user_device_list_as_unsubscribed_txn,
)
@@ -1001,9 +961,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
desc="update_device",
)
- def update_remote_device_list_cache_entry(
+ async def update_remote_device_list_cache_entry(
self, user_id: str, device_id: str, content: JsonDict, stream_id: int
- ):
+ ) -> None:
"""Updates a single device in the cache of a remote user's devicelist.
Note: assumes that we are the only thread that can be updating this user's
@@ -1014,11 +974,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_id: ID of decivice being updated
content: new data on this device
stream_id: the version of the device list
-
- Returns:
- Deferred[None]
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"update_remote_device_list_cache_entry",
self._update_remote_device_list_cache_entry_txn,
user_id,
@@ -1070,9 +1027,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
lock=False,
)
- def update_remote_device_list_cache(
+ async def update_remote_device_list_cache(
self, user_id: str, devices: List[dict], stream_id: int
- ):
+ ) -> None:
"""Replace the entire cache of the remote user's devices.
Note: assumes that we are the only thread that can be updating this user's
@@ -1082,11 +1039,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: User to update device list for
devices: list of device objects supplied over federation
stream_id: the version of the device list
-
- Returns:
- Deferred[None]
"""
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"update_remote_device_list_cache",
self._update_remote_device_list_cache_txn,
user_id,
@@ -1096,7 +1050,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def _update_remote_device_list_cache_txn(
self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
- ):
+ ) -> None:
self.db_pool.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
)
|