diff options
Diffstat (limited to 'synapse/rest/client/v2_alpha')
-rw-r--r-- | synapse/rest/client/v2_alpha/keys.py | 100 |
1 files changed, 70 insertions, 30 deletions
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 5f3a6207b5..739a08ada8 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet +from synapse.types import UserID from syutil.jsonutil import encode_canonical_json from ._base import client_v2_pattern @@ -164,45 +165,63 @@ class KeyQueryServlet(RestServlet): super(KeyQueryServlet, self).__init__() self.store = hs.get_datastore() self.auth = hs.get_auth() + self.federation = hs.get_replication_layer() + self.is_mine = hs.is_mine @defer.inlineCallbacks def on_POST(self, request, user_id, device_id): - logger.debug("onPOST") yield self.auth.get_user_by_req(request) try: body = json.loads(request.content.read()) except: raise SynapseError(400, "Invalid key JSON") - query = [] - for user_id, device_ids in body.get("device_keys", {}).items(): - if not device_ids: - query.append((user_id, None)) - else: - for device_id in device_ids: - query.append((user_id, device_id)) - results = yield self.store.get_e2e_device_keys(query) - defer.returnValue(self.json_result(request, results)) + result = yield self.handle_request(body) + defer.returnValue(result) @defer.inlineCallbacks def on_GET(self, request, user_id, device_id): auth_user, client_info = yield self.auth.get_user_by_req(request) auth_user_id = auth_user.to_string() - if not user_id: - user_id = auth_user_id - if not device_id: - device_id = None - # Returns a map of user_id->device_id->json_bytes. - results = yield self.store.get_e2e_device_keys([(user_id, device_id)]) - defer.returnValue(self.json_result(request, results)) - - def json_result(self, request, results): + user_id = user_id if user_id else auth_user_id + device_ids = [device_id] if device_id else [] + result = yield self.handle_request( + {"device_keys": {user_id: device_ids}} + ) + defer.returnValue(result) + + @defer.inlineCallbacks + def handle_request(self, body): + local_query = [] + remote_queries = {} + for user_id, device_ids in body.get("device_keys", {}).items(): + user = UserID.from_string(user_id) + if self.is_mine(user): + if not device_ids: + local_query.append((user_id, None)) + else: + for device_id in device_ids: + local_query.append((user_id, device_id)) + else: + remote_queries.set_default(user.domain, {})[user_id] = list( + device_ids + ) + results = yield self.store.get_e2e_device_keys(local_query) + json_result = {} for user_id, device_keys in results.items(): for device_id, json_bytes in device_keys.items(): json_result.setdefault(user_id, {})[device_id] = json.loads( json_bytes ) - return (200, {"device_keys": json_result}) + + for destination, device_keys in remote_queries.items(): + remote_result = yield self.federation.query_client_keys( + destination, {"device_keys": device_keys} + ) + for user_id, keys in remote_result.items(): + if user_id in device_keys: + json_result[user_id] = keys + defer.returnValue((200, {"device_keys": json_result})) class OneTimeKeyServlet(RestServlet): @@ -236,14 +255,16 @@ class OneTimeKeyServlet(RestServlet): self.store = hs.get_datastore() self.auth = hs.get_auth() self.clock = hs.get_clock() + self.federation = hs.get_replication_layer() + self.is_mine = hs.is_mine @defer.inlineCallbacks def on_GET(self, request, user_id, device_id, algorithm): yield self.auth.get_user_by_req(request) - results = yield self.store.claim_e2e_one_time_keys( - [(user_id, device_id, algorithm)] + result = yield self.handle_request( + {"one_time_keys": {user_id: {device_id: algorithm}}} ) - defer.returnValue(self.json_result(request, results)) + defer.returnValue(result) @defer.inlineCallbacks def on_POST(self, request, user_id, device_id, algorithm): @@ -252,14 +273,24 @@ class OneTimeKeyServlet(RestServlet): body = json.loads(request.content.read()) except: raise SynapseError(400, "Invalid key JSON") - query = [] + result = yield self.handle_request(body) + defer.returnValue(result) + + @defer.inlineCallbacks + def handle_request(self, body): + local_query = [] + remote_queries = {} for user_id, device_keys in body.get("one_time_keys", {}).items(): - for device_id, algorithm in device_keys.items(): - query.append((user_id, device_id, algorithm)) - results = yield self.store.claim_e2e_one_time_keys(query) - defer.returnValue(self.json_result(request, results)) + user = UserID.from_string(user_id) + if self.is_mine(user): + for device_id, algorithm in device_keys.items(): + local_query.append((user_id, device_id, algorithm)) + else: + remote_queries.set_default(user.domain, {})[user_id] = ( + device_keys + ) + results = yield self.store.claim_e2e_one_time_keys(local_query) - def json_result(self, request, results): json_result = {} for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): @@ -267,7 +298,16 @@ class OneTimeKeyServlet(RestServlet): json_result.setdefault(user_id, {})[device_id] = { key_id: json.loads(json_bytes) } - return (200, {"one_time_keys": json_result}) + + for destination, device_keys in remote_queries.items(): + remote_result = yield self.federation.query_client_keys( + destination, {"one_time_keys": device_keys} + ) + for user_id, keys in remote_result.items(): + if user_id in device_keys: + json_result[user_id] = keys + + defer.returnValue((200, {"one_time_keys": json_result})) def register_servlets(hs, http_server): |