diff --git a/changelog.d/8204.misc b/changelog.d/8204.misc
new file mode 100644
index 0000000000..979c8b227b
--- /dev/null
+++ b/changelog.d/8204.misc
@@ -0,0 +1 @@
+Refactor queries for device keys and cross-signatures.
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index db417d60de..ee4666337a 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -234,7 +234,9 @@ class DeviceWorkerHandler(BaseHandler):
return result
async def on_federation_query_user_devices(self, user_id):
- stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id)
+ stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query(
+ user_id
+ )
master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
self_signing_key = await self.store.get_e2e_cross_signing_key(
user_id, "self_signing"
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 596c72eb92..3b788c9625 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -48,6 +48,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
"DeviceListFederationStreamChangeCache", device_list_max
)
+ def get_device_stream_token(self) -> int:
+ return self._device_list_id_gen.get_current_token()
+
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 70cf15dd7f..e6536c8456 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -264,6 +264,9 @@ class DataStore(
# Used in _generate_user_daily_visits to keep track of progress
self._last_user_visit_update = self._get_start_of_day()
+ def get_device_stream_token(self) -> int:
+ return self._device_list_id_gen.get_current_token()
+
def take_presence_startup_info(self):
active_on_startup = self._presence_on_startup
self._presence_on_startup = None
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index def96637a2..e8379c73c4 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -14,6 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import logging
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
@@ -101,7 +102,7 @@ class DeviceWorkerStore(SQLBaseStore):
update included in the response), and the list of updates, where
each update is a pair of EDU type and EDU contents.
"""
- now_stream_id = self._device_list_id_gen.get_current_token()
+ now_stream_id = self.get_device_stream_token()
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
@@ -412,8 +413,10 @@ class DeviceWorkerStore(SQLBaseStore):
},
)
+ @abc.abstractmethod
def get_device_stream_token(self) -> int:
- return self._device_list_id_gen.get_current_token()
+ """Get the current stream id from the _device_list_id_gen"""
+ ...
@trace
async def get_user_devices_from_cache(
@@ -481,51 +484,6 @@ class DeviceWorkerStore(SQLBaseStore):
device["device_id"]: db_to_json(device["content"]) for device in devices
}
- def get_devices_with_keys_by_user(self, user_id: str):
- """Get all devices (with any device keys) for a user
-
- Returns:
- Deferred which resolves to (stream_id, devices)
- """
- return self.db_pool.runInteraction(
- "get_devices_with_keys_by_user",
- self._get_devices_with_keys_by_user_txn,
- user_id,
- )
-
- def _get_devices_with_keys_by_user_txn(
- self, txn: LoggingTransaction, user_id: str
- ) -> Tuple[int, List[JsonDict]]:
- now_stream_id = self._device_list_id_gen.get_current_token()
-
- devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])
-
- if devices:
- user_devices = devices[user_id]
- results = []
- for device_id, device in user_devices.items():
- result = {"device_id": device_id}
-
- key_json = device.get("key_json", None)
- if key_json:
- result["keys"] = db_to_json(key_json)
-
- if "signatures" in device:
- for sig_user_id, sigs in device["signatures"].items():
- result["keys"].setdefault("signatures", {}).setdefault(
- sig_user_id, {}
- ).update(sigs)
-
- device_display_name = device.get("device_display_name", None)
- if device_display_name:
- result["device_display_name"] = device_display_name
-
- results.append(result)
-
- return now_stream_id, results
-
- return now_stream_id, []
-
async def get_users_whose_devices_changed(
self, from_key: str, user_ids: Iterable[str]
) -> Set[str]:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 50ecddf7fa..fb3b1f94de 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,6 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
from canonicaljson import encode_canonical_json
@@ -22,7 +23,7 @@ from twisted.enterprise.adbapi import Connection
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import make_in_list_sql_clause
+from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@@ -33,6 +34,51 @@ if TYPE_CHECKING:
class EndToEndKeyWorkerStore(SQLBaseStore):
+ def get_e2e_device_keys_for_federation_query(self, user_id: str):
+ """Get all devices (with any device keys) for a user
+
+ Returns:
+ Deferred which resolves to (stream_id, devices)
+ """
+ return self.db_pool.runInteraction(
+ "get_e2e_device_keys_for_federation_query",
+ self._get_e2e_device_keys_for_federation_query_txn,
+ user_id,
+ )
+
+ def _get_e2e_device_keys_for_federation_query_txn(
+ self, txn: LoggingTransaction, user_id: str
+ ) -> Tuple[int, List[JsonDict]]:
+ now_stream_id = self.get_device_stream_token()
+
+ devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])
+
+ if devices:
+ user_devices = devices[user_id]
+ results = []
+ for device_id, device in user_devices.items():
+ result = {"device_id": device_id}
+
+ key_json = device.get("key_json", None)
+ if key_json:
+ result["keys"] = db_to_json(key_json)
+
+ if "signatures" in device:
+ for sig_user_id, sigs in device["signatures"].items():
+ result["keys"].setdefault("signatures", {}).setdefault(
+ sig_user_id, {}
+ ).update(sigs)
+
+ device_display_name = device.get("device_display_name", None)
+ if device_display_name:
+ result["device_display_name"] = device_display_name
+
+ results.append(result)
+
+ return now_stream_id, results
+
+ return now_stream_id, []
+
@trace
async def get_e2e_device_keys_for_cs_api(
self, query_list: List[Tuple[str, Optional[str]]]
@@ -533,6 +579,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
_get_all_user_signature_changes_for_remotes_txn,
)
+ @abc.abstractmethod
+ def get_device_stream_token(self) -> int:
+ """Get the current stream id from the _device_list_id_gen"""
+ ...
+
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
|