diff options
-rw-r--r-- | synapse/federation/__init__.py | 1 | ||||
-rw-r--r-- | synapse/federation/transport.py | 110 | ||||
-rw-r--r-- | synapse/http/client.py | 10 | ||||
-rw-r--r-- | tests/federation/test_federation.py | 1 | ||||
-rw-r--r-- | tests/utils.py | 3 |
5 files changed, 100 insertions, 25 deletions
diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py index 1351b68fd6..0112588656 100644 --- a/synapse/federation/__init__.py +++ b/synapse/federation/__init__.py @@ -22,6 +22,7 @@ from .transport import TransportLayer def initialize_http_replication(homeserver): transport = TransportLayer( + homeserver, homeserver.hostname, server=homeserver.get_resource_for_federation(), client=homeserver.get_http_client() diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py index 48fc9fbf5e..7b2631fbc8 100644 --- a/synapse/federation/transport.py +++ b/synapse/federation/transport.py @@ -54,7 +54,7 @@ class TransportLayer(object): we receive data. """ - def __init__(self, server_name, server, client): + def __init__(self, homeserver, server_name, server, client): """ Args: server_name (str): Local home server host @@ -63,6 +63,7 @@ class TransportLayer(object): client (synapse.protocol.http.HttpClient): the http client used to send requests """ + self.keyring = homeserver.get_keyring() self.server_name = server_name self.server = server self.client = client @@ -195,6 +196,66 @@ class TransportLayer(object): defer.returnValue(response) + @defer.inlineCallbacks + def _authenticate_request(self, request): + json_request = { + "method": request.method, + "uri": request.uri, + "destination": self.server_name, + "signatures": {}, + } + + content = None + origin = None + + if request.method == "PUT": + #TODO: Handle other method types? other content types? + content_bytes = request.content.read() + content = json.loads(content_bytes) + json_request["content"] = content + + def parse_auth_header(header_str): + params = auth.split(" ")[1].split(",") + param_dict = dict(kv.split("=") for kv in params) + def strip_quotes(value): + if value.startswith("\""): + return value[1:-1] + else: + return value + origin = strip_quotes(param_dict["origin"]) + key = strip_quotes(param_dict["key"]) + sig = strip_quotes(param_dict["sig"]) + return (origin, key, sig) + + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") + + if not auth_headers: + #TODO(markjh): Send a 401 response? + raise Exception("Missing auth headers") + + for auth in auth_headers: + if auth.startswith("X-Matrix"): + (origin, key, sig) = parse_auth_header(auth) + json_request["origin"] = origin + json_request["signatures"].setdefault(origin,{})[key] = sig + + from syutil.jsonutil import encode_canonical_json + logger.debug("Checking %s %s", + origin, encode_canonical_json(json_request)) + yield self.keyring.verify_json_for_server(origin, json_request) + + defer.returnValue((origin, content)) + + def _with_authentication(self, handler): + @defer.inlineCallbacks + def new_handler(request, *args, **kwargs): + (origin, content) = yield self._authenticate_request(request) + response = yield handler( + origin, content, request.args, *args, **kwargs + ) + defer.returnValue(response) + return new_handler + @log_function def register_received_handler(self, handler): """ Register a handler that will be fired when we receive data. @@ -208,7 +269,7 @@ class TransportLayer(object): self.server.register_path( "PUT", re.compile("^" + PREFIX + "/send/([^/]*)/$"), - self._on_send_request + self._with_authentication(self._on_send_request) ) @log_function @@ -226,9 +287,9 @@ class TransportLayer(object): self.server.register_path( "GET", re.compile("^" + PREFIX + "/pull/$"), - lambda request: handler.on_pull_request( - request.args["origin"][0], - request.args["v"] + self._with_authentication( + lambda origin, content, query: + handler.on_pull_request(query["origin"][0], query["v"]) ) ) @@ -237,8 +298,9 @@ class TransportLayer(object): self.server.register_path( "GET", re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"), - lambda request, pdu_origin, pdu_id: handler.on_pdu_request( - pdu_origin, pdu_id + self._with_authentication( + lambda origin, content, query, pdu_origin, pdu_id: + handler.on_pdu_request(pdu_origin, pdu_id) ) ) @@ -246,38 +308,47 @@ class TransportLayer(object): self.server.register_path( "GET", re.compile("^" + PREFIX + "/state/([^/]*)/$"), - lambda request, context: handler.on_context_state_request( - context + self._with_authentication( + lambda origin, content, query, context: + handler.on_context_state_request(context) ) ) self.server.register_path( "GET", re.compile("^" + PREFIX + "/backfill/([^/]*)/$"), - lambda request, context: self._on_backfill_request( - context, request.args["v"], - request.args["limit"] + self._with_authentication( + lambda origin, content, query, context: + self._on_backfill_request( + context, query["v"], query["limit"] + ) ) ) self.server.register_path( "GET", re.compile("^" + PREFIX + "/context/([^/]*)/$"), - lambda request, context: handler.on_context_pdus_request(context) + self._with_authentication( + lambda origin, content, query, context: + handler.on_context_pdus_request(context) + ) ) # This is when we receive a server-server Query self.server.register_path( "GET", re.compile("^" + PREFIX + "/query/([^/]*)$"), - lambda request, query_type: handler.on_query_request( - query_type, {k: v[0] for k, v in request.args.items()} + self._with_authentication( + lambda origin, content, query, query_type: + handler.on_query_request( + query_type, {k: v[0] for k, v in query.items()} + ) ) ) @defer.inlineCallbacks @log_function - def _on_send_request(self, request, transaction_id): + def _on_send_request(self, origin, content, query, transaction_id): """ Called on PUT /send/<transaction_id>/ Args: @@ -292,12 +363,7 @@ class TransportLayer(object): """ # Parse the request try: - data = request.content.read() - - l = data[:20].encode("string_escape") - logger.debug("Got data: \"%s\"", l) - - transaction_data = json.loads(data) + transaction_data = content logger.debug( "Decoded %s: %s", diff --git a/synapse/http/client.py b/synapse/http/client.py index 62fe14fa5e..9f54b74e3a 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -177,16 +177,20 @@ class MatrixHttpClient(BaseHttpClient): request = sign_json(request, self.server_name, self.signing_key) + from syutil.jsonutil import encode_canonical_json + logger.debug("Signing " + " " * 11 + "%s %s", + self.server_name, encode_canonical_json(request)) + auth_headers = [] for key,sig in request["signatures"][self.server_name].items(): - auth_headers.append( + auth_headers.append(bytes( "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % ( self.server_name, key, sig, ) - ) + )) - headers_dict["Authorization"] = auth_headers + headers_dict[b"Authorization"] = auth_headers @defer.inlineCallbacks def put_json(self, destination, path, data={}, json_data_callback=None): diff --git a/tests/federation/test_federation.py b/tests/federation/test_federation.py index 91edeaa4b9..8d277d6612 100644 --- a/tests/federation/test_federation.py +++ b/tests/federation/test_federation.py @@ -221,6 +221,7 @@ class FederationTestCase(unittest.TestCase): json_data_callback=ANY, ) + @defer.inlineCallbacks def test_recv_edu(self): recv_observer = Mock() diff --git a/tests/utils.py b/tests/utils.py index 797818be72..83dbd4f4d3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -76,6 +76,9 @@ class MockHttpResource(HttpServer): mock_content.configure_mock(**config) mock_request.content = mock_content + mock_request.method = http_method + mock_request.uri = path + # return the right path if the event requires it mock_request.path = path |