summary refs log tree commit diff
path: root/synapse/federation/transport/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/transport/server.py')
-rw-r--r--synapse/federation/transport/server.py180
1 files changed, 138 insertions, 42 deletions
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 19d09f5422..7a993fd1cf 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -14,25 +14,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import functools
+import logging
+import re
+
 from twisted.internet import defer
 
+import synapse
+from synapse.api.errors import Codes, FederationDeniedError, SynapseError
 from synapse.api.urls import FEDERATION_PREFIX as PREFIX
-from synapse.api.errors import Codes, SynapseError, FederationDeniedError
+from synapse.http.endpoint import parse_and_validate_server_name
 from synapse.http.server import JsonResource
 from synapse.http.servlet import (
-    parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
     parse_boolean_from_args,
+    parse_integer_from_args,
+    parse_json_object_from_request,
+    parse_string_from_args,
 )
+from synapse.types import ThirdPartyInstanceID, get_domain_from_id
+from synapse.util.logcontext import run_in_background
 from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.versionstring import get_version_string
-from synapse.util.logcontext import run_in_background
-from synapse.types import ThirdPartyInstanceID, get_domain_from_id
-
-import functools
-import logging
-import re
-import synapse
-
 
 logger = logging.getLogger(__name__)
 
@@ -99,26 +101,6 @@ class Authenticator(object):
 
         origin = None
 
-        def parse_auth_header(header_str):
-            try:
-                params = auth.split(" ")[1].split(",")
-                param_dict = dict(kv.split("=") for kv in params)
-
-                def strip_quotes(value):
-                    if value.startswith("\""):
-                        return value[1:-1]
-                    else:
-                        return value
-
-                origin = strip_quotes(param_dict["origin"])
-                key = strip_quotes(param_dict["key"])
-                sig = strip_quotes(param_dict["sig"])
-                return (origin, key, sig)
-            except Exception:
-                raise AuthenticationError(
-                    400, "Malformed Authorization header", Codes.UNAUTHORIZED
-                )
-
         auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
 
         if not auth_headers:
@@ -127,8 +109,8 @@ class Authenticator(object):
             )
 
         for auth in auth_headers:
-            if auth.startswith("X-Matrix"):
-                (origin, key, sig) = parse_auth_header(auth)
+            if auth.startswith(b"X-Matrix"):
+                (origin, key, sig) = _parse_auth_header(auth)
                 json_request["origin"] = origin
                 json_request["signatures"].setdefault(origin, {})[key] = sig
 
@@ -165,7 +147,84 @@ class Authenticator(object):
             logger.exception("Error resetting retry timings on %s", origin)
 
 
+def _parse_auth_header(header_bytes):
+    """Parse an X-Matrix auth header
+
+    Args:
+        header_bytes (bytes): header value
+
+    Returns:
+        Tuple[str, str, str]: origin, key id, signature.
+
+    Raises:
+        AuthenticationError if the header could not be parsed
+    """
+    try:
+        header_str = header_bytes.decode('utf-8')
+        params = header_str.split(" ")[1].split(",")
+        param_dict = dict(kv.split("=") for kv in params)
+
+        def strip_quotes(value):
+            if value.startswith("\""):
+                return value[1:-1]
+            else:
+                return value
+
+        origin = strip_quotes(param_dict["origin"])
+
+        # ensure that the origin is a valid server name
+        parse_and_validate_server_name(origin)
+
+        key = strip_quotes(param_dict["key"])
+        sig = strip_quotes(param_dict["sig"])
+        return origin, key, sig
+    except Exception as e:
+        logger.warn(
+            "Error parsing auth header '%s': %s",
+            header_bytes.decode('ascii', 'replace'),
+            e,
+        )
+        raise AuthenticationError(
+            400, "Malformed Authorization header", Codes.UNAUTHORIZED,
+        )
+
+
 class BaseFederationServlet(object):
