summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/app/frontend_proxy.py10
-rw-r--r--synapse/appservice/api.py10
-rw-r--r--synapse/config/logger.py6
-rw-r--r--synapse/federation/federation_server.py8
-rw-r--r--synapse/handlers/room_list.py9
-rw-r--r--synapse/handlers/sync.py6
-rw-r--r--synapse/http/client.py127
-rw-r--r--synapse/rest/client/v2_alpha/devices.py7
-rw-r--r--synapse/rest/client/v2_alpha/keys.py18
-rw-r--r--synapse/rest/client/v2_alpha/notifications.py2
-rw-r--r--synapse/rest/client/v2_alpha/sendtodevice.py2
-rw-r--r--synapse/rest/client/v2_alpha/thirdparty.py11
12 files changed, 137 insertions, 79 deletions
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index bee4c47498..abc7ef5725 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -50,8 +50,7 @@ logger = logging.getLogger("synapse.app.frontend_proxy")
 
 
 class KeyUploadServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
-                                  releases=())
+    PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
 
     def __init__(self, hs):
         """
@@ -89,9 +88,16 @@ class KeyUploadServlet(RestServlet):
 
         if body:
             # They're actually trying to upload something, proxy to main synapse.
+            # Pass through the auth headers, if any, in case the access token
+            # is there.
+            auth_headers = request.requestHeaders.getRawHeaders("Authorization", [])
+            headers = {
+                "Authorization": auth_headers,
+            }
             result = yield self.http_client.post_json_get_json(
                 self.main_uri + request.uri,
                 body,
+                headers=headers,
             )
 
             defer.returnValue((200, result))
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 6893610e71..40c433d7ae 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -18,6 +18,7 @@ from synapse.api.constants import ThirdPartyEntityKind
 from synapse.api.errors import CodeMessageException
 from synapse.http.client import SimpleHttpClient
 from synapse.events.utils import serialize_event
+from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.types import ThirdPartyInstanceID
 
@@ -192,9 +193,12 @@ class ApplicationServiceApi(SimpleHttpClient):
                 defer.returnValue(None)
 
         key = (service.id, protocol)
-        return self.protocol_meta_cache.get(key) or (
-            self.protocol_meta_cache.set(key, _get())
-        )
+        result = self.protocol_meta_cache.get(key)
+        if not result:
+            result = self.protocol_meta_cache.set(
+                key, preserve_fn(_get)()
+            )
+        return make_deferred_yieldable(result)
 
     @defer.inlineCallbacks
     def push_bulk(self, service, events, txn_id=None):
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 2dbeafa9dd..a1d6e4d4f7 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -148,8 +148,8 @@ def setup_logging(config, use_worker_options=False):
         "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
         " - %(message)s"
     )
-    if log_config is None:
 
+    if log_config is None:
         level = logging.INFO
         level_for_storage = logging.INFO
         if config.verbosity:
@@ -176,6 +176,10 @@ def setup_logging(config, use_worker_options=False):
                 logger.info("Opened new log file due to SIGHUP")
         else:
             handler = logging.StreamHandler()
+
+            def sighup(signum, stack):
+                pass
+
         handler.setFormatter(formatter)
 
         handler.addFilter(LoggingContextFilter(request=""))
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index e15228e70b..a2327f24b6 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -18,6 +18,7 @@ from .federation_base import FederationBase
 from .units import Transaction, Edu
 
 from synapse.util import async
+from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
 from synapse.util.logutils import log_function
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.events import FrozenEvent
@@ -253,12 +254,13 @@ class FederationServer(FederationBase):
         result = self._state_resp_cache.get((room_id, event_id))
         if not result:
             with (yield self._server_linearizer.queue((origin, room_id))):
-                resp = yield self._state_resp_cache.set(
+                d = self._state_resp_cache.set(
                     (room_id, event_id),
-                    self._on_context_state_request_compute(room_id, event_id)
+                    preserve_fn(self._on_context_state_request_compute)(room_id, event_id)
                 )
+                resp = yield make_deferred_yieldable(d)
         else:
-            resp = yield result
+            resp = yield make_deferred_yieldable(result)
 
         defer.returnValue((200, resp))
 
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 41e1781df7..2cf34e51cb 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -20,6 +20,7 @@ from ._base import BaseHandler
 from synapse.api.constants import (
     EventTypes, JoinRules,
 )
+from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
 from synapse.util.async import concurrently_execute
 from synapse.util.caches.descriptors import cachedInlineCallbacks
 from synapse.util.caches.response_cache import ResponseCache
@@ -70,6 +71,7 @@ class RoomListHandler(BaseHandler):
         if search_filter:
             # We explicitly don't bother caching searches or requests for
             # appservice specific lists.
+            logger.info("Bypassing cache as search request.")
             return self._get_public_room_list(
                 limit, since_token, search_filter, network_tuple=network_tuple,
             )
@@ -77,13 +79,16 @@ class RoomListHandler(BaseHandler):
         key = (limit, since_token, network_tuple)
         result = self.response_cache.get(key)
         if not result:
+            logger.info("No cached result, calculating one.")
             result = self.response_cache.set(
                 key,
-                self._get_public_room_list(
+                preserve_fn(self._get_public_room_list)(
                     limit, since_token, network_tuple=network_tuple
                 )
             )
-        return result
+        else:
+            logger.info("Using cached deferred result.")
+        return make_deferred_yieldable(result)
 
     @defer.inlineCallbacks
     def _get_public_room_list(self, limit=None, since_token=None,
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 219529936f..b12988f3c9 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -15,7 +15,7 @@
 
 from synapse.api.constants import Membership, EventTypes
 from synapse.util.async import concurrently_execute
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, make_deferred_yieldable, preserve_fn
 from synapse.util.metrics import Measure, measure_func
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.push.clientformat import format_push_rules_for_user
@@ -184,11 +184,11 @@ class SyncHandler(object):
         if not result:
             result = self.response_cache.set(
                 sync_config.request_key,
-                self._wait_for_sync_for_user(
+                preserve_fn(self._wait_for_sync_for_user)(
                     sync_config, since_token, timeout, full_state
                 )
             )
-        return result
+        return make_deferred_yieldable(result)
 
     @defer.inlineCallbacks
     def _wait_for_sync_for_user(self, sync_config, since_token, timeout,
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 9eba046bbf..4abb479ae3 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -18,7 +18,7 @@ from OpenSSL.SSL import VERIFY_NONE
 from synapse.api.errors import (
     CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
 )
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util.logcontext import make_deferred_yieldable
 from synapse.util import logcontext
 import synapse.metrics
 from synapse.http.endpoint import SpiderEndpoint
@@ -114,43 +114,73 @@ class SimpleHttpClient(object):
             raise e
 
     @defer.inlineCallbacks
-    def post_urlencoded_get_json(self, uri, args={}):
+    def post_urlencoded_get_json(self, uri, args={}, headers=None):
+        """
+        Args:
+            uri (str):
+            args (dict[str, str|List[str]]): query params
+            headers (dict[str, List[str]]|None): If not None, a map from
+               header name to a list of values for that header
+
+        Returns:
+            Deferred[object]: parsed json
+        """
+
         # TODO: Do we ever want to log message contents?
         logger.debug("post_urlencoded_get_json args: %s", args)
 
         query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
 
+        actual_headers = {
+            b"Content-Type": [b"application/x-www-form-urlencoded"],
+            b"User-Agent": [self.user_agent],
+        }
+        if headers:
+            actual_headers.update(headers)
+
         response = yield self.request(
             "POST",
             uri.encode("ascii"),
-            headers=Headers({
-                b"Content-Type": [b"application/x-www-form-urlencoded"],
-                b"User-Agent": [self.user_agent],
-            }),
+            headers=Headers(actual_headers),
             bodyProducer=FileBodyProducer(StringIO(query_bytes))
         )
 
-        body = yield preserve_context_over_fn(readBody, response)
+        body = yield make_deferred_yieldable(readBody(response))
 
         defer.returnValue(json.loads(body))
 
     @defer.inlineCallbacks
-    def post_json_get_json(self, uri, post_json):
+    def post_json_get_json(self, uri, post_json, headers=None):
+        """
+
+        Args:
+            uri (str):
+            post_json (object):
+            headers (dict[str, List[str]]|None): If not None, a map from
+               header name to a list of values for that header
+
+        Returns:
+            Deferred[object]: parsed json
+        """
         json_str = encode_canonical_json(post_json)
 
         logger.debug("HTTP POST %s -> %s", json_str, uri)
 
+        actual_headers = {
+            b"Content-Type": [b"application/json"],
+            b"User-Agent": [self.user_agent],
+        }
+        if headers:
+            actual_headers.update(headers)
+
         response = yield self.request(
             "POST",
             uri.encode("ascii"),
-            headers=Headers({
-                b"Content-Type": [b"application/json"],
-                b"User-Agent": [self.user_agent],
-            }),
+            headers=Headers(actual_headers),
             bodyProducer=FileBodyProducer(StringIO(json_str))
         )
 
-        body = yield preserve_context_over_fn(readBody, response)
+        body = yield make_deferred_yieldable(readBody(response))
 
         if 200 <= response.code < 300:
             defer.returnValue(json.loads(body))
@@ -160,7 +190,7 @@ class SimpleHttpClient(object):
         defer.returnValue(json.loads(body))
 
     @defer.inlineCallbacks
-    def get_json(self, uri, args={}):
+    def get_json(self, uri, args={}, headers=None):
         """ Gets some json from the given URI.
 
         Args:
