summary refs log tree commit diff
path: root/synapse/federation/transport/server/_base.py
blob: e1244814742b077df70ad7bf8a688201adcc1fa7 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
#  Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#

import functools
import logging
import re
import time
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tuple, cast

from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.urls import FEDERATION_V1_PREFIX
from synapse.http.server import HttpServer, ServletCallback
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
    active_span,
    set_tag,
    span_context_from_request,
    start_active_span,
    start_active_span_follows_from,
    whitelisted_homeserver,
)
from synapse.types import JsonDict
from synapse.util.cancellation import is_function_cancellable
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import parse_and_validate_server_name

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class AuthenticationError(SynapseError):
    """There was a problem authenticating the request

    Raised in this module when an X-Matrix Authorization header cannot be
    parsed, or when it names a destination that is not this server.
    """


class NoAuthenticationError(AuthenticationError):
    """The request had no authentication information

    Kept distinct from AuthenticationError so that the servlet wrapper can
    catch it separately: servlets with REQUIRE_AUTH set to False carry on with
    an origin of None instead of failing the request.
    """


class Authenticator:
    """Authenticates incoming federation requests using X-Matrix signatures.

    A single instance is handed to the federation servlets, which call
    `authenticate_request` for every incoming request.
    """

    def __init__(self, hs: "HomeServer"):
        self._clock = hs.get_clock()
        self.keyring = hs.get_keyring()
        self.server_name = hs.hostname
        self._is_mine_server_name = hs.is_mine_server_name
        self.store = hs.get_datastores().main
        # None means no whitelist is configured, i.e. any origin is acceptable.
        self.federation_domain_whitelist = (
            hs.config.federation.federation_domain_whitelist
        )
        self.notifier = hs.get_notifier()

        # Only set when running as a worker; used to tell the main process
        # that a remote server has come back up.
        self.replication_client = None
        if hs.config.worker.worker_app:
            self.replication_client = hs.get_replication_command_handler()

    # A method just so we can pass 'self' as the authenticator to the Servlets
    async def authenticate_request(
        self, request: SynapseRequest, content: Optional[JsonDict]
    ) -> str:
        """Authenticate a request from its X-Matrix Authorization header(s).

        Rebuilds the JSON object the sender is expected to have signed (method,
        uri, destination, origin, signatures and, if present, content) and
        hands it to the keyring for verification. On success the origin is
        recorded on `request.requester`.

        Args:
            request: the incoming federation request.
            content: the parsed JSON body of the request, or None if there is
                none. Included in the signed object only when not None.

        Returns:
            The server name of the origin that signed the request.

        Raises:
            NoAuthenticationError: if there is no Authorization header, or none
                of the headers uses the X-Matrix scheme.
            AuthenticationError: if an X-Matrix header is malformed, or names a
                destination which is not this server.
            FederationDeniedError: if a federation whitelist is configured and
                the origin is not on it.

            Any exception raised by the keyring while verifying the signature
            propagates unchanged.
        """
        now = self._clock.time_msec()
        json_request: JsonDict = {
            "method": request.method.decode("ascii"),
            "uri": request.uri.decode("ascii"),
            "destination": self.server_name,
            "signatures": {},
        }

        if content is not None:
            json_request["content"] = content

        origin: Optional[str] = None

        auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")

        if not auth_headers:
            raise NoAuthenticationError(
                HTTPStatus.UNAUTHORIZED,
                "Missing Authorization headers",
                Codes.UNAUTHORIZED,
            )

        # Headers using any other scheme are ignored. If several X-Matrix
        # headers are present, each contributes a signature and the origin of
        # the last one wins.
        for auth in auth_headers:
            if auth.startswith(b"X-Matrix"):
                (origin, key, sig, destination) = _parse_auth_header(auth)
                json_request["origin"] = origin
                json_request["signatures"].setdefault(origin, {})[key] = sig

                # if the origin_server sent a destination along it needs to match our own server_name
                if destination is not None and not self._is_mine_server_name(
                    destination
                ):
                    raise AuthenticationError(
                        HTTPStatus.UNAUTHORIZED,
                        "Destination mismatch in auth header",
                        Codes.UNAUTHORIZED,
                    )

        # Check for missing credentials *before* consulting the whitelist, so
        # that a request which only carries non-X-Matrix Authorization headers
        # is reported as unauthenticated (like one with no header at all)
        # rather than as federation being denied for an origin of None.
        if origin is None or not json_request["signatures"]:
            raise NoAuthenticationError(
                HTTPStatus.UNAUTHORIZED,
                "Missing Authorization headers",
                Codes.UNAUTHORIZED,
            )

        if (
            self.federation_domain_whitelist is not None
            and origin not in self.federation_domain_whitelist
        ):
            raise FederationDeniedError(origin)

        await self.keyring.verify_json_for_server(
            origin,
            json_request,
            now,
        )

        logger.debug("Request from %s", origin)
        request.requester = origin

        # If we get a valid signed request from the other side, it's probably
        # alive
        retry_timings = await self.store.get_destination_retry_timings(origin)
        if retry_timings and retry_timings.retry_last_ts:
            # Don't hold up the request while the retry state is cleared.
            run_in_background(self.reset_retry_timings, origin)

        return origin

    async def reset_retry_timings(self, origin: str) -> None:
        """Reset the stored retry timings for `origin` and announce it is up.

        Started in the background by `authenticate_request`, so every exception
        is logged and swallowed here rather than propagated.

        Args:
            origin: the server name of the remote server which just sent us a
                validly signed request.
        """
        try:
            logger.info("Marking origin %r as up", origin)
            await self.store.set_destination_retry_timings(origin, None, 0, 0)

            # Inform the relevant places that the remote server is back up.
            self.notifier.notify_remote_server_up(origin)
            if self.replication_client:
                # If we're on a worker we try and inform master about this. The
                # replication client doesn't hook into the notifier to avoid
                # infinite loops where we send a `REMOTE_SERVER_UP` command to
                # master, which then echoes it back to us which in turn pokes
                # the notifier.
                self.replication_client.send_remote_server_up(origin)

        except Exception:
            logger.exception("Error resetting retry timings on %s", origin)


