summary refs log tree commit diff
path: root/synapse/http
diff options
context:
space:
mode:
authormatrix.org <matrix@matrix.org>2014-08-12 15:10:52 +0100
committermatrix.org <matrix@matrix.org>2014-08-12 15:10:52 +0100
commit4f475c7697722e946e39e42f38f3dd03a95d8765 (patch)
tree076d96d3809fb836c7245fd9f7960e7b75888a77 /synapse/http
downloadsynapse-4f475c7697722e946e39e42f38f3dd03a95d8765.tar.xz
Reference Matrix Home Server
Diffstat (limited to 'synapse/http')
-rw-r--r--synapse/http/__init__.py14
-rw-r--r--synapse/http/client.py246
-rw-r--r--synapse/http/endpoint.py171
-rw-r--r--synapse/http/server.py181
4 files changed, 612 insertions, 0 deletions
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
new file mode 100644
index 0000000000..fe8a073cd3
--- /dev/null
+++ b/synapse/http/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/http/client.py b/synapse/http/client.py
new file mode 100644
index 0000000000..bb22b0ee9a
--- /dev/null
+++ b/synapse/http/client.py
@@ -0,0 +1,246 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer, reactor
+from twisted.web.client import _AgentBase, _URI, readBody
+from twisted.web.http_headers import Headers
+
+from synapse.http.endpoint import matrix_endpoint
+from synapse.util.async import sleep
+
+from syutil.jsonutil import encode_canonical_json
+
+from synapse.api.errors import CodeMessageException
+
+import json
+import logging
+import urllib
+
+
+logger = logging.getLogger(__name__)
+
+
+_destination_mappings = {
+    "red": "localhost:8080",
+    "blue": "localhost:8081",
+    "green": "localhost:8082",
+}
+
+
+class HttpClient(object):
+    """ Interface for talking json over http
+    """
+
+    def put_json(self, destination, path, data):
+        """ Sends the specifed json data using PUT
+
+        Args:
+            destination (str): The remote server to send the HTTP request
+                to.
+            path (str): The HTTP path.
+            data (dict): A dict containing the data that will be used as
+                the request body. This will be encoded as JSON.
+
+        Returns:
+            Deferred: Succeeds when we get *any* HTTP response.
+
+            The result of the deferred is a tuple of `(code, response)`,
+            where `response` is a dict representing the decoded JSON body.
+        """
+        pass
+
+    def get_json(self, destination, path, args=None):
+        """ Get's some json from the given host homeserver and path
+
+        Args:
+            destination (str): The remote server to send the HTTP request
+                to.
+            path (str): The HTTP path.
+            args (dict): A dictionary used to create query strings, defaults to
+                None.
+                **Note**: The value of each key is assumed to be an iterable
+                and *not* a string.
+
+        Returns:
+            Deferred: Succeeds when we get *any* HTTP response.
+
+            The result of the deferred is a tuple of `(code, response)`,
+            where `response` is a dict representing the decoded JSON body.
+        """
+        pass
+
+
+class MatrixHttpAgent(_AgentBase):
+
+    def __init__(self, reactor, pool=None):
+        _AgentBase.__init__(self, reactor, pool)
+
+    def request(self, destination, endpoint, method, path, params, query,
+                headers, body_producer):
+
+        host = b""
+        port = 0
+        fragment = b""
+
+        parsed_URI = _URI(b"http", destination, host, port, path, params,
+                          query, fragment)
+
+        # Set the connection pool key to be the destination.
+        key = destination
+
+        return self._requestWithEndpoint(key, endpoint, method, parsed_URI,
+                                         headers, body_producer,
+                                         parsed_URI.originForm)
+
+
+class TwistedHttpClient(HttpClient):
+    """ Wrapper around the twisted HTTP client api.
+
+    Attributes:
+        agent (twisted.web.client.Agent): The twisted Agent used to send the
+            requests.
+    """
+
+    def __init__(self):
+        self.agent = MatrixHttpAgent(reactor)
+
+    @defer.inlineCallbacks
+    def put_json(self, destination, path, data):
+        if destination in _destination_mappings:
+            destination = _destination_mappings[destination]
+
+        response = yield self._create_request(
+            destination.encode("ascii"),
+            "PUT",
+            path.encode("ascii"),
+            producer=_JsonProducer(data),
+            headers_dict={"Content-Type": ["application/json"]}
+        )
+
+        logger.debug("Getting resp body")
+        body = yield readBody(response)
+        logger.debug("Got resp body")
+
+        defer.returnValue((response.code, body))
+
+    @defer.inlineCallbacks
+    def get_json(self, destination, path, args={}):
+        if destination in _destination_mappings:
+            destination = _destination_mappings[destination]
+
+        logger.debug("get_json args: %s", args)
+        query_bytes = urllib.urlencode(args, True)
+
+        response = yield self._create_request(
+            destination.encode("ascii"),
+            "GET",
+            path.encode("ascii"),
+            query_bytes
+        )
+
+        body = yield readBody(response)
+
+        defer.returnValue(json.loads(body))
+
+    @defer.inlineCallbacks
+    def _create_request(self, destination, method, path_bytes, param_bytes=b"",
+                        query_bytes=b"", producer=None, headers_dict={}):
+        """ Creates and sends a request to the given url
+        """
+        headers_dict[b"User-Agent"] = [b"Synapse"]
+        headers_dict[b"Host"] = [destination]
+
+        logger.debug("Sending request to %s: %s %s;%s?%s",
+                     destination, method, path_bytes, param_bytes, query_bytes)
+
+        logger.debug(
+            "Types: %s",
+            [
+                type(destination), type(method), type(path_bytes),
+                type(param_bytes),
+                type(query_bytes)
+            ]
+        )
+
+        retries_left = 5
+
+        # TODO: setup and pass in an ssl_context to enable TLS
+        endpoint = matrix_endpoint(reactor, destination, timeout=10)
+
+        while True:
+            try:
+                response = yield self.agent.request(
+                    destination,
+                    endpoint,
+                    method,
+                    path_bytes,
+                    param_bytes,
+                    query_bytes,
+                    Headers(headers_dict),
+                    producer
+                )
+
+                logger.debug("Got response to %s", method)
+                break
+            except Exception as e:
+                logger.exception("Got error in _create_request")
+                _print_ex(e)
+
+                if retries_left:
+                    yield sleep(2 ** (5 - retries_left))
+                    retries_left -= 1
+                else:
+                    raise
+
+        if 200 <= response.code < 300:
+            # We need to update the transactions table to say it was sent?
+            pass
+        else:
+            # :'(
+            # Update transactions table?
+            logger.error(
+                "Got response %d %s", response.code, response.phrase
+            )
+            raise CodeMessageException(
+                response.code, response.phrase
+            )
+
+        defer.returnValue(response)
+
+
+def _print_ex(e):
+    if hasattr(e, "reasons") and e.reasons:
+        for ex in e.reasons:
+            _print_ex(ex)
+    else:
+        logger.exception(e)
+
+
+class _JsonProducer(object):
+    """ Used by the twisted http client to create the HTTP body from json
+    """
+    def __init__(self, jsn):
+        self.body = encode_canonical_json(jsn)
+        self.length = len(self.body)
+
+    def startProducing(self, consumer):
+        consumer.write(self.body)
+        return defer.succeed(None)
+
+    def pauseProducing(self):
+        pass
+
+    def stopProducing(self):
+        pass
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
new file mode 100644
index 0000000000..c4e6e63a80
--- /dev/null
+++ b/synapse/http/endpoint.py
@@ -0,0 +1,171 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
+from twisted.internet import defer
+from twisted.internet.error import ConnectError
+from twisted.names import client, dns
+from twisted.names.error import DNSNameError
+
+import collections
+import logging
+import random
+
+
+logger = logging.getLogger(__name__)
+
+
+def matrix_endpoint(reactor, destination, ssl_context_factory=None,
+                    timeout=None):
+    """Construct an endpoint for the given matrix destination.
+
+    Args:
+        reactor: Twisted reactor.
+        destination (bytes): The name of the server to connect to.
+        ssl_context_factory (twisted.internet.ssl.ContextFactory): Factory
+            which generates SSL contexts to use for TLS.
+        timeout (int): connection timeout in seconds
+    """
+
+    domain_port = destination.split(":")
+    domain = domain_port[0]
+    port = int(domain_port[1]) if domain_port[1:] else None
+
+    endpoint_kw_args = {}
+
+    if timeout is not None:
+        endpoint_kw_args.update(timeout=timeout)
+
+    if ssl_context_factory is None:
+        transport_endpoint = TCP4ClientEndpoint
+        default_port = 8080
+    else:
+        transport_endpoint = SSL4ClientEndpoint
+        endpoint_kw_args.update(ssl_context_factory=ssl_context_factory)
+        default_port = 443
+
+    if port is None:
+        return SRVClientEndpoint(
+            reactor, "matrix", domain, protocol="tcp",
+            default_port=default_port, endpoint=transport_endpoint,
+            endpoint_kw_args=endpoint_kw_args
+        )
+    else:
+        return transport_endpoint(reactor, domain, port, **endpoint_kw_args)
+
+
+class SRVClientEndpoint(object):
+    """An endpoint which looks up SRV records for a service.
+    Cycles through the list of servers starting with each call to connect
+    picking the next server.
+    Implements twisted.internet.interfaces.IStreamClientEndpoint.
+    """
+
+    _Server = collections.namedtuple(
+        "_Server", "priority weight host port"
+    )
+
+    def __init__(self, reactor, service, domain, protocol="tcp",
+                 default_port=None, endpoint=TCP4ClientEndpoint,
+                 endpoint_kw_args={}):
+        self.reactor = reactor
+        self.service_name = "_%s._%s.%s" % (service, protocol, domain)
+
+        if default_port is not None:
+            self.default_server = self._Server(
+                host=domain,
+                port=default_port,
+                priority=0,
+                weight=0
+            )
+        else:
+            self.default_server = None
+
+        self.endpoint = endpoint
+        self.endpoint_kw_args = endpoint_kw_args
+
+        self.servers = None
+        self.used_servers = None
+
+    @defer.inlineCallbacks
+    def fetch_servers(self):
+        try:
+            answers, auth, add = yield client.lookupService(self.service_name)
+        except DNSNameError:
+            answers = []
+
+        if (len(answers) == 1
+                and answers[0].type == dns.SRV
+                and answers[0].payload
+                and answers[0].payload.target == dns.Name('.')):
+            raise ConnectError("Service %s unavailable", self.service_name)
+
+        self.servers = []
+        self.used_servers = []
+
+        for answer in answers:
+            if answer.type != dns.SRV or not answer.payload:
+                continue
+            payload = answer.payload
+            self.servers.append(self._Server(
+                host=str(payload.target),
+                port=int(payload.port),
+                priority=int(payload.priority),
+                weight=int(payload.weight)
+            ))
+
+        self.servers.sort()
+
+    def pick_server(self):
+        if not self.servers:
+            if self.used_servers:
+                self.servers = self.used_servers
+                self.used_servers = []
+                self.servers.sort()
+            elif self.default_server:
+                return self.default_server
+            else:
+                raise ConnectError(
+                    "Not server available for %s", self.service_name
+                )
+
+        min_priority = self.servers[0].priority
+        weight_indexes = list(
+            (index, server.weight + 1)
+            for index, server in enumerate(self.servers)
+            if server.priority == min_priority
+        )
+
+        total_weight = sum(weight for index, weight in weight_indexes)
+        target_weight = random.randint(0, total_weight)
+
+        for index, weight in weight_indexes:
+            target_weight -= weight
+            if target_weight <= 0:
+                server = self.servers[index]
+                del self.servers[index]
+                self.used_servers.append(server)
+                return server
+
+    @defer.inlineCallbacks
+    def connect(self, protocolFactory):
+        if self.servers is None:
+            yield self.fetch_servers()
+        server = self.pick_server()
+        logger.info("Connecting to %s:%s", server.host, server.port)
+        endpoint = self.endpoint(
+            self.reactor, server.host, server.port, **self.endpoint_kw_args
+        )
+        connection = yield endpoint.connect(protocolFactory)
+        defer.returnValue(connection)
diff --git a/synapse/http/server.py b/synapse/http/server.py
new file mode 100644
index 0000000000..8823aade78
--- /dev/null
+++ b/synapse/http/server.py
@@ -0,0 +1,181 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from syutil.jsonutil import (
+    encode_canonical_json, encode_pretty_printed_json
+)
+from synapse.api.errors import cs_exception, CodeMessageException
+
+from twisted.internet import defer, reactor
+from twisted.web import server, resource
+from twisted.web.server import NOT_DONE_YET
+
+import collections
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class HttpServer(object):
+    """ Interface for registering callbacks on a HTTP server
+    """
+
+    def register_path(self, method, path_pattern, callback):
+        """ Register a callback that get's fired if we receive a http request
+        with the given method for a path that matches the given regex.
+
+        If the regex contains groups these get's passed to the calback via
+        an unpacked tuple.
+
+        Args:
+            method (str): The method to listen to.
+            path_pattern (str): The regex used to match requests.
+            callback (function): The function to fire if we receive a matched
+                request. The first argument will be the request object and
+                subsequent arguments will be any matched groups from the regex.
+                This should return a tuple of (code, response).
+        """
+        pass
+
+
+# The actual HTTP server impl, using twisted http server
+class TwistedHttpServer(HttpServer, resource.Resource):
+    """ This wraps the twisted HTTP server, and triggers the correct callbacks
+    on the transport_layer.
+
+    Register callbacks via register_path()
+    """
+
+    isLeaf = True
+
+    _PathEntry = collections.namedtuple("_PathEntry", ["pattern", "callback"])
+
+    def __init__(self):
+        resource.Resource.__init__(self)
+
+        self.path_regexs = {}
+
+    def register_path(self, method, path_pattern, callback):
+        self.path_regexs.setdefault(method, []).append(
+            self._PathEntry(path_pattern, callback)
+        )
+
+    def start_listening(self, port):
+        """ Registers the http server with the twisted reactor.
+
+        Args:
+            port (int): The port to listen on.
+
+        """
+        reactor.listenTCP(port, server.Site(self))
+
+    # Gets called by twisted
+    def render(self, request):
+        """ This get's called by twisted every time someone sends us a request.
+        """
+        self._async_render(request)
+        return server.NOT_DONE_YET
+
+    @defer.inlineCallbacks
+    def _async_render(self, request):
+        """ This get's called by twisted every time someone sends us a request.
+            This checks if anyone has registered a callback for that method and
+            path.
+        """
+        try:
+            # Loop through all the registered callbacks to check if the method
+            # and path regex match
+            for path_entry in self.path_regexs.get(request.method, []):
+                m = path_entry.pattern.match(request.path)
+                if m:
+                    # We found a match! Trigger callback and then return the
+                    # returned response. We pass both the request and any
+                    # matched groups from the regex to the callback.
+                    code, response = yield path_entry.callback(
+                        request,
+                        *m.groups()
+                    )
+
+                    self._send_response(request, code, response)
+                    return
+
+            # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
+            self._send_response(
+                request,
+                400,
+                {"error": "Unrecognized request"}
+            )
+        except CodeMessageException as e:
+            logger.exception(e)
+            self._send_response(
+                request,
+                e.code,
+                cs_exception(e)
+            )
+        except Exception as e:
+            logger.exception(e)
+            self._send_response(
+                request,
+                500,
+                {"error": "Internal server error"}
+            )
+
+    def _send_response(self, request, code, response_json_object):
+
+        if not self._request_user_agent_is_curl(request):
+            json_bytes = encode_canonical_json(response_json_object)
+        else:
+            json_bytes = encode_pretty_printed_json(response_json_object)
+
+        # TODO: Only enable CORS for the requests that need it.
+        respond_with_json_bytes(request, code, json_bytes, send_cors=True)
+
+    @staticmethod
+    def _request_user_agent_is_curl(request):
+        user_agents = request.requestHeaders.getRawHeaders(
+            "User-Agent", default=[]
+        )
+        for user_agent in user_agents:
+            if "curl" in user_agent:
+                return True
+        return False
+
+
+def respond_with_json_bytes(request, code, json_bytes, send_cors=False):
+    """Sends encoded JSON in response to the given request.
+
+    Args:
+        request (twisted.web.http.Request): The http request to respond to.
+        code (int): The HTTP response code.
+        json_bytes (bytes): The json bytes to use as the response body.
+        send_cors (bool): Whether to send Cross-Origin Resource Sharing headers
+            http://www.w3.org/TR/cors/
+    Returns:
+        twisted.web.server.NOT_DONE_YET"""
+
+    request.setResponseCode(code)
+    request.setHeader(b"Content-Type", b"application/json")
+
+    if send_cors:
+        request.setHeader("Access-Control-Allow-Origin", "*")
+        request.setHeader("Access-Control-Allow-Methods",
+                          "GET, POST, PUT, DELETE, OPTIONS")
+        request.setHeader("Access-Control-Allow-Headers",
+                          "Origin, X-Requested-With, Content-Type, Accept")
+
+    request.write(json_bytes)
+    request.finish()
+    return NOT_DONE_YET