summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
authorMark Haines <mark.haines@matrix.org>2014-10-13 14:37:46 +0100
committerMark Haines <mark.haines@matrix.org>2014-10-13 14:37:46 +0100
commit66848557672977e17f0d5ba785b7567b305ccdbe (patch)
tree0a0bee24bf0dd81559744d13e84eace9fa335972 /synapse/federation
parentSYN-75 sign at the request level rather than the transaction level (diff)
downloadsynapse-66848557672977e17f0d5ba785b7567b305ccdbe.tar.xz
Verify signatures for server2server requests
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/__init__.py1
-rw-r--r--synapse/federation/transport.py110
2 files changed, 89 insertions, 22 deletions
diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py
index 1351b68fd6..0112588656 100644
--- a/synapse/federation/__init__.py
+++ b/synapse/federation/__init__.py
@@ -22,6 +22,7 @@ from .transport import TransportLayer
 
 def initialize_http_replication(homeserver):
     transport = TransportLayer(
+        homeserver,
         homeserver.hostname,
         server=homeserver.get_resource_for_federation(),
         client=homeserver.get_http_client()
diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py
index 48fc9fbf5e..7b2631fbc8 100644
--- a/synapse/federation/transport.py
+++ b/synapse/federation/transport.py
@@ -54,7 +54,7 @@ class TransportLayer(object):
             we receive data.
     """
 
-    def __init__(self, server_name, server, client):
+    def __init__(self, homeserver, server_name, server, client):
         """
         Args:
             server_name (str): Local home server host
@@ -63,6 +63,7 @@ class TransportLayer(object):
             client (synapse.protocol.http.HttpClient): the http client used to
                 send requests
         """
+        self.keyring = homeserver.get_keyring()
         self.server_name = server_name
         self.server = server
         self.client = client
@@ -195,6 +196,66 @@ class TransportLayer(object):
 
         defer.returnValue(response)
 
+    @defer.inlineCallbacks
+    def _authenticate_request(self, request):
+        json_request = {
+            "method": request.method,
+            "uri": request.uri,
+            "destination": self.server_name,
+            "signatures": {},
+        }
+
+        content = None
+        origin = None
+
+        if request.method == "PUT":
+            #TODO: Handle other method types? other content types?
+            content_bytes = request.content.read()
+            content = json.loads(content_bytes)
+            json_request["content"] = content
+
+        def parse_auth_header(header_str):
+            params = auth.split(" ")[1].split(",")
+            param_dict = dict(kv.split("=") for kv in params)
+            def strip_quotes(value):
+                if value.startswith("\""):
+                    return value[1:-1]
+                else:
+                    return value
+            origin = strip_quotes(param_dict["origin"])
+            key = strip_quotes(param_dict["key"])
+            sig = strip_quotes(param_dict["sig"])
+            return (origin, key, sig)
+
+        auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+
+        if not auth_headers:
+            #TODO(markjh): Send a 401 response?
+            raise Exception("Missing auth headers")
+
+        for auth in auth_headers:
+            if auth.startswith("X-Matrix"):
+                (origin, key, sig) = parse_auth_header(auth)
+                json_request["origin"] = origin
+                json_request["signatures"].setdefault(origin,{})[key] = sig
+
+        from syutil.jsonutil import encode_canonical_json
+        logger.debug("Checking  %s %s",
+            origin, encode_canonical_json(json_request))
+        yield self.keyring.verify_json_for_server(origin, json_request)
+
+        defer.returnValue((origin, content))
+
+    def _with_authentication(self, handler):
+        @defer.inlineCallbacks
+        def new_handler(request, *args, **kwargs):
+            (origin, content) = yield self._authenticate_request(request)
+            response = yield handler(
+                origin, content, request.args, *args, **kwargs
+            )
+            defer.returnValue(response)
+        return new_handler
+
     @log_function
     def register_received_handler(self, handler):
         """ Register a handler that will be fired when we receive data.
@@ -208,7 +269,7 @@ class TransportLayer(object):
         self.server.register_path(
             "PUT",
             re.compile("^" + PREFIX + "/send/([^/]*)/$"),
-            self._on_send_request
+            self._with_authentication(self._on_send_request)
         )
 
     @log_function
@@ -226,9 +287,9 @@ class TransportLayer(object):
         self.server.register_path(
             "GET",
             re.compile("^" + PREFIX + "/pull/$"),
-            lambda request: handler.on_pull_request(
-                request.args["origin"][0],
-                request.args["v"]
+            self._with_authentication(
+                lambda origin, content, query:
+                handler.on_pull_request(query["origin"][0], query["v"])
             )
         )
 
@@ -237,8 +298,9 @@ class TransportLayer(object):
         self.server.register_path(
             "GET",
             re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"),
-            lambda request, pdu_origin, pdu_id: handler.on_pdu_request(
-                pdu_origin, pdu_id
+            self._with_authentication(
+                lambda origin, content, query, pdu_origin, pdu_id:
+                handler.on_pdu_request(pdu_origin, pdu_id)
             )
         )
 
@@ -246,38 +308,47 @@ class TransportLayer(object):
         self.server.register_path(
             "GET",
             re.compile("^" + PREFIX + "/state/([^/]*)/$"),
-            lambda request, context: handler.on_context_state_request(
-                context
+            self._with_authentication(
+                lambda origin, content, query, context:
+                handler.on_context_state_request(context)
             )
         )
 
         self.server.register_path(
             "GET",
             re.compile("^" + PREFIX + "/backfill/([^/]*)/$"),
-            lambda request, context: self._on_backfill_request(
-                context, request.args["v"],
-                request.args["limit"]
+            self._with_authentication(
+                lambda origin, content, query, context:
+                self._on_backfill_request(
+                    context, query["v"], query["limit"]
+                )
             )
         )
 
         self.server.register_path(
             "GET",
             re.compile("^" + PREFIX + "/context/([^/]*)/$"),
-            lambda request, context: handler.on_context_pdus_request(context)
+            self._with_authentication(
+                lambda origin, content, query, context:
+                handler.on_context_pdus_request(context)
+            )
         )
 
         # This is when we receive a server-server Query
         self.server.register_path(
             "GET",
             re.compile("^" + PREFIX + "/query/([^/]*)$"),
-            lambda request, query_type: handler.on_query_request(
-                query_type, {k: v[0] for k, v in request.args.items()}
+            self._with_authentication(
+                lambda origin, content, query, query_type:
+                handler.on_query_request(
+                    query_type, {k: v[0] for k, v in query.items()}
+                )
             )
         )
 
     @defer.inlineCallbacks
     @log_function
-    def _on_send_request(self, request, transaction_id):
+    def _on_send_request(self, origin, content, query, transaction_id):
         """ Called on PUT /send/<transaction_id>/
 
         Args:
@@ -292,12 +363,7 @@ class TransportLayer(object):
         """
         # Parse the request
         try:
-            data = request.content.read()
-
-            l = data[:20].encode("string_escape")
-            logger.debug("Got data: \"%s\"", l)
-
-            transaction_data = json.loads(data)
+            transaction_data = content
 
             logger.debug(
                 "Decoded %s: %s",