summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/__init__.py1
-rw-r--r--synapse/federation/pdu_codec.py4
-rw-r--r--synapse/federation/persistence.py2
-rw-r--r--synapse/federation/replication.py29
-rw-r--r--synapse/federation/transport.py163
-rw-r--r--synapse/federation/units.py27
6 files changed, 156 insertions, 70 deletions
diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py
index 1351b68fd6..0112588656 100644
--- a/synapse/federation/__init__.py
+++ b/synapse/federation/__init__.py
@@ -22,6 +22,7 @@ from .transport import TransportLayer
 
 def initialize_http_replication(homeserver):
     transport = TransportLayer(
+        homeserver,
         homeserver.hostname,
         server=homeserver.get_resource_for_federation(),
         client=homeserver.get_http_client()
diff --git a/synapse/federation/pdu_codec.py b/synapse/federation/pdu_codec.py
index cef61108dd..e8180d94fd 100644
--- a/synapse/federation/pdu_codec.py
+++ b/synapse/federation/pdu_codec.py
@@ -96,7 +96,7 @@ class PduCodec(object):
             if k not in ["event_id", "room_id", "type", "prev_events"]
         })
 
-        if "ts" not in kwargs:
-            kwargs["ts"] = int(self.clock.time_msec())
+        if "origin_server_ts" not in kwargs:
+            kwargs["origin_server_ts"] = int(self.clock.time_msec())
 
         return Pdu(**kwargs)
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index de36a80e41..7043fcc504 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -157,7 +157,7 @@ class TransactionActions(object):
         transaction.prev_ids = yield self.store.prep_send_transaction(
             transaction.transaction_id,
             transaction.destination,
-            transaction.ts,
+            transaction.origin_server_ts,
             [(p["pdu_id"], p["origin"]) for p in transaction.pdus]
         )
 
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 96b82f00cb..092411eaf9 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -159,7 +159,8 @@ class ReplicationLayer(object):
         return defer.succeed(None)
 
     @log_function
-    def make_query(self, destination, query_type, args):
+    def make_query(self, destination, query_type, args,
+                   retry_on_dns_fail=True):
         """Sends a federation Query to a remote homeserver of the given type
         and arguments.
 
@@ -174,7 +175,9 @@ class ReplicationLayer(object):
             a Deferred which will eventually yield a JSON object from the
             response
         """
-        return self.transport_layer.make_query(destination, query_type, args)
+        return self.transport_layer.make_query(
+            destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
+        )
 
     @defer.inlineCallbacks
     @log_function
@@ -316,7 +319,7 @@ class ReplicationLayer(object):
 
         if hasattr(transaction, "edus"):
             for edu in [Edu(**x) for x in transaction.edus]:
-                self.received_edu(edu.origin, edu.edu_type, edu.content)
+                self.received_edu(transaction.origin, edu.edu_type, edu.content)
 
         results = yield defer.DeferredList(dl)
 
@@ -418,7 +421,7 @@ class ReplicationLayer(object):
         return Transaction(
             origin=self.server_name,
             pdus=pdus,
-            ts=int(self._clock.time_msec()),
+            origin_server_ts=int(self._clock.time_msec()),
             destination=None,
         )
 
