summary refs log tree commit diff
path: root/synapse/federation/transport.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/transport.py')
-rw-r--r--synapse/federation/transport.py163
1 files changed, 120 insertions, 43 deletions
diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py
index afc777ec9e..e7517cac4d 100644
--- a/synapse/federation/transport.py
+++ b/synapse/federation/transport.py
@@ -24,6 +24,7 @@ over a different (albeit still reliable) protocol.
 from twisted.internet import defer
 
 from synapse.api.urls import FEDERATION_PREFIX as PREFIX
+from synapse.api.errors import Codes, SynapseError
 from synapse.util.logutils import log_function
 
 import logging
@@ -54,7 +55,7 @@ class TransportLayer(object):
             we receive data.
     """
 
-    def __init__(self, server_name, server, client):
+    def __init__(self, homeserver, server_name, server, client):
         """
         Args:
             server_name (str): Local home server host
@@ -63,6 +64,7 @@ class TransportLayer(object):
             client (synapse.protocol.http.HttpClient): the http client used to
                 send requests
         """
+        self.keyring = homeserver.get_keyring()
         self.server_name = server_name
         self.server = server
         self.client = client
@@ -144,7 +146,7 @@ class TransportLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def send_transaction(self, transaction, on_send_callback=None):
+    def send_transaction(self, transaction, json_data_callback=None):
         """ Sends the given Transaction to it's destination
 
         Args:
@@ -163,25 +165,15 @@ class TransportLayer(object):
         if transaction.destination == self.server_name:
             raise RuntimeError("Transport layer cannot send to itself!")
 
