summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_base.py110
-rw-r--r--synapse/federation/federation_client.py90
-rw-r--r--synapse/federation/federation_server.py51
-rw-r--r--synapse/federation/transport/client.py49
-rw-r--r--synapse/federation/transport/server.py12
5 files changed, 146 insertions, 166 deletions
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index eea64c1c9f..5c991e5412 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -15,11 +15,13 @@
 # limitations under the License.
 import logging
 from collections import namedtuple
+from typing import Iterable, List
 
 import six
 
 from twisted.internet import defer
-from twisted.internet.defer import DeferredList
+from twisted.internet.defer import Deferred, DeferredList
+from twisted.python.failure import Failure
 
 from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
 from synapse.api.errors import Codes, SynapseError
@@ -29,6 +31,7 @@ from synapse.api.room_versions import (
     RoomVersion,
 )
 from synapse.crypto.event_signing import check_event_content_hash
+from synapse.crypto.keyring import Keyring
 from synapse.events import EventBase, make_event_from_dict
 from synapse.events.utils import prune_event
 from synapse.http.servlet import assert_params_in_dict
@@ -36,10 +39,8 @@ from synapse.logging.context import (
     LoggingContext,
     PreserveLoggingContext,
     make_deferred_yieldable,
-    preserve_fn,
 )
 from synapse.types import JsonDict, get_domain_from_id
-from synapse.util import unwrapFirstError
 
 logger = logging.getLogger(__name__)
 
@@ -54,92 +55,23 @@ class FederationBase(object):
         self.store = hs.get_datastore()
         self._clock = hs.get_clock()
 
-    @defer.inlineCallbacks
-    def _check_sigs_and_hash_and_fetch(
-        self, origin, pdus, room_version, outlier=False, include_none=False
-    ):
-        """Takes a list of PDUs and checks the signatures and hashs of each
-        one. If a PDU fails its signature check then we check if we have it in
-        the database and if not then request if from the originating server of
-        that PDU.
-
-        If a PDU fails its content hash check then it is redacted.
-
-        The given list of PDUs are not modified, instead the function returns
-        a new list.
-
-        Args:
-            origin (str)
-            pdu (list)
-            room_version (str)
-            outlier (bool): Whether the events are outliers or not
-            include_none (str): Whether to include None in the returned list
-                for events that have failed their checks
-
-        Returns:
-            Deferred : A list of PDUs that have valid signatures and hashes.
-        """
-        deferreds = self._check_sigs_and_hashes(room_version, pdus)
-
-        @defer.inlineCallbacks
-        def handle_check_result(pdu, deferred):
-            try:
-                res = yield make_deferred_yieldable(deferred)
-            except SynapseError:
-                res = None
-
-            if not res:
-                # Check local db.
-                res = yield self.store.get_event(
-                    pdu.event_id, allow_rejected=True, allow_none=True
-                )
-
-            if not res and pdu.origin != origin:
-                try:
-                    res = yield self.get_pdu(
-                        destinations=[pdu.origin],
-                        event_id=pdu.event_id,
-                        room_version=room_version,
-                        outlier=outlier,
-                        timeout=10000,
-                    )
-                except SynapseError:
-                    pass
-
-            if not res:
-                logger.warning(
-                    "Failed to find copy of %s with valid signature", pdu.event_id
-                )
-
-            return res
-
-        handle = preserve_fn(handle_check_result)
-        deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
-
-        valid_pdus = yield make_deferred_yieldable(
-            defer.gatherResults(deferreds2, consumeErrors=True)
-        ).addErrback(unwrapFirstError)
-
-        if include_none:
-            return valid_pdus
-        else:
-            return [p for p in valid_pdus if p]
-
-    def _check_sigs_and_hash(self, room_version, pdu):
+    def _check_sigs_and_hash(self, room_version: str, pdu: EventBase) -> Deferred:
         return make_deferred_yieldable(
             self._check_sigs_and_hashes(room_version, [pdu])[0]
         )
 