@@ -489,7 +492,6 @@ class _TransactionQueue(object):
     """
 
     def __init__(self, hs, transaction_actions, transport_layer):
-
         self.server_name = hs.hostname
         self.transaction_actions = transaction_actions
         self.transport_layer = transport_layer
@@ -587,8 +589,8 @@ class _TransactionQueue(object):
             logger.debug("TX [%s] Persisting transaction...", destination)
 
             transaction = Transaction.create_new(
-                ts=self._clock.time_msec(),
-                transaction_id=self._next_txn_id,
+                origin_server_ts=self._clock.time_msec(),
+                transaction_id=str(self._next_txn_id),
                 origin=self.server_name,
                 destination=destination,
                 pdus=pdus,
@@ -606,18 +608,17 @@ class _TransactionQueue(object):
 
             # FIXME (erikj): This is a bit of a hack to make the Pdu age
             # keys work
-            def cb(transaction):
+            def json_data_cb():
+                data = transaction.get_dict()
                 now = int(self._clock.time_msec())
-                if "pdus" in transaction:
-                    for p in transaction["pdus"]:
+                if "pdus" in data:
+                    for p in data["pdus"]:
                         if "age_ts" in p:
                             p["age"] = now - int(p["age_ts"])
-
-                return transaction
+                return data
 
             code, response = yield self.transport_layer.send_transaction(
-                transaction,
-                on_send_callback=cb,
+                transaction, json_data_cb
             )
 
             logger.debug("TX [%s] Sent transaction", destination)
diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py
index afc777ec9e..e7517cac4d 100644
--- a/synapse/federation/transport.py
+++ b/synapse/federation/transport.py
@@ -24,6 +24,7 @@ over a different (albeit still reliable) protocol.
 from twisted.internet import defer
 
 from synapse.api.urls import FEDERATION_PREFIX as PREFIX
+from synapse.api.errors import Codes, SynapseError
 from synapse.util.logutils import log_function
 
 import logging
@@ -54,7 +55,7 @@ class TransportLayer(object):
             we receive data.
     """
 
-    def __init__(self, server_name, server, client):
+    def __init__(self, homeserver, server_name, server, client):
         """
         Args:
             server_name (str): Local home server host
@@ -63,6 +64,7 @@ class TransportLayer(object):
             client (synapse.protocol.http.HttpClient): the http client used to
                 send requests
         """
+        self.keyring = homeserver.get_keyring()
         self.server_name = server_name
         self.server = server
         self.client = client
@@ -144,7 +146,7 @@ class TransportLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def send_transaction(self, transaction, on_send_callback=None):
+    def send_transaction(self, transaction, json_data_callback=None):
         """ Sends the given Transaction to it's destination
 
         Args:
@@ -163,25 +165,15 @@ class TransportLayer(object):
         if transaction.destination == self.server_name:
             raise RuntimeError("Transport layer cannot send to itself!")
 
-        data = transaction.get_dict()
-
-        # FIXME (erikj): This is a bit of a hack to make the Pdu age
-        # keys work
-        def cb(destination, method, path_bytes, producer):
-            if not on_send_callback:
-                return
-
-            transaction = json.loads(producer.body)
-
-            new_transaction = on_send_callback(transaction)
-
-            producer.reset(new_transaction)
+        # FIXME: This is only used by the tests. The actual json sent is
+        # generated by the json_data_callback.
+        json_data = transaction.get_dict()
 
         code, response = yield self.client.put_json(
             transaction.destination,
             path=PREFIX + "/send/%s/" % transaction.transaction_id,
-            data=data,
-            on_send_callback=cb,
+            data=json_data,
+            json_data_callback=json_data_callback,
         )
 
         logger.debug(
@@ -193,17 +185,93 @@ class TransportLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def make_query(self, destination, query_type, args):
+    def make_query(self, destination, query_type, args, retry_on_dns_fail):
         path = PREFIX + "/query/%s" % query_type
 
         response = yield self.client.get_json(
             destination=destination,
             path=path,
-            args=args
+            args=args,
+            retry_on_dns_fail=retry_on_dns_fail,
         )
 
         defer.returnValue(response)
 
+    @defer.inlineCallbacks
+    def _authenticate_request(self, request):
+        json_request = {
+            "method": request.method,
+            "uri": request.uri,
+            "destination": self.server_name,
+            "signatures": {},
+        }
+
+        content = None
+        origin = None
+
+        if request.method == "PUT":
+            #TODO: Handle other method types? other content types?
+            try:
+                content_bytes = request.content.read()
+                content = json.loads(content_bytes)
+                json_request["content"] = content
+            except:
+                raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON)
+
+        def parse_auth_header(header_str):
+            try:
+                params = auth.split(" ")[1].split(",")
+                param_dict = dict(kv.split("=") for kv in params)
+                def strip_quotes(value):
+                    if value.startswith("\""):
+                        return value[1:-1]
+                    else:
+                        return value
+                origin = strip_quotes(param_dict["origin"])
+                key = strip_quotes(param_dict["key"])
+                sig = strip_quotes(param_dict["sig"])
+                return (origin, key, sig)
+            except:
+                raise SynapseError(
+                    400, "Malformed Authorization header", Codes.UNAUTHORIZED
+                )
+
+        auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+
+        if not auth_headers:
+            raise SynapseError(
+                401, "Missing Authorization headers", Codes.UNAUTHORIZED,
+            )
+
+        for auth in auth_headers:
+            if auth.startswith("X-Matrix"):
+                (origin, key, sig) = parse_auth_header(auth)
+                json_request["origin"] = origin
+                json_request["signatures"].setdefault(origin,{})[key] = sig
+
+        if not json_request["signatures"]:
+            raise SynapseError(
+                401, "Missing Authorization headers", Codes.UNAUTHORIZED,
+            )
+
+        yield self.keyring.verify_json_for_server(origin, json_request)
+
+        defer.returnValue((origin, content))
+
+    def _with_authentication(self, handler):
+        @defer.inlineCallbacks
+        def new_handler(request, *args, **kwargs):
+            try:
+                (origin, content) = yield self._authenticate_request(request)
+                response = yield handler(
+                    origin, content, request.args, *args, **kwargs
+                )
+            except:
+                logger.exception("_authenticate_request failed")
+                raise
+            defer.returnValue(response)
+        return new_handler
+
     @log_function
     def register_received_handler(self, handler):
         """ Register a handler that will be fired when we receive data.