def _parse_auth_header(header_bytes: bytes) -> Tuple[str, str, str, Optional[str]]:
    """Parse an X-Matrix auth header

    The header has the form `X-Matrix k1=v1,k2=v2,...`. Parameter names are
    matched case-insensitively, and values may optionally be written as
    double-quoted strings with backslash escapes.

    Args:
        header_bytes: header value

    Returns:
        A tuple of (origin, key id, signature, destination). The destination
        is None if the header did not include a `destination` parameter.

    Raises:
        AuthenticationError if the header could not be parsed
    """
    try:
        header_str = header_bytes.decode("utf-8")
        space_or_tab = "[ \t]"
        # Drop the scheme, then split the remainder on commas, tolerating
        # optional whitespace either side of each comma.
        params = re.split(
            rf"{space_or_tab}*,{space_or_tab}*",
            re.split(r"^X-Matrix +", header_str, maxsplit=1)[1],
        )
        # If a parameter is repeated, the last occurrence wins.
        param_dict: Dict[str, str] = {
            k.lower(): v for k, v in [param.split("=", maxsplit=1) for param in params]
        }

        def strip_quotes(value: str) -> str:
            """Unwrap a quoted-string value; bare tokens are returned as-is."""
            if not value.startswith('"'):
                return value
            # Refuse an unterminated quoted string rather than silently
            # dropping the final character of the value.
            if len(value) < 2 or not value.endswith('"'):
                raise ValueError("Unterminated quoted string: %r" % (value,))
            # Remove the surrounding quotes and undo backslash escaping.
            return re.sub(r"\\(.)", lambda matchobj: matchobj.group(1), value[1:-1])

        origin = strip_quotes(param_dict["origin"])

        # ensure that the origin is a valid server name
        parse_and_validate_server_name(origin)

        key = strip_quotes(param_dict["key"])
        sig = strip_quotes(param_dict["sig"])

        # get the destination server_name from the auth header if it exists
        destination = param_dict.get("destination")
        if destination is not None:
            destination = strip_quotes(destination)

        return origin, key, sig, destination
    except Exception as e:
        # Anything that goes wrong above (bad encoding, missing scheme or
        # parameter, invalid server name, ...) means the header is malformed.
        logger.warning(
            "Error parsing auth header '%s': %s",
            header_bytes.decode("ascii", "replace"),
            e,
        )
        raise AuthenticationError(
            HTTPStatus.BAD_REQUEST, "Malformed Authorization header", Codes.UNAUTHORIZED
        ) from e


