diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index c938339ddd..5fe102e2f2 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -27,9 +27,9 @@ from twisted.internet import defer
from synapse.api.constants import EduTypes
from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
+from synapse.handlers.device import DeviceHandler
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
-from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import (
JsonDict,
UserID,
@@ -37,7 +37,8 @@ from synapse.types import (
get_verify_key_from_cross_signing_key,
)
from synapse.util import json_decoder, unwrapFirstError
-from synapse.util.async_helpers import Linearizer
+from synapse.util.async_helpers import Linearizer, delay_cancellation
+from synapse.util.cancellation import cancellable
from synapse.util.retryutils import NotRetryingDestination
if TYPE_CHECKING:
@@ -48,33 +49,30 @@ logger = logging.getLogger(__name__)
class E2eKeysHandler:
def __init__(self, hs: "HomeServer"):
+ self.config = hs.config
self.store = hs.get_datastores().main
self.federation = hs.get_federation_client()
self.device_handler = hs.get_device_handler()
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
- self._edu_updater = SigningKeyEduUpdater(hs, self)
-
federation_registry = hs.get_federation_registry()
- self._is_master = hs.config.worker.worker_app is None
- if not self._is_master:
- self._user_device_resync_client = (
- ReplicationUserDevicesResyncRestServlet.make_client(hs)
- )
- else:
+ is_master = hs.config.worker.worker_app is None
+ if is_master:
+ edu_updater = SigningKeyEduUpdater(hs)
+
# Only register this edu handler on master as it requires writing
# device updates to the db
federation_registry.register_edu_handler(
EduTypes.SIGNING_KEY_UPDATE,
- self._edu_updater.incoming_signing_key_update,
+ edu_updater.incoming_signing_key_update,
)
# also handle the unstable version
# FIXME: remove this when enough servers have upgraded
federation_registry.register_edu_handler(
EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
- self._edu_updater.incoming_signing_key_update,
+ edu_updater.incoming_signing_key_update,
)
# doesn't really work as part of the generic query API, because the
@@ -91,6 +89,7 @@ class E2eKeysHandler:
)
@trace
+ @cancellable
async def query_devices(
self,
query_body: JsonDict,
@@ -173,6 +172,35 @@ class E2eKeysHandler:
user_ids_not_in_cache,
remote_results,
) = await self.store.get_user_devices_from_cache(query_list)
+
+ # Check that the homeserver still shares a room with all cached users.
+ # Note that this check may be slightly racy when a remote user leaves a
+ # room after we have fetched their cached device list. In the worst case
+ # we will do extra federation queries for devices that we had cached.
+ cached_users = set(remote_results.keys())
+ valid_cached_users = (
+ await self.store.get_users_server_still_shares_room_with(
+ remote_results.keys()
+ )
+ )
+ invalid_cached_users = cached_users - valid_cached_users
+ if invalid_cached_users:
+ # Fix up results. If we get here, it means there was either a bug in
+ # device list tracking, or we hit the race mentioned above.
+ # TODO: In practice, this path is hit fairly often in existing
+ # deployments when clients query the keys of departed remote
+ # users. A background update to mark the appropriate device
+ # lists as unsubscribed is needed.
+ # https://github.com/matrix-org/synapse/issues/13651
+ # Note that this currently introduces a failure mode when clients
+ # are trying to decrypt old messages from a remote user whose
+ # homeserver is no longer available. We may want to consider falling
+ # back to the cached data when we fail to retrieve a device list
+ # over federation for such remote users.
+ user_ids_not_in_cache.update(invalid_cached_users)
+ for invalid_user_id in invalid_cached_users:
+ remote_results.pop(invalid_user_id)
+
for user_id, devices in remote_results.items():
user_devices = results.setdefault(user_id, {})
for device_id, device in devices.items():
@@ -208,22 +236,26 @@ class E2eKeysHandler:
r[user_id] = remote_queries[user_id]
# Now fetch any devices that we don't have in our cache
+ # TODO It might make sense to propagate cancellations into the
+ # deferreds which are querying remote homeservers.
await make_deferred_yieldable(
- defer.gatherResults(
- [
- run_in_background(
- self._query_devices_for_destination,
- results,
- cross_signing_keys,
- failures,
- destination,
- queries,
- timeout,
- )
- for destination, queries in remote_queries_not_in_cache.items()
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError)
+ delay_cancellation(
+ defer.gatherResults(
+ [
+ run_in_background(
+ self._query_devices_for_destination,
+ results,
+ cross_signing_keys,
+ failures,
+ destination,
+ queries,
+ timeout,
+ )
+ for destination, queries in remote_queries_not_in_cache.items()
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+ )
)
ret = {"device_keys": results, "failures": failures}
@@ -283,14 +315,13 @@ class E2eKeysHandler:
# probably be tracking their device lists. However, we haven't
# done an initial sync on the device list so we do it now.
try:
- if self._is_master:
- resync_results = await self.device_handler.device_list_updater.user_device_resync(
+ resync_results = (
+ await self.device_handler.device_list_updater.user_device_resync(
user_id
)
- else:
- resync_results = await self._user_device_resync_client(
- user_id=user_id
- )
+ )
+ if resync_results is None:
+ raise ValueError("Device resync failed")
# Add the device keys to the results.
user_devices = resync_results["devices"]
@@ -347,6 +378,7 @@ class E2eKeysHandler:
return
+ @cancellable
async def get_cross_signing_keys_from_cache(
self, query: Iterable[str], from_user_id: Optional[str]
) -> Dict[str, Dict[str, dict]]:
@@ -393,14 +425,19 @@ class E2eKeysHandler:
}
@trace
+ @cancellable
async def query_local_devices(
- self, query: Mapping[str, Optional[List[str]]]
+ self,
+ query: Mapping[str, Optional[List[str]]],
+ include_displaynames: bool = True,
) -> Dict[str, Dict[str, dict]]:
"""Get E2E device keys for local users
Args:
query: map from user_id to a list
of devices to query (None for all devices)
+ include_displaynames: Whether to include device displaynames in the returned
+ device details.
Returns:
A map from user_id -> device_id -> device details
@@ -432,7 +469,9 @@ class E2eKeysHandler:
# make sure that each queried user appears in the result dict
result_dict[user_id] = {}
- results = await self.store.get_e2e_device_keys_for_cs_api(local_query)
+ results = await self.store.get_e2e_device_keys_for_cs_api(
+ local_query, include_displaynames
+ )
# Build the result structure
for user_id, device_keys in results.items():
@@ -445,11 +484,33 @@ class E2eKeysHandler:
async def on_federation_query_client_keys(
self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
) -> JsonDict:
- """Handle a device key query from a federated server"""
+ """Handle a device key query from a federated server:
+
+ Handles the path: GET /_matrix/federation/v1/users/keys/query
+
+ Args:
+ query_body: The body of the query request. Should contain a key
+ "device_keys" that map to a dictionary of user ID's -> list of
+ device IDs. If the list of device IDs is empty, all devices of
+ that user will be queried.
+
+ Returns:
+ A json dictionary containing the following:
+ - device_keys: A dictionary containing the requested device information.
+ - master_keys: An optional dictionary of user ID -> master cross-signing
+ key info.
+ - self_signing_key: An optional dictionary of user ID -> self-signing
+ key info.
+ """
device_keys_query: Dict[str, Optional[List[str]]] = query_body.get(
"device_keys", {}
)
- res = await self.query_local_devices(device_keys_query)
+ res = await self.query_local_devices(
+ device_keys_query,
+ include_displaynames=(
+ self.config.federation.allow_device_name_lookup_over_federation
+ ),
+ )
ret = {"device_keys": res}
# add in the cross-signing keys
@@ -539,6 +600,8 @@ class E2eKeysHandler:
async def upload_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict
) -> JsonDict:
+ # This can only be called from the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
time_now = self.clock.time_msec()
@@ -666,6 +729,8 @@ class E2eKeysHandler:
user_id: the user uploading the keys
keys: the signing keys
"""
+ # This can only be called from the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
# if a master key is uploaded, then check it. Otherwise, load the
# stored master key, to check signatures on other keys
@@ -757,6 +822,9 @@ class E2eKeysHandler:
Raises:
SynapseError: if the signatures dict is not valid.
"""
+ # This can only be called from the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
+
failures = {}
# signatures to be stored. Each item will be a SignatureListItem
@@ -804,7 +872,7 @@ class E2eKeysHandler:
- signatures of the user's master key by the user's devices.
Args:
- user_id (string): the user uploading the keys
+ user_id: the user uploading the keys
signatures (dict[string, dict]): map of devices to signed keys
Returns:
@@ -1134,6 +1202,9 @@ class E2eKeysHandler:
A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
If the key cannot be retrieved, all values in the tuple will instead be None.
"""
+ # This can only be called from the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
+
try:
remote_result = await self.federation.query_user_devices(
user.domain, user.to_string()
@@ -1330,11 +1401,14 @@ class SignatureListItem:
class SigningKeyEduUpdater:
"""Handles incoming signing key updates from federation and updates the DB"""
- def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
- self.e2e_keys_handler = e2e_keys_handler
+
+ device_handler = hs.get_device_handler()
+ assert isinstance(device_handler, DeviceHandler)
+ self._device_handler = device_handler
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
@@ -1379,9 +1453,6 @@ class SigningKeyEduUpdater:
user_id: the user whose updates we are processing
"""
- device_handler = self.e2e_keys_handler.device_handler
- device_list_updater = device_handler.device_list_updater
-
async with self._remote_edu_linearizer.queue(user_id):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
@@ -1393,13 +1464,11 @@ class SigningKeyEduUpdater:
logger.info("pending updates: %r", pending_updates)
for master_key, self_signing_key in pending_updates:
- new_device_ids = (
- await device_list_updater.process_cross_signing_key_update(
- user_id,
- master_key,
- self_signing_key,
- )
+ new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update(
+ user_id,
+ master_key,
+ self_signing_key,
)
device_ids = device_ids + new_device_ids
- await device_handler.notify_device_update(user_id, device_ids)
+ await self._device_handler.notify_device_update(user_id, device_ids)
|