@@ -169,6 +199,8 @@ class SimpleHttpClient(object):
                 None.
                 **Note**: The value of each key is assumed to be an iterable
                 and *not* a string.
+            headers (dict[str, List[str]]|None): If not None, a map from
+               header name to a list of values for that header
         Returns:
             Deferred: Succeeds when we get *any* 2xx HTTP response, with the
             HTTP body as JSON.
@@ -177,13 +209,13 @@ class SimpleHttpClient(object):
             error message.
         """
         try:
-            body = yield self.get_raw(uri, args)
+            body = yield self.get_raw(uri, args, headers=headers)
             defer.returnValue(json.loads(body))
         except CodeMessageException as e:
             raise self._exceptionFromFailedRequest(e.code, e.msg)
 
     @defer.inlineCallbacks
-    def put_json(self, uri, json_body, args={}):
+    def put_json(self, uri, json_body, args={}, headers=None):
         """ Puts some json to the given URI.
 
         Args:
@@ -193,6 +225,8 @@ class SimpleHttpClient(object):
                 None.
                 **Note**: The value of each key is assumed to be an iterable
                 and *not* a string.
+            headers (dict[str, List[str]]|None): If not None, a map from
+               header name to a list of values for that header
         Returns:
             Deferred: Succeeds when we get *any* 2xx HTTP response, with the
             HTTP body as JSON.