class BaseFederationServlet:
    """Abstract base class for federation servlet classes.

    The servlet object should have a PATH attribute which takes the form of a regexp to
    match against the request path (excluding the /federation/v1 prefix).

    The servlet should also implement one or more of on_GET, on_POST, on_PUT, to match
    the appropriate HTTP method. These methods must be *asynchronous* and have the
    signature:

        on_<METHOD>(self, origin, content, query, **kwargs)

        With arguments:

            origin (str|None): The authenticated server_name of the calling server,
                unless REQUIRE_AUTH is set to False and authentication failed.

            content (dict|None): decoded json body of the request. None if the
                request was a GET.

            query (dict[bytes, list[bytes]]): Query params from the request. url-decoded
                (ie, '+' and '%xx' are decoded) but note that it is *not* utf8-decoded
                yet. As a special case, a servlet class named
                FederationMediaDownloadServlet is passed the request object itself
                in this position.

            **kwargs (dict[str, str]): the dict mapping keys to path
                components as specified in the path match regexp.

        Returns:
            Optional[Tuple[int, object]]: either (response code, response object) to
                 return a JSON response, or None if the request has already been handled.

        Raises:
            SynapseError: to return an error code

            Exception: other exceptions will be caught, logged, and a 500 will be
                returned.
    """

    PATH = ""  # Overridden in subclasses, the regex to match against the path.

    # If False, requests without any authentication are passed to the handler
    # with an origin of None instead of being rejected.
    REQUIRE_AUTH = True

    PREFIX = FEDERATION_V1_PREFIX  # Allows specifying the API version

    RATELIMIT = True  # Whether to rate limit requests or not

    def __init__(
        self,
        hs: "HomeServer",
        authenticator: Authenticator,
        ratelimiter: FederationRateLimiter,
        server_name: str,
    ):
        self.hs = hs
        self.authenticator = authenticator
        self.ratelimiter = ratelimiter
        self.server_name = server_name

    def _wrap(self, func: Callable[..., Awaitable[Tuple[int, Any]]]) -> ServletCallback:
        """Wrap an on_<METHOD> handler with authentication, tracing and ratelimiting.

        Args:
            func: the bound on_<METHOD> method of this servlet.

        Returns:
            A callback suitable for passing to `HttpServer.register_paths`.
        """
        authenticator = self.authenticator
        ratelimiter = self.ratelimiter

        @functools.wraps(func)
        async def new_func(
            request: SynapseRequest, *args: Any, **kwargs: str
        ) -> Optional[Tuple[int, Any]]:
            """A callback which can be passed to HttpServer.register_paths

            Args:
                request:
                *args: unused?
                **kwargs: the dict mapping keys to path components as specified
                    in the path match regexp.

            Returns:
                (response code, response object) as returned by the callback method.
                None if the request has already been handled.
            """
            content = None
            if request.method in [b"PUT", b"POST"]:
                # TODO: Handle other method types? other content types?
                content = parse_json_object_from_request(request)

            try:
                with start_active_span("authenticate_request"):
                    origin: Optional[str] = await authenticator.authenticate_request(
                        request, content
                    )
            except NoAuthenticationError:
                # Missing credentials are only fatal for servlets which
                # require authentication; otherwise carry on anonymously.
                origin = None
                if self.REQUIRE_AUTH:
                    logger.warning(
                        "authenticate_request failed: missing authentication"
                    )
                    raise
            except Exception as e:
                logger.warning("authenticate_request failed: %s", e)
                raise

            # update the active opentracing span with the authenticated entity
            set_tag("authenticated_entity", str(origin))

            # if the origin is authenticated and whitelisted, use its span context
            # as the parent.
            context = None
            if origin and whitelisted_homeserver(origin):
                context = span_context_from_request(request)

            if context:
                servlet_span = active_span()
                # a scope which uses the origin's context as a parent
                processing_start_time = time.time()
                scope = start_active_span_follows_from(
                    "incoming-federation-request",
                    child_of=context,
                    contexts=(servlet_span,),
                    start_time=processing_start_time,
                )

            else:
                # just use our context as a parent
                scope = start_active_span(
                    "incoming-federation-request",
                )

            async def call_handler() -> Tuple[int, Any]:
                """Invoke the wrapped handler with the arguments it expects."""
                # The media download servlet is handed the request itself
                # rather than just the parsed query parameters.
                if (
                    func.__self__.__class__.__name__  # type: ignore
                    == "FederationMediaDownloadServlet"
                ):
                    return await func(origin, content, request, *args, **kwargs)
                return await func(origin, content, request.args, *args, **kwargs)

            try:
                with scope:
                    if origin and self.RATELIMIT:
                        with ratelimiter.ratelimit(origin) as d:
                            # Wait for our turn; the handler must run inside
                            # the ratelimiter's context manager.
                            await d
                            if request._disconnected:
                                logger.warning(
                                    "client disconnected before we started processing "
                                    "request"
                                )
                                return None
                            response = await call_handler()
                    else:
                        response = await call_handler()
            finally:
                # if we used the origin's context as the parent, add a new span using
                # the servlet span as a parent, so that we have a link
                if context:
                    scope2 = start_active_span_follows_from(
                        "process-federation_request",
                        contexts=(scope.span,),
                        start_time=processing_start_time,
                    )
                    scope2.close()

            return response

        return cast(ServletCallback, new_func)

    def register(self, server: HttpServer) -> None:
        """Register each implemented on_<METHOD> handler with the HTTP server.

        Args:
            server: the HTTP server to register the wrapped handlers on.

        Raises:
            Exception: if any handler has been marked as cancellable, since the
                wrapper added by `_wrap` does not support cancellation.
        """
        pattern = re.compile("^" + self.PREFIX + self.PATH + "$")

        for method in ("GET", "PUT", "POST"):
            code = getattr(self, "on_%s" % (method), None)
            if code is None:
                continue

            if is_function_cancellable(code):
                # The wrapper added by `self._wrap` will inherit the cancellable flag,
                # but the wrapper itself does not support cancellation yet.
                # Once resolved, the cancellation tests in
                # `tests/federation/transport/server/test__base.py` can be re-enabled.
                raise Exception(
                    f"{self.__class__.__name__}.on_{method} has been marked as "
                    "cancellable, but federation servlets do not support cancellation "
                    "yet."
                )

            server.register_paths(
                method,
                (pattern,),
                self._wrap(code),
                self.__class__.__name__,
            )