+    """Abstract base class for federation servlet classes.
+
+    The servlet object should have a PATH attribute which takes the form of a regexp to
+    match against the request path (excluding the /federation/v1 prefix).
+
+    The servlet should also implement one or more of on_GET, on_POST, on_PUT, to match
+    the appropriate HTTP method. These methods have the signature:
+
+        on_<METHOD>(self, origin, content, query, **kwargs)
+
+        With arguments:
+
+            origin (unicode|None): The authenticated server_name of the calling server,
+                unless REQUIRE_AUTH is set to False and authentication failed.
+
+            content (unicode|None): decoded json body of the request. None if the
+                request was a GET.
+
+            query (dict[bytes, list[bytes]]): Query params from the request. url-decoded
+                (ie, '+' and '%xx' are decoded) but note that it is *not* utf8-decoded
+                yet.
+
+            **kwargs (dict[unicode, unicode]): the dict mapping keys to path
+                components as specified in the path match regexp.
+
+        Returns:
+            Deferred[(int, object)|None]: either (response code, response object) to
+                 return a JSON response, or None if the request has already been handled.
+
+        Raises:
+            SynapseError: to return an error code
+
+            Exception: other exceptions will be caught, logged, and a 500 will be
+                returned.
+    """
     REQUIRE_AUTH = True
 
     def __init__(self, handler, authenticator, ratelimiter, server_name):
@@ -180,6 +239,18 @@ class BaseFederationServlet(object):
         @defer.inlineCallbacks
         @functools.wraps(func)
         def new_func(request, *args, **kwargs):
+            """ A callback which can be passed to HttpServer.RegisterPaths
+
+            Args:
+                request (twisted.web.http.Request):
+                *args: unused?
+                **kwargs (dict[unicode, unicode]): the dict mapping keys to path
+                    components as specified in the path match regexp.
+
+            Returns:
+                Deferred[(int, object)|None]: (response code, response object) as returned
+                    by the callback method. None if the request has already been handled.
+            """
             content = None
             if request.method in ["PUT", "POST"]:
                 # TODO: Handle other method types? other content types?
@@ -190,10 +261,10 @@ class BaseFederationServlet(object):
             except NoAuthenticationError:
                 origin = None
                 if self.REQUIRE_AUTH:
-                    logger.exception("authenticate_request failed")
+                    logger.warn("authenticate_request failed: missing authentication")
                     raise
-            except Exception:
-                logger.exception("authenticate_request failed")
+            except Exception as e:
+                logger.warn("authenticate_request failed: %s", e)
                 raise
 
             if origin:
@@ -259,11 +330,10 @@ class FederationSendServlet(BaseFederationServlet):
             )
 
             logger.info(
-                "Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)",
+                "Received txn %s from %s. (PDUs: %d, EDUs: %d)",
                 transaction_id, origin,
                 len(transaction_data.get("pdus", [])),
                 len(transaction_data.get("edus", [])),
-                len(transaction_data.get("failures", [])),
             )
 
             # We should ideally be getting this from the security layer.
@@ -361,8 +431,32 @@ class FederationMakeJoinServlet(BaseFederationServlet):
     PATH = "/make_join/(?P<context>[^/]*)/(?P<user_id>[^/]*)"
 
     @defer.inlineCallbacks
-    def on_GET(self, origin, content, query, context, user_id):
-        content = yield self.handler.on_make_join_request(context, user_id)
+    def on_GET(self, origin, _content, query, context, user_id):
+        """
+        Args:
+            origin (unicode): The authenticated server_name of the calling server
+
+            _content (None): (GETs don't have bodies)
+
+            query (dict[bytes, list[bytes]]): Query params from the request.
+
+            **kwargs (dict[unicode, unicode]): the dict mapping keys to path
+                components as specified in the path match regexp.
+
+        Returns:
+            Deferred[(int, object)|None]: either (response code, response object) to
+                 return a JSON response, or None if the request has already been handled.
+        """
+        versions = query.get(b'ver')
+        if versions is not None:
+            supported_versions = [v.decode("utf-8") for v in versions]
+        else:
+            supported_versions = ["1"]
+
+        content = yield self.handler.on_make_join_request(
+            origin, context, user_id,
+            supported_versions=supported_versions,
+        )
         defer.returnValue((200, content))
 
 
@@ -371,15 +465,17 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, origin, content, query, context, user_id):
-        content = yield self.handler.on_make_leave_request(context, user_id)
+        content = yield self.handler.on_make_leave_request(
+            origin, context, user_id,
+        )
         defer.returnValue((200, content))
 
 
 class FederationSendLeaveServlet(BaseFederationServlet):
-    PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<txid>[^/]*)"
+    PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
 
     @defer.inlineCallbacks
-    def on_PUT(self, origin, content, query, room_id, txid):
+    def on_PUT(self, origin, content, query, room_id, event_id):
         content = yield self.handler.on_send_leave_request(origin, content)
         defer.returnValue((200, content))