-    def _check_sigs_and_hashes(self, room_version, pdus):
+    def _check_sigs_and_hashes(
+        self, room_version: str, pdus: List[EventBase]
+    ) -> List[Deferred]:
         """Checks that each of the received events is correctly signed by the
         sending server.
 
         Args:
-            room_version (str): The room version of the PDUs
-            pdus (list[FrozenEvent]): the events to be checked
+            room_version: The room version of the PDUs
+            pdus: the events to be checked
 
         Returns:
-            list[Deferred]: for each input event, a deferred which:
+            For each input event, a deferred which:
               * returns the original event if the checks pass
               * returns a redacted version of the event (if the signature
                 matched but the hash did not)
@@ -150,7 +82,7 @@ class FederationBase(object):
 
         ctx = LoggingContext.current_context()
 
-        def callback(_, pdu):
+        def callback(_, pdu: EventBase):
             with PreserveLoggingContext(ctx):
                 if not check_event_content_hash(pdu):
                     # let's try to distinguish between failures because the event was
@@ -187,7 +119,7 @@ class FederationBase(object):
 
                 return pdu
 
-        def errback(failure, pdu):
+        def errback(failure: Failure, pdu: EventBase):
             failure.trap(SynapseError)
             with PreserveLoggingContext(ctx):
                 logger.warning(
@@ -213,16 +145,18 @@ class PduToCheckSig(
     pass
 
 
-def _check_sigs_on_pdus(keyring, room_version, pdus):
+def _check_sigs_on_pdus(
+    keyring: Keyring, room_version: str, pdus: Iterable[EventBase]
+) -> List[Deferred]:
     """Check that the given events are correctly signed
 
     Args:
-        keyring (synapse.crypto.Keyring): keyring object to do the checks
-        room_version (str): the room version of the PDUs
-        pdus (Collection[EventBase]): the events to be checked
+        keyring: keyring object to do the checks
+        room_version: the room version of the PDUs
+        pdus: the events to be checked
 
     Returns:
-        List[Deferred]: a Deferred for each event in pdus, which will either succeed if
+        A Deferred for each event in pdus, which will either succeed if
            the signatures are valid, or fail (with a SynapseError) if not.
     """
 
@@ -327,7 +261,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
     return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
 
 
-def _flatten_deferred_list(deferreds):
+def _flatten_deferred_list(deferreds: List[Deferred]) -> Deferred:
     """Given a list of deferreds, either return the single deferred,
     combine into a DeferredList, or return an already resolved deferred.
     """
@@ -339,7 +273,7 @@ def _flatten_deferred_list(deferreds):
         return defer.succeed(None)
 
 
-def _is_invite_via_3pid(event):
+def _is_invite_via_3pid(event: EventBase) -> bool:
     return (
         event.type == EventTypes.Member
         and event.membership == Membership.INVITE
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 4870e39652..8c6b839478 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -33,6 +33,7 @@ from typing import (
 from prometheus_client import Counter
 
 from twisted.internet import defer
+from twisted.internet.defer import Deferred
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import (
@@ -51,7 +52,7 @@ from synapse.api.room_versions import (
 )
 from synapse.events import EventBase, builder
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
-from synapse.logging.context import make_deferred_yieldable
+from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.logging.utils import log_function
 from synapse.types import JsonDict
 from synapse.util import unwrapFirstError
@@ -187,7 +188,7 @@ class FederationClient(FederationBase):
 
     async def backfill(
         self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
-    ) -> List[EventBase]:
+    ) -> Optional[List[EventBase]]:
         """Requests some more historic PDUs for the given room from the
         given destination server.
 
