diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py
index 93296af204..755eee8cf6 100644
--- a/synapse/federation/transport.py
+++ b/synapse/federation/transport.py
@@ -24,6 +24,7 @@ over a different (albeit still reliable) protocol.
from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
+from synapse.api.errors import Codes, SynapseError
from synapse.util.logutils import log_function
import logging
@@ -54,7 +55,7 @@ class TransportLayer(object):
we receive data.
"""
- def __init__(self, server_name, server, client):
+ def __init__(self, homeserver, server_name, server, client):
"""
Args:
server_name (str): Local home server host
@@ -63,6 +64,7 @@ class TransportLayer(object):
client (synapse.protocol.http.HttpClient): the http client used to
send requests
"""
+ self.keyring = homeserver.get_keyring()
self.server_name = server_name
self.server = server
self.client = client
@@ -144,7 +146,7 @@ class TransportLayer(object):
@defer.inlineCallbacks
@log_function
- def send_transaction(self, transaction, on_send_callback=None):
+ def send_transaction(self, transaction, json_data_callback=None):
""" Sends the given Transaction to it's destination
Args:
@@ -163,25 +165,15 @@ class TransportLayer(object):
if transaction.destination == self.server_name:
raise RuntimeError("Transport layer cannot send to itself!")
- data = transaction.get_dict()
-
- # FIXME (erikj): This is a bit of a hack to make the Pdu age
- # keys work
- def cb(destination, method, path_bytes, producer):
- if not on_send_callback:
- return
-
- transaction = json.loads(producer.body)
-
- new_transaction = on_send_callback(transaction)
-
- producer.reset(new_transaction)
+ # FIXME: This is only used by the tests. The actual json sent is
+ # generated by the json_data_callback.
+ json_data = transaction.get_dict()
code, response = yield self.client.put_json(
transaction.destination,
path=PREFIX + "/send/%s/" % transaction.transaction_id,
- data=data,
- on_send_callback=cb,
+ data=json_data,
+ json_data_callback=json_data_callback,
)
logger.debug(
@@ -205,6 +197,72 @@ class TransportLayer(object):
defer.returnValue(response)
+ @defer.inlineCallbacks
+ def _authenticate_request(self, request):
+ json_request = {
+ "method": request.method,
+ "uri": request.uri,
+ "destination": self.server_name,
+ "signatures": {},
+ }
+
+ content = None
+ origin = None
+
+ if request.method == "PUT":
+ #TODO: Handle other method types? other content types?
+ try:
+ content_bytes = request.content.read()
+ content = json.loads(content_bytes)
+ json_request["content"] = content
+ except:
+ raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON)
+
+ def parse_auth_header(header_str):
+ try:
+ params = auth.split(" ")[1].split(",")
+ param_dict = dict(kv.split("=") for kv in params)
+ def strip_quotes(value):
+ if value.startswith("\""):
+ return value[1:-1]
+ else:
+ return value
+ origin = strip_quotes(param_dict["origin"])
+ key = strip_quotes(param_dict["key"])
+ sig = strip_quotes(param_dict["sig"])
+ return (origin, key, sig)
+ except:
+ raise SynapseError(
+ 400, "Malformed Authorization header", Codes.UNAUTHORIZED
+ )
+
+ auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+
+ for auth in auth_headers:
+ if auth.startswith("X-Matrix"):
+ (origin, key, sig) = parse_auth_header(auth)
+ json_request["origin"] = origin
+ json_request["signatures"].setdefault(origin,{})[key] = sig
+
+ if not json_request["signatures"]:
+ raise SynapseError(
+ 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
+ )
+
+ yield self.keyring.verify_json_for_server(origin, json_request)
+
+ defer.returnValue((origin, content))
+
+ def _with_authentication(self, handler):
+ @defer.inlineCallbacks
+ def new_handler(request, *args, **kwargs):
+ (origin, content) = yield self._authenticate_request(request)
+ response = yield handler(
+ origin, content, request.args, *args, **kwargs
+ )
+ defer.returnValue(response)
+ return new_handler
+
@log_function
def register_received_handler(self, handler):
""" Register a handler that will be fired when we receive data.
@@ -218,7 +276,7 @@ class TransportLayer(object):
self.server.register_path(
"PUT",
re.compile("^" + PREFIX + "/send/([^/]*)/$"),
- self._on_send_request
+ self._with_authentication(self._on_send_request)
)
@log_function
@@ -236,9 +294,9 @@ class TransportLayer(object):
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/pull/$"),
- lambda request: handler.on_pull_request(
- request.args["origin"][0],
- request.args["v"]
+ self._with_authentication(
+ lambda origin, content, query:
+ handler.on_pull_request(query["origin"][0], query["v"])
)
)
@@ -247,8 +305,9 @@ class TransportLayer(object):
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"),
- lambda request, pdu_origin, pdu_id: handler.on_pdu_request(
- pdu_origin, pdu_id
+ self._with_authentication(
+ lambda origin, content, query, pdu_origin, pdu_id:
+ handler.on_pdu_request(pdu_origin, pdu_id)
)
)
@@ -256,38 +315,47 @@ class TransportLayer(object):
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/state/([^/]*)/$"),
- lambda request, context: handler.on_context_state_request(
- context
+ self._with_authentication(
+ lambda origin, content, query, context:
+ handler.on_context_state_request(context)
)
)
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/backfill/([^/]*)/$"),
- lambda request, context: self._on_backfill_request(
- context, request.args["v"],
- request.args["limit"]
+ self._with_authentication(
+ lambda origin, content, query, context:
+ self._on_backfill_request(
+ context, query["v"], query["limit"]
+ )
)
)
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/context/([^/]*)/$"),
- lambda request, context: handler.on_context_pdus_request(context)
+ self._with_authentication(
+ lambda origin, content, query, context:
+ handler.on_context_pdus_request(context)
+ )
)
# This is when we receive a server-server Query
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/query/([^/]*)$"),
- lambda request, query_type: handler.on_query_request(
- query_type, {k: v[0] for k, v in request.args.items()}
+ self._with_authentication(
+ lambda origin, content, query, query_type:
+ handler.on_query_request(
+ query_type, {k: v[0] for k, v in query.items()}
+ )
)
)
@defer.inlineCallbacks
@log_function
- def _on_send_request(self, request, transaction_id):
+ def _on_send_request(self, origin, content, query, transaction_id):
""" Called on PUT /send/<transaction_id>/
Args:
@@ -302,12 +370,7 @@ class TransportLayer(object):
"""
# Parse the request
try:
- data = request.content.read()
-
- l = data[:20].encode("string_escape")
- logger.debug("Got data: \"%s\"", l)
-
- transaction_data = json.loads(data)
+ transaction_data = content
logger.debug(
"Decoded %s: %s",
|