diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 5bfd700931..60a1acc6f4 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
+import ujson as json
import logging
+from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.api.errors import SynapseError, CodeMessageException
@@ -29,7 +30,9 @@ class E2eKeysHandler(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.federation = hs.get_replication_layer()
+ self.device_handler = hs.get_device_handler()
self.is_mine_id = hs.is_mine_id
+ self.clock = hs.get_clock()
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
@@ -103,9 +106,9 @@ class E2eKeysHandler(object):
for destination in remote_queries
]))
- defer.returnValue((200, {
+ defer.returnValue({
"device_keys": results, "failures": failures,
- }))
+ })
@defer.inlineCallbacks
def query_local_devices(self, query):
@@ -159,3 +162,99 @@ class E2eKeysHandler(object):
device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query)
defer.returnValue({"device_keys": res})
+
+ @defer.inlineCallbacks
+ def claim_one_time_keys(self, query, timeout):
+ local_query = []
+ remote_queries = {}
+
+ for user_id, device_keys in query.get("one_time_keys", {}).items():
+ if self.is_mine_id(user_id):
+ for device_id, algorithm in device_keys.items():
+ local_query.append((user_id, device_id, algorithm))
+ else:
+ domain = get_domain_from_id(user_id)
+ remote_queries.setdefault(domain, {})[user_id] = device_keys
+
+ results = yield self.store.claim_e2e_one_time_keys(local_query)
+
+ json_result = {}
+ failures = {}
+ for user_id, device_keys in results.items():
+ for device_id, keys in device_keys.items():
+ for key_id, json_bytes in keys.items():
+ json_result.setdefault(user_id, {})[device_id] = {
+ key_id: json.loads(json_bytes)
+ }
+
+ @defer.inlineCallbacks
+ def claim_client_keys(destination):
+ device_keys = remote_queries[destination]
+ try:
+ remote_result = yield self.federation.claim_client_keys(
+ destination,
+ {"one_time_keys": device_keys},
+ timeout=timeout
+ )
+ for user_id, keys in remote_result["one_time_keys"].items():
+ if user_id in device_keys:
+ json_result[user_id] = keys
+ except CodeMessageException as e:
+ failures[destination] = {
+ "status": e.code, "message": e.message
+ }
+
+ yield preserve_context_over_deferred(defer.gatherResults([
+ preserve_fn(claim_client_keys)(destination)
+ for destination in remote_queries
+ ]))
+
+ defer.returnValue({
+ "one_time_keys": json_result,
+ "failures": failures
+ })
+
+ @defer.inlineCallbacks
+ def upload_keys_for_user(self, user_id, device_id, keys):
+ time_now = self.clock.time_msec()
+
+ # TODO: Validate the JSON to make sure it has the right keys.
+ device_keys = keys.get("device_keys", None)
+ if device_keys:
+ logger.info(
+ "Updating device_keys for device %r for user %s at %d",
+ device_id, user_id, time_now
+ )
+ # TODO: Sign the JSON with the server key
+ yield self.store.set_e2e_device_keys(
+ user_id, device_id, time_now,
+ encode_canonical_json(device_keys)
+ )
+
+ one_time_keys = keys.get("one_time_keys", None)
+ if one_time_keys:
+ logger.info(
+ "Adding %d one_time_keys for device %r for user %r at %d",
+ len(one_time_keys), device_id, user_id, time_now
+ )
+ key_list = []
+ for key_id, key_json in one_time_keys.items():
+ algorithm, key_id = key_id.split(":")
+ key_list.append((
+ algorithm, key_id, encode_canonical_json(key_json)
+ ))
+
+ yield self.store.add_e2e_one_time_keys(
+ user_id, device_id, time_now, key_list
+ )
+
+ # the device should have been registered already, but it may have been
+ # deleted due to a race with a DELETE request. Or we may be using an
+ # old access_token without an associated device_id. Either way, we
+ # need to double-check the device is registered to avoid ending up with
+ # keys without a corresponding device.
+ self.device_handler.check_device_registered(user_id, device_id)
+
+ result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
+
+ defer.returnValue({"one_time_key_counts": result})
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 8f05727652..f185f9a774 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -15,16 +15,12 @@
import logging
-import simplejson as json
-from canonicaljson import encode_canonical_json
from twisted.internet import defer
-from synapse.api.errors import SynapseError, CodeMessageException
+from synapse.api.errors import SynapseError
from synapse.http.servlet import (
RestServlet, parse_json_object_from_request, parse_integer
)
-from synapse.types import get_domain_from_id
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
@@ -64,17 +60,13 @@ class KeyUploadServlet(RestServlet):
hs (synapse.server.HomeServer): server
"""
super(KeyUploadServlet, self).__init__()
- self.store = hs.get_datastore()
- self.clock = hs.get_clock()
self.auth = hs.get_auth()
- self.device_handler = hs.get_device_handler()
+ self.e2e_keys_handler = hs.get_e2e_keys_handler()
@defer.inlineCallbacks
def on_POST(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
-
user_id = requester.user.to_string()
-
body = parse_json_object_from_request(request)
if device_id is not None:
@@ -94,47 +86,10 @@ class KeyUploadServlet(RestServlet):
"To upload keys, you must pass device_id when authenticating"
)
- time_now = self.clock.time_msec()
-
- # TODO: Validate the JSON to make sure it has the right keys.
- device_keys = body.get("device_keys", None)
- if device_keys:
- logger.info(
- "Updating device_keys for device %r for user %s at %d",
- device_id, user_id, time_now
- )
- # TODO: Sign the JSON with the server key
- yield self.store.set_e2e_device_keys(
- user_id, device_id, time_now,
- encode_canonical_json(device_keys)
- )
-
- one_time_keys = body.get("one_time_keys", None)
- if one_time_keys:
- logger.info(
- "Adding %d one_time_keys for device %r for user %r at %d",
- len(one_time_keys), device_id, user_id, time_now
- )
- key_list = []
- for key_id, key_json in one_time_keys.items():
- algorithm, key_id = key_id.split(":")
- key_list.append((
- algorithm, key_id, encode_canonical_json(key_json)
- ))
-
- yield self.store.add_e2e_one_time_keys(
- user_id, device_id, time_now, key_list
- )
-
- # the device should have been registered already, but it may have been
- # deleted due to a race with a DELETE request. Or we may be using an
- # old access_token without an associated device_id. Either way, we
- # need to double-check the device is registered to avoid ending up with
- # keys without a corresponding device.
- self.device_handler.check_device_registered(user_id, device_id)
-
- result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
- defer.returnValue((200, {"one_time_key_counts": result}))
+ result = yield self.e2e_keys_handler.upload_keys_for_user(
+ user_id, device_id, body
+ )
+ defer.returnValue((200, result))
class KeyQueryServlet(RestServlet):
@@ -199,7 +154,7 @@ class KeyQueryServlet(RestServlet):
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
result = yield self.e2e_keys_handler.query_devices(body, timeout)
- defer.returnValue(result)
+ defer.returnValue((200, result))
@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id):
@@ -212,7 +167,7 @@ class KeyQueryServlet(RestServlet):
{"device_keys": {user_id: device_ids}},
timeout,
)
- defer.returnValue(result)
+ defer.returnValue((200, result))
class OneTimeKeyServlet(RestServlet):
@@ -244,80 +199,29 @@ class OneTimeKeyServlet(RestServlet):
def __init__(self, hs):
super(OneTimeKeyServlet, self).__init__()
- self.store = hs.get_datastore()
self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.federation = hs.get_replication_layer()
- self.is_mine_id = hs.is_mine_id
+ self.e2e_keys_handler = hs.get_e2e_keys_handler()
@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request)
timeout = parse_integer(request, "timeout", 10 * 1000)
- result = yield self.handle_request(
+ result = yield self.e2e_keys_handler.claim_one_time_keys(
{"one_time_keys": {user_id: {device_id: algorithm}}},
timeout,
)
- defer.returnValue(result)
+ defer.returnValue((200, result))
@defer.inlineCallbacks
def on_POST(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
- result = yield self.handle_request(body, timeout)
- defer.returnValue(result)
-
- @defer.inlineCallbacks
- def handle_request(self, body, timeout):
- local_query = []
- remote_queries = {}
-
- for user_id, device_keys in body.get("one_time_keys", {}).items():
- if self.is_mine_id(user_id):
- for device_id, algorithm in device_keys.items():
- local_query.append((user_id, device_id, algorithm))
- else:
- domain = get_domain_from_id(user_id)
- remote_queries.setdefault(domain, {})[user_id] = device_keys
-
- results = yield self.store.claim_e2e_one_time_keys(local_query)
-
- json_result = {}
- failures = {}
- for user_id, device_keys in results.items():
- for device_id, keys in device_keys.items():
- for key_id, json_bytes in keys.items():
- json_result.setdefault(user_id, {})[device_id] = {
- key_id: json.loads(json_bytes)
- }
-
- @defer.inlineCallbacks
- def claim_client_keys(destination):
- device_keys = remote_queries[destination]
- try:
- remote_result = yield self.federation.claim_client_keys(
- destination,
- {"one_time_keys": device_keys},
- timeout=timeout
- )
- for user_id, keys in remote_result["one_time_keys"].items():
- if user_id in device_keys:
- json_result[user_id] = keys
- except CodeMessageException as e:
- failures[destination] = {
- "status": e.code, "message": e.message
- }
-
- yield preserve_context_over_deferred(defer.gatherResults([
- preserve_fn(claim_client_keys)(destination)
- for destination in remote_queries
- ]))
-
- defer.returnValue((200, {
- "one_time_keys": json_result,
- "failures": failures
- }))
+ result = yield self.e2e_keys_handler.claim_one_time_keys(
+ body,
+ timeout,
+ )
+ defer.returnValue((200, result))
def register_servlets(hs, http_server):
|