diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 1ece54ccfc..668cec513b 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -35,6 +35,7 @@ from synapse.api.errors import CodeMessageException, Codes, NotFoundError, Synap
from synapse.handlers.device import DeviceHandler
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
+from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
from synapse.types import (
JsonDict,
JsonMapping,
@@ -45,7 +46,10 @@ from synapse.types import (
from synapse.util import json_decoder
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.cancellation import cancellable
-from synapse.util.retryutils import NotRetryingDestination
+from synapse.util.retryutils import (
+ NotRetryingDestination,
+ filter_destinations_by_retry_limiter,
+)
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -53,6 +57,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
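+# Name of the worker lock used to serialise one-time key uploads for a
+# given (user, device) pair.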
+ONE_TIME_KEY_UPLOAD = "one_time_key_upload_lock"
+
+
class E2eKeysHandler:
def __init__(self, hs: "HomeServer"):
self.config = hs.config
@@ -62,6 +69,7 @@ class E2eKeysHandler:
self._appservice_handler = hs.get_application_service_handler()
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
+ self._worker_lock_handler = hs.get_worker_locks_handler()
federation_registry = hs.get_federation_registry()
@@ -82,6 +90,12 @@ class E2eKeysHandler:
edu_updater.incoming_signing_key_update,
)
+ self.device_key_uploader = self.upload_device_keys_for_user
+ else:
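+ # On workers, proxy the upload to the main process over the
+ # replication HTTP API: persisting device keys and sending out
+ # the resulting device list updates happens on the main process.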
+ self.device_key_uploader = (
+ ReplicationUploadKeysForUserRestServlet.make_client(hs)
+ )
+
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
@@ -145,6 +159,11 @@ class E2eKeysHandler:
remote_queries = {}
for user_id, device_ids in device_keys_query.items():
+ if not UserID.is_valid(user_id):
+ # Ignore invalid user IDs, which is the same behaviour as if
+ # the user existed but had no keys.
+ continue
+
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
local_query[user_id] = device_ids
@@ -259,10 +278,8 @@ class E2eKeysHandler:
"%d destinations to query devices for", len(remote_queries_not_in_cache)
)
- async def _query(
- destination_queries: Tuple[str, Dict[str, Iterable[str]]]
- ) -> None:
- destination, queries = destination_queries
+ async def _query(destination: str) -> None:
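+ # `_query` receives one destination at a time from the filtered
+ # list computed below.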
+ queries = remote_queries_not_in_cache[destination]
return await self._query_devices_for_destination(
results,
cross_signing_keys,
@@ -272,9 +289,20 @@ class E2eKeysHandler:
timeout,
)
+ # Only try to fetch keys from destinations that are not marked as
+ # down.
+ filtered_destinations = await filter_destinations_by_retry_limiter(
+ remote_queries_not_in_cache.keys(),
+ self.clock,
+ self.store,
+ # Give an arbitrary grace period so that hosts which have only
+ # recently gone down are still tried.
+ retry_due_within_ms=60 * 1000,
+ )
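+ # Destinations filtered out here are simply not queried on this
+ # request.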
+
await concurrently_execute(
_query,
- remote_queries_not_in_cache.items(),
+ filtered_destinations,
10,
delay_cancellation=True,
)
@@ -775,36 +803,17 @@ class E2eKeysHandler:
"one_time_keys": A mapping from algorithm to number of keys for that
algorithm, including those previously persisted.
"""
- # This can only be called from the main process.
- assert isinstance(self.device_handler, DeviceHandler)
-
time_now = self.clock.time_msec()
# TODO: Validate the JSON to make sure it has the right keys.
device_keys = keys.get("device_keys", None)
if device_keys:
- logger.info(
- "Updating device_keys for device %r for user %s at %d",
- device_id,
- user_id,
- time_now,
- )
- log_kv(
- {
- "message": "Updating device_keys for user.",
- "user_id": user_id,
- "device_id": device_id,
- }
- )
- # TODO: Sign the JSON with the server key
- changed = await self.store.set_e2e_device_keys(
- user_id, device_id, time_now, device_keys
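+ # Either persist the keys directly (main process) or proxy the
+ # upload over replication HTTP (workers); see the constructor.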
+ await self.device_key_uploader(
+ user_id=user_id,
+ device_id=device_id,
+ keys={"device_keys": device_keys},
)
- if changed:
- # Only notify about device updates *if* the keys actually changed
- await self.device_handler.notify_device_update(user_id, [device_id])
- else:
- log_kv({"message": "Not updating device_keys for user", "user_id": user_id})
+
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
log_kv(
@@ -840,60 +849,106 @@ class E2eKeysHandler:
{"message": "Did not update fallback_keys", "reason": "no keys given"}
)
- # the device should have been registered already, but it may have been
- # deleted due to a race with a DELETE request. Or we may be using an
- # old access_token without an associated device_id. Either way, we
- # need to double-check the device is registered to avoid ending up with
- # keys without a corresponding device.
- await self.device_handler.check_device_registered(user_id, device_id)
-
result = await self.store.count_e2e_one_time_keys(user_id, device_id)
set_tag("one_time_key_counts", str(result))
return {"one_time_key_counts": result}
- async def _upload_one_time_keys_for_user(
- self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
+ @tag_args
+ async def upload_device_keys_for_user(
+ self, user_id: str, device_id: str, keys: JsonDict
) -> None:
+ """Upload device keys for the given user and device, notifying other
+ devices and servers of any change.
+
+ Args:
+ user_id: user whose keys are being uploaded.
+ device_id: device whose keys are being uploaded.
+ keys: a dict containing the `device_keys` of a /keys/upload request.
+ """
+ # This can only be called from the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
+
+ time_now = self.clock.time_msec()
+
+ device_keys = keys["device_keys"]
logger.info(
- "Adding one_time_keys %r for device %r for user %r at %d",
- one_time_keys.keys(),
+ "Updating device_keys for device %r for user %s at %d",
device_id,
user_id,
time_now,
)
+ log_kv(
+ {
+ "message": "Updating device_keys for user.",
+ "user_id": user_id,
+ "device_id": device_id,
+ }
+ )
+ # TODO: Sign the JSON with the server key
+ changed = await self.store.set_e2e_device_keys(
+ user_id, device_id, time_now, device_keys
+ )
+ if changed:
+ # Only notify about device updates *if* the keys actually changed
+ await self.device_handler.notify_device_update(user_id, [device_id])
- # make a list of (alg, id, key) tuples
- key_list = []
- for key_id, key_obj in one_time_keys.items():
- algorithm, key_id = key_id.split(":")
- key_list.append((algorithm, key_id, key_obj))
+ # the device should have been registered already, but it may have been
+ # deleted due to a race with a DELETE request. Or we may be using an
+ # old access_token without an associated device_id. Either way, we
+ # need to double-check the device is registered to avoid ending up with
+ # keys without a corresponding device.
+ await self.device_handler.check_device_registered(user_id, device_id)
- # First we check if we have already persisted any of the keys.
- existing_key_map = await self.store.get_e2e_one_time_keys(
- user_id, device_id, [k_id for _, k_id, _ in key_list]
- )
+ async def _upload_one_time_keys_for_user(
+ self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
+ ) -> None:
+ # We take out a lock so that we don't have to worry about a client
+ # sending duplicate requests: without it, two concurrent uploads of
+ # the same key could both pass the existing-key check below.
+ lock_key = f"{user_id}_{device_id}"
+ async with self._worker_lock_handler.acquire_lock(
+ ONE_TIME_KEY_UPLOAD, lock_key
+ ):
+ logger.info(
+ "Adding one_time_keys %r for device %r for user %r at %d",
+ one_time_keys.keys(),
+ device_id,
+ user_id,
+ time_now,
+ )
- new_keys = [] # Keys that we need to insert. (alg, id, json) tuples.
- for algorithm, key_id, key in key_list:
- ex_json = existing_key_map.get((algorithm, key_id), None)
- if ex_json:
- if not _one_time_keys_match(ex_json, key):
- raise SynapseError(
- 400,
- (
- "One time key %s:%s already exists. "
- "Old key: %s; new key: %r"
+ # make a list of (alg, id, key) tuples
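+ # Key IDs arrive in the form "<algorithm>:<key_id>", e.g.
+ # "signed_curve25519:AAAAHQ".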
+ key_list = []
+ for key_id, key_obj in one_time_keys.items():
+ algorithm, key_id = key_id.split(":")
+ key_list.append((algorithm, key_id, key_obj))
+
+ # First we check if we have already persisted any of the keys.
+ existing_key_map = await self.store.get_e2e_one_time_keys(
+ user_id, device_id, [k_id for _, k_id, _ in key_list]
+ )
+
+ new_keys = [] # Keys that we need to insert. (alg, id, json) tuples.
+ for algorithm, key_id, key in key_list:
+ ex_json = existing_key_map.get((algorithm, key_id), None)
+ if ex_json:
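+ # Re-uploading the same value for an existing key is a no-op;
+ # a different value for the same key ID is an error.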
+ if not _one_time_keys_match(ex_json, key):
+ raise SynapseError(
+ 400,
+ (
+ "One time key %s:%s already exists. "
+ "Old key: %s; new key: %r"
+ )
+ % (algorithm, key_id, ex_json, key),
)
- % (algorithm, key_id, ex_json, key),
+ else:
+ new_keys.append(
+ (algorithm, key_id, encode_canonical_json(key).decode("ascii"))
)
- else:
- new_keys.append(
- (algorithm, key_id, encode_canonical_json(key).decode("ascii"))
- )
- log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
- await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
+ log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
+ await self.store.add_e2e_one_time_keys(
+ user_id, device_id, time_now, new_keys
+ )
async def upload_signing_keys_for_user(
self, user_id: str, keys: JsonDict