@@ -199,9 +200,9 @@ class FederationClient(FederationBase):
         """
         logger.debug("backfill extrem=%s", extremities)
 
-        # If there are no extremeties then we've (probably) reached the start.
+        # If there are no extremities then we've (probably) reached the start.
         if not extremities:
-            return
+            return None
 
         transaction_data = await self.transport_layer.backfill(
             dest, room_id, extremities, limit
@@ -284,7 +285,7 @@ class FederationClient(FederationBase):
                 pdu_list = [
                     event_from_pdu_json(p, room_version, outlier=outlier)
                     for p in transaction_data["pdus"]
-                ]
+                ]  # type: List[EventBase]
 
                 if pdu_list and pdu_list[0]:
                     pdu = pdu_list[0]
@@ -345,6 +346,83 @@ class FederationClient(FederationBase):
 
         return state_event_ids, auth_event_ids
 
+    async def _check_sigs_and_hash_and_fetch(
+        self,
+        origin: str,
+        pdus: List[EventBase],
+        room_version: str,
+        outlier: bool = False,
+        include_none: bool = False,
+    ) -> List[EventBase]:
+        """Takes a list of PDUs and checks the signatures and hashs of each
+        one. If a PDU fails its signature check then we check if we have it in
+        the database and if not then request if from the originating server of
+        that PDU.
+
+        If a PDU fails its content hash check then it is redacted.
+
+        The given list of PDUs are not modified, instead the function returns
+        a new list.
+
+        Args:
+            origin
+            pdu
+            room_version
+            outlier: Whether the events are outliers or not
+            include_none: Whether to include None in the returned list
+                for events that have failed their checks
+
+        Returns:
+            Deferred : A list of PDUs that have valid signatures and hashes.
+        """
+        deferreds = self._check_sigs_and_hashes(room_version, pdus)
+
+        @defer.inlineCallbacks
+        def handle_check_result(pdu: EventBase, deferred: Deferred):
+            try:
+                res = yield make_deferred_yieldable(deferred)
+            except SynapseError:
+                res = None
+
+            if not res:
+                # Check local db.
+                res = yield self.store.get_event(
+                    pdu.event_id, allow_rejected=True, allow_none=True
+                )
+
+            if not res and pdu.origin != origin:
+                try:
+                    res = yield defer.ensureDeferred(
+                        self.get_pdu(
+                            destinations=[pdu.origin],
+                            event_id=pdu.event_id,
+                            room_version=room_version,  # type: ignore
+                            outlier=outlier,
+                            timeout=10000,
+                        )
+                    )
+                except SynapseError:
+                    pass
+
+            if not res:
+                logger.warning(
+                    "Failed to find copy of %s with valid signature", pdu.event_id
+                )
+
+            return res
+
+        handle = preserve_fn(handle_check_result)
+        deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
+
+        valid_pdus = await make_deferred_yieldable(
+            defer.gatherResults(deferreds2, consumeErrors=True)
+        ).addErrback(unwrapFirstError)
+
+        if include_none:
+            return valid_pdus
+        else:
+            return [p for p in valid_pdus if p]
+
     async def get_event_auth(self, destination, room_id, event_id):
         res = await self.transport_layer.get_event_auth(destination, room_id, event_id)
 
@@ -615,7 +693,7 @@ class FederationClient(FederationBase):
             ]
             if auth_chain_create_events != [create_event.event_id]:
                 raise InvalidResponseError(
-                    "Unexpected create event(s) in auth chain"
+                    "Unexpected create event(s) in auth chain: %s"
                     % (auth_chain_create_events,)
                 )
 
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 7f9da49326..275b9c99d7 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -470,57 +470,6 @@ class FederationServer(FederationBase):
             res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
         return 200, res
 
-    async def on_query_auth_request(self, origin, content, room_id, event_id):
-        """
-        Content is a dict with keys::
-            auth_chain (list): A list of events that give the auth chain.
-            missing (list): A list of event_ids indicating what the other
-              side (`origin`) think we're missing.
-            rejects (dict): A mapping from event_id to a 2-tuple of reason
-              string and a proof (or None) of why the event was rejected.
-              The keys of this dict give the list of events the `origin` has
-              rejected.
-
-        Args:
-            origin (str)
-            content (dict)
-            event_id (str)
-
-        Returns:
-            Deferred: Results in `dict` with the same format as `content`
-        """
-        with (await self._server_linearizer.queue((origin, room_id))):
-            origin_host, _ = parse_server_name(origin)
-            await self.check_server_matches_acl(origin_host, room_id)
-
-            room_version = await self.store.get_room_version(room_id)
-
-            auth_chain = [
-                event_from_pdu_json(e, room_version) for e in content["auth_chain"]
-            ]
-
-            signed_auth = await self._check_sigs_and_hash_and_fetch(
-                origin, auth_chain, outlier=True, room_version=room_version.identifier
-            )
-
-            ret = await self.handler.on_query_auth(
-                origin,
-                event_id,
-                room_id,
-                signed_auth,
-                content.get("rejects", []),
-                content.get("missing", []),
-            )
-
-            time_now = self._clock.time_msec()
-            send_content = {
-                "auth_chain": [e.get_pdu_json(time_now) for e in ret["auth_chain"]],
-                "rejects": ret.get("rejects", []),
-                "missing": ret.get("missing", []),
-            }
-
-        return 200, send_content
-
     @log_function
     def on_query_client_keys(self, origin, content):
         return self.on_query_request("client_keys", content)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index dc563538de..383e3fdc8b 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -399,20 +399,30 @@ class TransportLayerClient(object):
             {
               "device_keys": {
                 "<user_id>": ["<device_id>"]
-            } }
+              }
+            }
 
         Response:
             {
               "device_keys": {
                 "<user_id>": {
                   "<device_id>": {...}
-            } } }
+                }
+              },
+              "master_key": {
+                "<user_id>": {...}
+                }
+              },
+              "self_signing_key": {
+                "<user_id>": {...}
+              }
+            }
 
         Args:
             destination(str): The server to query.
             query_content(dict): The user ids to query.
         Returns:
-            A dict containg the device keys.
+            A dict containing device and cross-signing keys.
         """
         path = _create_v1_path("/user/keys/query")
 