@@ -217,7 +285,7 @@ class TransportLayer(object):
         self.server.register_path(
             "PUT",
             re.compile("^" + PREFIX + "/send/([^/]*)/$"),
-            self._on_send_request
+            self._with_authentication(self._on_send_request)
         )
 
     @log_function
@@ -235,9 +303,9 @@ class TransportLayer(object):
         self.server.register_path(
             "GET",
             re.compile("^" + PREFIX + "/pull/$"),
-            lambda request: handler.on_pull_request(
-                request.args["origin"][0],
-                request.args["v"]
+            self._with_authentication(
+                lambda origin, content, query:
+                handler.on_pull_request(query["origin"][0], query["v"])
             )
         )
 
@@ -246,8 +314,9 @@ class TransportLayer(object):
         self.server.register_path(
             "GET",
             re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"),
-            lambda request, pdu_origin, pdu_id: handler.on_pdu_request(
-                pdu_origin, pdu_id
+            self._with_authentication(
+                lambda origin, content, query, pdu_origin, pdu_id:
+                handler.on_pdu_request(pdu_origin, pdu_id)
             )
         )
 
@@ -255,38 +324,47 @@ class TransportLayer(object):
         self.server.register_path(
             "GET",
             re.compile("^" + PREFIX + "/state/([^/]*)/$"),
-            lambda request, context: handler.on_context_state_request(
-                context
+            self._with_authentication(
+                lambda origin, content, query, context:
+                handler.on_context_state_request(context)
             )
         )
 
         self.server.register_path(
             "GET",
             re.compile("^" + PREFIX + "/backfill/([^/]*)/$"),
-            lambda request, context: self._on_backfill_request(
-                context, request.args["v"],
-                request.args["limit"]
+            self._with_authentication(
+                lambda origin, content, query, context:
+                self._on_backfill_request(
+                    context, query["v"], query["limit"]
+                )
             )
         )
 
         self.server.register_path(
             "GET",
             re.compile("^" + PREFIX + "/context/([^/]*)/$"),
-            lambda request, context: handler.on_context_pdus_request(context)
+            self._with_authentication(
+                lambda origin, content, query, context:
+                handler.on_context_pdus_request(context)
+            )
         )
 
         # This is when we receive a server-server Query
         self.server.register_path(
             "GET",
             re.compile("^" + PREFIX + "/query/([^/]*)$"),
-            lambda request, query_type: handler.on_query_request(
-                query_type, {k: v[0] for k, v in request.args.items()}
+            self._with_authentication(
+                lambda origin, content, query, query_type:
+                handler.on_query_request(
+                    query_type, {k: v[0] for k, v in query.items()}
+                )
             )
         )
 
     @defer.inlineCallbacks
     @log_function
-    def _on_send_request(self, request, transaction_id):
+    def _on_send_request(self, origin, content, query, transaction_id):
         """ Called on PUT /send/<transaction_id>/
 
         Args:
@@ -301,12 +379,7 @@ class TransportLayer(object):
         """
         # Parse the request
         try:
-            data = request.content.read()
-
-            l = data[:20].encode("string_escape")
-            logger.debug("Got data: \"%s\"", l)
-
-            transaction_data = json.loads(data)
+            transaction_data = content
 
             logger.debug(
                 "Decoded %s: %s",
@@ -328,9 +401,13 @@ class TransportLayer(object):
             defer.returnValue((400, {"error": "Invalid transaction"}))
             return
 
-        code, response = yield self.received_handler.on_incoming_transaction(
-            transaction_data
-        )
+        try:
+            code, response = yield self.received_handler.on_incoming_transaction(
+                transaction_data
+            )
+        except:
+            logger.exception("on_incoming_transaction failed")
+            raise
 
         defer.returnValue((code, response))
 
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index 622fe66a8f..b2fb964180 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -40,7 +40,7 @@ class Pdu(JsonEncodedObject):
 
         {
             "pdu_id": "78c",
-            "ts": 1404835423000,
+            "origin_server_ts": 1404835423000,
             "origin": "bar",
             "prev_ids": [
                 ["23b", "foo"],
@@ -55,7 +55,7 @@ class Pdu(JsonEncodedObject):
         "pdu_id",
         "context",
         "origin",
-        "ts",
+        "origin_server_ts",
         "pdu_type",
         "destinations",
         "transaction_id",
@@ -82,7 +82,7 @@ class Pdu(JsonEncodedObject):
         "pdu_id",
         "context",
         "origin",
-        "ts",
+        "origin_server_ts",
         "pdu_type",
         "content",
     ]
@@ -118,6 +118,7 @@ class Pdu(JsonEncodedObject):
         """
         if pdu_tuple:
             d = copy.copy(pdu_tuple.pdu_entry._asdict())
+            d["origin_server_ts"] = d.pop("ts")
 
             d["content"] = json.loads(d["content_json"])
             del d["content_json"]
@@ -156,11 +157,15 @@ class Edu(JsonEncodedObject):
     ]
 
     required_keys = [
-        "origin",
-        "destination",
         "edu_type",
     ]
 
+#    TODO: SYN-103: Remove "origin" and "destination" keys.
+#    internal_keys = [
+#        "origin",
+#        "destination",
+#    ]
+
 
 class Transaction(JsonEncodedObject):
     """ A transaction is a list of Pdus and Edus to be sent to a remote home
@@ -182,10 +187,12 @@ class Transaction(JsonEncodedObject):
         "transaction_id",
         "origin",
         "destination",
-        "ts",
+        "origin_server_ts",
         "previous_ids",
         "pdus",
         "edus",
+        "transaction_id",
+        "destination",
     ]
 
     internal_keys = [
@@ -197,7 +204,7 @@ class Transaction(JsonEncodedObject):
         "transaction_id",
         "origin",
         "destination",
-        "ts",
+        "origin_server_ts",
         "pdus",
     ]
 
@@ -219,10 +226,10 @@ class Transaction(JsonEncodedObject):
     @staticmethod
     def create_new(pdus, **kwargs):
         """ Used to create a new transaction. Will auto fill out
-        transaction_id and ts keys.
+        transaction_id and origin_server_ts keys.
         """
-        if "ts" not in kwargs:
-            raise KeyError("Require 'ts' to construct a Transaction")
+        if "origin_server_ts" not in kwargs:
+            raise KeyError("Require 'origin_server_ts' to construct a Transaction")
         if "transaction_id" not in kwargs:
             raise KeyError(
                 "Require 'transaction_id' to construct a Transaction"