diff options
Diffstat (limited to 'synapse/federation/transport/server.py')
-rw-r--r-- | synapse/federation/transport/server.py | 180 |
1 files changed, 138 insertions, 42 deletions
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 19d09f5422..7a993fd1cf 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -14,25 +14,27 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools +import logging +import re + from twisted.internet import defer +import synapse +from synapse.api.errors import Codes, FederationDeniedError, SynapseError from synapse.api.urls import FEDERATION_PREFIX as PREFIX -from synapse.api.errors import Codes, SynapseError, FederationDeniedError +from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.server import JsonResource from synapse.http.servlet import ( - parse_json_object_from_request, parse_integer_from_args, parse_string_from_args, parse_boolean_from_args, + parse_integer_from_args, + parse_json_object_from_request, + parse_string_from_args, ) +from synapse.types import ThirdPartyInstanceID, get_domain_from_id +from synapse.util.logcontext import run_in_background from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.versionstring import get_version_string -from synapse.util.logcontext import run_in_background -from synapse.types import ThirdPartyInstanceID, get_domain_from_id - -import functools -import logging -import re -import synapse - logger = logging.getLogger(__name__) @@ -99,26 +101,6 @@ class Authenticator(object): origin = None - def parse_auth_header(header_str): - try: - params = auth.split(" ")[1].split(",") - param_dict = dict(kv.split("=") for kv in params) - - def strip_quotes(value): - if value.startswith("\""): - return value[1:-1] - else: - return value - - origin = strip_quotes(param_dict["origin"]) - key = strip_quotes(param_dict["key"]) - sig = strip_quotes(param_dict["sig"]) - return (origin, key, sig) - except Exception: - raise AuthenticationError( - 400, "Malformed Authorization header", Codes.UNAUTHORIZED - ) - auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") if not auth_headers: @@ -127,8 +109,8 @@ class Authenticator(object): ) for auth in auth_headers: - if auth.startswith("X-Matrix"): - (origin, key, sig) = parse_auth_header(auth) + if auth.startswith(b"X-Matrix"): + (origin, key, sig) = _parse_auth_header(auth) json_request["origin"] = origin json_request["signatures"].setdefault(origin, {})[key] = sig @@ -165,7 +147,84 @@ class Authenticator(object): logger.exception("Error resetting retry timings on %s", origin) +def _parse_auth_header(header_bytes): + """Parse an X-Matrix auth header + + Args: + header_bytes (bytes): header value + + Returns: + Tuple[str, str, str]: origin, key id, signature. + + Raises: + AuthenticationError if the header could not be parsed + """ + try: + header_str = header_bytes.decode('utf-8') + params = header_str.split(" ")[1].split(",") + param_dict = dict(kv.split("=") for kv in params) + + def strip_quotes(value): + if value.startswith("\""): + return value[1:-1] + else: + return value + + origin = strip_quotes(param_dict["origin"]) + + # ensure that the origin is a valid server name + parse_and_validate_server_name(origin) + + key = strip_quotes(param_dict["key"]) + sig = strip_quotes(param_dict["sig"]) + return origin, key, sig + except Exception as e: + logger.warn( + "Error parsing auth header '%s': %s", + header_bytes.decode('ascii', 'replace'), + e, + ) + raise AuthenticationError( + 400, "Malformed Authorization header", Codes.UNAUTHORIZED, + ) + + class BaseFederationServlet(object): + """Abstract base class for federation servlet classes. + + The servlet object should have a PATH attribute which takes the form of a regexp to + match against the request path (excluding the /federation/v1 prefix). + + The servlet should also implement one or more of on_GET, on_POST, on_PUT, to match + the appropriate HTTP method. These methods have the signature: + + on_<METHOD>(self, origin, content, query, **kwargs) + + With arguments: + + origin (unicode|None): The authenticated server_name of the calling server, + unless REQUIRE_AUTH is set to False and authentication failed. + + content (unicode|None): decoded json body of the request. None if the + request was a GET. + + query (dict[bytes, list[bytes]]): Query params from the request. url-decoded + (ie, '+' and '%xx' are decoded) but note that it is *not* utf8-decoded + yet. + + **kwargs (dict[unicode, unicode]): the dict mapping keys to path + components as specified in the path match regexp. + + Returns: + Deferred[(int, object)|None]: either (response code, response object) to + return a JSON response, or None if the request has already been handled. + + Raises: + SynapseError: to return an error code + + Exception: other exceptions will be caught, logged, and a 500 will be + returned. + """ REQUIRE_AUTH = True def __init__(self, handler, authenticator, ratelimiter, server_name): @@ -180,6 +239,18 @@ class BaseFederationServlet(object): @defer.inlineCallbacks @functools.wraps(func) def new_func(request, *args, **kwargs): + """ A callback which can be passed to HttpServer.RegisterPaths + + Args: + request (twisted.web.http.Request): + *args: unused? + **kwargs (dict[unicode, unicode]): the dict mapping keys to path + components as specified in the path match regexp. + + Returns: + Deferred[(int, object)|None]: (response code, response object) as returned + by the callback method. None if the request has already been handled. + """ content = None if request.method in ["PUT", "POST"]: # TODO: Handle other method types? other content types? @@ -190,10 +261,10 @@ class BaseFederationServlet(object): except NoAuthenticationError: origin = None if self.REQUIRE_AUTH: - logger.exception("authenticate_request failed") + logger.warn("authenticate_request failed: missing authentication") raise - except Exception: - logger.exception("authenticate_request failed") + except Exception as e: + logger.warn("authenticate_request failed: %s", e) raise if origin: @@ -259,11 +330,10 @@ class FederationSendServlet(BaseFederationServlet): ) logger.info( - "Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)", + "Received txn %s from %s. (PDUs: %d, EDUs: %d)", transaction_id, origin, len(transaction_data.get("pdus", [])), len(transaction_data.get("edus", [])), - len(transaction_data.get("failures", [])), ) # We should ideally be getting this from the security layer. @@ -361,8 +431,32 @@ class FederationMakeJoinServlet(BaseFederationServlet): PATH = "/make_join/(?P<context>[^/]*)/(?P<user_id>[^/]*)" @defer.inlineCallbacks - def on_GET(self, origin, content, query, context, user_id): - content = yield self.handler.on_make_join_request(context, user_id) + def on_GET(self, origin, _content, query, context, user_id): + """ + Args: + origin (unicode): The authenticated server_name of the calling server + + _content (None): (GETs don't have bodies) + + query (dict[bytes, list[bytes]]): Query params from the request. + + **kwargs (dict[unicode, unicode]): the dict mapping keys to path + components as specified in the path match regexp. + + Returns: + Deferred[(int, object)|None]: either (response code, response object) to + return a JSON response, or None if the request has already been handled. + """ + versions = query.get(b'ver') + if versions is not None: + supported_versions = [v.decode("utf-8") for v in versions] + else: + supported_versions = ["1"] + + content = yield self.handler.on_make_join_request( + origin, context, user_id, + supported_versions=supported_versions, + ) defer.returnValue((200, content)) @@ -371,15 +465,17 @@ class FederationMakeLeaveServlet(BaseFederationServlet): @defer.inlineCallbacks def on_GET(self, origin, content, query, context, user_id): - content = yield self.handler.on_make_leave_request(context, user_id) + content = yield self.handler.on_make_leave_request( + origin, context, user_id, + ) defer.returnValue((200, content)) class FederationSendLeaveServlet(BaseFederationServlet): - PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<txid>[^/]*)" + PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" @defer.inlineCallbacks - def on_PUT(self, origin, content, query, room_id, txid): + def on_PUT(self, origin, content, query, room_id, event_id): content = yield self.handler.on_send_leave_request(origin, content) defer.returnValue((200, content)) |