diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 0bc6e0801d..ee8f94e340 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource
-from synapse.http.servlet import parse_json_object_from_request, parse_string
+from synapse.http.servlet import parse_json_object_from_request
from synapse.util.ratelimitutils import FederationRateLimiter
import functools
@@ -60,6 +60,16 @@ class TransportLayerServer(JsonResource):
)
+class AuthenticationError(SynapseError):
+ """There was a problem authenticating the request"""
+ pass
+
+
+class NoAuthenticationError(AuthenticationError):
+ """The request had no authentication information"""
+ pass
+
+
class Authenticator(object):
def __init__(self, hs):
self.keyring = hs.get_keyring()
@@ -67,7 +77,7 @@ class Authenticator(object):
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
- def authenticate_request(self, request):
+ def authenticate_request(self, request, content):
json_request = {
"method": request.method,
"uri": request.uri,
@@ -75,17 +85,10 @@ class Authenticator(object):
"signatures": {},
}
- content = None
- origin = None
+ if content is not None:
+ json_request["content"] = content
- if request.method in ["PUT", "POST"]:
- # TODO: Handle other method types? other content types?
- try:
- content_bytes = request.content.read()
- content = json.loads(content_bytes)
- json_request["content"] = content
- except:
- raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON)
+ origin = None
def parse_auth_header(header_str):
try:
@@ -103,14 +106,14 @@ class Authenticator(object):
sig = strip_quotes(param_dict["sig"])
return (origin, key, sig)
except:
- raise SynapseError(
+ raise AuthenticationError(
400, "Malformed Authorization header", Codes.UNAUTHORIZED
)
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if not auth_headers:
- raise SynapseError(
+ raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
)
@@ -121,7 +124,7 @@ class Authenticator(object):
json_request["signatures"].setdefault(origin, {})[key] = sig
if not json_request["signatures"]:
- raise SynapseError(
+ raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
)
@@ -130,10 +133,12 @@ class Authenticator(object):
logger.info("Request from %s", origin)
request.authenticated_entity = origin
- defer.returnValue((origin, content))
+ defer.returnValue(origin)
class BaseFederationServlet(object):
+ REQUIRE_AUTH = True
+
def __init__(self, handler, authenticator, ratelimiter, server_name,
room_list_handler):
self.handler = handler
@@ -141,29 +146,46 @@ class BaseFederationServlet(object):
self.ratelimiter = ratelimiter
self.room_list_handler = room_list_handler
- def _wrap(self, code):
+ def _wrap(self, func):
authenticator = self.authenticator
ratelimiter = self.ratelimiter
@defer.inlineCallbacks
- @functools.wraps(code)
- def new_code(request, *args, **kwargs):
+ @functools.wraps(func)
+ def new_func(request, *args, **kwargs):
+ content = None
+ if request.method in ["PUT", "POST"]:
+ # TODO: Handle other method types? other content types?
+ content = parse_json_object_from_request(request)
+
try:
- (origin, content) = yield authenticator.authenticate_request(request)
+ origin = yield authenticator.authenticate_request(request, content)
+ except NoAuthenticationError:
+ origin = None
+ if self.REQUIRE_AUTH:
+ logger.exception("authenticate_request failed")
+ raise
+ except:
+ logger.exception("authenticate_request failed")
+ raise
+
+ if origin:
with ratelimiter.ratelimit(origin) as d:
yield d
- response = yield code(
+ response = yield func(
origin, content, request.args, *args, **kwargs
)
- except:
- logger.exception("authenticate_request failed")
- raise
+ else:
+ response = yield func(
+ origin, content, request.args, *args, **kwargs
+ )
+
defer.returnValue(response)
# Extra logic that functools.wraps() doesn't finish
- new_code.__self__ = code.__self__
+ new_func.__self__ = func.__self__
- return new_code
+ return new_func
def register(self, server):
pattern = re.compile("^" + PREFIX + self.PATH + "$")
@@ -429,9 +451,10 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
class On3pidBindServlet(BaseFederationServlet):
PATH = "/3pid/onbind"
+ REQUIRE_AUTH = False
+
@defer.inlineCallbacks
- def on_POST(self, request):
- content = parse_json_object_from_request(request)
+ def on_POST(self, origin, content, query):
if "invites" in content:
last_exception = None
for invite in content["invites"]:
@@ -453,11 +476,6 @@ class On3pidBindServlet(BaseFederationServlet):
raise last_exception
defer.returnValue((200, {}))
- # Avoid doing remote HS authorization checks which are done by default by
- # BaseFederationServlet.
- def _wrap(self, code):
- return code
-
class OpenIdUserInfo(BaseFederationServlet):
"""
@@ -478,9 +496,11 @@ class OpenIdUserInfo(BaseFederationServlet):
PATH = "/openid/userinfo"
+ REQUIRE_AUTH = False
+
@defer.inlineCallbacks
- def on_GET(self, request):
- token = parse_string(request, "access_token")
+ def on_GET(self, origin, content, query):
+ token = query.get("access_token", [None])[0]
if token is None:
defer.returnValue((401, {
"errcode": "M_MISSING_TOKEN", "error": "Access Token required"
@@ -497,11 +517,6 @@ class OpenIdUserInfo(BaseFederationServlet):
defer.returnValue((200, {"sub": user_id}))
- # Avoid doing remote HS authorization checks which are done by default by
- # BaseFederationServlet.
- def _wrap(self, code):
- return code
-
class PublicRoomList(BaseFederationServlet):
"""
|