@@ -205,17 +239,21 @@ class SimpleHttpClient(object):
 
         json_str = encode_canonical_json(json_body)
 
+        actual_headers = {
+            b"Content-Type": [b"application/json"],
+            b"User-Agent": [self.user_agent],
+        }
+        if headers:
+            actual_headers.update(headers)
+
         response = yield self.request(
             "PUT",
             uri.encode("ascii"),
-            headers=Headers({
-                b"User-Agent": [self.user_agent],
-                "Content-Type": ["application/json"]
-            }),
+            headers=Headers(actual_headers),
             bodyProducer=FileBodyProducer(StringIO(json_str))
         )
 
-        body = yield preserve_context_over_fn(readBody, response)
+        body = yield make_deferred_yieldable(readBody(response))
 
         if 200 <= response.code < 300:
             defer.returnValue(json.loads(body))
@@ -226,7 +264,7 @@ class SimpleHttpClient(object):
             raise CodeMessageException(response.code, body)
 
     @defer.inlineCallbacks
-    def get_raw(self, uri, args={}):
+    def get_raw(self, uri, args={}, headers=None):
         """ Gets raw text from the given URI.
 
         Args:
@@ -235,6 +273,8 @@ class SimpleHttpClient(object):
                 None.
                 **Note**: The value of each key is assumed to be an iterable
                 and *not* a string.
+            headers (dict[str, List[str]]|None): If not None, a map from
+               header name to a list of values for that header
         Returns:
             Deferred: Succeeds when we get *any* 2xx HTTP response, with the
             HTTP body at text.
@@ -246,15 +286,19 @@ class SimpleHttpClient(object):
             query_bytes = urllib.urlencode(args, True)
             uri = "%s?%s" % (uri, query_bytes)
 
+        actual_headers = {
+            b"User-Agent": [self.user_agent],
+        }
+        if headers:
+            actual_headers.update(headers)
+
         response = yield self.request(
             "GET",
             uri.encode("ascii"),
-            headers=Headers({
-                b"User-Agent": [self.user_agent],
-            })
+            headers=Headers(actual_headers),
         )
 
-        body = yield preserve_context_over_fn(readBody, response)
+        body = yield make_deferred_yieldable(readBody(response))
 
         if 200 <= response.code < 300:
             defer.returnValue(body)
@@ -274,27 +318,33 @@ class SimpleHttpClient(object):
     # The two should be factored out.
 
     @defer.inlineCallbacks
-    def get_file(self, url, output_stream, max_size=None):
+    def get_file(self, url, output_stream, max_size=None, headers=None):
         """GETs a file from a given URL
         Args:
             url (str): The URL to GET
             output_stream (file): File to write the response body to.
