summary refs log tree commit diff
path: root/synapse/federation/transport/server/_base.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-08-16 10:14:31 -0400
committerGitHub <noreply@github.com>2021-08-16 10:14:31 -0400
commit87b62f8bb23f99d76bf0ee62c8217fa45a087673 (patch)
treebff2d59eb703ce9726f7e9114a27430c0e4324ab /synapse/federation/transport/server/_base.py
parentClean up some logging in the federation event handler (#10591) (diff)
downloadsynapse-87b62f8bb23f99d76bf0ee62c8217fa45a087673.tar.xz
Split `synapse.federation.transport.server` into multiple files. (#10590)
Diffstat (limited to 'synapse/federation/transport/server/_base.py')
-rw-r--r--synapse/federation/transport/server/_base.py328
1 files changed, 328 insertions, 0 deletions
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
new file mode 100644
index 0000000000..624c859f1e
--- /dev/null
+++ b/synapse/federation/transport/server/_base.py
@@ -0,0 +1,328 @@
+#  Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+import functools
+import logging
+import re
+
+from synapse.api.errors import Codes, FederationDeniedError, SynapseError
+from synapse.api.urls import FEDERATION_V1_PREFIX
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.logging import opentracing
+from synapse.logging.context import run_in_background
+from synapse.logging.opentracing import (
+    SynapseTags,
+    start_active_span,
+    start_active_span_from_request,
+    tags,
+    whitelisted_homeserver,
+)
+from synapse.server import HomeServer
+from synapse.util.ratelimitutils import FederationRateLimiter
+from synapse.util.stringutils import parse_and_validate_server_name
+
+logger = logging.getLogger(__name__)
+
+
+class AuthenticationError(SynapseError):
+    """There was a problem authenticating the request"""
+
+
+class NoAuthenticationError(AuthenticationError):
+    """The request had no authentication information"""
+
+
+class Authenticator:
+    def __init__(self, hs: HomeServer):
+        self._clock = hs.get_clock()
+        self.keyring = hs.get_keyring()
+        self.server_name = hs.hostname
+        self.store = hs.get_datastore()
+        self.federation_domain_whitelist = hs.config.federation_domain_whitelist
+        self.notifier = hs.get_notifier()
+
+        self.replication_client = None
+        if hs.config.worker.worker_app:
+            self.replication_client = hs.get_tcp_replication()
+
+    # A method just so we can pass 'self' as the authenticator to the Servlets
+    async def authenticate_request(self, request, content):
+        now = self._clock.time_msec()
+        json_request = {
+            "method": request.method.decode("ascii"),
+            "uri": request.uri.decode("ascii"),
+            "destination": self.server_name,
+            "signatures": {},
+        }
+
+        if content is not None:
+            json_request["content"] = content
+
+        origin = None
+
+        auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+
+        if not auth_headers:
+            raise NoAuthenticationError(
+                401, "Missing Authorization headers", Codes.UNAUTHORIZED
+            )
+
+        for auth in auth_headers:
+            if auth.startswith(b"X-Matrix"):
+                (origin, key, sig) = _parse_auth_header(auth)
+                json_request["origin"] = origin
+                json_request["signatures"].setdefault(origin, {})[key] = sig
+
+        if (
+            self.federation_domain_whitelist is not None
+            and origin not in self.federation_domain_whitelist
+        ):
+            raise FederationDeniedError(origin)
+
+        if origin is None or not json_request["signatures"]:
+            raise NoAuthenticationError(
+                401, "Missing Authorization headers", Codes.UNAUTHORIZED
+            )
+
+        await self.keyring.verify_json_for_server(
+            origin,
+            json_request,
+            now,
+        )
+
+        logger.debug("Request from %s", origin)
+        request.requester = origin
+
+        # If we get a valid signed request from the other side, its probably
+        # alive
+        retry_timings = await self.store.get_destination_retry_timings(origin)
+        if retry_timings and retry_timings.retry_last_ts:
+            run_in_background(self._reset_retry_timings, origin)
+
+        return origin
+
+    async def _reset_retry_timings(self, origin):
+        try:
+            logger.info("Marking origin %r as up", origin)
+            await self.store.set_destination_retry_timings(origin, None, 0, 0)
+
+            # Inform the relevant places that the remote server is back up.
+            self.notifier.notify_remote_server_up(origin)
+            if self.replication_client:
+                # If we're on a worker we try and inform master about this. The
+                # replication client doesn't hook into the notifier to avoid
+                # infinite loops where we send a `REMOTE_SERVER_UP` command to
+                # master, which then echoes it back to us which in turn pokes
+                # the notifier.
+                self.replication_client.send_remote_server_up(origin)
+
+        except Exception:
+            logger.exception("Error resetting retry timings on %s", origin)
+
+
+def _parse_auth_header(header_bytes):
+    """Parse an X-Matrix auth header
+
+    Args:
+        header_bytes (bytes): header value
+
+    Returns:
+        Tuple[str, str, str]: origin, key id, signature.
+
+    Raises:
+        AuthenticationError if the header could not be parsed
+    """
+    try:
+        header_str = header_bytes.decode("utf-8")
+        params = header_str.split(" ")[1].split(",")
+        param_dict = dict(kv.split("=") for kv in params)
+
+        def strip_quotes(value):
+            if value.startswith('"'):
+                return value[1:-1]
+            else:
+                return value
+
+        origin = strip_quotes(param_dict["origin"])
+
+        # ensure that the origin is a valid server name
+        parse_and_validate_server_name(origin)
+
+        key = strip_quotes(param_dict["key"])
+        sig = strip_quotes(param_dict["sig"])
+        return origin, key, sig
+    except Exception as e:
+        logger.warning(
+            "Error parsing auth header '%s': %s",
+            header_bytes.decode("ascii", "replace"),
+            e,
+        )
+        raise AuthenticationError(
+            400, "Malformed Authorization header", Codes.UNAUTHORIZED
+        )
+
+
+class BaseFederationServlet:
+    """Abstract base class for federation servlet classes.
+
+    The servlet object should have a PATH attribute which takes the form of a regexp to
+    match against the request path (excluding the /federation/v1 prefix).
+
+    The servlet should also implement one or more of on_GET, on_POST, on_PUT, to match
+    the appropriate HTTP method. These methods must be *asynchronous* and have the
+    signature:
+
+        on_<METHOD>(self, origin, content, query, **kwargs)
+
+        With arguments:
+
+            origin (unicode|None): The authenticated server_name of the calling server,
+                unless REQUIRE_AUTH is set to False and authentication failed.
+
+            content (unicode|None): decoded json body of the request. None if the
+                request was a GET.
+
+            query (dict[bytes, list[bytes]]): Query params from the request. url-decoded
+                (ie, '+' and '%xx' are decoded) but note that it is *not* utf8-decoded
+                yet.
+
+            **kwargs (dict[unicode, unicode]): the dict mapping keys to path
+                components as specified in the path match regexp.
+
+        Returns:
+            Optional[Tuple[int, object]]: either (response code, response object) to
+                 return a JSON response, or None if the request has already been handled.
+
+        Raises:
+            SynapseError: to return an error code
+
+            Exception: other exceptions will be caught, logged, and a 500 will be
+                returned.
+    """
+
+    PATH = ""  # Overridden in subclasses, the regex to match against the path.
+
+    REQUIRE_AUTH = True
+
+    PREFIX = FEDERATION_V1_PREFIX  # Allows specifying the API version
+
+    RATELIMIT = True  # Whether to rate limit requests or not
+
+    def __init__(
+        self,
+        hs: HomeServer,
+        authenticator: Authenticator,
+        ratelimiter: FederationRateLimiter,
+        server_name: str,
+    ):
+        self.hs = hs
+        self.authenticator = authenticator
+        self.ratelimiter = ratelimiter
+        self.server_name = server_name
+
+    def _wrap(self, func):
+        authenticator = self.authenticator
+        ratelimiter = self.ratelimiter
+
+        @functools.wraps(func)
+        async def new_func(request, *args, **kwargs):
+            """A callback which can be passed to HttpServer.RegisterPaths
+
+            Args:
+                request (twisted.web.http.Request):
+                *args: unused?
+                **kwargs (dict[unicode, unicode]): the dict mapping keys to path
+                    components as specified in the path match regexp.
+
+            Returns:
+                Tuple[int, object]|None: (response code, response object) as returned by
+                    the callback method. None if the request has already been handled.
+            """
+            content = None
+            if request.method in [b"PUT", b"POST"]:
+                # TODO: Handle other method types? other content types?
+                content = parse_json_object_from_request(request)
+
+            try:
+                origin = await authenticator.authenticate_request(request, content)
+            except NoAuthenticationError:
+                origin = None
+                if self.REQUIRE_AUTH:
+                    logger.warning(
+                        "authenticate_request failed: missing authentication"
+                    )
+                    raise
+            except Exception as e:
+                logger.warning("authenticate_request failed: %s", e)
+                raise
+
+            request_tags = {
+                SynapseTags.REQUEST_ID: request.get_request_id(),
+                tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
+                tags.HTTP_METHOD: request.get_method(),
+                tags.HTTP_URL: request.get_redacted_uri(),
+                tags.PEER_HOST_IPV6: request.getClientIP(),
+                "authenticated_entity": origin,
+                "servlet_name": request.request_metrics.name,
+            }
+
+            # Only accept the span context if the origin is authenticated
+            # and whitelisted
+            if origin and whitelisted_homeserver(origin):
+                scope = start_active_span_from_request(
+                    request, "incoming-federation-request", tags=request_tags
+                )
+            else:
+                scope = start_active_span(
+                    "incoming-federation-request", tags=request_tags
+                )
+
+            with scope:
+                opentracing.inject_response_headers(request.responseHeaders)
+
+                if origin and self.RATELIMIT:
+                    with ratelimiter.ratelimit(origin) as d:
+                        await d
+                        if request._disconnected:
+                            logger.warning(
+                                "client disconnected before we started processing "
+                                "request"
+                            )
+                            return -1, None
+                        response = await func(
+                            origin, content, request.args, *args, **kwargs
+                        )
+                else:
+                    response = await func(
+                        origin, content, request.args, *args, **kwargs
+                    )
+
+            return response
+
+        return new_func
+
+    def register(self, server):
+        pattern = re.compile("^" + self.PREFIX + self.PATH + "$")
+
+        for method in ("GET", "PUT", "POST"):
+            code = getattr(self, "on_%s" % (method), None)
+            if code is None:
+                continue
+
+            server.register_paths(
+                method,
+                (pattern,),
+                self._wrap(code),
+                self.__class__.__name__,
+            )