diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py
new file mode 100644
index 0000000000..343e932cb1
--- /dev/null
+++ b/synapse/http/additional_resource.py
@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.http.server import wrap_request_handler
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+
+
+class AdditionalResource(Resource):
+ """Resource wrapper for additional_resources
+
+ If the user has configured additional_resources, we need to wrap the
+ handler class with a Resource so that we can map it into the resource tree.
+
+ This class is also where we wrap the request handler with logging, metrics,
+ and exception handling.
+ """
+ def __init__(self, hs, handler):
+ """Initialise AdditionalResource
+
+ The ``handler`` should return a deferred which completes when it has
+ done handling the request. It should write a response with
+ ``request.write()``, and call ``request.finish()``.
+
+ Args:
+ hs (synapse.server.HomeServer): homeserver
+ handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
+ function to be called to handle the request.
+ """
+ Resource.__init__(self)
+ self._handler = handler
+
+ # these are required by the request_handler wrapper
+ self.version_string = hs.version_string
+ self.clock = hs.get_clock()
+
+ def render(self, request):
+ self._async_render(request)
+ return NOT_DONE_YET
+
+ @wrap_request_handler
+ def _async_render(self, request):
+ return self._handler(request)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 9eba046bbf..4abb479ae3 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -18,7 +18,7 @@ from OpenSSL.SSL import VERIFY_NONE
from synapse.api.errors import (
CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
)
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util.logcontext import make_deferred_yieldable
from synapse.util import logcontext
import synapse.metrics
from synapse.http.endpoint import SpiderEndpoint
@@ -114,43 +114,73 @@ class SimpleHttpClient(object):
raise e
@defer.inlineCallbacks
- def post_urlencoded_get_json(self, uri, args={}):
+ def post_urlencoded_get_json(self, uri, args={}, headers=None):
+ """
+ Args:
+ uri (str):
+ args (dict[str, str|List[str]]): query params
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
+
+ Returns:
+ Deferred[object]: parsed json
+ """
+
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)
query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
+ actual_headers = {
+ b"Content-Type": [b"application/x-www-form-urlencoded"],
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"POST",
uri.encode("ascii"),
- headers=Headers({
- b"Content-Type": [b"application/x-www-form-urlencoded"],
- b"User-Agent": [self.user_agent],
- }),
+ headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(query_bytes))
)
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
- def post_json_get_json(self, uri, post_json):
+ def post_json_get_json(self, uri, post_json, headers=None):
+ """
+
+ Args:
+ uri (str):
+ post_json (object):
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
+
+ Returns:
+ Deferred[object]: parsed json
+ """
json_str = encode_canonical_json(post_json)
logger.debug("HTTP POST %s -> %s", json_str, uri)
+ actual_headers = {
+ b"Content-Type": [b"application/json"],
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"POST",
uri.encode("ascii"),
- headers=Headers({
- b"Content-Type": [b"application/json"],
- b"User-Agent": [self.user_agent],
- }),
+ headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(json_str))
)
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
@@ -160,7 +190,7 @@ class SimpleHttpClient(object):
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
- def get_json(self, uri, args={}):
+ def get_json(self, uri, args={}, headers=None):
""" Gets some json from the given URI.
Args:
@@ -169,6 +199,8 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
@@ -177,13 +209,13 @@ class SimpleHttpClient(object):
error message.
"""
try:
- body = yield self.get_raw(uri, args)
+ body = yield self.get_raw(uri, args, headers=headers)
defer.returnValue(json.loads(body))
except CodeMessageException as e:
raise self._exceptionFromFailedRequest(e.code, e.msg)
@defer.inlineCallbacks
- def put_json(self, uri, json_body, args={}):
+ def put_json(self, uri, json_body, args={}, headers=None):
""" Puts some json to the given URI.
Args:
@@ -193,6 +225,8 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
@@ -205,17 +239,21 @@ class SimpleHttpClient(object):
json_str = encode_canonical_json(json_body)
+ actual_headers = {
+ b"Content-Type": [b"application/json"],
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"PUT",
uri.encode("ascii"),
- headers=Headers({
- b"User-Agent": [self.user_agent],
- "Content-Type": ["application/json"]
- }),
+ headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(json_str))
)
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
@@ -226,7 +264,7 @@ class SimpleHttpClient(object):
raise CodeMessageException(response.code, body)
@defer.inlineCallbacks
- def get_raw(self, uri, args={}):
+ def get_raw(self, uri, args={}, headers=None):
""" Gets raw text from the given URI.
Args:
@@ -235,6 +273,8 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body at text.
@@ -246,15 +286,19 @@ class SimpleHttpClient(object):
query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
+ actual_headers = {
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"GET",
uri.encode("ascii"),
- headers=Headers({
- b"User-Agent": [self.user_agent],
- })
+ headers=Headers(actual_headers),
)
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
defer.returnValue(body)
@@ -274,27 +318,33 @@ class SimpleHttpClient(object):
# The two should be factored out.
@defer.inlineCallbacks
- def get_file(self, url, output_stream, max_size=None):
+ def get_file(self, url, output_stream, max_size=None, headers=None):
"""GETs a file from a given URL
Args:
url (str): The URL to GET
output_stream (file): File to write the response body to.
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
Returns:
A (int,dict,string,int) tuple of the file length, dict of the response
headers, absolute URI of the response and HTTP response code.
"""
+ actual_headers = {
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"GET",
url.encode("ascii"),
- headers=Headers({
- b"User-Agent": [self.user_agent],
- })
+ headers=Headers(actual_headers),
)
- headers = dict(response.headers.getAllRawHeaders())
+ resp_headers = dict(response.headers.getAllRawHeaders())
- if 'Content-Length' in headers and headers['Content-Length'] > max_size:
+ if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size:
logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
raise SynapseError(
502,
@@ -315,10 +365,9 @@ class SimpleHttpClient(object):
# straight back in again
try:
- length = yield preserve_context_over_fn(
- _readBodyToFile,
- response, output_stream, max_size
- )
+ length = yield make_deferred_yieldable(_readBodyToFile(
+ response, output_stream, max_size,
+ ))
except Exception as e:
logger.exception("Failed to download body")
raise SynapseError(
@@ -327,7 +376,9 @@ class SimpleHttpClient(object):
Codes.UNKNOWN,
)
- defer.returnValue((length, headers, response.request.absoluteURI, response.code))
+ defer.returnValue(
+ (length, resp_headers, response.request.absoluteURI, response.code),
+ )
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
@@ -395,7 +446,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
)
try:
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
defer.returnValue(body)
except PartialDownloadError as e:
# twisted dislikes google's response, no content length.
|