diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index dc1d4d8fc6..c5ff16adf3 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -130,9 +130,7 @@ class KeyUploadServlet(RestServlet):
# old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with
# keys without a corresponding device.
- self.device_handler.check_device_registered(
- user_id, device_id, "unknown device"
- )
+ self.device_handler.check_device_registered(user_id, device_id)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue((200, {"one_time_key_counts": result}))
@@ -186,17 +184,19 @@ class KeyQueryServlet(RestServlet):
)
def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer):
+ """
super(KeyQueryServlet, self).__init__()
- self.store = hs.get_datastore()
self.auth = hs.get_auth()
- self.federation = hs.get_replication_layer()
- self.is_mine = hs.is_mine
+ self.e2e_keys_handler = hs.get_e2e_keys_handler()
@defer.inlineCallbacks
def on_POST(self, request, user_id, device_id):
yield self.auth.get_user_by_req(request)
body = parse_json_object_from_request(request)
- result = yield self.handle_request(body)
+ result = yield self.e2e_keys_handler.query_devices(body)
defer.returnValue(result)
@defer.inlineCallbacks
@@ -205,45 +205,11 @@ class KeyQueryServlet(RestServlet):
auth_user_id = requester.user.to_string()
user_id = user_id if user_id else auth_user_id
device_ids = [device_id] if device_id else []
- result = yield self.handle_request(
+ result = yield self.e2e_keys_handler.query_devices(
{"device_keys": {user_id: device_ids}}
)
defer.returnValue(result)
- @defer.inlineCallbacks
- def handle_request(self, body):
- local_query = []
- remote_queries = {}
- for user_id, device_ids in body.get("device_keys", {}).items():
- user = UserID.from_string(user_id)
- if self.is_mine(user):
- if not device_ids:
- local_query.append((user_id, None))
- else:
- for device_id in device_ids:
- local_query.append((user_id, device_id))
- else:
- remote_queries.setdefault(user.domain, {})[user_id] = list(
- device_ids
- )
- results = yield self.store.get_e2e_device_keys(local_query)
-
- json_result = {}
- for user_id, device_keys in results.items():
- for device_id, json_bytes in device_keys.items():
- json_result.setdefault(user_id, {})[device_id] = json.loads(
- json_bytes
- )
-
- for destination, device_keys in remote_queries.items():
- remote_result = yield self.federation.query_client_keys(
- destination, {"device_keys": device_keys}
- )
- for user_id, keys in remote_result["device_keys"].items():
- if user_id in device_keys:
- json_result[user_id] = keys
- defer.returnValue((200, {"device_keys": json_result}))
-
class OneTimeKeyServlet(RestServlet):
"""
|