@@ -429,14 +439,30 @@ class TransportLayerClient(object):
         Response:
             {
               "stream_id": "...",
-              "devices": [ { ... } ]
+              "devices": [ { ... } ],
+              "master_key": {
+                "user_id": "<user_id>",
+                "usage": [...],
+                "keys": {...},
+                "signatures": {
+                  "<user_id>": {...}
+                }
+              },
+              "self_signing_key": {
+                "user_id": "<user_id>",
+                "usage": [...],
+                "keys": {...},
+                "signatures": {
+                  "<user_id>": {...}
+                }
+              }
             }
 
         Args:
             destination(str): The server to query.
             query_content(dict): The user ids to query.
         Returns:
-            A dict containg the device keys.
+            A dict containing device and cross-signing keys.
         """
         path = _create_v1_path("/user/devices/%s", user_id)
 
@@ -454,8 +480,10 @@ class TransportLayerClient(object):
             {
               "one_time_keys": {
                 "<user_id>": {
-                    "<device_id>": "<algorithm>"
-            } } }
+                  "<device_id>": "<algorithm>"
+                }
+              }
+            }
 
         Response:
             {
@@ -463,13 +491,16 @@ class TransportLayerClient(object):
                 "<user_id>": {
                   "<device_id>": {
                     "<algorithm>:<key_id>": "<key_base64>"
-            } } } }
+                  }
+                }
+              }
+            }
 
         Args:
             destination(str): The server to query.
             query_content(dict): The user ids to query.
         Returns:
-            A dict containg the one-time keys.
+            A dict containing the one-time keys.
         """
 
         path = _create_v1_path("/user/keys/claim")
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 92a9ae2320..af4595498c 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -643,17 +643,6 @@ class FederationClientKeysClaimServlet(BaseFederationServlet):
         return 200, response
 
 
-class FederationQueryAuthServlet(BaseFederationServlet):
-    PATH = "/query_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
-
-    async def on_POST(self, origin, content, query, context, event_id):
-        new_content = await self.handler.on_query_auth_request(
-            origin, content, context, event_id
-        )
-
-        return 200, new_content
-
-
 class FederationGetMissingEventsServlet(BaseFederationServlet):
     # TODO(paul): Why does this path alone end with "/?" optional?
     PATH = "/get_missing_events/(?P<room_id>[^/]*)/?"
@@ -1412,7 +1401,6 @@ FEDERATION_SERVLET_CLASSES = (
     FederationV2SendLeaveServlet,
     FederationV1InviteServlet,
     FederationV2InviteServlet,
-    FederationQueryAuthServlet,
     FederationGetMissingEventsServlet,
     FederationEventAuthServlet,
     FederationClientKeysQueryServlet,