+            headers (dict[str, List[str]]|None): If not None, a map from
+               header name to a list of values for that header
         Returns:
             A (int,dict,string,int) tuple of the file length, dict of the response
             headers, absolute URI of the response and HTTP response code.
         """
 
+        actual_headers = {
+            b"User-Agent": [self.user_agent],
+        }
+        if headers:
+            actual_headers.update(headers)
+
         response = yield self.request(
             "GET",
             url.encode("ascii"),
-            headers=Headers({
-                b"User-Agent": [self.user_agent],
-            })
+            headers=Headers(actual_headers),
         )
 
-        headers = dict(response.headers.getAllRawHeaders())
+        resp_headers = dict(response.headers.getAllRawHeaders())
 
-        if 'Content-Length' in headers and headers['Content-Length'] > max_size:
+        if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size:
             logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
             raise SynapseError(
                 502,
@@ -315,10 +365,9 @@ class SimpleHttpClient(object):
         # straight back in again
 
         try:
-            length = yield preserve_context_over_fn(
-                _readBodyToFile,
-                response, output_stream, max_size
-            )
+            length = yield make_deferred_yieldable(_readBodyToFile(
+                response, output_stream, max_size,
+            ))
         except Exception as e:
             logger.exception("Failed to download body")
             raise SynapseError(
@@ -327,7 +376,9 @@ class SimpleHttpClient(object):
                 Codes.UNKNOWN,
             )
 
-        defer.returnValue((length, headers, response.request.absoluteURI, response.code))
+        defer.returnValue(
+            (length, resp_headers, response.request.absoluteURI, response.code),
+        )
 
 
 # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
@@ -395,7 +446,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
         )
 
         try:
-            body = yield preserve_context_over_fn(readBody, response)
+            body = yield make_deferred_yieldable(readBody(response))
             defer.returnValue(body)
         except PartialDownloadError as e:
             # twisted dislikes google's response, no content length.
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index b57ba95d24..2a2438b7dc 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 
 
 class DevicesRestServlet(servlet.RestServlet):
-    PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
+    PATTERNS = client_v2_patterns("/devices$", v2_alpha=False)
 
     def __init__(self, hs):
         """
