diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index bfebb0f644..58ef8d3ce4 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,3 +13,37 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import re
+
+from twisted.internet.defer import CancelledError
+from twisted.python import failure
+
+from synapse.api.errors import SynapseError
+
+
+class RequestTimedOutError(SynapseError):
+ """Exception representing timeout of an outbound request"""
+ def __init__(self):
+ super(RequestTimedOutError, self).__init__(504, "Timed out")
+
+
+def cancelled_to_request_timed_out_error(value, timeout):
+ """Turns CancelledErrors into RequestTimedOutErrors.
+
+ For use with async.add_timeout_to_deferred
+ """
+ if isinstance(value, failure.Failure):
+ value.trap(CancelledError)
+ raise RequestTimedOutError()
+ return value
+
+
+ACCESS_TOKEN_RE = re.compile(br'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
+
+
+def redact_uri(uri):
+ """Strips access tokens from the uri replaces with <redacted>"""
+ return ACCESS_TOKEN_RE.sub(
+ br'\1<redacted>\3',
+ uri
+ )
diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py
new file mode 100644
index 0000000000..0e10e3f8f7
--- /dev/null
+++ b/synapse/http/additional_resource.py
@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+
+from synapse.http.server import wrap_json_request_handler
+
+
+class AdditionalResource(Resource):
+ """Resource wrapper for additional_resources
+
+ If the user has configured additional_resources, we need to wrap the
+ handler class with a Resource so that we can map it into the resource tree.
+
+ This class is also where we wrap the request handler with logging, metrics,
+ and exception handling.
+ """
+ def __init__(self, hs, handler):
+ """Initialise AdditionalResource
+
+ The ``handler`` should return a deferred which completes when it has
+ done handling the request. It should write a response with
+ ``request.write()``, and call ``request.finish()``.
+
+ Args:
+ hs (synapse.server.HomeServer): homeserver
+ handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
+ function to be called to handle the request.
+ """
+ Resource.__init__(self)
+ self._handler = handler
+
+ # required by the request_handler wrapper
+ self.clock = hs.get_clock()
+
+ def render(self, request):
+ self._async_render(request)
+ return NOT_DONE_YET
+
+ @wrap_json_request_handler
+ def _async_render(self, request):
+ return self._handler(request)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 9eba046bbf..25b6307884 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,49 +13,49 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from OpenSSL import SSL
-from OpenSSL.SSL import VERIFY_NONE
+import logging
+import urllib
-from synapse.api.errors import (
- CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
-)
-from synapse.util.logcontext import preserve_context_over_fn
-from synapse.util import logcontext
-import synapse.metrics
-from synapse.http.endpoint import SpiderEndpoint
+from six import StringIO
-from canonicaljson import encode_canonical_json
+from canonicaljson import encode_canonical_json, json
+from prometheus_client import Counter
-from twisted.internet import defer, reactor, ssl, protocol, task
+from OpenSSL import SSL
+from OpenSSL.SSL import VERIFY_NONE
+from twisted.internet import defer, protocol, reactor, ssl, task
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.web._newclient import ResponseDone
from twisted.web.client import (
- BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
- readBody, PartialDownloadError,
+ Agent,
+ BrowserLikeRedirectAgent,
+ ContentDecoderAgent,
+ FileBodyProducer as TwistedFileBodyProducer,
+ GzipDecoder,
+ HTTPConnectionPool,
+ PartialDownloadError,
+ readBody,
)
-from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
-from twisted.web._newclient import ResponseDone
-
-from StringIO import StringIO
-
-import simplejson as json
-import logging
-import urllib
+from synapse.api.errors import (
+ CodeMessageException,
+ Codes,
+ MatrixCodeMessageException,
+ SynapseError,
+)
+from synapse.http import cancelled_to_request_timed_out_error, redact_uri
+from synapse.http.endpoint import SpiderEndpoint
+from synapse.util.async import add_timeout_to_deferred
+from synapse.util.caches import CACHE_SIZE_FACTOR
+from synapse.util.logcontext import make_deferred_yieldable
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-outgoing_requests_counter = metrics.register_counter(
- "requests",
- labels=["method"],
-)
-incoming_responses_counter = metrics.register_counter(
- "responses",
- labels=["method", "code"],
-)
+outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"])
+incoming_responses_counter = Counter("synapse_http_client_responses", "",
+ ["method", "code"])
class SimpleHttpClient(object):
@@ -64,13 +65,23 @@ class SimpleHttpClient(object):
"""
def __init__(self, hs):
self.hs = hs
+
+ pool = HTTPConnectionPool(reactor)
+
+ # the pusher makes lots of concurrent SSL connections to sygnal, and
+ # tends to do so in batches, so we need to allow the pool to keep lots
+ # of idle connections around.
+ pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
+ pool.cachedConnectionTimeout = 2 * 60
+
# The default context factory in Twisted 14.0.0 (which we require) is
# BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser'
self.agent = Agent(
reactor,
connectTimeout=15,
- contextFactory=hs.get_http_client_context_factory()
+ contextFactory=hs.get_http_client_context_factory(),
+ pool=pool,
)
self.user_agent = hs.version_string
self.clock = hs.get_clock()
@@ -81,76 +92,103 @@ class SimpleHttpClient(object):
def request(self, method, uri, *args, **kwargs):
# A small wrapper around self.agent.request() so we can easily attach
# counters to it
- outgoing_requests_counter.inc(method)
+ outgoing_requests_counter.labels(method).inc()
- def send_request():
+ # log request but strip `access_token` (AS requests for example include this)
+ logger.info("Sending request %s %s", method, redact_uri(uri))
+
+ try:
request_deferred = self.agent.request(
method, uri, *args, **kwargs
)
-
- return self.clock.time_bound_deferred(
- request_deferred,
- time_out=60,
+ add_timeout_to_deferred(
+ request_deferred, 60, self.hs.get_reactor(),
+ cancelled_to_request_timed_out_error,
)
+ response = yield make_deferred_yieldable(request_deferred)
- logger.info("Sending request %s %s", method, uri)
-
- try:
- with logcontext.PreserveLoggingContext():
- response = yield send_request()
-
- incoming_responses_counter.inc(method, response.code)
+ incoming_responses_counter.labels(method, response.code).inc()
logger.info(
"Received response to %s %s: %s",
- method, uri, response.code
+ method, redact_uri(uri), response.code
)
defer.returnValue(response)
except Exception as e:
- incoming_responses_counter.inc(method, "ERR")
+ incoming_responses_counter.labels(method, "ERR").inc()
logger.info(
"Error sending request to %s %s: %s %s",
- method, uri, type(e).__name__, e.message
+ method, redact_uri(uri), type(e).__name__, e.message
)
- raise e
+ raise
@defer.inlineCallbacks
- def post_urlencoded_get_json(self, uri, args={}):
+ def post_urlencoded_get_json(self, uri, args={}, headers=None):
+ """
+ Args:
+ uri (str):
+ args (dict[str, str|List[str]]): query params
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
+
+ Returns:
+ Deferred[object]: parsed json
+ """
+
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)
query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
+ actual_headers = {
+ b"Content-Type": [b"application/x-www-form-urlencoded"],
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"POST",
uri.encode("ascii"),
- headers=Headers({
- b"Content-Type": [b"application/x-www-form-urlencoded"],
- b"User-Agent": [self.user_agent],
- }),
+ headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(query_bytes))
)
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
- def post_json_get_json(self, uri, post_json):
+ def post_json_get_json(self, uri, post_json, headers=None):
+ """
+
+ Args:
+ uri (str):
+ post_json (object):
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
+
+ Returns:
+ Deferred[object]: parsed json
+ """
json_str = encode_canonical_json(post_json)
logger.debug("HTTP POST %s -> %s", json_str, uri)
+ actual_headers = {
+ b"Content-Type": [b"application/json"],
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"POST",
uri.encode("ascii"),
- headers=Headers({
- b"Content-Type": [b"application/json"],
- b"User-Agent": [self.user_agent],
- }),
+ headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(json_str))
)
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
@@ -160,7 +198,7 @@ class SimpleHttpClient(object):
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
- def get_json(self, uri, args={}):
+ def get_json(self, uri, args={}, headers=None):
""" Gets some json from the given URI.
Args:
@@ -169,6 +207,8 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
@@ -177,13 +217,13 @@ class SimpleHttpClient(object):
error message.
"""
try:
- body = yield self.get_raw(uri, args)
+ body = yield self.get_raw(uri, args, headers=headers)
defer.returnValue(json.loads(body))
except CodeMessageException as e:
raise self._exceptionFromFailedRequest(e.code, e.msg)
@defer.inlineCallbacks
- def put_json(self, uri, json_body, args={}):
+ def put_json(self, uri, json_body, args={}, headers=None):
""" Puts some json to the given URI.
Args:
@@ -193,6 +233,8 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
@@ -205,17 +247,21 @@ class SimpleHttpClient(object):
json_str = encode_canonical_json(json_body)
+ actual_headers = {
+ b"Content-Type": [b"application/json"],
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"PUT",
uri.encode("ascii"),
- headers=Headers({
- b"User-Agent": [self.user_agent],
- "Content-Type": ["application/json"]
- }),
+ headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(json_str))
)
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
@@ -226,7 +272,7 @@ class SimpleHttpClient(object):
raise CodeMessageException(response.code, body)
@defer.inlineCallbacks
- def get_raw(self, uri, args={}):
+ def get_raw(self, uri, args={}, headers=None):
""" Gets raw text from the given URI.
Args:
@@ -235,6 +281,8 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body at text.
@@ -246,15 +294,19 @@ class SimpleHttpClient(object):
query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
+ actual_headers = {
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"GET",
uri.encode("ascii"),
- headers=Headers({
- b"User-Agent": [self.user_agent],
- })
+ headers=Headers(actual_headers),
)
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
defer.returnValue(body)
@@ -274,27 +326,33 @@ class SimpleHttpClient(object):
# The two should be factored out.
@defer.inlineCallbacks
- def get_file(self, url, output_stream, max_size=None):
+ def get_file(self, url, output_stream, max_size=None, headers=None):
"""GETs a file from a given URL
Args:
url (str): The URL to GET
output_stream (file): File to write the response body to.
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
Returns:
A (int,dict,string,int) tuple of the file length, dict of the response
headers, absolute URI of the response and HTTP response code.
"""
+ actual_headers = {
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"GET",
url.encode("ascii"),
- headers=Headers({
- b"User-Agent": [self.user_agent],
- })
+ headers=Headers(actual_headers),
)
- headers = dict(response.headers.getAllRawHeaders())
+ resp_headers = dict(response.headers.getAllRawHeaders())
- if 'Content-Length' in headers and headers['Content-Length'] > max_size:
+ if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size:
logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
raise SynapseError(
502,
@@ -315,10 +373,9 @@ class SimpleHttpClient(object):
# straight back in again
try:
- length = yield preserve_context_over_fn(
- _readBodyToFile,
- response, output_stream, max_size
- )
+ length = yield make_deferred_yieldable(_readBodyToFile(
+ response, output_stream, max_size,
+ ))
except Exception as e:
logger.exception("Failed to download body")
raise SynapseError(
@@ -327,7 +384,9 @@ class SimpleHttpClient(object):
Codes.UNKNOWN,
)
- defer.returnValue((length, headers, response.request.absoluteURI, response.code))
+ defer.returnValue(
+ (length, resp_headers, response.request.absoluteURI, response.code),
+ )
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
@@ -395,7 +454,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
)
try:
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
defer.returnValue(body)
except PartialDownloadError as e:
# twisted dislikes google's response, no content length.
@@ -446,7 +505,7 @@ class SpiderHttpClient(SimpleHttpClient):
reactor,
SpiderEndpointFactory(hs)
)
- ), [('gzip', GzipDecoder)]
+ ), [(b'gzip', GzipDecoder)]
)
# We could look like Chrome:
# self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko)
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index d8923c9abb..d65daa72bb 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -12,30 +12,97 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.internet import defer, reactor
-from twisted.internet.error import ConnectError
-from twisted.names import client, dns
-from twisted.names.error import DNSNameError, DomainError
-
import collections
import logging
import random
+import re
import time
+from twisted.internet import defer
+from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.internet.error import ConnectError
+from twisted.names import client, dns
+from twisted.names.error import DNSNameError, DomainError
logger = logging.getLogger(__name__)
SERVER_CACHE = {}
-
+# our record of an individual server which can be tried to reach a destination.
+#
+# "host" is the hostname acquired from the SRV record. Except when there's
+# no SRV record, in which case it is the original hostname.
_Server = collections.namedtuple(
"_Server", "priority weight host port expires"
)
+def parse_server_name(server_name):
+ """Split a server name into host/port parts.
+
+ Args:
+ server_name (str): server name to parse
+
+ Returns:
+ Tuple[str, int|None]: host/port parts.
+
+ Raises:
+ ValueError if the server name could not be parsed.
+ """
+ try:
+ if server_name[-1] == ']':
+ # ipv6 literal, hopefully
+ return server_name, None
+
+ domain_port = server_name.rsplit(":", 1)
+ domain = domain_port[0]
+ port = int(domain_port[1]) if domain_port[1:] else None
+ return domain, port
+ except Exception:
+ raise ValueError("Invalid server name '%s'" % server_name)
+
+
+VALID_HOST_REGEX = re.compile(
+ "\\A[0-9a-zA-Z.-]+\\Z",
+)
+
+
+def parse_and_validate_server_name(server_name):
+ """Split a server name into host/port parts and do some basic validation.
+
+ Args:
+ server_name (str): server name to parse
+
+ Returns:
+ Tuple[str, int|None]: host/port parts.
+
+ Raises:
+ ValueError if the server name could not be parsed.
+ """
+ host, port = parse_server_name(server_name)
+
+ # these tests don't need to be bulletproof as we'll find out soon enough
+ # if somebody is giving us invalid data. What we *do* need is to be sure
+ # that nobody is sneaking IP literals in that look like hostnames, etc.
+
+ # look for ipv6 literals
+ if host[0] == '[':
+ if host[-1] != ']':
+ raise ValueError("Mismatched [...] in server name '%s'" % (
+ server_name,
+ ))
+ return host, port
+
+ # otherwise it should only be alphanumerics.
+ if not VALID_HOST_REGEX.match(host):
+ raise ValueError("Server name '%s' contains invalid characters" % (
+ server_name,
+ ))
+
+ return host, port
+
+
def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
timeout=None):
"""Construct an endpoint for the given matrix destination.
@@ -48,9 +115,7 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
timeout (int): connection timeout in seconds
"""
- domain_port = destination.split(":")
- domain = domain_port[0]
- port = int(domain_port[1]) if domain_port[1:] else None
+ domain, port = parse_server_name(destination)
endpoint_kw_args = {}
@@ -72,21 +137,22 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
reactor, "matrix", domain, protocol="tcp",
default_port=default_port, endpoint=transport_endpoint,
endpoint_kw_args=endpoint_kw_args
- ))
+ ), reactor)
else:
return _WrappingEndpointFac(transport_endpoint(
reactor, domain, port, **endpoint_kw_args
- ))
+ ), reactor)
class _WrappingEndpointFac(object):
- def __init__(self, endpoint_fac):
+ def __init__(self, endpoint_fac, reactor):
self.endpoint_fac = endpoint_fac
+ self.reactor = reactor
@defer.inlineCallbacks
def connect(self, protocolFactory):
conn = yield self.endpoint_fac.connect(protocolFactory)
- conn = _WrappedConnection(conn)
+ conn = _WrappedConnection(conn, self.reactor)
defer.returnValue(conn)
@@ -96,9 +162,10 @@ class _WrappedConnection(object):
"""
__slots__ = ["conn", "last_request"]
- def __init__(self, conn):
+ def __init__(self, conn, reactor):
object.__setattr__(self, "conn", conn)
object.__setattr__(self, "last_request", time.time())
+ self._reactor = reactor
def __getattr__(self, name):
return getattr(self.conn, name)
@@ -113,10 +180,15 @@ class _WrappedConnection(object):
if time.time() - self.last_request >= 2.5 * 60:
self.abort()
# Abort the underlying TLS connection. The abort() method calls
- # loseConnection() on the underlying TLS connection which tries to
+ # loseConnection() on the TLS connection which tries to
# shutdown the connection cleanly. We call abortConnection()
- # since that will promptly close the underlying TCP connection.
- self.transport.abortConnection()
+ # since that will promptly close the TLS connection.
+ #
+ # In Twisted >18.4; the TLS connection will be None if it has closed
+ # which will make abortConnection() throw. Check that the TLS connection
+ # is not None before trying to close it.
+ if self.transport.getHandle() is not None:
+ self.transport.abortConnection()
def request(self, request):
self.last_request = time.time()
@@ -124,14 +196,14 @@ class _WrappedConnection(object):
# Time this connection out if we haven't send a request in the last
# N minutes
# TODO: Cancel the previous callLater?
- reactor.callLater(3 * 60, self._time_things_out_maybe)
+ self._reactor.callLater(3 * 60, self._time_things_out_maybe)
d = self.conn.request(request)
def update_request_time(res):
self.last_request = time.time()
# TODO: Cancel the previous callLater?
- reactor.callLater(3 * 60, self._time_things_out_maybe)
+ self._reactor.callLater(3 * 60, self._time_things_out_maybe)
return res
d.addCallback(update_request_time)
@@ -219,9 +291,10 @@ class SRVClientEndpoint(object):
return self.default_server
else:
raise ConnectError(
- "Not server available for %s" % self.service_name
+ "No server available for %s" % self.service_name
)
+ # look for all servers with the same priority
min_priority = self.servers[0].priority
weight_indexes = list(
(index, server.weight + 1)
@@ -231,11 +304,22 @@ class SRVClientEndpoint(object):
total_weight = sum(weight for index, weight in weight_indexes)
target_weight = random.randint(0, total_weight)
-
for index, weight in weight_indexes:
target_weight -= weight
if target_weight <= 0:
server = self.servers[index]
+ # XXX: this looks totally dubious:
+ #
+ # (a) we never reuse a server until we have been through
+ # all of the servers at the same priority, so if the
+ # weights are A: 100, B:1, we always do ABABAB instead of
+ # AAAA...AAAB (approximately).
+ #
+ # (b) After using all the servers at the lowest priority,
+ # we move onto the next priority. We should only use the
+ # second priority if servers at the top priority are
+ # unreachable.
+ #
del self.servers[index]
self.used_servers.append(server)
return server
@@ -272,7 +356,7 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
if (len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
- and answers[0].payload.target == dns.Name('.')):
+ and answers[0].payload.target == dns.Name(b'.')):
raise ConnectError("Service %s unavailable" % service_name)
for answer in answers:
@@ -280,26 +364,14 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
continue
payload = answer.payload
- host = str(payload.target)
- srv_ttl = answer.ttl
-
- try:
- answers, _, _ = yield dns_client.lookupAddress(host)
- except DNSNameError:
- continue
- for answer in answers:
- if answer.type == dns.A and answer.payload:
- ip = answer.payload.dottedQuad()
- host_ttl = min(srv_ttl, answer.ttl)
-
- servers.append(_Server(
- host=ip,
- port=int(payload.port),
- priority=int(payload.priority),
- weight=int(payload.weight),
- expires=int(clock.time()) + host_ttl,
- ))
+ servers.append(_Server(
+ host=str(payload.target),
+ port=int(payload.port),
+ priority=int(payload.priority),
+ weight=int(payload.weight),
+ expires=int(clock.time()) + answer.ttl,
+ ))
servers.sort()
cache[service_name] = list(servers)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 747a791f83..bf1aa29502 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,48 +13,46 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.util.retryutils
-from twisted.internet import defer, reactor, protocol
-from twisted.internet.error import DNSLookupError
-from twisted.web.client import readBody, HTTPConnectionPool, Agent
-from twisted.web.http_headers import Headers
-from twisted.web._newclient import ResponseDone
-
-from synapse.http.endpoint import matrix_federation_endpoint
-from synapse.util.async import sleep
-from synapse.util import logcontext
-import synapse.metrics
-
-from canonicaljson import encode_canonical_json
-
-from synapse.api.errors import (
- SynapseError, Codes, HttpResponseException,
-)
-
-from signedjson.sign import sign_json
-
import cgi
-import simplejson as json
import logging
import random
import sys
import urllib
-import urlparse
+from six import string_types
+from six.moves.urllib import parse as urlparse
-logger = logging.getLogger(__name__)
-outbound_logger = logging.getLogger("synapse.http.outbound")
+from canonicaljson import encode_canonical_json, json
+from prometheus_client import Counter
+from signedjson.sign import sign_json
-metrics = synapse.metrics.get_metrics_for(__name__)
+from twisted.internet import defer, protocol, reactor
+from twisted.internet.error import DNSLookupError
+from twisted.web._newclient import ResponseDone
+from twisted.web.client import Agent, HTTPConnectionPool, readBody
+from twisted.web.http_headers import Headers
-outgoing_requests_counter = metrics.register_counter(
- "requests",
- labels=["method"],
-)
-incoming_responses_counter = metrics.register_counter(
- "responses",
- labels=["method", "code"],
+import synapse.metrics
+import synapse.util.retryutils
+from synapse.api.errors import (
+ Codes,
+ FederationDeniedError,
+ HttpResponseException,
+ SynapseError,
)
+from synapse.http import cancelled_to_request_timed_out_error
+from synapse.http.endpoint import matrix_federation_endpoint
+from synapse.util import logcontext
+from synapse.util.async import add_timeout_to_deferred
+from synapse.util.logcontext import make_deferred_yieldable
+
+logger = logging.getLogger(__name__)
+outbound_logger = logging.getLogger("synapse.http.outbound")
+
+outgoing_requests_counter = Counter("synapse_http_matrixfederationclient_requests",
+ "", ["method"])
+incoming_responses_counter = Counter("synapse_http_matrixfederationclient_responses",
+ "", ["method", "code"])
MAX_LONG_RETRIES = 10
@@ -123,11 +122,22 @@ class MatrixFederationHttpClient(object):
Fails with ``HTTPRequestException``: if we get an HTTP response
code >= 300.
+
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
+
(May also fail with plenty of other Exceptions for things like DNS
failures, connection failures, SSL failures.)
"""
+ if (
+ self.hs.config.federation_domain_whitelist and
+ destination not in self.hs.config.federation_domain_whitelist
+ ):
+ raise FederationDeniedError(destination)
+
limiter = yield synapse.util.retryutils.get_retry_limiter(
destination,
self.clock,
@@ -173,21 +183,21 @@ class MatrixFederationHttpClient(object):
producer = body_callback(method, http_url_bytes, headers_dict)
try:
- def send_request():
- request_deferred = self.agent.request(
- method,
- url_bytes,
- Headers(headers_dict),
- producer
- )
-
- return self.clock.time_bound_deferred(
- request_deferred,
- time_out=timeout / 1000. if timeout else 60,
- )
-
- with logcontext.PreserveLoggingContext():
- response = yield send_request()
+ request_deferred = self.agent.request(
+ method,
+ url_bytes,
+ Headers(headers_dict),
+ producer
+ )
+ add_timeout_to_deferred(
+ request_deferred,
+ timeout / 1000. if timeout else 60,
+ self.hs.get_reactor(),
+ cancelled_to_request_timed_out_error,
+ )
+ response = yield make_deferred_yieldable(
+ request_deferred,
+ )
log_result = "%d %s" % (response.code, response.phrase,)
break
@@ -204,18 +214,15 @@ class MatrixFederationHttpClient(object):
raise
logger.warn(
- "{%s} Sending request failed to %s: %s %s: %s - %s",
+ "{%s} Sending request failed to %s: %s %s: %s",
txn_id,
destination,
method,
url_bytes,
- type(e).__name__,
_flatten_response_never_received(e),
)
- log_result = "%s - %s" % (
- type(e).__name__, _flatten_response_never_received(e),
- )
+ log_result = _flatten_response_never_received(e)
if retries_left and not timeout:
if long_retries:
@@ -227,7 +234,7 @@ class MatrixFederationHttpClient(object):
delay = min(delay, 2)
delay *= random.uniform(0.8, 1.4)
- yield sleep(delay)
+ yield self.clock.sleep(delay)
retries_left -= 1
else:
raise
@@ -253,14 +260,35 @@ class MatrixFederationHttpClient(object):
defer.returnValue(response)
def sign_request(self, destination, method, url_bytes, headers_dict,
- content=None):
+ content=None, destination_is=None):
+ """
+ Signs a request by adding an Authorization header to headers_dict
+ Args:
+ destination (bytes|None): The desination home server of the request.
+ May be None if the destination is an identity server, in which case
+ destination_is must be non-None.
+ method (bytes): The HTTP method of the request
+ url_bytes (bytes): The URI path of the request
+ headers_dict (dict): Dictionary of request headers to append to
+ content (bytes): The body of the request
+ destination_is (bytes): As 'destination', but if the destination is an
+ identity server
+
+ Returns:
+ None
+ """
request = {
"method": method,
"uri": url_bytes,
"origin": self.server_name,
- "destination": destination,
}
+ if destination is not None:
+ request["destination"] = destination
+
+ if destination_is is not None:
+ request["destination_is"] = destination_is
+
if content is not None:
request["content"] = content
@@ -278,7 +306,8 @@ class MatrixFederationHttpClient(object):
headers_dict[b"Authorization"] = auth_headers
@defer.inlineCallbacks
- def put_json(self, destination, path, data={}, json_data_callback=None,
+ def put_json(self, destination, path, args={}, data={},
+ json_data_callback=None,
long_retries=False, timeout=None,
ignore_backoff=False,
backoff_on_404=False):
@@ -288,6 +317,7 @@ class MatrixFederationHttpClient(object):
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
+ args (dict): query params
data (dict): A dict containing the data that will be used as
the request body. This will be encoded as JSON.
json_data_callback (callable): A callable returning the dict to
@@ -311,6 +341,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
"""
if not json_data_callback:
@@ -331,6 +364,7 @@ class MatrixFederationHttpClient(object):
path,
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
+ query_bytes=encode_query_args(args),
long_retries=long_retries,
timeout=timeout,
ignore_backoff=ignore_backoff,
@@ -347,7 +381,7 @@ class MatrixFederationHttpClient(object):
@defer.inlineCallbacks
def post_json(self, destination, path, data={}, long_retries=False,
- timeout=None, ignore_backoff=False):
+ timeout=None, ignore_backoff=False, args={}):
""" Sends the specifed json data using POST
Args:
@@ -362,6 +396,7 @@ class MatrixFederationHttpClient(object):
giving up. None indicates no timeout.
ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway.
+ args (dict): query params
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body.
@@ -371,6 +406,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
"""
def body_callback(method, url_bytes, headers_dict):
@@ -383,6 +421,7 @@ class MatrixFederationHttpClient(object):
destination,
"POST",
path,
+ query_bytes=encode_query_args(args),
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
long_retries=long_retries,
@@ -424,16 +463,12 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
"""
logger.debug("get_json args: %s", args)
- encoded_args = {}
- for k, vs in args.items():
- if isinstance(vs, basestring):
- vs = [vs]
- encoded_args[k] = [v.encode("UTF-8") for v in vs]
-
- query_bytes = urllib.urlencode(encoded_args, True)
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
def body_callback(method, url_bytes, headers_dict):
@@ -444,7 +479,7 @@ class MatrixFederationHttpClient(object):
destination,
"GET",
path,
- query_bytes=query_bytes,
+ query_bytes=encode_query_args(args),
body_callback=body_callback,
retry_on_dns_fail=retry_on_dns_fail,
timeout=timeout,
@@ -461,6 +496,55 @@ class MatrixFederationHttpClient(object):
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
+ def delete_json(self, destination, path, long_retries=False,
+ timeout=None, ignore_backoff=False, args={}):
+ """Send a DELETE request to the remote expecting some json response
+
+ Args:
+ destination (str): The remote server to send the HTTP request
+ to.
+ path (str): The HTTP path.
+ long_retries (bool): A boolean that indicates whether we should
+ retry for a short or long time.
+ timeout(int): How long to try (in ms) the destination for before
+ giving up. None indicates no timeout.
+ ignore_backoff (bool): true to ignore the historical backoff data and
+ try the request anyway.
+ Returns:
+ Deferred: Succeeds when we get a 2xx HTTP response. The result
+ will be the decoded JSON body.
+
+ Fails with ``HTTPRequestException`` if we get an HTTP response
+ code >= 300.
+
+ Fails with ``NotRetryingDestination`` if we are not yet ready
+ to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
+ """
+
+ response = yield self._request(
+ destination,
+ "DELETE",
+ path,
+ query_bytes=encode_query_args(args),
+ headers_dict={"Content-Type": ["application/json"]},
+ long_retries=long_retries,
+ timeout=timeout,
+ ignore_backoff=ignore_backoff,
+ )
+
+ if 200 <= response.code < 300:
+ # We need to update the transactions table to say it was sent?
+ check_content_type_is_json(response.headers)
+
+ with logcontext.PreserveLoggingContext():
+ body = yield readBody(response)
+
+ defer.returnValue(json.loads(body))
+
+ @defer.inlineCallbacks
def get_file(self, destination, path, output_stream, args={},
retry_on_dns_fail=True, max_size=None,
ignore_backoff=False):
@@ -481,11 +565,14 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
"""
encoded_args = {}
for k, vs in args.items():
- if isinstance(vs, basestring):
+ if isinstance(vs, string_types):
vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs]
@@ -513,7 +600,7 @@ class MatrixFederationHttpClient(object):
length = yield _readBodyToFile(
response, output_stream, max_size
)
- except:
+ except Exception:
logger.exception("Failed to download body")
raise
@@ -578,12 +665,14 @@ class _JsonProducer(object):
def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
- return ", ".join(
+ reasons = ", ".join(
_flatten_response_never_received(f.value)
for f in e.reasons
)
+
+ return "%s:[%s]" % (type(e).__name__, reasons)
else:
- return "%s: %s" % (type(e).__name__, e.message,)
+ return repr(e)
def check_content_type_is_json(headers):
@@ -598,7 +687,7 @@ def check_content_type_is_json(headers):
RuntimeError if the
"""
- c_type = headers.getRawHeaders("Content-Type")
+ c_type = headers.getRawHeaders(b"Content-Type")
if c_type is None:
raise RuntimeError(
"No Content-Type header"
@@ -610,3 +699,15 @@ def check_content_type_is_json(headers):
raise RuntimeError(
"Content-Type not application/json: was '%s'" % c_type
)
+
+
+def encode_query_args(args):
+ encoded_args = {}
+ for k, vs in args.items():
+ if isinstance(vs, string_types):
+ vs = [vs]
+ encoded_args[k] = [v.encode("UTF-8") for v in vs]
+
+ query_bytes = urllib.urlencode(encoded_args, True)
+
+ return query_bytes
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
new file mode 100644
index 0000000000..588e280571
--- /dev/null
+++ b/synapse/http/request_metrics.py
@@ -0,0 +1,231 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from prometheus_client.core import Counter, Histogram
+
+from synapse.metrics import LaterGauge
+from synapse.util.logcontext import LoggingContext
+
+logger = logging.getLogger(__name__)
+
+
+# total number of responses served, split by method/servlet/tag
+response_count = Counter(
+ "synapse_http_server_response_count", "", ["method", "servlet", "tag"]
+)
+
+requests_counter = Counter(
+ "synapse_http_server_requests_received", "", ["method", "servlet"]
+)
+
+outgoing_responses_counter = Counter(
+ "synapse_http_server_responses", "", ["method", "code"]
+)
+
+response_timer = Histogram(
+ "synapse_http_server_response_time_seconds", "sec",
+ ["method", "servlet", "tag", "code"],
+)
+
+response_ru_utime = Counter(
+ "synapse_http_server_response_ru_utime_seconds", "sec", ["method", "servlet", "tag"]
+)
+
+response_ru_stime = Counter(
+ "synapse_http_server_response_ru_stime_seconds", "sec", ["method", "servlet", "tag"]
+)
+
+response_db_txn_count = Counter(
+ "synapse_http_server_response_db_txn_count", "", ["method", "servlet", "tag"]
+)
+
+# seconds spent waiting for db txns, excluding scheduling time, when processing
+# this request
+response_db_txn_duration = Counter(
+ "synapse_http_server_response_db_txn_duration_seconds",
+ "",
+ ["method", "servlet", "tag"],
+)
+
+# seconds spent waiting for a db connection, when processing this request
+response_db_sched_duration = Counter(
+ "synapse_http_server_response_db_sched_duration_seconds",
+ "",
+ ["method", "servlet", "tag"],
+)
+
+# size in bytes of the response written
+response_size = Counter(
+ "synapse_http_server_response_size", "", ["method", "servlet", "tag"]
+)
+
+# In flight metrics are incremented while the requests are in flight, rather
+# than when the response was written.
+
+in_flight_requests_ru_utime = Counter(
+ "synapse_http_server_in_flight_requests_ru_utime_seconds",
+ "",
+ ["method", "servlet"],
+)
+
+in_flight_requests_ru_stime = Counter(
+ "synapse_http_server_in_flight_requests_ru_stime_seconds",
+ "",
+ ["method", "servlet"],
+)
+
+in_flight_requests_db_txn_count = Counter(
+ "synapse_http_server_in_flight_requests_db_txn_count", "", ["method", "servlet"]
+)
+
+# seconds spent waiting for db txns, excluding scheduling time, when processing
+# this request
+in_flight_requests_db_txn_duration = Counter(
+ "synapse_http_server_in_flight_requests_db_txn_duration_seconds",
+ "",
+ ["method", "servlet"],
+)
+
+# seconds spent waiting for a db connection, when processing this request
+in_flight_requests_db_sched_duration = Counter(
+ "synapse_http_server_in_flight_requests_db_sched_duration_seconds",
+ "",
+ ["method", "servlet"],
+)
+
+# The set of all in flight requests, set[RequestMetrics]
+_in_flight_requests = set()
+
+
+def _get_in_flight_counts():
+ """Returns a count of all in flight requests by (method, server_name)
+
+ Returns:
+ dict[tuple[str, str], int]
+ """
+ # Cast to a list to prevent it changing while the Prometheus
+ # thread is collecting metrics
+ reqs = list(_in_flight_requests)
+
+ for rm in reqs:
+ rm.update_metrics()
+
+ # Map from (method, name) -> int, the number of in flight requests of that
+ # type
+ counts = {}
+ for rm in reqs:
+ key = (rm.method, rm.name,)
+ counts[key] = counts.get(key, 0) + 1
+
+ return counts
+
+
+LaterGauge(
+ "synapse_http_server_in_flight_requests_count",
+ "",
+ ["method", "servlet"],
+ _get_in_flight_counts,
+)
+
+
+class RequestMetrics(object):
+ def start(self, time_sec, name, method):
+ self.start = time_sec
+ self.start_context = LoggingContext.current_context()
+ self.name = name
+ self.method = method
+
+ # _request_stats records resource usage that we have already added
+ # to the "in flight" metrics.
+ self._request_stats = self.start_context.get_resource_usage()
+
+ _in_flight_requests.add(self)
+
+ def stop(self, time_sec, request):
+ _in_flight_requests.discard(self)
+
+ context = LoggingContext.current_context()
+
+ tag = ""
+ if context:
+ tag = context.tag
+
+ if context != self.start_context:
+ logger.warn(
+ "Context have unexpectedly changed %r, %r",
+ context, self.start_context
+ )
+ return
+
+ response_code = str(request.code)
+
+ outgoing_responses_counter.labels(request.method, response_code).inc()
+
+ response_count.labels(request.method, self.name, tag).inc()
+
+ response_timer.labels(request.method, self.name, tag, response_code).observe(
+ time_sec - self.start
+ )
+
+ resource_usage = context.get_resource_usage()
+
+ response_ru_utime.labels(request.method, self.name, tag).inc(
+ resource_usage.ru_utime,
+ )
+ response_ru_stime.labels(request.method, self.name, tag).inc(
+ resource_usage.ru_stime,
+ )
+ response_db_txn_count.labels(request.method, self.name, tag).inc(
+ resource_usage.db_txn_count
+ )
+ response_db_txn_duration.labels(request.method, self.name, tag).inc(
+ resource_usage.db_txn_duration_sec
+ )
+ response_db_sched_duration.labels(request.method, self.name, tag).inc(
+ resource_usage.db_sched_duration_sec
+ )
+
+ response_size.labels(request.method, self.name, tag).inc(request.sentLength)
+
+ # We always call this at the end to ensure that we update the metrics
+ # regardless of whether a call to /metrics while the request was in
+ # flight.
+ self.update_metrics()
+
+ def update_metrics(self):
+ """Updates the in flight metrics with values from this request.
+ """
+ new_stats = self.start_context.get_resource_usage()
+
+ diff = new_stats - self._request_stats
+ self._request_stats = new_stats
+
+ in_flight_requests_ru_utime.labels(self.method, self.name).inc(diff.ru_utime)
+ in_flight_requests_ru_stime.labels(self.method, self.name).inc(diff.ru_stime)
+
+ in_flight_requests_db_txn_count.labels(self.method, self.name).inc(
+ diff.db_txn_count
+ )
+
+ in_flight_requests_db_txn_duration.labels(self.method, self.name).inc(
+ diff.db_txn_duration_sec
+ )
+
+ in_flight_requests_db_sched_duration.labels(self.method, self.name).inc(
+ diff.db_sched_duration_sec
+ )
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 7ef3d526b1..c70fdbdfd2 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,148 +13,205 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import cgi
+import collections
+import logging
+import urllib
+from six.moves import http_client
-from synapse.api.errors import (
- cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError, Codes
-)
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
-from synapse.util.caches import intern_dict
-from synapse.util.metrics import Measure
-import synapse.metrics
-import synapse.events
-
-from canonicaljson import (
- encode_canonical_json, encode_pretty_printed_json
-)
+from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
from twisted.internet import defer
-from twisted.web import server, resource
+from twisted.python import failure
+from twisted.web import resource, server
from twisted.web.server import NOT_DONE_YET
from twisted.web.util import redirectTo
-import collections
-import logging
-import urllib
-import ujson
+import synapse.events
+import synapse.metrics
+from synapse.api.errors import (
+ CodeMessageException,
+ Codes,
+ SynapseError,
+ UnrecognizedRequestError,
+ cs_exception,
+)
+from synapse.http.request_metrics import requests_counter
+from synapse.util.caches import intern_dict
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
+HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
+<html lang=en>
+ <head>
+ <meta charset="utf-8">
+ <title>Error {code}</title>
+ </head>
+ <body>
+ <p>{msg}</p>
+ </body>
+</html>
+"""
-incoming_requests_counter = metrics.register_counter(
- "requests",
- labels=["method", "servlet", "tag"],
-)
-outgoing_responses_counter = metrics.register_counter(
- "responses",
- labels=["method", "code"],
-)
-response_timer = metrics.register_distribution(
- "response_time",
- labels=["method", "servlet", "tag"]
-)
+def wrap_json_request_handler(h):
+ """Wraps a request handler method with exception handling.
-response_ru_utime = metrics.register_distribution(
- "response_ru_utime", labels=["method", "servlet", "tag"]
-)
+ Also adds logging as per wrap_request_handler_with_logging.
-response_ru_stime = metrics.register_distribution(
- "response_ru_stime", labels=["method", "servlet", "tag"]
-)
+ The handler method must have a signature of "handle_foo(self, request)",
+ where "self" must have a "clock" attribute (and "request" must be a
+ SynapseRequest).
-response_db_txn_count = metrics.register_distribution(
- "response_db_txn_count", labels=["method", "servlet", "tag"]
-)
+ The handler must return a deferred. If the deferred succeeds we assume that
+ a response has been sent. If the deferred fails with a SynapseError we use
+ it to send a JSON response with the appropriate HTTP reponse code. If the
+ deferred fails with any other type of error we send a 500 reponse.
+ """
-response_db_txn_duration = metrics.register_distribution(
- "response_db_txn_duration", labels=["method", "servlet", "tag"]
-)
+ @defer.inlineCallbacks
+ def wrapped_request_handler(self, request):
+ try:
+ yield h(self, request)
+ except CodeMessageException as e:
+ code = e.code
+ if isinstance(e, SynapseError):
+ logger.info(
+ "%s SynapseError: %s - %s", request, code, e.msg
+ )
+ else:
+ logger.exception(e)
+ respond_with_json(
+ request, code, cs_exception(e), send_cors=True,
+ pretty_print=_request_user_agent_is_curl(request),
+ )
+ except Exception:
+ # failure.Failure() fishes the original Failure out
+ # of our stack, and thus gives us a sensible stack
+ # trace.
+ f = failure.Failure()
+ logger.error(
+ "Failed handle request via %r: %r: %s",
+ h,
+ request,
+ f.getTraceback().rstrip(),
+ )
+ respond_with_json(
+ request,
+ 500,
+ {
+ "error": "Internal server error",
+ "errcode": Codes.UNKNOWN,
+ },
+ send_cors=True,
+ pretty_print=_request_user_agent_is_curl(request),
+ )
-_next_request_id = 0
+ return wrap_request_handler_with_logging(wrapped_request_handler)
-def request_handler(include_metrics=False):
- """Decorator for ``wrap_request_handler``"""
- return lambda request_handler: wrap_request_handler(request_handler, include_metrics)
+def wrap_html_request_handler(h):
+ """Wraps a request handler method with exception handling.
+ Also adds logging as per wrap_request_handler_with_logging.
-def wrap_request_handler(request_handler, include_metrics=False):
- """Wraps a method that acts as a request handler with the necessary logging
- and exception handling.
+ The handler method must have a signature of "handle_foo(self, request)",
+ where "self" must have a "clock" attribute (and "request" must be a
+ SynapseRequest).
+ """
+ def wrapped_request_handler(self, request):
+ d = defer.maybeDeferred(h, self, request)
+ d.addErrback(_return_html_error, request)
+ return d
- The method must have a signature of "handle_foo(self, request)". The
- argument "self" must have "version_string" and "clock" attributes. The
- argument "request" must be a twisted HTTP request.
+ return wrap_request_handler_with_logging(wrapped_request_handler)
- The method must return a deferred. If the deferred succeeds we assume that
- a response has been sent. If the deferred fails with a SynapseError we use
- it to send a JSON response with the appropriate HTTP reponse code. If the
- deferred fails with any other type of error we send a 500 reponse.
- We insert a unique request-id into the logging context for this request and
- log the response and duration for this request.
+def _return_html_error(f, request):
+ """Sends an HTML error page corresponding to the given failure
+
+ Args:
+ f (twisted.python.failure.Failure):
+ request (twisted.web.iweb.IRequest):
"""
+ if f.check(CodeMessageException):
+ cme = f.value
+ code = cme.code
+ msg = cme.msg
+
+ if isinstance(cme, SynapseError):
+ logger.info(
+ "%s SynapseError: %s - %s", request, code, msg
+ )
+ else:
+ logger.error(
+ "Failed handle request %r: %s",
+ request,
+ f.getTraceback().rstrip(),
+ )
+ else:
+ code = http_client.INTERNAL_SERVER_ERROR
+ msg = "Internal server error"
+
+ logger.error(
+ "Failed handle request %r: %s",
+ request,
+ f.getTraceback().rstrip(),
+ )
+
+ body = HTML_ERROR_TEMPLATE.format(
+ code=code, msg=cgi.escape(msg),
+ ).encode("utf-8")
+ request.setResponseCode(code)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Content-Length", b"%i" % (len(body),))
+ request.write(body)
+ finish_request(request)
+
+def wrap_request_handler_with_logging(h):
+ """Wraps a request handler to provide logging and metrics
+
+ The handler method must have a signature of "handle_foo(self, request)",
+ where "self" must have a "clock" attribute (and "request" must be a
+ SynapseRequest).
+
+ As well as calling `request.processing` (which will log the response and
+ duration for this request), the wrapped request handler will insert the
+ request id into the logging context.
+ """
@defer.inlineCallbacks
def wrapped_request_handler(self, request):
- global _next_request_id
- request_id = "%s-%s" % (request.method, _next_request_id)
- _next_request_id += 1
+ """
+ Args:
+ self:
+ request (synapse.http.site.SynapseRequest):
+ """
+ request_id = request.get_request_id()
with LoggingContext(request_id) as request_context:
+ request_context.request = request_id
with Measure(self.clock, "wrapped_request_handler"):
- request_metrics = RequestMetrics()
- request_metrics.start(self.clock, name=self.__class__.__name__)
-
- request_context.request = request_id
- with request.processing():
- try:
- with PreserveLoggingContext(request_context):
- if include_metrics:
- yield request_handler(self, request, request_metrics)
- else:
- yield request_handler(self, request)
- except CodeMessageException as e:
- code = e.code
- if isinstance(e, SynapseError):
- logger.info(
- "%s SynapseError: %s - %s", request, code, e.msg
- )
- else:
- logger.exception(e)
- outgoing_responses_counter.inc(request.method, str(code))
- respond_with_json(
- request, code, cs_exception(e), send_cors=True,
- pretty_print=_request_user_agent_is_curl(request),
- version_string=self.version_string,
- )
- except:
- logger.exception(
- "Failed handle request %s.%s on %r: %r",
- request_handler.__module__,
- request_handler.__name__,
- self,
- request
- )
- respond_with_json(
- request,
- 500,
- {
- "error": "Internal server error",
- "errcode": Codes.UNKNOWN,
- },
- send_cors=True
- )
- finally:
- try:
- request_metrics.stop(
- self.clock, request
- )
- except Exception as e:
- logger.warn("Failed to stop metrics: %r", e)
+ # we start the request metrics timer here with an initial stab
+ # at the servlet name. For most requests that name will be
+ # JsonResource (or a subclass), and JsonResource._async_render
+ # will update it once it picks a servlet.
+ servlet_name = self.__class__.__name__
+ with request.processing(servlet_name):
+ with PreserveLoggingContext(request_context):
+ d = defer.maybeDeferred(h, self, request)
+
+ # record the arrival of the request *after*
+ # dispatching to the handler, so that the handler
+ # can update the servlet name in the request
+ # metrics
+ requests_counter.labels(request.method,
+ request.request_metrics.name).inc()
+ yield d
return wrapped_request_handler
@@ -183,7 +241,7 @@ class JsonResource(HttpServer, resource.Resource):
""" This implements the HttpServer interface and provides JSON support for
Resources.
- Register callbacks via register_path()
+ Register callbacks via register_paths()
Callbacks can return a tuple of status code and a dict in which case the
the dict will automatically be sent to the client as a JSON object.
@@ -203,7 +261,6 @@ class JsonResource(HttpServer, resource.Resource):
self.canonical_json = canonical_json
self.clock = hs.get_clock()
self.path_regexs = {}
- self.version_string = hs.version_string
self.hs = hs
def register_paths(self, method, path_patterns, callback):
@@ -219,122 +276,103 @@ class JsonResource(HttpServer, resource.Resource):
self._async_render(request)
return server.NOT_DONE_YET
- # Disable metric reporting because _async_render does its own metrics.
- # It does its own metric reporting because _async_render dispatches to
- # a callback and it's the class name of that callback we want to report
- # against rather than the JsonResource itself.
- @request_handler(include_metrics=True)
+ @wrap_json_request_handler
@defer.inlineCallbacks
- def _async_render(self, request, request_metrics):
+ def _async_render(self, request):
""" This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and
path.
"""
- if request.method == "OPTIONS":
- self._send_response(request, 200, {})
- return
+ callback, group_dict = self._get_handler_for_request(request)
- # Loop through all the registered callbacks to check if the method
- # and path regex match
- for path_entry in self.path_regexs.get(request.method, []):
- m = path_entry.pattern.match(request.path)
- if not m:
- continue
+ servlet_instance = getattr(callback, "__self__", None)
+ if servlet_instance is not None:
+ servlet_classname = servlet_instance.__class__.__name__
+ else:
+ servlet_classname = "%r" % callback
+ request.request_metrics.name = servlet_classname
- # We found a match! Trigger callback and then return the
- # returned response. We pass both the request and any
- # matched groups from the regex to the callback.
+ # Now trigger the callback. If it returns a response, we send it
+ # here. If it throws an exception, that is handled by the wrapper
+ # installed by @request_handler.
- callback = path_entry.callback
+ kwargs = intern_dict({
+ name: urllib.unquote(value).decode("UTF-8") if value else value
+ for name, value in group_dict.items()
+ })
- kwargs = intern_dict({
- name: urllib.unquote(value).decode("UTF-8") if value else value
- for name, value in m.groupdict().items()
- })
+ callback_return = yield callback(request, **kwargs)
+ if callback_return is not None:
+ code, response = callback_return
+ self._send_response(request, code, response)
- callback_return = yield callback(request, **kwargs)
- if callback_return is not None:
- code, response = callback_return
- self._send_response(request, code, response)
+ def _get_handler_for_request(self, request):
+ """Finds a callback method to handle the given request
- servlet_instance = getattr(callback, "__self__", None)
- if servlet_instance is not None:
- servlet_classname = servlet_instance.__class__.__name__
- else:
- servlet_classname = "%r" % callback
+ Args:
+ request (twisted.web.http.Request):
- request_metrics.name = servlet_classname
+ Returns:
+ Tuple[Callable, dict[str, str]]: callback method, and the dict
+ mapping keys to path components as specified in the handler's
+ path match regexp.
- return
+ The callback will normally be a method registered via
+ register_paths, so will return (possibly via Deferred) either
+ None, or a tuple of (http code, response body).
+ """
+ if request.method == b"OPTIONS":
+ return _options_handler, {}
+
+ # Loop through all the registered callbacks to check if the method
+ # and path regex match
+ for path_entry in self.path_regexs.get(request.method, []):
+ m = path_entry.pattern.match(request.path)
+ if m:
+ # We found a match!
+ return path_entry.callback, m.groupdict()
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
- raise UnrecognizedRequestError()
+ return _unrecognised_request_handler, {}
def _send_response(self, request, code, response_json_object,
response_code_message=None):
- # could alternatively use request.notifyFinish() and flip a flag when
- # the Deferred fires, but since the flag is RIGHT THERE it seems like
- # a waste.
- if request._disconnected:
- logger.warn(
- "Not sending response to request %s, already disconnected.",
- request)
- return
-
- outgoing_responses_counter.inc(request.method, str(code))
-
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
request, code, response_json_object,
send_cors=True,
response_code_message=response_code_message,
pretty_print=_request_user_agent_is_curl(request),
- version_string=self.version_string,
canonical_json=self.canonical_json,
)
-class RequestMetrics(object):
- def start(self, clock, name):
- self.start = clock.time_msec()
- self.start_context = LoggingContext.current_context()
- self.name = name
+def _options_handler(request):
+ """Request handler for OPTIONS requests
- def stop(self, clock, request):
- context = LoggingContext.current_context()
+ This is a request handler suitable for return from
+ _get_handler_for_request. It returns a 200 and an empty body.
- tag = ""
- if context:
- tag = context.tag
+ Args:
+ request (twisted.web.http.Request):
- if context != self.start_context:
- logger.warn(
- "Context have unexpectedly changed %r, %r",
- context, self.start_context
- )
- return
+ Returns:
+ Tuple[int, dict]: http code, response body.
+ """
+ return 200, {}
- incoming_requests_counter.inc(request.method, self.name, tag)
- response_timer.inc_by(
- clock.time_msec() - self.start, request.method,
- self.name, tag
- )
+def _unrecognised_request_handler(request):
+ """Request handler for unrecognised requests
- ru_utime, ru_stime = context.get_resource_usage()
+ This is a request handler suitable for return from
+ _get_handler_for_request. It actually just raises an
+ UnrecognizedRequestError.
- response_ru_utime.inc_by(
- ru_utime, request.method, self.name, tag
- )
- response_ru_stime.inc_by(
- ru_stime, request.method, self.name, tag
- )
- response_db_txn_count.inc_by(
- context.db_txn_count, request.method, self.name, tag
- )
- response_db_txn_duration.inc_by(
- context.db_txn_duration, request.method, self.name, tag
- )
+ Args:
+ request (twisted.web.http.Request):
+ """
+ raise UnrecognizedRequestError()
class RootRedirect(resource.Resource):
@@ -355,26 +393,33 @@ class RootRedirect(resource.Resource):
def respond_with_json(request, code, json_object, send_cors=False,
response_code_message=None, pretty_print=False,
- version_string="", canonical_json=True):
+ canonical_json=True):
+ # could alternatively use request.notifyFinish() and flip a flag when
+ # the Deferred fires, but since the flag is RIGHT THERE it seems like
+ # a waste.
+ if request._disconnected:
+ logger.warn(
+ "Not sending response to request %s, already disconnected.",
+ request)
+ return
+
if pretty_print:
json_bytes = encode_pretty_printed_json(json_object) + "\n"
else:
if canonical_json or synapse.events.USE_FROZEN_DICTS:
json_bytes = encode_canonical_json(json_object)
else:
- # ujson doesn't like frozen_dicts.
- json_bytes = ujson.dumps(json_object, ensure_ascii=False)
+ json_bytes = json.dumps(json_object)
return respond_with_json_bytes(
request, code, json_bytes,
send_cors=send_cors,
response_code_message=response_code_message,
- version_string=version_string
)
def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
- version_string="", response_code_message=None):
+ response_code_message=None):
"""Sends encoded JSON in response to the given request.
Args:
@@ -388,8 +433,8 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
request.setResponseCode(code, message=response_code_message)
request.setHeader(b"Content-Type", b"application/json")
- request.setHeader(b"Server", version_string)
request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),))
+ request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
if send_cors:
set_cors_headers(request)
@@ -437,9 +482,9 @@ def finish_request(request):
def _request_user_agent_is_curl(request):
user_agents = request.requestHeaders.getRawHeaders(
- "User-Agent", default=[]
+ b"User-Agent", default=[]
)
for user_agent in user_agents:
- if "curl" in user_agent:
+ if b"curl" in user_agent:
return True
return False
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 9a4c36ad5d..882816dc8f 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -15,10 +15,11 @@
""" This module contains base REST classes for constructing REST servlets. """
-from synapse.api.errors import SynapseError, Codes
-
import logging
-import simplejson
+
+from canonicaljson import json
+
+from synapse.api.errors import Codes, SynapseError
logger = logging.getLogger(__name__)
@@ -48,7 +49,7 @@ def parse_integer_from_args(args, name, default=None, required=False):
if name in args:
try:
return int(args[name][0])
- except:
+ except Exception:
message = "Query parameter %r must be an integer" % (name,)
raise SynapseError(400, message)
else:
@@ -88,7 +89,7 @@ def parse_boolean_from_args(args, name, default=None, required=False):
"true": True,
"false": False,
}[args[name][0]]
- except:
+ except Exception:
message = (
"Boolean query parameter %r must be one of"
" ['true', 'false']"
@@ -148,11 +149,13 @@ def parse_string_from_args(args, name, default=None, required=False,
return default
-def parse_json_value_from_request(request):
+def parse_json_value_from_request(request, allow_empty_body=False):
"""Parse a JSON value from the body of a twisted HTTP request.
Args:
request: the twisted HTTP request.
+ allow_empty_body (bool): if True, an empty body will be accepted and
+ turned into None
Returns:
The JSON value.
@@ -162,28 +165,39 @@ def parse_json_value_from_request(request):
"""
try:
content_bytes = request.content.read()
- except:
+ except Exception:
raise SynapseError(400, "Error reading JSON content.")
+ if not content_bytes and allow_empty_body:
+ return None
+
try:
- content = simplejson.loads(content_bytes)
- except simplejson.JSONDecodeError:
+ content = json.loads(content_bytes)
+ except Exception as e:
+ logger.warn("Unable to parse JSON: %s", e)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
return content
-def parse_json_object_from_request(request):
+def parse_json_object_from_request(request, allow_empty_body=False):
"""Parse a JSON object from the body of a twisted HTTP request.
Args:
request: the twisted HTTP request.
+ allow_empty_body (bool): if True, an empty body will be accepted and
+ turned into an empty dict.
Raises:
SynapseError if the request body couldn't be decoded as JSON or
if it wasn't a JSON object.
"""
- content = parse_json_value_from_request(request)
+ content = parse_json_value_from_request(
+ request, allow_empty_body=allow_empty_body,
+ )
+
+ if allow_empty_body and content is None:
+ return {}
if type(content) != dict:
message = "Content must be a JSON object."
@@ -192,7 +206,7 @@ def parse_json_object_from_request(request):
return content
-def assert_params_in_request(body, required):
+def assert_params_in_dict(body, required):
absent = []
for k in required:
if k not in body:
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 4b09d7ee66..5fd30a4c2c 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -12,27 +12,50 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.logcontext import LoggingContext
-from twisted.web.server import Site, Request
-
import contextlib
import logging
-import re
import time
-ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
+from twisted.web.server import Request, Site
+
+from synapse.http import redact_uri
+from synapse.http.request_metrics import RequestMetrics
+from synapse.util.logcontext import ContextResourceUsage, LoggingContext
+
+logger = logging.getLogger(__name__)
+
+_next_request_seq = 0
class SynapseRequest(Request):
- def __init__(self, site, *args, **kw):
- Request.__init__(self, *args, **kw)
+ """Class which encapsulates an HTTP request to synapse.
+
+ All of the requests processed in synapse are of this type.
+
+ It extends twisted's twisted.web.server.Request, and adds:
+ * Unique request ID
+ * Redaction of access_token query-params in __repr__
+ * Logging at start and end
+ * Metrics to record CPU, wallclock and DB time by endpoint.
+
+ It provides a method `processing` which should be called by the Resource
+ which is handling the request, and returns a context manager.
+
+ """
+ def __init__(self, site, channel, *args, **kw):
+ Request.__init__(self, channel, *args, **kw)
self.site = site
+ self._channel = channel
self.authenticated_entity = None
self.start_time = 0
+ global _next_request_seq
+ self.request_seq = _next_request_seq
+ _next_request_seq += 1
+
def __repr__(self):
# We overwrite this so that we don't log ``access_token``
- return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
+ return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % (
self.__class__.__name__,
id(self),
self.method,
@@ -41,16 +64,27 @@ class SynapseRequest(Request):
self.site.site_tag,
)
+ def get_request_id(self):
+ return "%s-%i" % (self.method, self.request_seq)
+
def get_redacted_uri(self):
- return ACCESS_TOKEN_RE.sub(
- r'\1<redacted>\3',
- self.uri
- )
+ return redact_uri(self.uri)
def get_user_agent(self):
- return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
+ return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
+
+ def render(self, resrc):
+ # override the Server header which is set by twisted
+ self.setHeader("Server", self.site.server_version_string)
+ return Request.render(self, resrc)
+
+ def _started_processing(self, servlet_name):
+ self.start_time = time.time()
+ self.request_metrics = RequestMetrics()
+ self.request_metrics.start(
+ self.start_time, name=servlet_name, method=self.method,
+ )
- def started_processing(self):
self.site.access_logger.info(
"%s - %s - Received request: %s %s",
self.getClientIP(),
@@ -58,44 +92,85 @@ class SynapseRequest(Request):
self.method,
self.get_redacted_uri()
)
- self.start_time = int(time.time() * 1000)
-
- def finished_processing(self):
+ def _finished_processing(self):
try:
context = LoggingContext.current_context()
- ru_utime, ru_stime = context.get_resource_usage()
- db_txn_count = context.db_txn_count
- db_txn_duration = context.db_txn_duration
- except:
- ru_utime, ru_stime = (0, 0)
- db_txn_count, db_txn_duration = (0, 0)
+ usage = context.get_resource_usage()
+ except Exception:
+ usage = ContextResourceUsage()
+
+ end_time = time.time()
+
+ # need to decode as it could be raw utf-8 bytes
+ # from a IDN servname in an auth header
+ authenticated_entity = self.authenticated_entity
+ if authenticated_entity is not None:
+ authenticated_entity = authenticated_entity.decode("utf-8", "replace")
+
+ # ...or could be raw utf-8 bytes in the User-Agent header.
+ # N.B. if you don't do this, the logger explodes cryptically
+ # with maximum recursion trying to log errors about
+ # the charset problem.
+ # c.f. https://github.com/matrix-org/synapse/issues/3471
+ user_agent = self.get_user_agent()
+ if user_agent is not None:
+ user_agent = user_agent.decode("utf-8", "replace")
self.site.access_logger.info(
"%s - %s - {%s}"
- " Processed request: %dms (%dms, %dms) (%dms/%d)"
- " %sB %s \"%s %s %s\" \"%s\"",
+ " Processed request: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
+ " %sB %s \"%s %s %s\" \"%s\" [%d dbevts]",
self.getClientIP(),
self.site.site_tag,
- self.authenticated_entity,
- int(time.time() * 1000) - self.start_time,
- int(ru_utime * 1000),
- int(ru_stime * 1000),
- int(db_txn_duration * 1000),
- int(db_txn_count),
+ authenticated_entity,
+ end_time - self.start_time,
+ usage.ru_utime,
+ usage.ru_stime,
+ usage.db_sched_duration_sec,
+ usage.db_txn_duration_sec,
+ int(usage.db_txn_count),
self.sentLength,
self.code,
self.method,
self.get_redacted_uri(),
self.clientproto,
- self.get_user_agent(),
+ user_agent,
+ usage.evt_db_fetch_count,
)
+ try:
+ self.request_metrics.stop(end_time, self)
+ except Exception as e:
+ logger.warn("Failed to stop metrics: %r", e)
+
@contextlib.contextmanager
- def processing(self):
- self.started_processing()
+ def processing(self, servlet_name):
+ """Record the fact that we are processing this request.
+
+ Returns a context manager; the correct way to use this is:
+
+ @defer.inlineCallbacks
+ def handle_request(request):
+ with request.processing("FooServlet"):
+ yield really_handle_the_request()
+
+ This will log the request's arrival. Once the context manager is
+ closed, the completion of the request will be logged, and the various
+ metrics will be updated.
+
+ Args:
+ servlet_name (str): the name of the servlet which will be
+ processing this request. This is used in the metrics.
+
+ It is possible to update this afterwards by updating
+ self.request_metrics.servlet_name.
+ """
+ # TODO: we should probably just move this into render() and finish(),
+ # to save having to call a separate method.
+ self._started_processing(servlet_name)
yield
- self.finished_processing()
+ self._finished_processing()
class XForwardedForRequest(SynapseRequest):
@@ -133,7 +208,8 @@ class SynapseSite(Site):
Subclass of a twisted http Site that does access logging with python's
standard logging
"""
- def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
+ def __init__(self, logger_name, site_tag, config, resource,
+ server_version_string, *args, **kwargs):
Site.__init__(self, resource, *args, **kwargs)
self.site_tag = site_tag
@@ -141,6 +217,7 @@ class SynapseSite(Site):
proxied = config.get("x_forwarded", False)
self.requestFactory = SynapseRequestFactory(self, proxied)
self.access_logger = logging.getLogger(logger_name)
+ self.server_version_string = server_version_string
def log(self, request):
pass
|