-        data = transaction.get_dict()
-
-        # FIXME (erikj): This is a bit of a hack to make the Pdu age
-        # keys work
-        def cb(destination, method, path_bytes, producer):
-            if not on_send_callback:
-                return
-
-            transaction = json.loads(producer.body)
-
-            new_transaction = on_send_callback(transaction)
-
-            producer.reset(new_transaction)
+        # FIXME: This is only used by the tests. The actual json sent is
+        # generated by the json_data_callback.
+        json_data = transaction.get_dict()
 
         code, response = yield self.client.put_json(
             transaction.destination,
             path=PREFIX + "/send/%s/" % transaction.transaction_id,
-            data=data,
-            on_send_callback=cb,
+            data=json_data,
+            json_data_callback=json_data_callback,
         )
 
         logger.debug(
@@ -193,17 +185,93 @@ class TransportLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def make_query(self, destination, query_type, args):
+    def make_query(self, destination, query_type, args, retry_on_dns_fail):
         path = PREFIX + "/query/%s" % query_type
 
         response = yield self.client.get_json(
             destination=destination,
             path=path,
-            args=args
+            args=args,
+            retry_on_dns_fail=retry_on_dns_fail,
         )
 
         defer.returnValue(response)
 
+    @defer.inlineCallbacks
+    def _authenticate_request(self, request):
+        json_request = {
+            "method": request.method,
+            "uri": request.uri,
+            "destination": self.server_name,
+            "signatures": {},
+        }
+
+        content = None
+        origin = None
+
+        if request.method == "PUT":
+            #TODO: Handle other method types? other content types?
+            try:
+                content_bytes = request.content.read()
+                content = json.loads(content_bytes)
+                json_request["content"] = content
+            except:
+                raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON)
+
+        def parse_auth_header(header_str):
+            try:
+                params = auth.split(" ")[1].split(",")
+                param_dict = dict(kv.split("=") for kv in params)
+                def strip_quotes(value):
+                    if value.startswith("\""):
+                        return value[1:-1]
+                    else:
+                        return value
+                origin = strip_quotes(param_dict["origin"])
+                key = strip_quotes(param_dict["key"])
+                sig = strip_quotes(param_dict["sig"])
+                return (origin, key, sig)
+            except:
+                raise SynapseError(
+                    400, "Malformed Authorization header", Codes.UNAUTHORIZED
+                )
+
+        auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+
+        if not auth_headers:
+            raise SynapseError(
+                401, "Missing Authorization headers", Codes.UNAUTHORIZED,
+            )
+
+        for auth in auth_headers:
+            if auth.startswith("X-Matrix"):
+                (origin, key, sig) = parse_auth_header(auth)
+                json_request["origin"] = origin
+                json_request["signatures"].setdefault(origin,{})[key] = sig
+
+        if not json_request["signatures"]:
+            raise SynapseError(
+                401, "Missing Authorization headers", Codes.UNAUTHORIZED,
+            )
+
+        yield self.keyring.verify_json_for_server(origin, json_request)
+
+        defer.returnValue((origin, content))
+
+    def _with_authentication(self, handler):
+        @defer.inlineCallbacks
+        def new_handler(request, *args, **kwargs):
+            try:
+                (origin, content) = yield self._authenticate_request(request)
+                response = yield handler(
+                    origin, content, request.args, *args, **kwargs
+                )
+            except:
+                logger.exception("_authenticate_request failed")
+                raise
+            defer.returnValue(response)
+        return new_handler
+
     @log_function
     def register_received_handler(self, handler):
         """ Register a handler that will be fired when we receive data.
@@ -217,7 +285,7 @@ class TransportLayer(object):
         self.server.register_path(
             "PUT",
             re.compile("^" + PREFIX + "/send/([^/]*)/$"),
-            self._on_send_request
+            self._with_authentication(self._on_send_request)
         )
 
     @log_function
@@ -235,9 +303,9 @@ class TransportLayer(object):
         self.server.register_path(
             "GET",
             re.compile("^" + PREFIX + "/pull/$"),
-            lambda request: handler.on_pull_request(
-                request.args["origin"][0],
-                request.args["v"]
+            self._with_authentication(
+                lambda origin, content, query:
+                handler.on_pull_request(query["origin"][0], query["v"])
             )
         )
 
@@ -246,8 +314,9 @@ class TransportLayer(object):
         self.server.register_path(
             "GET",
             re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"),
-            lambda request, pdu_origin, pdu_id: handler.on_pdu_request(
-                pdu_origin, pdu_id
+            self._with_authentication(
+                lambda origin, content, query, pdu_origin, pdu_id:
+                handler.on_pdu_request(pdu_origin, pdu_id)
             )
         )
 
@@ -255,38 +324,47 @@ class TransportLayer(object):
         self.server.register_path(
             "GET",
             re.compile("^" + PREFIX + "/state/([^/]*)/$"),
-            lambda request, context: handler.on_context_state_request(
-                context
+            self._with_authentication(
+                lambda origin, content, query, context:
+                handler.on_context_state_request(context)
             )
         )
 
         self.server.register_path(
             "GET",
             re.compile("^" + PREFIX + "/backfill/([^/]*)/$"),
-            lambda request, context: self._on_backfill_request(
-                context, request.args["v"],
-                request.args["limit"]
+            self._with_authentication(
+                lambda origin, content, query, context:
+                self._on_backfill_request(
+                    context, query["v"], query["limit"]
+                )
             )
         )
 
         self.server.register_path(
             "GET",
             re.compile("^" + PREFIX + "/context/([^/]*)/$"),
-            lambda request, context: handler.on_context_pdus_request(context)
+            self._with_authentication(
+                lambda origin, content, query, context:
+                handler.on_context_pdus_request(context)
+            )
         )
 
         # This is when we receive a server-server Query
         self.server.register_path(
             "GET",
             re.compile("^" + PREFIX + "/query/([^/]*)$"),
-            lambda request, query_type: handler.on_query_request(
-                query_type, {k: v[0] for k, v in request.args.items()}
+            self._with_authentication(
+                lambda origin, content, query, query_type:
+                handler.on_query_request(
+                    query_type, {k: v[0] for k, v in query.items()}
+                )
             )
         )
 
     @defer.inlineCallbacks
     @log_function
-    def _on_send_request(self, request, transaction_id):
+    def _on_send_request(self, origin, content, query, transaction_id):
         """ Called on PUT /send/<transaction_id>/
 
         Args:
@@ -301,12 +379,7 @@ class TransportLayer(object):
         """
         # Parse the request
         try:
-            data = request.content.read()
-
-            l = data[:20].encode("string_escape")
-            logger.debug("Got data: \"%s\"", l)
-
-            transaction_data = json.loads(data)
+            transaction_data = content
 
             logger.debug(
                 "Decoded %s: %s",
@@ -328,9 +401,13 @@ class TransportLayer(object):
             defer.returnValue((400, {"error": "Invalid transaction"}))
             return
 
-        code, response = yield self.received_handler.on_incoming_transaction(
-            transaction_data
-        )
+        try:
+            code, response = yield self.received_handler.on_incoming_transaction(
+                transaction_data
+            )
+        except:
+            logger.exception("on_incoming_transaction failed")
+            raise
 
         defer.returnValue((code, response))