diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 19d09f5422..7a993fd1cf 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -14,25 +14,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import functools
+import logging
+import re
+
from twisted.internet import defer
+import synapse
+from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
-from synapse.api.errors import Codes, SynapseError, FederationDeniedError
+from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
- parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
parse_boolean_from_args,
+ parse_integer_from_args,
+ parse_json_object_from_request,
+ parse_string_from_args,
)
+from synapse.types import ThirdPartyInstanceID, get_domain_from_id
+from synapse.util.logcontext import run_in_background
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
-from synapse.util.logcontext import run_in_background
-from synapse.types import ThirdPartyInstanceID, get_domain_from_id
-
-import functools
-import logging
-import re
-import synapse
-
logger = logging.getLogger(__name__)
@@ -99,26 +101,6 @@ class Authenticator(object):
origin = None
- def parse_auth_header(header_str):
- try:
- params = auth.split(" ")[1].split(",")
- param_dict = dict(kv.split("=") for kv in params)
-
- def strip_quotes(value):
- if value.startswith("\""):
- return value[1:-1]
- else:
- return value
-
- origin = strip_quotes(param_dict["origin"])
- key = strip_quotes(param_dict["key"])
- sig = strip_quotes(param_dict["sig"])
- return (origin, key, sig)
- except Exception:
- raise AuthenticationError(
- 400, "Malformed Authorization header", Codes.UNAUTHORIZED
- )
-
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if not auth_headers:
@@ -127,8 +109,8 @@ class Authenticator(object):
)
for auth in auth_headers:
- if auth.startswith("X-Matrix"):
- (origin, key, sig) = parse_auth_header(auth)
+ if auth.startswith(b"X-Matrix"):
+ (origin, key, sig) = _parse_auth_header(auth)
json_request["origin"] = origin
json_request["signatures"].setdefault(origin, {})[key] = sig
@@ -165,7 +147,84 @@ class Authenticator(object):
logger.exception("Error resetting retry timings on %s", origin)
+def _parse_auth_header(header_bytes):
+ """Parse an X-Matrix auth header
+
+ Args:
+ header_bytes (bytes): header value
+
+ Returns:
+ Tuple[str, str, str]: origin, key id, signature.
+
+ Raises:
+ AuthenticationError if the header could not be parsed
+ """
+ try:
+ header_str = header_bytes.decode('utf-8')
+ params = header_str.split(" ")[1].split(",")
+ param_dict = dict(kv.split("=") for kv in params)
+
+ def strip_quotes(value):
+ if value.startswith("\""):
+ return value[1:-1]
+ else:
+ return value
+
+ origin = strip_quotes(param_dict["origin"])
+
+ # ensure that the origin is a valid server name
+ parse_and_validate_server_name(origin)
+
+ key = strip_quotes(param_dict["key"])
+ sig = strip_quotes(param_dict["sig"])
+ return origin, key, sig
+ except Exception as e:
+ logger.warn(
+ "Error parsing auth header '%s': %s",
+ header_bytes.decode('ascii', 'replace'),
+ e,
+ )
+ raise AuthenticationError(
+ 400, "Malformed Authorization header", Codes.UNAUTHORIZED,
+ )
+
+
class BaseFederationServlet(object):
+ """Abstract base class for federation servlet classes.
+
+ The servlet object should have a PATH attribute which takes the form of a regexp to
+ match against the request path (excluding the /federation/v1 prefix).
+
+ The servlet should also implement one or more of on_GET, on_POST, on_PUT, to match
+ the appropriate HTTP method. These methods have the signature:
+
+ on_<METHOD>(self, origin, content, query, **kwargs)
+
+ With arguments:
+
+ origin (unicode|None): The authenticated server_name of the calling server,
+ unless REQUIRE_AUTH is set to False and authentication failed.
+
+ content (unicode|None): decoded json body of the request. None if the
+ request was a GET.
+
+ query (dict[bytes, list[bytes]]): Query params from the request. url-decoded
+ (ie, '+' and '%xx' are decoded) but note that it is *not* utf8-decoded
+ yet.
+
+ **kwargs (dict[unicode, unicode]): the dict mapping keys to path
+ components as specified in the path match regexp.
+
+ Returns:
+ Deferred[(int, object)|None]: either (response code, response object) to
+ return a JSON response, or None if the request has already been handled.
+
+ Raises:
+ SynapseError: to return an error code
+
+ Exception: other exceptions will be caught, logged, and a 500 will be
+ returned.
+ """
REQUIRE_AUTH = True
def __init__(self, handler, authenticator, ratelimiter, server_name):
@@ -180,6 +239,18 @@ class BaseFederationServlet(object):
@defer.inlineCallbacks
@functools.wraps(func)
def new_func(request, *args, **kwargs):
+ """ A callback which can be passed to HttpServer.RegisterPaths
+
+ Args:
+ request (twisted.web.http.Request):
+ *args: unused?
+ **kwargs (dict[unicode, unicode]): the dict mapping keys to path
+ components as specified in the path match regexp.
+
+ Returns:
+ Deferred[(int, object)|None]: (response code, response object) as returned
+ by the callback method. None if the request has already been handled.
+ """
content = None
if request.method in ["PUT", "POST"]:
# TODO: Handle other method types? other content types?
@@ -190,10 +261,10 @@ class BaseFederationServlet(object):
except NoAuthenticationError:
origin = None
if self.REQUIRE_AUTH:
- logger.exception("authenticate_request failed")
+ logger.warn("authenticate_request failed: missing authentication")
raise
- except Exception:
- logger.exception("authenticate_request failed")
+ except Exception as e:
+ logger.warn("authenticate_request failed: %s", e)
raise
if origin:
@@ -259,11 +330,10 @@ class FederationSendServlet(BaseFederationServlet):
)
logger.info(
- "Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)",
+ "Received txn %s from %s. (PDUs: %d, EDUs: %d)",
transaction_id, origin,
len(transaction_data.get("pdus", [])),
len(transaction_data.get("edus", [])),
- len(transaction_data.get("failures", [])),
)
# We should ideally be getting this from the security layer.
@@ -361,8 +431,32 @@ class FederationMakeJoinServlet(BaseFederationServlet):
PATH = "/make_join/(?P<context>[^/]*)/(?P<user_id>[^/]*)"
@defer.inlineCallbacks
- def on_GET(self, origin, content, query, context, user_id):
- content = yield self.handler.on_make_join_request(context, user_id)
+ def on_GET(self, origin, _content, query, context, user_id):
+ """
+ Args:
+ origin (unicode): The authenticated server_name of the calling server
+
+ _content (None): (GETs don't have bodies)
+
+ query (dict[bytes, list[bytes]]): Query params from the request.
+
+ **kwargs (dict[unicode, unicode]): the dict mapping keys to path
+ components as specified in the path match regexp.
+
+ Returns:
+ Deferred[(int, object)|None]: either (response code, response object) to
+ return a JSON response, or None if the request has already been handled.
+ """
+ versions = query.get(b'ver')
+ if versions is not None:
+ supported_versions = [v.decode("utf-8") for v in versions]
+ else:
+ supported_versions = ["1"]
+
+ content = yield self.handler.on_make_join_request(
+ origin, context, user_id,
+ supported_versions=supported_versions,
+ )
defer.returnValue((200, content))
@@ -371,15 +465,17 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
@defer.inlineCallbacks
def on_GET(self, origin, content, query, context, user_id):
- content = yield self.handler.on_make_leave_request(context, user_id)
+ content = yield self.handler.on_make_leave_request(
+ origin, context, user_id,
+ )
defer.returnValue((200, content))
class FederationSendLeaveServlet(BaseFederationServlet):
- PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<txid>[^/]*)"
+ PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
@defer.inlineCallbacks
- def on_PUT(self, origin, content, query, room_id, txid):
+ def on_PUT(self, origin, content, query, room_id, event_id):
content = yield self.handler.on_send_leave_request(origin, content)
defer.returnValue((200, content))
|