@@ -51,7 +51,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
     API for bulk deletion of devices. Accepts a JSON object with a devices
     key which lists the device_ids to delete. Requires user interactive auth.
     """
-    PATTERNS = client_v2_patterns("/delete_devices", releases=[], v2_alpha=False)
+    PATTERNS = client_v2_patterns("/delete_devices", v2_alpha=False)
 
     def __init__(self, hs):
         super(DeleteDevicesRestServlet, self).__init__()
@@ -93,8 +93,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
 
 
 class DeviceRestServlet(servlet.RestServlet):
-    PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
-                                  releases=[], v2_alpha=False)
+    PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False)
 
     def __init__(self, hs):
         """
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 943e87e7fd..3cc87ea63f 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -53,8 +53,7 @@ class KeyUploadServlet(RestServlet):
       },
     }
     """
-    PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
-                                  releases=())
+    PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
 
     def __init__(self, hs):
         """
@@ -128,10 +127,7 @@ class KeyQueryServlet(RestServlet):
     } } } } } }
     """
 
-    PATTERNS = client_v2_patterns(
-        "/keys/query$",
-        releases=()
-    )
+    PATTERNS = client_v2_patterns("/keys/query$")
 
     def __init__(self, hs):
         """
@@ -160,10 +156,7 @@ class KeyChangesServlet(RestServlet):
         200 OK
         { "changed": ["@foo:example.com"] }
     """
-    PATTERNS = client_v2_patterns(
-        "/keys/changes$",
-        releases=()
-    )
+    PATTERNS = client_v2_patterns("/keys/changes$")
 
     def __init__(self, hs):
         """
@@ -213,10 +206,7 @@ class OneTimeKeyServlet(RestServlet):
     } } } }
 
     """
-    PATTERNS = client_v2_patterns(
-        "/keys/claim$",
-        releases=()
-    )
+    PATTERNS = client_v2_patterns("/keys/claim$")
 
     def __init__(self, hs):
         super(OneTimeKeyServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
index fd2a3d69d4..ec170109fe 100644
--- a/synapse/rest/client/v2_alpha/notifications.py
+++ b/synapse/rest/client/v2_alpha/notifications.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
 
 
 class NotificationsServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/notifications$", releases=())
+    PATTERNS = client_v2_patterns("/notifications$")
 
     def __init__(self, hs):
         super(NotificationsServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index d607bd2970..90bdb1db15 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
 class SendToDeviceRestServlet(servlet.RestServlet):
     PATTERNS = client_v2_patterns(
         "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
-        releases=[], v2_alpha=False
+        v2_alpha=False
     )
 
     def __init__(self, hs):
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index 6fceb23e26..6773b9ba60 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
 
 
 class ThirdPartyProtocolsServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=())
+    PATTERNS = client_v2_patterns("/thirdparty/protocols")
 
     def __init__(self, hs):
         super(ThirdPartyProtocolsServlet, self).__init__()
@@ -43,8 +43,7 @@ class ThirdPartyProtocolsServlet(RestServlet):
 
 
 class ThirdPartyProtocolServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$",
-                                  releases=())
+    PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
 
     def __init__(self, hs):
         super(ThirdPartyProtocolServlet, self).__init__()
@@ -66,8 +65,7 @@ class ThirdPartyProtocolServlet(RestServlet):
 
 
 class ThirdPartyUserServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
-                                  releases=())
+    PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
 
     def __init__(self, hs):
         super(ThirdPartyUserServlet, self).__init__()
@@ -90,8 +88,7 @@ class ThirdPartyUserServlet(RestServlet):
 
 
 class ThirdPartyLocationServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$",
-                                  releases=())
+    PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
 
     def __init__(self, hs):
         super(ThirdPartyLocationServlet, self).__init__()