diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 7736d14fb5..21a86a4c6d 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -134,6 +134,40 @@ class FederationClient(FederationBase):
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
)
+ @log_function
+ def query_client_keys(self, destination, content, retry_on_dns_fail=True):
+ """Query device keys for a device hosted on a remote server.
+
+ Args:
+ destination (str): Domain name of the remote homeserver
+ content (dict): The query content.
+
+ Returns:
+ a Deferred which will eventually yield a JSON object from the
+ response
+ """
+ sent_queries_counter.inc("client_device_keys")
+ return self.transport_layer.query_client_keys(
+ destination, content, retry_on_dns_fail=retry_on_dns_fail
+ )
+
+ @log_function
+ def claim_client_keys(self, destination, content, retry_on_dns_fail=True):
+ """Claims one-time keys for a device hosted on a remote server.
+
+ Args:
+ destination (str): Domain name of the remote homeserver
+ content (dict): The query content.
+
+ Returns:
+ a Deferred which will eventually yield a JSON object from the
+ response
+ """
+ sent_queries_counter.inc("client_one_time_keys")
+ return self.transport_layer.claim_client_keys(
+ destination, content, retry_on_dns_fail=retry_on_dns_fail
+ )
+
@defer.inlineCallbacks
@log_function
def backfill(self, dest, context, limit, extremities):
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index cd79e23f4b..c32908ac28 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -27,6 +27,7 @@ from synapse.api.errors import FederationError, SynapseError
from synapse.crypto.event_signing import compute_event_signature
+import simplejson as json
import logging
@@ -314,6 +315,42 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
+ def on_query_client_keys(self, origin, content):
+ query = []
+ for user_id, device_ids in content.get("device_keys", {}).items():
+ if not device_ids:
+ query.append((user_id, None))
+ else:
+ for device_id in device_ids:
+ query.append((user_id, device_id))
+ results = yield self.store.get_e2e_device_keys(query)
+ json_result = {}
+ for user_id, device_keys in results.items():
+ for device_id, json_bytes in device_keys.items():
+ json_result.setdefault(user_id, {})[device_id] = json.loads(
+ json_bytes
+ )
+ defer.returnValue({"device_keys": json_result})
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_claim_client_keys(self, origin, content):
+ query = []
+ for user_id, device_keys in content.get("one_time_keys", {}).items():
+ for device_id, algorithm in device_keys.items():
+ query.append((user_id, device_id, algorithm))
+ results = yield self.store.claim_e2e_one_time_keys(query)
+ json_result = {}
+ for user_id, device_keys in results.items():
+ for device_id, keys in device_keys.items():
+ for key_id, json_bytes in keys.items():
+ json_result.setdefault(user_id, {})[device_id] = {
+ key_id: json.loads(json_bytes)
+ }
+ defer.returnValue({"one_time_keys": json_result})
+
+ @defer.inlineCallbacks
+ @log_function
def on_get_missing_events(self, origin, room_id, earliest_events,
latest_events, limit, min_depth):
missing_events = yield self.handler.on_get_missing_events(
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 610a4c3163..df5083dd22 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -224,6 +224,76 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
+ def query_client_keys(self, destination, query_content):
+ """Query the device keys for a list of user ids hosted on a remote
+ server.
+
+ Request:
+ {
+ "device_keys": {
+ "<user_id>": ["<device_id>"]
+ } }
+
+ Response:
+ {
+ "device_keys": {
+ "<user_id>": {
+ "<device_id>": {...}
+ } } }
+
+ Args:
+ destination(str): The server to query.
+ query_content(dict): The user ids to query.
+ Returns:
+ A dict containg the device keys.
+ """
+ path = PREFIX + "/client_keys/query"
+
+ content = yield self.client.post_json(
+ destination=destination,
+ path=path,
+ data=query_content,
+ )
+ defer.returnValue(content)
+
+ @defer.inlineCallbacks
+ @log_function
+ def claim_client_keys(self, destination, query_content):
+ """Claim one-time keys for a list of devices hosted on a remote server.
+
+ Request:
+ {
+ "one_time_keys": {
+ "<user_id>": {
+ "<device_id>": "<algorithm>"
+ } } }
+
+ Response:
+ {
+ "device_keys": {
+ "<user_id>": {
+ "<device_id>": {
+ "<algorithm>:<key_id>": "<key_base64>"
+ } } } }
+
+ Args:
+ destination(str): The server to query.
+ query_content(dict): The user ids to query.
+ Returns:
+ A dict containg the one-time keys.
+ """
+
+ path = PREFIX + "/client_keys/claim"
+
+ content = yield self.client.post_json(
+ destination=destination,
+ path=path,
+ data=query_content,
+ )
+ defer.returnValue(content)
+
+ @defer.inlineCallbacks
+ @log_function
def get_missing_events(self, destination, room_id, earliest_events,
latest_events, limit, min_depth):
path = PREFIX + "/get_missing_events/%s" % (room_id,)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index bad93c6b2f..fb59383ecd 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -325,6 +325,24 @@ class FederationInviteServlet(BaseFederationServlet):
defer.returnValue((200, content))
+class FederationClientKeysQueryServlet(BaseFederationServlet):
+ PATH = "/client_keys/query"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content):
+ response = yield self.handler.on_client_key_query(origin, content)
+ defer.returnValue((200, response))
+
+
+class FederationClientKeysClaimServlet(BaseFederationServlet):
+ PATH = "/client_keys/claim"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content):
+ response = yield self.handler.on_client_key_claim(origin, content)
+ defer.returnValue((200, response))
+
+
class FederationQueryAuthServlet(BaseFederationServlet):
PATH = "/query_auth/([^/]*)/([^/]*)"
@@ -373,4 +391,6 @@ SERVLET_CLASSES = (
FederationQueryAuthServlet,
FederationGetMissingEventsServlet,
FederationEventAuthServlet,
+ FederationClientKeysQueryServlet,
+ FederationClientKeysClaimServlet,
)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 5f3a6207b5..739a08ada8 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
+from synapse.types import UserID
from syutil.jsonutil import encode_canonical_json
from ._base import client_v2_pattern
@@ -164,45 +165,63 @@ class KeyQueryServlet(RestServlet):
super(KeyQueryServlet, self).__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
+ self.federation = hs.get_replication_layer()
+ self.is_mine = hs.is_mine
@defer.inlineCallbacks
def on_POST(self, request, user_id, device_id):
- logger.debug("onPOST")
yield self.auth.get_user_by_req(request)
try:
body = json.loads(request.content.read())
except:
raise SynapseError(400, "Invalid key JSON")
- query = []
- for user_id, device_ids in body.get("device_keys", {}).items():
- if not device_ids:
- query.append((user_id, None))
- else:
- for device_id in device_ids:
- query.append((user_id, device_id))
- results = yield self.store.get_e2e_device_keys(query)
- defer.returnValue(self.json_result(request, results))
+ result = yield self.handle_request(body)
+ defer.returnValue(result)
@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id):
auth_user, client_info = yield self.auth.get_user_by_req(request)
auth_user_id = auth_user.to_string()
- if not user_id:
- user_id = auth_user_id
- if not device_id:
- device_id = None
- # Returns a map of user_id->device_id->json_bytes.
- results = yield self.store.get_e2e_device_keys([(user_id, device_id)])
- defer.returnValue(self.json_result(request, results))
-
- def json_result(self, request, results):
+ user_id = user_id if user_id else auth_user_id
+ device_ids = [device_id] if device_id else []
+ result = yield self.handle_request(
+ {"device_keys": {user_id: device_ids}}
+ )
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def handle_request(self, body):
+ local_query = []
+ remote_queries = {}
+ for user_id, device_ids in body.get("device_keys", {}).items():
+ user = UserID.from_string(user_id)
+ if self.is_mine(user):
+ if not device_ids:
+ local_query.append((user_id, None))
+ else:
+ for device_id in device_ids:
+ local_query.append((user_id, device_id))
+ else:
+ remote_queries.set_default(user.domain, {})[user_id] = list(
+ device_ids
+ )
+ results = yield self.store.get_e2e_device_keys(local_query)
+
json_result = {}
for user_id, device_keys in results.items():
for device_id, json_bytes in device_keys.items():
json_result.setdefault(user_id, {})[device_id] = json.loads(
json_bytes
)
- return (200, {"device_keys": json_result})
+
+ for destination, device_keys in remote_queries.items():
+ remote_result = yield self.federation.query_client_keys(
+ destination, {"device_keys": device_keys}
+ )
+ for user_id, keys in remote_result.items():
+ if user_id in device_keys:
+ json_result[user_id] = keys
+ defer.returnValue((200, {"device_keys": json_result}))
class OneTimeKeyServlet(RestServlet):
@@ -236,14 +255,16 @@ class OneTimeKeyServlet(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
+ self.federation = hs.get_replication_layer()
+ self.is_mine = hs.is_mine
@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request)
- results = yield self.store.claim_e2e_one_time_keys(
- [(user_id, device_id, algorithm)]
+ result = yield self.handle_request(
+ {"one_time_keys": {user_id: {device_id: algorithm}}}
)
- defer.returnValue(self.json_result(request, results))
+ defer.returnValue(result)
@defer.inlineCallbacks
def on_POST(self, request, user_id, device_id, algorithm):
@@ -252,14 +273,24 @@ class OneTimeKeyServlet(RestServlet):
body = json.loads(request.content.read())
except:
raise SynapseError(400, "Invalid key JSON")
- query = []
+ result = yield self.handle_request(body)
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def handle_request(self, body):
+ local_query = []
+ remote_queries = {}
for user_id, device_keys in body.get("one_time_keys", {}).items():
- for device_id, algorithm in device_keys.items():
- query.append((user_id, device_id, algorithm))
- results = yield self.store.claim_e2e_one_time_keys(query)
- defer.returnValue(self.json_result(request, results))
+ user = UserID.from_string(user_id)
+ if self.is_mine(user):
+ for device_id, algorithm in device_keys.items():
+ local_query.append((user_id, device_id, algorithm))
+ else:
+ remote_queries.set_default(user.domain, {})[user_id] = (
+ device_keys
+ )
+ results = yield self.store.claim_e2e_one_time_keys(local_query)
- def json_result(self, request, results):
json_result = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
@@ -267,7 +298,16 @@ class OneTimeKeyServlet(RestServlet):
json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_bytes)
}
- return (200, {"one_time_keys": json_result})
+
+ for destination, device_keys in remote_queries.items():
+ remote_result = yield self.federation.query_client_keys(
+ destination, {"one_time_keys": device_keys}
+ )
+ for user_id, keys in remote_result.items():
+ if user_id in device_keys:
+ json_result[user_id] = keys
+
+ defer.returnValue((200, {"one_time_keys": json_result}))
def register_servlets(hs, http_server):
|