From aa07c37cf0a3b812e6aa1bb2d97d543e6925c8e2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Sep 2020 12:41:21 +0100 Subject: Move and rename `get_devices_with_keys_by_user` (#8204) * Move `get_devices_with_keys_by_user` to `EndToEndKeyWorkerStore` this seems a better fit for it. This commit simply moves the existing code: no other changes at all. * Rename `get_devices_with_keys_by_user` to better reflect what it does. * get_device_stream_token abstract method To avoid referencing fields which are declared in the derived classes, make `get_device_stream_token` abstract, and define that in the classes which define `_device_list_id_gen`. --- synapse/storage/databases/main/__init__.py | 3 ++ synapse/storage/databases/main/devices.py | 52 +++------------------- synapse/storage/databases/main/end_to_end_keys.py | 53 ++++++++++++++++++++++- 3 files changed, 60 insertions(+), 48 deletions(-) (limited to 'synapse/storage/databases/main') diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 70cf15dd7f..e6536c8456 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -264,6 +264,9 @@ class DataStore( # Used in _generate_user_daily_visits to keep track of progress self._last_user_visit_update = self._get_start_of_day() + def get_device_stream_token(self) -> int: + return self._device_list_id_gen.get_current_token() + def take_presence_startup_info(self): active_on_startup = self._presence_on_startup self._presence_on_startup = None diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index def96637a2..e8379c73c4 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc import logging from typing import Any, Dict, Iterable, List, Optional, Set, Tuple @@ -101,7 +102,7 @@ class DeviceWorkerStore(SQLBaseStore): update included in the response), and the list of updates, where each update is a pair of EDU type and EDU contents. """ - now_stream_id = self._device_list_id_gen.get_current_token() + now_stream_id = self.get_device_stream_token() has_changed = self._device_list_federation_stream_cache.has_entity_changed( destination, int(from_stream_id) @@ -412,8 +413,10 @@ class DeviceWorkerStore(SQLBaseStore): }, ) + @abc.abstractmethod def get_device_stream_token(self) -> int: - return self._device_list_id_gen.get_current_token() + """Get the current stream id from the _device_list_id_gen""" + ... @trace async def get_user_devices_from_cache( @@ -481,51 +484,6 @@ class DeviceWorkerStore(SQLBaseStore): device["device_id"]: db_to_json(device["content"]) for device in devices } - def get_devices_with_keys_by_user(self, user_id: str): - """Get all devices (with any device keys) for a user - - Returns: - Deferred which resolves to (stream_id, devices) - """ - return self.db_pool.runInteraction( - "get_devices_with_keys_by_user", - self._get_devices_with_keys_by_user_txn, - user_id, - ) - - def _get_devices_with_keys_by_user_txn( - self, txn: LoggingTransaction, user_id: str - ) -> Tuple[int, List[JsonDict]]: - now_stream_id = self._device_list_id_gen.get_current_token() - - devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)]) - - if devices: - user_devices = devices[user_id] - results = [] - for device_id, device in user_devices.items(): - result = {"device_id": device_id} - - key_json = device.get("key_json", None) - if key_json: - result["keys"] = db_to_json(key_json) - - if "signatures" in device: - for sig_user_id, sigs in device["signatures"].items(): - result["keys"].setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) - - device_display_name = device.get("device_display_name", None) - if device_display_name: - result["device_display_name"] = device_display_name - - results.append(result) - - return now_stream_id, results - - return now_stream_id, [] - async def get_users_whose_devices_changed( self, from_key: str, user_ids: Iterable[str] ) -> Set[str]: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 50ecddf7fa..fb3b1f94de 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple from canonicaljson import encode_canonical_json @@ -22,7 +23,7 @@ from twisted.enterprise.adbapi import Connection from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import make_in_list_sql_clause +from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -33,6 +34,51 @@ if TYPE_CHECKING: class EndToEndKeyWorkerStore(SQLBaseStore): + def get_e2e_device_keys_for_federation_query(self, user_id: str): + """Get all devices (with any device keys) for a user + + Returns: + Deferred which resolves to (stream_id, devices) + """ + return self.db_pool.runInteraction( + "get_e2e_device_keys_for_federation_query", + self._get_e2e_device_keys_for_federation_query_txn, + user_id, + ) + + def _get_e2e_device_keys_for_federation_query_txn( + self, txn: LoggingTransaction, user_id: str + ) -> Tuple[int, List[JsonDict]]: + now_stream_id = self.get_device_stream_token() + + devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)]) + + if devices: + user_devices = devices[user_id] + results = [] + for device_id, device in user_devices.items(): + result = {"device_id": device_id} + + key_json = device.get("key_json", None) + if key_json: + result["keys"] = db_to_json(key_json) + + if "signatures" in device: + for sig_user_id, sigs in device["signatures"].items(): + result["keys"].setdefault("signatures", {}).setdefault( + sig_user_id, {} + ).update(sigs) + + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + + results.append(result) + + return now_stream_id, results + + return now_stream_id, [] + @trace async def get_e2e_device_keys_for_cs_api( self, query_list: List[Tuple[str, Optional[str]]] @@ -533,6 +579,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore): _get_all_user_signature_changes_for_remotes_txn, ) + @abc.abstractmethod + def get_device_stream_token(self) -> int: + """Get the current stream id from the _device_list_id_gen""" + ... + class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): -- cgit 1.4.1