diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index eea64c1c9f..5c991e5412 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -15,11 +15,13 @@
# limitations under the License.
import logging
from collections import namedtuple
+from typing import Iterable, List
import six
from twisted.internet import defer
-from twisted.internet.defer import DeferredList
+from twisted.internet.defer import Deferred, DeferredList
+from twisted.python.failure import Failure
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
@@ -29,6 +31,7 @@ from synapse.api.room_versions import (
RoomVersion,
)
from synapse.crypto.event_signing import check_event_content_hash
+from synapse.crypto.keyring import Keyring
from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_dict
@@ -36,10 +39,8 @@ from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
make_deferred_yieldable,
- preserve_fn,
)
from synapse.types import JsonDict, get_domain_from_id
-from synapse.util import unwrapFirstError
logger = logging.getLogger(__name__)
@@ -54,92 +55,23 @@ class FederationBase(object):
self.store = hs.get_datastore()
self._clock = hs.get_clock()
- @defer.inlineCallbacks
- def _check_sigs_and_hash_and_fetch(
- self, origin, pdus, room_version, outlier=False, include_none=False
- ):
- """Takes a list of PDUs and checks the signatures and hashs of each
- one. If a PDU fails its signature check then we check if we have it in
- the database and if not then request if from the originating server of
- that PDU.
-
- If a PDU fails its content hash check then it is redacted.
-
- The given list of PDUs are not modified, instead the function returns
- a new list.
-
- Args:
- origin (str)
- pdu (list)
- room_version (str)
- outlier (bool): Whether the events are outliers or not
- include_none (str): Whether to include None in the returned list
- for events that have failed their checks
-
- Returns:
- Deferred : A list of PDUs that have valid signatures and hashes.
- """
- deferreds = self._check_sigs_and_hashes(room_version, pdus)
-
- @defer.inlineCallbacks
- def handle_check_result(pdu, deferred):
- try:
- res = yield make_deferred_yieldable(deferred)
- except SynapseError:
- res = None
-
- if not res:
- # Check local db.
- res = yield self.store.get_event(
- pdu.event_id, allow_rejected=True, allow_none=True
- )
-
- if not res and pdu.origin != origin:
- try:
- res = yield self.get_pdu(
- destinations=[pdu.origin],
- event_id=pdu.event_id,
- room_version=room_version,
- outlier=outlier,
- timeout=10000,
- )
- except SynapseError:
- pass
-
- if not res:
- logger.warning(
- "Failed to find copy of %s with valid signature", pdu.event_id
- )
-
- return res
-
- handle = preserve_fn(handle_check_result)
- deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
-
- valid_pdus = yield make_deferred_yieldable(
- defer.gatherResults(deferreds2, consumeErrors=True)
- ).addErrback(unwrapFirstError)
-
- if include_none:
- return valid_pdus
- else:
- return [p for p in valid_pdus if p]
-
- def _check_sigs_and_hash(self, room_version, pdu):
+ def _check_sigs_and_hash(self, room_version: str, pdu: EventBase) -> Deferred:
return make_deferred_yieldable(
self._check_sigs_and_hashes(room_version, [pdu])[0]
)
- def _check_sigs_and_hashes(self, room_version, pdus):
+ def _check_sigs_and_hashes(
+ self, room_version: str, pdus: List[EventBase]
+ ) -> List[Deferred]:
"""Checks that each of the received events is correctly signed by the
sending server.
Args:
- room_version (str): The room version of the PDUs
- pdus (list[FrozenEvent]): the events to be checked
+ room_version: The room version of the PDUs
+ pdus: the events to be checked
Returns:
- list[Deferred]: for each input event, a deferred which:
+ For each input event, a deferred which:
* returns the original event if the checks pass
* returns a redacted version of the event (if the signature
matched but the hash did not)
@@ -150,7 +82,7 @@ class FederationBase(object):
ctx = LoggingContext.current_context()
- def callback(_, pdu):
+ def callback(_, pdu: EventBase):
with PreserveLoggingContext(ctx):
if not check_event_content_hash(pdu):
# let's try to distinguish between failures because the event was
@@ -187,7 +119,7 @@ class FederationBase(object):
return pdu
- def errback(failure, pdu):
+ def errback(failure: Failure, pdu: EventBase):
failure.trap(SynapseError)
with PreserveLoggingContext(ctx):
logger.warning(
@@ -213,16 +145,18 @@ class PduToCheckSig(
pass
-def _check_sigs_on_pdus(keyring, room_version, pdus):
+def _check_sigs_on_pdus(
+ keyring: Keyring, room_version: str, pdus: Iterable[EventBase]
+) -> List[Deferred]:
"""Check that the given events are correctly signed
Args:
- keyring (synapse.crypto.Keyring): keyring object to do the checks
- room_version (str): the room version of the PDUs
- pdus (Collection[EventBase]): the events to be checked
+ keyring: keyring object to do the checks
+ room_version: the room version of the PDUs
+ pdus: the events to be checked
Returns:
- List[Deferred]: a Deferred for each event in pdus, which will either succeed if
+ A Deferred for each event in pdus, which will either succeed if
the signatures are valid, or fail (with a SynapseError) if not.
"""
@@ -327,7 +261,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
-def _flatten_deferred_list(deferreds):
+def _flatten_deferred_list(deferreds: List[Deferred]) -> Deferred:
"""Given a list of deferreds, either return the single deferred,
combine into a DeferredList, or return an already resolved deferred.
"""
@@ -339,7 +273,7 @@ def _flatten_deferred_list(deferreds):
return defer.succeed(None)
-def _is_invite_via_3pid(event):
+def _is_invite_via_3pid(event: EventBase) -> bool:
return (
event.type == EventTypes.Member
and event.membership == Membership.INVITE
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 4870e39652..8c6b839478 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -33,6 +33,7 @@ from typing import (
from prometheus_client import Counter
from twisted.internet import defer
+from twisted.internet.defer import Deferred
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
@@ -51,7 +52,7 @@ from synapse.api.room_versions import (
)
from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
-from synapse.logging.context import make_deferred_yieldable
+from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.logging.utils import log_function
from synapse.types import JsonDict
from synapse.util import unwrapFirstError
@@ -187,7 +188,7 @@ class FederationClient(FederationBase):
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
- ) -> List[EventBase]:
+ ) -> Optional[List[EventBase]]:
"""Requests some more historic PDUs for the given room from the
given destination server.
@@ -199,9 +200,9 @@ class FederationClient(FederationBase):
"""
logger.debug("backfill extrem=%s", extremities)
- # If there are no extremeties then we've (probably) reached the start.
+ # If there are no extremities then we've (probably) reached the start.
if not extremities:
- return
+ return None
transaction_data = await self.transport_layer.backfill(
dest, room_id, extremities, limit
@@ -284,7 +285,7 @@ class FederationClient(FederationBase):
pdu_list = [
event_from_pdu_json(p, room_version, outlier=outlier)
for p in transaction_data["pdus"]
- ]
+ ] # type: List[EventBase]
if pdu_list and pdu_list[0]:
pdu = pdu_list[0]
@@ -345,6 +346,83 @@ class FederationClient(FederationBase):
return state_event_ids, auth_event_ids
+ async def _check_sigs_and_hash_and_fetch(
+ self,
+ origin: str,
+ pdus: List[EventBase],
+ room_version: str,
+ outlier: bool = False,
+ include_none: bool = False,
+ ) -> List[EventBase]:
+ """Takes a list of PDUs and checks the signatures and hashs of each
+ one. If a PDU fails its signature check then we check if we have it in
+ the database and if not then request if from the originating server of
+ that PDU.
+
+ If a PDU fails its content hash check then it is redacted.
+
+ The given list of PDUs are not modified, instead the function returns
+ a new list.
+
+ Args:
+ origin
+ pdu
+ room_version
+ outlier: Whether the events are outliers or not
+ include_none: Whether to include None in the returned list
+ for events that have failed their checks
+
+ Returns:
+ Deferred : A list of PDUs that have valid signatures and hashes.
+ """
+ deferreds = self._check_sigs_and_hashes(room_version, pdus)
+
+ @defer.inlineCallbacks
+ def handle_check_result(pdu: EventBase, deferred: Deferred):
+ try:
+ res = yield make_deferred_yieldable(deferred)
+ except SynapseError:
+ res = None
+
+ if not res:
+ # Check local db.
+ res = yield self.store.get_event(
+ pdu.event_id, allow_rejected=True, allow_none=True
+ )
+
+ if not res and pdu.origin != origin:
+ try:
+ res = yield defer.ensureDeferred(
+ self.get_pdu(
+ destinations=[pdu.origin],
+ event_id=pdu.event_id,
+ room_version=room_version, # type: ignore
+ outlier=outlier,
+ timeout=10000,
+ )
+ )
+ except SynapseError:
+ pass
+
+ if not res:
+ logger.warning(
+ "Failed to find copy of %s with valid signature", pdu.event_id
+ )
+
+ return res
+
+ handle = preserve_fn(handle_check_result)
+ deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
+
+ valid_pdus = await make_deferred_yieldable(
+ defer.gatherResults(deferreds2, consumeErrors=True)
+ ).addErrback(unwrapFirstError)
+
+ if include_none:
+ return valid_pdus
+ else:
+ return [p for p in valid_pdus if p]
+
async def get_event_auth(self, destination, room_id, event_id):
res = await self.transport_layer.get_event_auth(destination, room_id, event_id)
@@ -615,7 +693,7 @@ class FederationClient(FederationBase):
]
if auth_chain_create_events != [create_event.event_id]:
raise InvalidResponseError(
- "Unexpected create event(s) in auth chain"
+ "Unexpected create event(s) in auth chain: %s"
% (auth_chain_create_events,)
)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 7f9da49326..275b9c99d7 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -470,57 +470,6 @@ class FederationServer(FederationBase):
res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
return 200, res
- async def on_query_auth_request(self, origin, content, room_id, event_id):
- """
- Content is a dict with keys::
- auth_chain (list): A list of events that give the auth chain.
- missing (list): A list of event_ids indicating what the other
- side (`origin`) think we're missing.
- rejects (dict): A mapping from event_id to a 2-tuple of reason
- string and a proof (or None) of why the event was rejected.
- The keys of this dict give the list of events the `origin` has
- rejected.
-
- Args:
- origin (str)
- content (dict)
- event_id (str)
-
- Returns:
- Deferred: Results in `dict` with the same format as `content`
- """
- with (await self._server_linearizer.queue((origin, room_id))):
- origin_host, _ = parse_server_name(origin)
- await self.check_server_matches_acl(origin_host, room_id)
-
- room_version = await self.store.get_room_version(room_id)
-
- auth_chain = [
- event_from_pdu_json(e, room_version) for e in content["auth_chain"]
- ]
-
- signed_auth = await self._check_sigs_and_hash_and_fetch(
- origin, auth_chain, outlier=True, room_version=room_version.identifier
- )
-
- ret = await self.handler.on_query_auth(
- origin,
- event_id,
- room_id,
- signed_auth,
- content.get("rejects", []),
- content.get("missing", []),
- )
-
- time_now = self._clock.time_msec()
- send_content = {
- "auth_chain": [e.get_pdu_json(time_now) for e in ret["auth_chain"]],
- "rejects": ret.get("rejects", []),
- "missing": ret.get("missing", []),
- }
-
- return 200, send_content
-
@log_function
def on_query_client_keys(self, origin, content):
return self.on_query_request("client_keys", content)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index dc563538de..383e3fdc8b 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -399,20 +399,30 @@ class TransportLayerClient(object):
{
"device_keys": {
"<user_id>": ["<device_id>"]
- } }
+ }
+ }
Response:
{
"device_keys": {
"<user_id>": {
"<device_id>": {...}
- } } }
+ }
+ },
+ "master_key": {
+ "<user_id>": {...}
+ }
+ },
+ "self_signing_key": {
+ "<user_id>": {...}
+ }
+ }
Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
- A dict containg the device keys.
+ A dict containing device and cross-signing keys.
"""
path = _create_v1_path("/user/keys/query")
@@ -429,14 +439,30 @@ class TransportLayerClient(object):
Response:
{
"stream_id": "...",
- "devices": [ { ... } ]
+ "devices": [ { ... } ],
+ "master_key": {
+ "user_id": "<user_id>",
+ "usage": [...],
+ "keys": {...},
+ "signatures": {
+ "<user_id>": {...}
+ }
+ },
+ "self_signing_key": {
+ "user_id": "<user_id>",
+ "usage": [...],
+ "keys": {...},
+ "signatures": {
+ "<user_id>": {...}
+ }
+ }
}
Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
- A dict containg the device keys.
+ A dict containing device and cross-signing keys.
"""
path = _create_v1_path("/user/devices/%s", user_id)
@@ -454,8 +480,10 @@ class TransportLayerClient(object):
{
"one_time_keys": {
"<user_id>": {
- "<device_id>": "<algorithm>"
- } } }
+ "<device_id>": "<algorithm>"
+ }
+ }
+ }
Response:
{
@@ -463,13 +491,16 @@ class TransportLayerClient(object):
"<user_id>": {
"<device_id>": {
"<algorithm>:<key_id>": "<key_base64>"
- } } } }
+ }
+ }
+ }
+ }
Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
- A dict containg the one-time keys.
+ A dict containing the one-time keys.
"""
path = _create_v1_path("/user/keys/claim")
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 92a9ae2320..af4595498c 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -643,17 +643,6 @@ class FederationClientKeysClaimServlet(BaseFederationServlet):
return 200, response
-class FederationQueryAuthServlet(BaseFederationServlet):
- PATH = "/query_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
-
- async def on_POST(self, origin, content, query, context, event_id):
- new_content = await self.handler.on_query_auth_request(
- origin, content, context, event_id
- )
-
- return 200, new_content
-
-
class FederationGetMissingEventsServlet(BaseFederationServlet):
# TODO(paul): Why does this path alone end with "/?" optional?
PATH = "/get_missing_events/(?P<room_id>[^/]*)/?"
@@ -1412,7 +1401,6 @@ FEDERATION_SERVLET_CLASSES = (
FederationV2SendLeaveServlet,
FederationV1InviteServlet,
FederationV2InviteServlet,
- FederationQueryAuthServlet,
FederationGetMissingEventsServlet,
FederationEventAuthServlet,
FederationClientKeysQueryServlet,
|