Diffstat (limited to 'synapse/http')
-rw-r--r--  synapse/http/client.py   | 173
-rw-r--r--  synapse/http/endpoint.py |  80
-rw-r--r--  synapse/http/server.py   | 107
-rw-r--r--  synapse/http/servlet.py  |  81
-rw-r--r--  synapse/http/site.py     | 146
5 files changed, 496 insertions, 91 deletions
diff --git a/synapse/http/client.py b/synapse/http/client.py
index cbd45b2bbe..c7fa692435 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -15,17 +15,24 @@
 from OpenSSL import SSL
 from OpenSSL.SSL import VERIFY_NONE
 
-from synapse.api.errors import CodeMessageException
+from synapse.api.errors import (
+    CodeMessageException, SynapseError, Codes,
+)
 from synapse.util.logcontext import preserve_context_over_fn
 import synapse.metrics
+from synapse.http.endpoint import SpiderEndpoint
 
 from canonicaljson import encode_canonical_json
 
-from twisted.internet import defer, reactor, ssl
+from twisted.internet import defer, reactor, ssl, protocol
+from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
 from twisted.web.client import (
-    Agent, readBody, FileBodyProducer, PartialDownloadError,
+    BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
+    readBody, FileBodyProducer, PartialDownloadError,
 )
+from twisted.web.http import PotentialDataLoss
 from twisted.web.http_headers import Headers
+from twisted.web._newclient import ResponseDone
 
 from StringIO import StringIO
 
@@ -238,6 +245,107 @@ class SimpleHttpClient(object):
         else:
             raise CodeMessageException(response.code, body)
 
+    # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
+    # The two should be factored out.
+
+    @defer.inlineCallbacks
+    def get_file(self, url, output_stream, max_size=None):
+        """GETs a file from a given URL
+        Args:
+            url (str): The URL to GET
+            output_stream (file): File to write the response body to.
+            max_size (int|None): Maximum file size, in bytes, to allow.
+        Returns:
+            A (int, dict, string, int) tuple of the file length, dict of the
+            response headers, absolute URI of the response and HTTP response
+            code.
+        """
+
+        response = yield self.request(
+            "GET",
+            url.encode("ascii"),
+            headers=Headers({
+                b"User-Agent": [self.user_agent],
+            })
+        )
+
+        headers = dict(response.headers.getAllRawHeaders())
+
+        if (max_size is not None and 'Content-Length' in headers
+                and int(headers['Content-Length'][0]) > max_size):
+            logger.warn("Requested URL is too large > %r bytes" % (max_size,))
+            raise SynapseError(
+                502,
+                "Requested file is too large > %r bytes" % (max_size,),
+                Codes.TOO_LARGE,
+            )
+
+        if response.code > 299:
+            logger.warn("Got %d when downloading %s" % (response.code, url))
+            raise SynapseError(
+                502,
+                "Got error %d" % (response.code,),
+                Codes.UNKNOWN,
+            )
+
+        # TODO: if our Content-Type is HTML or something, just read the first
+        # N bytes into RAM rather than saving it all to disk only to read it
+        # straight back in again
+
+        try:
+            length = yield preserve_context_over_fn(
+                _readBodyToFile,
+                response, output_stream, max_size
+            )
+        except Exception as e:
+            logger.exception("Failed to download body")
+            raise SynapseError(
+                502,
+                ("Failed to download remote body: %s" % e),
+                Codes.UNKNOWN,
+            )
+
+        defer.returnValue((length, headers, response.request.absoluteURI, response.code))
+
+
+# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
+# The two should be factored out.
+
+class _ReadBodyToFileProtocol(protocol.Protocol):
+    def __init__(self, stream, deferred, max_size):
+        self.stream = stream
+        self.deferred = deferred
+        self.length = 0
+        self.max_size = max_size
+
+    def dataReceived(self, data):
+        self.stream.write(data)
+        self.length += len(data)
+        if self.max_size is not None and self.length >= self.max_size:
+            self.deferred.errback(SynapseError(
+                502,
+                "Requested file is too large > %r bytes" % (self.max_size,),
+                Codes.TOO_LARGE,
+            ))
+            self.deferred = defer.Deferred()
+            self.transport.loseConnection()
+
+    def connectionLost(self, reason):
+        if reason.check(ResponseDone):
+            self.deferred.callback(self.length)
+        elif reason.check(PotentialDataLoss):
+            # stolen from https://github.com/twisted/treq/pull/49/files
+            # http://twistedmatrix.com/trac/ticket/4840
+            self.deferred.callback(self.length)
+        else:
+            self.deferred.errback(reason)
+
+
+# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
+# The two should be factored out.
+
+def _readBodyToFile(response, stream, max_size):
+    d = defer.Deferred()
+    response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
+    return d
+
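The size cap enforced in _ReadBodyToFileProtocol.dataReceived is an ordinary streaming pattern; a minimal sketch of the same idea over plain file-like objects (copy_with_limit and its parameters are illustrative names, not part of this diff):

# Illustrative sketch of the size-capped copy that _ReadBodyToFileProtocol
# performs chunk by chunk; `source`, `sink` and `max_size` are assumptions
# for the example, not names from this diff.
def copy_with_limit(source, sink, max_size=None, chunk_size=4096):
    length = 0
    while True:
        chunk = source.read(chunk_size)
        if not chunk:
            return length
        sink.write(chunk)
        length += len(chunk)
        if max_size is not None and length >= max_size:
            # the protocol raises a 502 SynapseError (M_TOO_LARGE) here
            raise ValueError("Requested file is too large > %r bytes" % (max_size,))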
 
 class CaptchaServerHttpClient(SimpleHttpClient):
     """
@@ -269,6 +377,60 @@ class CaptchaServerHttpClient(SimpleHttpClient):
             defer.returnValue(e.response)
 
 
+class SpiderEndpointFactory(object):
+    def __init__(self, hs):
+        self.blacklist = hs.config.url_preview_ip_range_blacklist
+        self.whitelist = hs.config.url_preview_ip_range_whitelist
+        self.policyForHTTPS = hs.get_http_client_context_factory()
+
+    def endpointForURI(self, uri):
+        logger.info("Getting endpoint for %s", uri.toBytes())
+        if uri.scheme == "http":
+            return SpiderEndpoint(
+                reactor, uri.host, uri.port, self.blacklist, self.whitelist,
+                endpoint=TCP4ClientEndpoint,
+                endpoint_kw_args={
+                    'timeout': 15
+                },
+            )
+        elif uri.scheme == "https":
+            tlsPolicy = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port)
+            return SpiderEndpoint(
+                reactor, uri.host, uri.port, self.blacklist, self.whitelist,
+                endpoint=SSL4ClientEndpoint,
+                endpoint_kw_args={
+                    'sslContextFactory': tlsPolicy,
+                    'timeout': 15
+                },
+            )
+        else:
+            logger.warn("Can't get endpoint for unrecognised scheme %s", uri.scheme)
+
+
+class SpiderHttpClient(SimpleHttpClient):
+    """
+    Separate HTTP client for spidering arbitrary URLs.
+    Special in that it follows redirects and has a UA that looks
+    like a browser.
+
+    Used by the preview_url endpoint in the content repo.
+    """
+    def __init__(self, hs):
+        SimpleHttpClient.__init__(self, hs)
+        # clobber the base class's agent and UA:
+        self.agent = ContentDecoderAgent(
+            BrowserLikeRedirectAgent(
+                Agent.usingEndpointFactory(
+                    reactor,
+                    SpiderEndpointFactory(hs)
+                )
+            ), [('gzip', GzipDecoder)]
+        )
+        # We could look like Chrome:
+        # self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko)
+        #                   Chrome Safari" % hs.version_string)
+
+
 def encode_urlencode_args(args):
     return {k: encode_urlencode_arg(v) for k, v in args.items()}
 
@@ -301,5 +463,8 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
         self._context = SSL.Context(SSL.SSLv23_METHOD)
         self._context.set_verify(VERIFY_NONE, lambda *_: None)
 
-    def getContext(self, hostname, port):
+    def getContext(self, hostname=None, port=None):
         return self._context
+
+    def creatorForNetloc(self, hostname, port):
+        return self
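Taken together, the client.py additions would be used roughly as follows. This is a hedged sketch: hs is assumed to be a HomeServer carrying the url_preview_ip_range_* config, and download_for_preview is an illustrative helper, not something added by this diff.

from twisted.internet import defer

@defer.inlineCallbacks
def download_for_preview(hs, url, output_path, max_size=50 * 1024 * 1024):
    # SpiderHttpClient refuses blacklisted IPs and follows redirects;
    # get_file streams the body to output_stream, capped at max_size.
    client = SpiderHttpClient(hs)
    with open(output_path, "wb") as output_stream:
        length, headers, uri, code = yield client.get_file(
            url, output_stream, max_size=max_size,
        )
    defer.returnValue((length, headers, uri, code))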
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 4775f6707d..442696d393 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -22,6 +22,7 @@ from twisted.names.error import DNSNameError, DomainError
 import collections
 import logging
 import random
+import time
 
 
 logger = logging.getLogger(__name__)
@@ -31,7 +32,7 @@ SERVER_CACHE = {}
 
 
 _Server = collections.namedtuple(
-    "_Server", "priority weight host port"
+    "_Server", "priority weight host port expires"
 )
 
 
@@ -74,6 +75,41 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
         return transport_endpoint(reactor, domain, port, **endpoint_kw_args)
 
 
+class SpiderEndpoint(object):
+    """An endpoint which refuses to connect to blacklisted IP addresses
+    Implements twisted.internet.interfaces.IStreamClientEndpoint.
+    """
+    def __init__(self, reactor, host, port, blacklist, whitelist,
+                 endpoint=TCP4ClientEndpoint, endpoint_kw_args={}):
+        self.reactor = reactor
+        self.host = host
+        self.port = port
+        self.blacklist = blacklist
+        self.whitelist = whitelist
+        self.endpoint = endpoint
+        self.endpoint_kw_args = endpoint_kw_args
+
+    @defer.inlineCallbacks
+    def connect(self, protocolFactory):
+        address = yield self.reactor.resolve(self.host)
+
+        from netaddr import IPAddress
+        ip_address = IPAddress(address)
+
+        if ip_address in self.blacklist:
+            if self.whitelist is None or ip_address not in self.whitelist:
+                raise ConnectError(
+                    "Refusing to spider blacklisted IP address %s" % address
+                )
+
+        logger.info("Connecting to %s:%s", address, self.port)
+        endpoint = self.endpoint(
+            self.reactor, address, self.port, **self.endpoint_kw_args
+        )
+        connection = yield endpoint.connect(protocolFactory)
+        defer.returnValue(connection)
+
+
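The blacklist check in SpiderEndpoint.connect leans on netaddr containment; a self-contained sketch of the same test, assuming the configured ranges are netaddr IPSet objects (the sample ranges are made up):

from netaddr import IPAddress, IPSet

blacklist = IPSet(["10.0.0.0/8", "127.0.0.0/8"])  # illustrative ranges
whitelist = IPSet(["10.1.2.3"])                   # illustrative exception

def is_spiderable(address):
    # Mirrors SpiderEndpoint.connect: refused only if the resolved IP is
    # blacklisted and not explicitly whitelisted.
    ip = IPAddress(address)
    if ip in blacklist and (whitelist is None or ip not in whitelist):
        return False
    return True

assert not is_spiderable("127.0.0.1")
assert is_spiderable("10.1.2.3")
assert is_spiderable("8.8.8.8")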
 class SRVClientEndpoint(object):
     """An endpoint which looks up SRV records for a service.
     Cycles through the list of servers starting with each call to connect
@@ -92,7 +128,8 @@ class SRVClientEndpoint(object):
                 host=domain,
                 port=default_port,
                 priority=0,
-                weight=0
+                weight=0,
+                expires=0,
             )
         else:
             self.default_server = None
@@ -118,7 +155,7 @@ class SRVClientEndpoint(object):
                 return self.default_server
             else:
                 raise ConnectError(
-                    "Not server available for %s", self.service_name
+                    "Not server available for %s" % self.service_name
                 )
 
         min_priority = self.servers[0].priority
@@ -153,7 +190,13 @@ class SRVClientEndpoint(object):
 
 
 @defer.inlineCallbacks
-def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
+def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
+    cache_entry = cache.get(service_name, None)
+    if cache_entry:
+        if all(s.expires > int(clock.time()) for s in cache_entry):
+            servers = list(cache_entry)
+            defer.returnValue(servers)
+
     servers = []
 
     try:
@@ -166,34 +209,33 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
                 and answers[0].type == dns.SRV
                 and answers[0].payload
                 and answers[0].payload.target == dns.Name('.')):
-            raise ConnectError("Service %s unavailable", service_name)
+            raise ConnectError("Service %s unavailable" % service_name)
 
         for answer in answers:
             if answer.type != dns.SRV or not answer.payload:
                 continue
 
             payload = answer.payload
-
             host = str(payload.target)
+            srv_ttl = answer.ttl
 
             try:
                 answers, _, _ = yield dns_client.lookupAddress(host)
             except DNSNameError:
                 continue
 
-            ips = [
-                answer.payload.dottedQuad()
-                for answer in answers
-                if answer.type == dns.A and answer.payload
-            ]
-
-            for ip in ips:
-                servers.append(_Server(
-                    host=ip,
-                    port=int(payload.port),
-                    priority=int(payload.priority),
-                    weight=int(payload.weight)
-                ))
+            for answer in answers:
+                if answer.type == dns.A and answer.payload:
+                    ip = answer.payload.dottedQuad()
+                    host_ttl = min(srv_ttl, answer.ttl)
+
+                    servers.append(_Server(
+                        host=ip,
+                        port=int(payload.port),
+                        priority=int(payload.priority),
+                        weight=int(payload.weight),
+                        expires=int(clock.time()) + host_ttl,
+                    ))
 
         servers.sort()
         cache[service_name] = list(servers)
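The new caching behaviour of resolve_service only reuses a cached SRV answer while every entry's expires timestamp lies in the future; a minimal sketch of that rule (cached_servers is an illustrative helper, not part of this diff):

import collections
import time

_Server = collections.namedtuple("_Server", "priority weight host port expires")

def cached_servers(cache, service_name, clock=time):
    # reuse the cache only while no entry has expired; otherwise the
    # caller falls through to a fresh SRV lookup
    entry = cache.get(service_name)
    if entry and all(s.expires > int(clock.time()) for s in entry):
        return list(entry)
    return None

cache = {
    "_matrix._tcp.example.org": [
        _Server(priority=0, weight=0, host="1.2.3.4", port=8448,
                expires=int(time.time()) + 300),
    ],
}
assert cached_servers(cache, "_matrix._tcp.example.org") is not None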
diff --git a/synapse/http/server.py b/synapse/http/server.py
index b82196fd5e..f705abab94 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -74,7 +74,12 @@ response_db_txn_duration = metrics.register_distribution(
 _next_request_id = 0
 
 
-def request_handler(request_handler):
+def request_handler(report_metrics=True):
+    """Decorator for ``wrap_request_handler``"""
+    return lambda request_handler: wrap_request_handler(request_handler, report_metrics)
+
+
+def wrap_request_handler(request_handler, report_metrics):
     """Wraps a method that acts as a request handler with the necessary logging
     and exception handling.
 
@@ -96,7 +101,12 @@ def request_handler(request_handler):
         global _next_request_id
         request_id = "%s-%s" % (request.method, _next_request_id)
         _next_request_id += 1
+
         with LoggingContext(request_id) as request_context:
+            if report_metrics:
+                request_metrics = RequestMetrics()
+                request_metrics.start(self.clock)
+
             request_context.request = request_id
             with request.processing():
                 try:
@@ -133,6 +143,14 @@ def request_handler(request_handler):
                         },
                         send_cors=True
                     )
+                finally:
+                    try:
+                        if report_metrics:
+                            request_metrics.stop(
+                                self.clock, request, self.__class__.__name__
+                            )
+                    except:
+                        pass
     return wrapped_request_handler
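request_handler is now a decorator factory: calling it returns the actual decorator, so call sites change from @request_handler to @request_handler(...). A stripped-down sketch of that shape, with ExampleResource and the pass-through body purely illustrative:

import functools

def request_handler(report_metrics=True):
    def decorate(handler):
        @functools.wraps(handler)
        def wrapped(self, request):
            # wrap_request_handler additionally sets up the logging context
            # and, when report_metrics is True, RequestMetrics start/stop
            return handler(self, request)
        return wrapped
    return decorate

class ExampleResource(object):
    @request_handler(report_metrics=False)
    def _async_render(self, request):
        return "ok"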
 
 
@@ -197,19 +215,23 @@ class JsonResource(HttpServer, resource.Resource):
         self._async_render(request)
         return server.NOT_DONE_YET
 
-    @request_handler
+    # Disable metric reporting because _async_render does its own metrics.
+    # It does its own metric reporting because _async_render dispatches to
+    # a callback and it's the class name of that callback we want to report
+    # against rather than the JsonResource itself.
+    @request_handler(report_metrics=False)
     @defer.inlineCallbacks
     def _async_render(self, request):
         """ This gets called from render() every time someone sends us a request.
             This checks if anyone has registered a callback for that method and
             path.
         """
-        start = self.clock.time_msec()
         if request.method == "OPTIONS":
             self._send_response(request, 200, {})
             return
 
-        start_context = LoggingContext.current_context()
+        request_metrics = RequestMetrics()
+        request_metrics.start(self.clock)
 
         # Loop through all the registered callbacks to check if the method
         # and path regex match
@@ -241,40 +263,7 @@ class JsonResource(HttpServer, resource.Resource):
                 self._send_response(request, code, response)
 
             try:
-                context = LoggingContext.current_context()
-
-                tag = ""
-                if context:
-                    tag = context.tag
-
-                    if context != start_context:
-                        logger.warn(
-                            "Context have unexpectedly changed %r, %r",
-                            context, self.start_context
-                        )
-                        return
-
-                incoming_requests_counter.inc(request.method, servlet_classname, tag)
-
-                response_timer.inc_by(
-                    self.clock.time_msec() - start, request.method,
-                    servlet_classname, tag
-                )
-
-                ru_utime, ru_stime = context.get_resource_usage()
-
-                response_ru_utime.inc_by(
-                    ru_utime, request.method, servlet_classname, tag
-                )
-                response_ru_stime.inc_by(
-                    ru_stime, request.method, servlet_classname, tag
-                )
-                response_db_txn_count.inc_by(
-                    context.db_txn_count, request.method, servlet_classname, tag
-                )
-                response_db_txn_duration.inc_by(
-                    context.db_txn_duration, request.method, servlet_classname, tag
-                )
+                request_metrics.stop(self.clock, request, servlet_classname)
             except:
                 pass
 
@@ -307,6 +296,48 @@ class JsonResource(HttpServer, resource.Resource):
         )
 
 
+class RequestMetrics(object):
+    def start(self, clock):
+        self.start = clock.time_msec()
+        self.start_context = LoggingContext.current_context()
+
+    def stop(self, clock, request, servlet_classname):
+        context = LoggingContext.current_context()
+
+        tag = ""
+        if context:
+            tag = context.tag
+
+            if context != self.start_context:
+                logger.warn(
+                    "Context have unexpectedly changed %r, %r",
+                    context, self.start_context
+                )
+                return
+
+        incoming_requests_counter.inc(request.method, servlet_classname, tag)
+
+        response_timer.inc_by(
+            clock.time_msec() - self.start, request.method,
+            servlet_classname, tag
+        )
+
+        ru_utime, ru_stime = context.get_resource_usage()
+
+        response_ru_utime.inc_by(
+            ru_utime, request.method, servlet_classname, tag
+        )
+        response_ru_stime.inc_by(
+            ru_stime, request.method, servlet_classname, tag
+        )
+        response_db_txn_count.inc_by(
+            context.db_txn_count, request.method, servlet_classname, tag
+        )
+        response_db_txn_duration.inc_by(
+            context.db_txn_duration, request.method, servlet_classname, tag
+        )
+
+
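RequestMetrics pulls the per-request metric bookkeeping out of _async_render so other handlers can reuse it; a hedged usage sketch, where clock is assumed to expose time_msec() like synapse's Clock and callback stands in for the dispatched servlet method:

def render_with_metrics(clock, request, callback, servlet_classname):
    # clock is assumed to expose time_msec(); callback stands in for the
    # servlet method that _async_render dispatches to.
    request_metrics = RequestMetrics()
    request_metrics.start(clock)
    try:
        return callback(request)
    finally:
        try:
            # report against the callback's servlet class, not JsonResource
            request_metrics.stop(clock, request, servlet_classname)
        except:
            pass  # metrics failures must never break the response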
 class RootRedirect(resource.Resource):
     """Redirects the root '/' path to another path."""
 
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 1c8bd8666f..e41afeab8e 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -26,14 +26,19 @@ logger = logging.getLogger(__name__)
 def parse_integer(request, name, default=None, required=False):
     """Parse an integer parameter from the request string
 
-    :param request: the twisted HTTP request.
-    :param name (str): the name of the query parameter.
-    :param default: value to use if the parameter is absent, defaults to None.
-    :param required (bool): whether to raise a 400 SynapseError if the
-        parameter is absent, defaults to False.
-    :return: An int value or the default.
-    :raises
-        SynapseError if the parameter is absent and required, or if the
+    Args:
+        request: the twisted HTTP request.
+        name (str): the name of the query parameter.
+        default (int|None): value to use if the parameter is absent, defaults
+            to None.
+        required (bool): whether to raise a 400 SynapseError if the
+            parameter is absent, defaults to False.
+
+    Returns:
+        int|None: An int value or the default.
+
+    Raises:
+        SynapseError: if the parameter is absent and required, or if the
             parameter is present and not an integer.
     """
     if name in request.args:
@@ -53,14 +58,19 @@ def parse_integer(request, name, default=None, required=False):
 def parse_boolean(request, name, default=None, required=False):
     """Parse a boolean parameter from the request query string
 
-    :param request: the twisted HTTP request.
-    :param name (str): the name of the query parameter.
-    :param default: value to use if the parameter is absent, defaults to None.
-    :param required (bool): whether to raise a 400 SynapseError if the
-        parameter is absent, defaults to False.
-    :return: A bool value or the default.
-    :raises
-        SynapseError if the parameter is absent and required, or if the
+    Args:
+        request: the twisted HTTP request.
+        name (str): the name of the query parameter.
+        default (bool|None): value to use if the parameter is absent, defaults
+            to None.
+        required (bool): whether to raise a 400 SynapseError if the
+            parameter is absent, defaults to False.
+
+    Returns:
+        bool|None: A bool value or the default.
+
+    Raises:
+        SynapseError: if the parameter is absent and required, or if the
             parameter is present and not one of "true" or "false".
     """
 
@@ -88,15 +98,20 @@ def parse_string(request, name, default=None, required=False,
                  allowed_values=None, param_type="string"):
     """Parse a string parameter from the request query string.
 
-    :param request: the twisted HTTP request.
-    :param name (str): the name of the query parameter.
-    :param default: value to use if the parameter is absent, defaults to None.
-    :param required (bool): whether to raise a 400 SynapseError if the
-        parameter is absent, defaults to False.
-    :param allowed_values (list): List of allowed values for the string,
-        or None if any value is allowed, defaults to None
-    :return: A string value or the default.
-    :raises
+    Args:
+        request: the twisted HTTP request.
+        name (str): the name of the query parameter.
+        default (str|None): value to use if the parameter is absent, defaults
+            to None.
+        required (bool): whether to raise a 400 SynapseError if the
+            parameter is absent, defaults to False.
+        allowed_values (list[str]): List of allowed values for the string,
+            or None if any value is allowed, defaults to None
+
+    Returns:
+        str|None: A string value or the default.
+
+    Raises:
         SynapseError if the parameter is absent and required, or if the
             parameter is present, must be one of a list of allowed values and
             is not one of those allowed values.
@@ -122,9 +137,13 @@ def parse_string(request, name, default=None, required=False,
 def parse_json_value_from_request(request):
     """Parse a JSON value from the body of a twisted HTTP request.
 
-    :param request: the twisted HTTP request.
-    :returns: The JSON value.
-    :raises
+    Args:
+        request: the twisted HTTP request.
+
+    Returns:
+        The JSON value.
+
+    Raises:
         SynapseError if the request body couldn't be decoded as JSON.
     """
     try:
@@ -143,8 +162,10 @@ def parse_json_value_from_request(request):
 def parse_json_object_from_request(request):
     """Parse a JSON object from the body of a twisted HTTP request.
 
-    :param request: the twisted HTTP request.
-    :raises
+    Args:
+        request: the twisted HTTP request.
+
+    Raises:
         SynapseError if the request body couldn't be decoded as JSON or
             if it wasn't a JSON object.
     """
diff --git a/synapse/http/site.py b/synapse/http/site.py
new file mode 100644
index 0000000000..4b09d7ee66
--- /dev/null
+++ b/synapse/http/site.py
@@ -0,0 +1,146 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.logcontext import LoggingContext
+from twisted.web.server import Site, Request
+
+import contextlib
+import logging
+import re
+import time
+
+ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
+
+
+class SynapseRequest(Request):
+    def __init__(self, site, *args, **kw):
+        Request.__init__(self, *args, **kw)
+        self.site = site
+        self.authenticated_entity = None
+        self.start_time = 0
+
+    def __repr__(self):
+        # We overwrite this so that we don't log ``access_token``
+        return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
+            self.__class__.__name__,
+            id(self),
+            self.method,
+            self.get_redacted_uri(),
+            self.clientproto,
+            self.site.site_tag,
+        )
+
+    def get_redacted_uri(self):
+        return ACCESS_TOKEN_RE.sub(
+            r'\1<redacted>\3',
+            self.uri
+        )
+
+    def get_user_agent(self):
+        return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
+
+    def started_processing(self):
+        self.site.access_logger.info(
+            "%s - %s - Received request: %s %s",
+            self.getClientIP(),
+            self.site.site_tag,
+            self.method,
+            self.get_redacted_uri()
+        )
+        self.start_time = int(time.time() * 1000)
+
+    def finished_processing(self):
+
+        try:
+            context = LoggingContext.current_context()
+            ru_utime, ru_stime = context.get_resource_usage()
+            db_txn_count = context.db_txn_count
+            db_txn_duration = context.db_txn_duration
+        except:
+            ru_utime, ru_stime = (0, 0)
+            db_txn_count, db_txn_duration = (0, 0)
+
+        self.site.access_logger.info(
+            "%s - %s - {%s}"
+            " Processed request: %dms (%dms, %dms) (%dms/%d)"
+            " %sB %s \"%s %s %s\" \"%s\"",
+            self.getClientIP(),
+            self.site.site_tag,
+            self.authenticated_entity,
+            int(time.time() * 1000) - self.start_time,
+            int(ru_utime * 1000),
+            int(ru_stime * 1000),
+            int(db_txn_duration * 1000),
+            int(db_txn_count),
+            self.sentLength,
+            self.code,
+            self.method,
+            self.get_redacted_uri(),
+            self.clientproto,
+            self.get_user_agent(),
+        )
+
+    @contextlib.contextmanager
+    def processing(self):
+        self.started_processing()
+        yield
+        self.finished_processing()
+
+
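For reference, the access-token redaction performed by get_redacted_uri behaves as below (the sample URI is made up):

import re

ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')

uri = "/_matrix/client/r0/sync?since=s1&access_token=secret123&a=b"
redacted = ACCESS_TOKEN_RE.sub(r'\1<redacted>\3', uri)
assert redacted == "/_matrix/client/r0/sync?since=s1&access_token=<redacted>&a=b"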
+class XForwardedForRequest(SynapseRequest):
+    """
+    Add a layer on top of another request that only uses the value of an
+    X-Forwarded-For header as the result of C{getClientIP}.
+    """
+    def __init__(self, *args, **kw):
+        SynapseRequest.__init__(self, *args, **kw)
+
+    def getClientIP(self):
+        """
+        @return: The client address (the first address) in the value of the
+            I{X-Forwarded-For header}.  If the header is not present, return
+            C{b"-"}.
+        """
+        return self.requestHeaders.getRawHeaders(
+            b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip()
+
+
+class SynapseRequestFactory(object):
+    def __init__(self, site, x_forwarded_for):
+        self.site = site
+        self.x_forwarded_for = x_forwarded_for
+
+    def __call__(self, *args, **kwargs):
+        if self.x_forwarded_for:
+            return XForwardedForRequest(self.site, *args, **kwargs)
+        else:
+            return SynapseRequest(self.site, *args, **kwargs)
+
+
+class SynapseSite(Site):
+    """
+    Subclass of a twisted http Site that does access logging with python's
+    standard logging
+    """
+    def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
+        Site.__init__(self, resource, *args, **kwargs)
+
+        self.site_tag = site_tag
+
+        proxied = config.get("x_forwarded", False)
+        self.requestFactory = SynapseRequestFactory(self, proxied)
+        self.access_logger = logging.getLogger(logger_name)
+
+    def log(self, request):
+        pass
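
A hedged sketch of wiring the new SynapseSite into a listener; the logger name, site tag, port and config dict are illustrative, and root_resource stands in for synapse's real resource tree:

from twisted.internet import reactor
from twisted.web.resource import Resource

root_resource = Resource()  # placeholder for the real resource tree
site = SynapseSite(
    "synapse.access.http.example",   # logger_name: where access logs go
    "example",                       # site_tag: appears in each log line
    {"x_forwarded": True},           # trust X-Forwarded-For for client IPs
    root_resource,
)
reactor.listenTCP(8008, site)
reactor.run()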