diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 09a2492afc..5fe102e2f2 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -27,9 +27,9 @@ from twisted.internet import defer
from synapse.api.constants import EduTypes
from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
+from synapse.handlers.device import DeviceHandler
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
-from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import (
JsonDict,
UserID,
@@ -49,33 +49,30 @@ logger = logging.getLogger(__name__)
class E2eKeysHandler:
def __init__(self, hs: "HomeServer"):
+ self.config = hs.config
self.store = hs.get_datastores().main
self.federation = hs.get_federation_client()
self.device_handler = hs.get_device_handler()
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
- self._edu_updater = SigningKeyEduUpdater(hs, self)
-
federation_registry = hs.get_federation_registry()
- self._is_master = hs.config.worker.worker_app is None
- if not self._is_master:
- self._user_device_resync_client = (
- ReplicationUserDevicesResyncRestServlet.make_client(hs)
- )
- else:
+ is_master = hs.config.worker.worker_app is None
+ if is_master:
+ edu_updater = SigningKeyEduUpdater(hs)
+
# Only register this edu handler on master as it requires writing
# device updates to the db
federation_registry.register_edu_handler(
EduTypes.SIGNING_KEY_UPDATE,
- self._edu_updater.incoming_signing_key_update,
+ edu_updater.incoming_signing_key_update,
)
# also handle the unstable version
# FIXME: remove this when enough servers have upgraded
federation_registry.register_edu_handler(
EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
- self._edu_updater.incoming_signing_key_update,
+ edu_updater.incoming_signing_key_update,
)
# doesn't really work as part of the generic query API, because the
@@ -318,14 +315,13 @@ class E2eKeysHandler:
# probably be tracking their device lists. However, we haven't
# done an initial sync on the device list so we do it now.
try:
- if self._is_master:
- resync_results = await self.device_handler.device_list_updater.user_device_resync(
+ resync_results = (
+ await self.device_handler.device_list_updater.user_device_resync(
user_id
)
- else:
- resync_results = await self._user_device_resync_client(
- user_id=user_id
- )
+ )
+ if resync_results is None:
+ raise ValueError("Device resync failed")
# Add the device keys to the results.
user_devices = resync_results["devices"]
@@ -431,13 +427,17 @@ class E2eKeysHandler:
@trace
@cancellable
async def query_local_devices(
- self, query: Mapping[str, Optional[List[str]]]
+ self,
+ query: Mapping[str, Optional[List[str]]],
+ include_displaynames: bool = True,
) -> Dict[str, Dict[str, dict]]:
"""Get E2E device keys for local users
Args:
query: map from user_id to a list
of devices to query (None for all devices)
+ include_displaynames: Whether to include device displaynames in the returned
+ device details.
Returns:
A map from user_id -> device_id -> device details
@@ -469,7 +469,9 @@ class E2eKeysHandler:
# make sure that each queried user appears in the result dict
result_dict[user_id] = {}
- results = await self.store.get_e2e_device_keys_for_cs_api(local_query)
+ results = await self.store.get_e2e_device_keys_for_cs_api(
+ local_query, include_displaynames
+ )
# Build the result structure
for user_id, device_keys in results.items():
@@ -482,11 +484,33 @@ class E2eKeysHandler:
async def on_federation_query_client_keys(
self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
) -> JsonDict:
- """Handle a device key query from a federated server"""
+ """Handle a device key query from a federated server:
+
+ Handles the path: GET /_matrix/federation/v1/users/keys/query
+
+ Args:
+ query_body: The body of the query request. Should contain a key
+ "device_keys" that map to a dictionary of user ID's -> list of
+ device IDs. If the list of device IDs is empty, all devices of
+ that user will be queried.
+
+ Returns:
+ A json dictionary containing the following:
+ - device_keys: A dictionary containing the requested device information.
+ - master_keys: An optional dictionary of user ID -> master cross-signing
+ key info.
+ - self_signing_key: An optional dictionary of user ID -> self-signing
+ key info.
+ """
device_keys_query: Dict[str, Optional[List[str]]] = query_body.get(
"device_keys", {}
)
- res = await self.query_local_devices(device_keys_query)
+ res = await self.query_local_devices(
+ device_keys_query,
+ include_displaynames=(
+ self.config.federation.allow_device_name_lookup_over_federation
+ ),
+ )
ret = {"device_keys": res}
# add in the cross-signing keys
@@ -576,6 +600,8 @@ class E2eKeysHandler:
async def upload_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict
) -> JsonDict:
+ # This can only be called from the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
time_now = self.clock.time_msec()
@@ -703,6 +729,8 @@ class E2eKeysHandler:
user_id: the user uploading the keys
keys: the signing keys
"""
+ # This can only be called from the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
# if a master key is uploaded, then check it. Otherwise, load the
# stored master key, to check signatures on other keys
@@ -794,6 +822,9 @@ class E2eKeysHandler:
Raises:
SynapseError: if the signatures dict is not valid.
"""
+ # This can only be called from the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
+
failures = {}
# signatures to be stored. Each item will be a SignatureListItem
@@ -841,7 +872,7 @@ class E2eKeysHandler:
- signatures of the user's master key by the user's devices.
Args:
- user_id (string): the user uploading the keys
+ user_id: the user uploading the keys
signatures (dict[string, dict]): map of devices to signed keys
Returns:
@@ -1171,6 +1202,9 @@ class E2eKeysHandler:
A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
If the key cannot be retrieved, all values in the tuple will instead be None.
"""
+ # This can only be called from the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
+
try:
remote_result = await self.federation.query_user_devices(
user.domain, user.to_string()
@@ -1367,11 +1401,14 @@ class SignatureListItem:
class SigningKeyEduUpdater:
"""Handles incoming signing key updates from federation and updates the DB"""
- def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
- self.e2e_keys_handler = e2e_keys_handler
+
+ device_handler = hs.get_device_handler()
+ assert isinstance(device_handler, DeviceHandler)
+ self._device_handler = device_handler
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
@@ -1416,9 +1453,6 @@ class SigningKeyEduUpdater:
user_id: the user whose updates we are processing
"""
- device_handler = self.e2e_keys_handler.device_handler
- device_list_updater = device_handler.device_list_updater
-
async with self._remote_edu_linearizer.queue(user_id):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
@@ -1430,13 +1464,11 @@ class SigningKeyEduUpdater:
logger.info("pending updates: %r", pending_updates)
for master_key, self_signing_key in pending_updates:
- new_device_ids = (
- await device_list_updater.process_cross_signing_key_update(
- user_id,
- master_key,
- self_signing_key,
- )
+ new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update(
+ user_id,
+ master_key,
+ self_signing_key,
)
device_ids = device_ids + new_device_ids
- await device_handler.notify_device_update(user_id, device_ids)
+ await self._device_handler.notify_device_update(user_id, device_ids)
|