diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 299493af91..bdfa247604 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -32,7 +32,8 @@ logger = logging.getLogger(__name__)
class FederationBase(object):
@defer.inlineCallbacks
- def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False):
+ def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
+ include_none=False):
"""Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of
@@ -50,84 +51,108 @@ class FederationBase(object):
Returns:
Deferred : A list of PDUs that have valid signatures and hashes.
"""
+ deferreds = self._check_sigs_and_hashes(pdus)
- signed_pdus = []
+ def callback(pdu):
+ return pdu
- @defer.inlineCallbacks
- def do(pdu):
- try:
- new_pdu = yield self._check_sigs_and_hash(pdu)
- signed_pdus.append(new_pdu)
- except SynapseError:
- # FIXME: We should handle signature failures more gracefully.
+ def errback(failure, pdu):
+ failure.trap(SynapseError)
+ return None
+ def try_local_db(res, pdu):
+ if not res:
# Check local db.
- new_pdu = yield self.store.get_event(
+ return self.store.get_event(
pdu.event_id,
allow_rejected=True,
allow_none=True,
)
- if new_pdu:
- signed_pdus.append(new_pdu)
- return
-
- # Check pdu.origin
- if pdu.origin != origin:
- try:
- new_pdu = yield self.get_pdu(
- destinations=[pdu.origin],
- event_id=pdu.event_id,
- outlier=outlier,
- timeout=10000,
- )
-
- if new_pdu:
- signed_pdus.append(new_pdu)
- return
- except:
- pass
-
+ return res
+
+ def try_remote(res, pdu):
+ if not res and pdu.origin != origin:
+ return self.get_pdu(
+ destinations=[pdu.origin],
+ event_id=pdu.event_id,
+ outlier=outlier,
+ timeout=10000,
+ ).addErrback(lambda e: None)
+ return res
+
+ def warn(res, pdu):
+ if not res:
logger.warn(
"Failed to find copy of %s with valid signature",
pdu.event_id,
)
+ return res
+
+ for pdu, deferred in zip(pdus, deferreds):
+ deferred.addCallbacks(
+ callback, errback, errbackArgs=[pdu]
+ ).addCallback(
+ try_local_db, pdu
+ ).addCallback(
+ try_remote, pdu
+ ).addCallback(
+ warn, pdu
+ )
- yield defer.gatherResults(
- [do(pdu) for pdu in pdus],
+ valid_pdus = yield defer.gatherResults(
+ deferreds,
consumeErrors=True
).addErrback(unwrapFirstError)
- defer.returnValue(signed_pdus)
+ if include_none:
+ defer.returnValue(valid_pdus)
+ else:
+ defer.returnValue([p for p in valid_pdus if p])
- @defer.inlineCallbacks
def _check_sigs_and_hash(self, pdu):
- """Throws a SynapseError if the PDU does not have the correct
+ return self._check_sigs_and_hashes([pdu])[0]
+
+ def _check_sigs_and_hashes(self, pdus):
+ """Throws a SynapseError if a PDU does not have the correct
signatures.
Returns:
FrozenEvent: Either the given event or it redacted if it failed the
content hash check.
"""
- # Check signatures are correct.
- redacted_event = prune_event(pdu)
- redacted_pdu_json = redacted_event.get_pdu_json()
- try:
- yield self.keyring.verify_json_for_server(
- pdu.origin, redacted_pdu_json
- )
- except SynapseError:
+ redacted_pdus = [
+ prune_event(pdu)
+ for pdu in pdus
+ ]
+
+ deferreds = self.keyring.verify_json_objects_for_server([
+ (p.origin, p.get_pdu_json())
+ for p in redacted_pdus
+ ])
+
+ def callback(_, pdu, redacted):
+ if not check_event_content_hash(pdu):
+ logger.warn(
+ "Event content has been tampered, redacting %s: %s",
+ pdu.event_id, pdu.get_pdu_json()
+ )
+ return redacted
+ return pdu
+
+ def errback(failure, pdu):
+ failure.trap(SynapseError)
logger.warn(
"Signature check failed for %s",
pdu.event_id,
)
- raise
+ return failure
- if not check_event_content_hash(pdu):
- logger.warn(
- "Event content has been tampered, redacting.",
- pdu.event_id,
+ for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
+ deferred.addCallbacks(
+ callback, errback,
+ callbackArgs=[pdu, redacted],
+ errbackArgs=[pdu],
)
- defer.returnValue(redacted_event)
- defer.returnValue(pdu)
+ return deferreds
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index d3b46b24c1..f5e346cdbc 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -23,13 +23,14 @@ from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError,
)
from synapse.util import unwrapFirstError
-from synapse.util.expiringcache import ExpiringCache
+from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.events import FrozenEvent
import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
+import copy
import itertools
import logging
import random
@@ -133,6 +134,36 @@ class FederationClient(FederationBase):
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
)
+ @log_function
+ def query_client_keys(self, destination, content):
+ """Query device keys for a device hosted on a remote server.
+
+ Args:
+ destination (str): Domain name of the remote homeserver
+ content (dict): The query content.
+
+ Returns:
+ a Deferred which will eventually yield a JSON object from the
+ response
+ """
+ sent_queries_counter.inc("client_device_keys")
+ return self.transport_layer.query_client_keys(destination, content)
+
+ @log_function
+ def claim_client_keys(self, destination, content):
+ """Claims one-time keys for a device hosted on a remote server.
+
+ Args:
+ destination (str): Domain name of the remote homeserver
+ content (dict): The query content.
+
+ Returns:
+ a Deferred which will eventually yield a JSON object from the
+ response
+ """
+ sent_queries_counter.inc("client_one_time_keys")
+ return self.transport_layer.claim_client_keys(destination, content)
+
@defer.inlineCallbacks
@log_function
def backfill(self, dest, context, limit, extremities):
@@ -167,7 +198,7 @@ class FederationClient(FederationBase):
# FIXME: We should handle signature failures more gracefully.
pdus[:] = yield defer.gatherResults(
- [self._check_sigs_and_hash(pdu) for pdu in pdus],
+ self._check_sigs_and_hashes(pdus),
consumeErrors=True,
).addErrback(unwrapFirstError)
@@ -230,7 +261,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0]
# Check signatures are correct.
- pdu = yield self._check_sigs_and_hash(pdu)
+ pdu = yield self._check_sigs_and_hashes([pdu])[0]
break
@@ -327,6 +358,9 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def make_join(self, destinations, room_id, user_id):
for destination in destinations:
+ if destination == self.server_name:
+ continue
+
try:
ret = yield self.transport_layer.make_join(
destination, room_id, user_id
@@ -353,6 +387,9 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def send_join(self, destinations, pdu):
for destination in destinations:
+ if destination == self.server_name:
+ continue
+
try:
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join(
@@ -374,17 +411,39 @@ class FederationClient(FederationBase):
for p in content.get("auth_chain", [])
]
- signed_state, signed_auth = yield defer.gatherResults(
- [
- self._check_sigs_and_hash_and_fetch(
- destination, state, outlier=True
- ),
- self._check_sigs_and_hash_and_fetch(
- destination, auth_chain, outlier=True
- )
- ],
- consumeErrors=True
- ).addErrback(unwrapFirstError)
+ pdus = {
+ p.event_id: p
+ for p in itertools.chain(state, auth_chain)
+ }
+
+ valid_pdus = yield self._check_sigs_and_hash_and_fetch(
+ destination, pdus.values(),
+ outlier=True,
+ )
+
+ valid_pdus_map = {
+ p.event_id: p
+ for p in valid_pdus
+ }
+
+ # NB: We *need* to copy to ensure that we don't have multiple
+ # references being passed on, as that causes... issues.
+ signed_state = [
+ copy.copy(valid_pdus_map[p.event_id])
+ for p in state
+ if p.event_id in valid_pdus_map
+ ]
+
+ signed_auth = [
+ valid_pdus_map[p.event_id]
+ for p in auth_chain
+ if p.event_id in valid_pdus_map
+ ]
+
+ # NB: We *need* to copy to ensure that we don't have multiple
+ # references being passed on, as that causes... issues.
+ for s in signed_state:
+ s.internal_metadata = copy.deepcopy(s.internal_metadata)
auth_chain.sort(key=lambda e: e.depth)
@@ -396,7 +455,7 @@ class FederationClient(FederationBase):
except CodeMessageException:
raise
except Exception as e:
- logger.warn(
+ logger.exception(
"Failed to send_join via %s: %s",
destination, e.message
)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index cd79e23f4b..725c6f3fa5 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -27,6 +27,7 @@ from synapse.api.errors import FederationError, SynapseError
from synapse.crypto.event_signing import compute_event_signature
+import simplejson as json
import logging
@@ -314,6 +315,48 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
+ def on_query_client_keys(self, origin, content):
+ query = []
+ for user_id, device_ids in content.get("device_keys", {}).items():
+ if not device_ids:
+ query.append((user_id, None))
+ else:
+ for device_id in device_ids:
+ query.append((user_id, device_id))
+
+ results = yield self.store.get_e2e_device_keys(query)
+
+ json_result = {}
+ for user_id, device_keys in results.items():
+ for device_id, json_bytes in device_keys.items():
+ json_result.setdefault(user_id, {})[device_id] = json.loads(
+ json_bytes
+ )
+
+ defer.returnValue({"device_keys": json_result})
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_claim_client_keys(self, origin, content):
+ query = []
+ for user_id, device_keys in content.get("one_time_keys", {}).items():
+ for device_id, algorithm in device_keys.items():
+ query.append((user_id, device_id, algorithm))
+
+ results = yield self.store.claim_e2e_one_time_keys(query)
+
+ json_result = {}
+ for user_id, device_keys in results.items():
+ for device_id, keys in device_keys.items():
+ for key_id, json_bytes in keys.items():
+ json_result.setdefault(user_id, {})[device_id] = {
+ key_id: json.loads(json_bytes)
+ }
+
+ defer.returnValue({"one_time_keys": json_result})
+
+ @defer.inlineCallbacks
+ @log_function
def on_get_missing_events(self, origin, room_id, earliest_events,
latest_events, limit, min_depth):
missing_events = yield self.handler.on_get_missing_events(
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 610a4c3163..ced703364b 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -224,6 +224,76 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
+ def query_client_keys(self, destination, query_content):
+ """Query the device keys for a list of user ids hosted on a remote
+ server.
+
+ Request:
+ {
+ "device_keys": {
+ "<user_id>": ["<device_id>"]
+ } }
+
+ Response:
+ {
+ "device_keys": {
+ "<user_id>": {
+ "<device_id>": {...}
+ } } }
+
+ Args:
+ destination(str): The server to query.
+ query_content(dict): The user ids to query.
+ Returns:
+ A dict containg the device keys.
+ """
+ path = PREFIX + "/user/keys/query"
+
+ content = yield self.client.post_json(
+ destination=destination,
+ path=path,
+ data=query_content,
+ )
+ defer.returnValue(content)
+
+ @defer.inlineCallbacks
+ @log_function
+ def claim_client_keys(self, destination, query_content):
+ """Claim one-time keys for a list of devices hosted on a remote server.
+
+ Request:
+ {
+ "one_time_keys": {
+ "<user_id>": {
+ "<device_id>": "<algorithm>"
+ } } }
+
+ Response:
+ {
+ "device_keys": {
+ "<user_id>": {
+ "<device_id>": {
+ "<algorithm>:<key_id>": "<key_base64>"
+ } } } }
+
+ Args:
+ destination(str): The server to query.
+ query_content(dict): The user ids to query.
+ Returns:
+ A dict containg the one-time keys.
+ """
+
+ path = PREFIX + "/user/keys/claim"
+
+ content = yield self.client.post_json(
+ destination=destination,
+ path=path,
+ data=query_content,
+ )
+ defer.returnValue(content)
+
+ @defer.inlineCallbacks
+ @log_function
def get_missing_events(self, destination, room_id, earliest_events,
latest_events, limit, min_depth):
path = PREFIX + "/get_missing_events/%s" % (room_id,)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index bad93c6b2f..36f250e1a3 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -325,6 +325,24 @@ class FederationInviteServlet(BaseFederationServlet):
defer.returnValue((200, content))
+class FederationClientKeysQueryServlet(BaseFederationServlet):
+ PATH = "/user/keys/query"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query):
+ response = yield self.handler.on_query_client_keys(origin, content)
+ defer.returnValue((200, response))
+
+
+class FederationClientKeysClaimServlet(BaseFederationServlet):
+ PATH = "/user/keys/claim"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query):
+ response = yield self.handler.on_claim_client_keys(origin, content)
+ defer.returnValue((200, response))
+
+
class FederationQueryAuthServlet(BaseFederationServlet):
PATH = "/query_auth/([^/]*)/([^/]*)"
@@ -373,4 +391,6 @@ SERVLET_CLASSES = (
FederationQueryAuthServlet,
FederationGetMissingEventsServlet,
FederationEventAuthServlet,
+ FederationClientKeysQueryServlet,
+ FederationClientKeysClaimServlet,
)
|