summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2019-01-21 14:04:19 +0000
committerErik Johnston <erik@matrix.org>2019-01-21 14:04:19 +0000
commit35e1d67b4e9c2c8b0abf35d41c2d9c56d486f6e3 (patch)
tree2c844b79665fc32b0bfc49285586394edf367131
parentNewsfile (diff)
parentMerge pull request #4390 from matrix-org/erikj/versioned_fed_apis (diff)
downloadsynapse-35e1d67b4e9c2c8b0abf35d41c2d9c56d486f6e3.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/fed_v2_invite_server
-rw-r--r--.codecov.yml15
-rw-r--r--MANIFEST.in1
-rw-r--r--changelog.d/4387.misc1
-rw-r--r--changelog.d/4392.bugfix1
-rw-r--r--changelog.d/4397.bugfix1
-rw-r--r--changelog.d/4399.misc1
-rw-r--r--changelog.d/4400.misc1
-rw-r--r--changelog.d/4404.bugfix1
-rw-r--r--changelog.d/4407.bugfix1
-rw-r--r--changelog.d/4408.misc1
-rw-r--r--changelog.d/4409.misc1
-rw-r--r--changelog.d/4411.bugfix1
-rw-r--r--synapse/api/constants.py3
-rw-r--r--synapse/config/key.py3
-rw-r--r--synapse/handlers/device.py19
-rw-r--r--synapse/handlers/identity.py9
-rw-r--r--synapse/handlers/room.py1
-rw-r--r--synapse/http/endpoint.py75
-rw-r--r--synapse/http/matrixfederationclient.py53
-rw-r--r--synapse/python_dependencies.py15
-rw-r--r--synapse/storage/events.py13
-rw-r--r--synapse/util/async_helpers.py4
-rw-r--r--tests/http/test_fedclient.py54
-rw-r--r--tests/util/test_async_utils.py104
24 files changed, 261 insertions, 118 deletions
diff --git a/.codecov.yml b/.codecov.yml
new file mode 100644
index 0000000000..a05698a39c
--- /dev/null
+++ b/.codecov.yml
@@ -0,0 +1,15 @@
+comment:
+  layout: "diff"
+
+coverage:
+  status:
+    project:
+      default:
+        target: 0  # Target % coverage, can be auto. Turned off for now
+        threshold: null
+        base: auto
+    patch:
+      default:
+        target: 0
+        threshold: null
+        base: auto
diff --git a/MANIFEST.in b/MANIFEST.in
index 29303cc8b5..7e4c031b79 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -37,6 +37,7 @@ prune docker
 prune .circleci
 prune .coveragerc
 prune debian
+prune .codecov.yml
 
 exclude jenkins*
 recursive-exclude jenkins *.sh
diff --git a/changelog.d/4387.misc b/changelog.d/4387.misc
new file mode 100644
index 0000000000..0c04a0fa9b
--- /dev/null
+++ b/changelog.d/4387.misc
@@ -0,0 +1 @@
+Fix a comment in the generated config file
diff --git a/changelog.d/4392.bugfix b/changelog.d/4392.bugfix
new file mode 100644
index 0000000000..2223f7dcd6
--- /dev/null
+++ b/changelog.d/4392.bugfix
@@ -0,0 +1 @@
+Fix typo in ALL_USER_TYPES definition to ensure type is a tuple
diff --git a/changelog.d/4397.bugfix b/changelog.d/4397.bugfix
new file mode 100644
index 0000000000..e7526d4454
--- /dev/null
+++ b/changelog.d/4397.bugfix
@@ -0,0 +1 @@
+Fix high CPU usage due to remote devicelist updates
diff --git a/changelog.d/4399.misc b/changelog.d/4399.misc
new file mode 100644
index 0000000000..2f77a8fa54
--- /dev/null
+++ b/changelog.d/4399.misc
@@ -0,0 +1 @@
+Update dependencies on msgpack and pymacaroons to use the up-to-date packages.
diff --git a/changelog.d/4400.misc b/changelog.d/4400.misc
new file mode 100644
index 0000000000..3d299dfe95
--- /dev/null
+++ b/changelog.d/4400.misc
@@ -0,0 +1 @@
+Tweak codecov settings to make them less loud.
diff --git a/changelog.d/4404.bugfix b/changelog.d/4404.bugfix
new file mode 100644
index 0000000000..5d40a3a60b
--- /dev/null
+++ b/changelog.d/4404.bugfix
@@ -0,0 +1 @@
+Fix potential bug where creating or joining a room could fail
diff --git a/changelog.d/4407.bugfix b/changelog.d/4407.bugfix
new file mode 100644
index 0000000000..54c5e76d1f
--- /dev/null
+++ b/changelog.d/4407.bugfix
@@ -0,0 +1 @@
+Fix incorrect logcontexts after a Deferred was cancelled
diff --git a/changelog.d/4408.misc b/changelog.d/4408.misc
new file mode 100644
index 0000000000..729bafd62e
--- /dev/null
+++ b/changelog.d/4408.misc
@@ -0,0 +1 @@
+Refactor 'sign_request' as 'build_auth_headers'
\ No newline at end of file
diff --git a/changelog.d/4409.misc b/changelog.d/4409.misc
new file mode 100644
index 0000000000..9cf2adfbb1
--- /dev/null
+++ b/changelog.d/4409.misc
@@ -0,0 +1 @@
+Remove redundant federation connection wrapping code
diff --git a/changelog.d/4411.bugfix b/changelog.d/4411.bugfix
new file mode 100644
index 0000000000..219e98a924
--- /dev/null
+++ b/changelog.d/4411.bugfix
@@ -0,0 +1 @@
+Ensure encrypted room state is persisted across room upgrades.
\ No newline at end of file
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 87bc1cb53d..46c4b4b9dc 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -68,6 +68,7 @@ class EventTypes(object):
     Aliases = "m.room.aliases"
     Redaction = "m.room.redaction"
     ThirdPartyInvite = "m.room.third_party_invite"
+    Encryption = "m.room.encryption"
 
     RoomHistoryVisibility = "m.room.history_visibility"
     CanonicalAlias = "m.room.canonical_alias"
@@ -128,4 +129,4 @@ class UserTypes(object):
     'admin' and 'guest' users should also be UserTypes. Normal users are type None
     """
     SUPPORT = "support"
-    ALL_USER_TYPES = (SUPPORT)
+    ALL_USER_TYPES = (SUPPORT,)
diff --git a/synapse/config/key.py b/synapse/config/key.py
index 3b11f0cfa9..dce4b19a2d 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -83,9 +83,6 @@ class KeyConfig(Config):
         # a secret which is used to sign access tokens. If none is specified,
         # the registration_shared_secret is used, if one is given; otherwise,
         # a secret key is derived from the signing key.
-        #
-        # Note that changing this will invalidate any active access tokens, so
-        # all clients will have to log back in.
         %(macaroon_secret_key)s
 
         # Used to enable access token expiration.
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 9e017116a9..8955cde4ed 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -532,6 +532,25 @@ class DeviceListEduUpdater(object):
 
                 stream_id = result["stream_id"]
                 devices = result["devices"]
+
+                # If the remote server has more than ~1000 devices for this user
+                # we assume that something is going horribly wrong (e.g. a bot
+                # that logs in and creates a new device every time it tries to
+                # send a message).  Maintaining lots of devices per user in the
+                # cache can cause serious performance issues as if this request
+                # takes more than 60s to complete, internal replication from the
+                # inbound federation worker to the synapse master may time out
+                # causing the inbound federation to fail and causing the remote
+                # server to retry, causing a DoS.  So in this scenario we give
+                # up on storing the total list of devices and only handle the
+                # delta instead.
+                if len(devices) > 1000:
+                    logger.warn(
+                        "Ignoring device list snapshot for %s as it has >1K devs (%d)",
+                        user_id, len(devices)
+                    )
+                    devices = []
+
                 yield self.store.update_remote_device_list_cache(
                     user_id, devices, stream_id,
                 )
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 5feb3f22a6..39184f0e22 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -167,18 +167,21 @@ class IdentityHandler(BaseHandler):
             "mxid": mxid,
             "threepid": threepid,
         }
-        headers = {}
+
         # we abuse the federation http client to sign the request, but we have to send it
         # using the normal http client since we don't want the SRV lookup and want normal
         # 'browser-like' HTTPS.
-        self.federation_http_client.sign_request(
+        auth_headers = self.federation_http_client.build_auth_headers(
             destination=None,
             method='POST',
             url_bytes='/_matrix/identity/api/v1/3pid/unbind'.encode('ascii'),
-            headers_dict=headers,
             content=content,
             destination_is=id_server,
         )
+        headers = {
+            b"Authorization": auth_headers,
+        }
+
         try:
             yield self.http_client.post_json_get_json(
                 url,
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 581e96c743..cb8c5f77dd 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -269,6 +269,7 @@ class RoomCreationHandler(BaseHandler):
             (EventTypes.RoomHistoryVisibility, ""),
             (EventTypes.GuestAccess, ""),
             (EventTypes.RoomAvatar, ""),
+            (EventTypes.Encryption, ""),
         )
 
         old_room_state_ids = yield self.store.get_filtered_current_state_ids(
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index f86a0b624e..1c3b7ea28a 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -140,82 +140,15 @@ def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=
         default_port = 8448
 
     if port is None:
-        return _WrappingEndpointFac(SRVClientEndpoint(
+        return SRVClientEndpoint(
             reactor, "matrix", domain, protocol="tcp",
             default_port=default_port, endpoint=transport_endpoint,
             endpoint_kw_args=endpoint_kw_args
-        ), reactor)
+        )
     else:
-        return _WrappingEndpointFac(transport_endpoint(
+        return transport_endpoint(
             reactor, domain, port, **endpoint_kw_args
-        ), reactor)
-
-
-class _WrappingEndpointFac(object):
-    def __init__(self, endpoint_fac, reactor):
-        self.endpoint_fac = endpoint_fac
-        self.reactor = reactor
-
-    @defer.inlineCallbacks
-    def connect(self, protocolFactory):
-        conn = yield self.endpoint_fac.connect(protocolFactory)
-        conn = _WrappedConnection(conn, self.reactor)
-        defer.returnValue(conn)
-
-
-class _WrappedConnection(object):
-    """Wraps a connection and calls abort on it if it hasn't seen any action
-    for 2.5-3 minutes.
-    """
-    __slots__ = ["conn", "last_request"]
-
-    def __init__(self, conn, reactor):
-        object.__setattr__(self, "conn", conn)
-        object.__setattr__(self, "last_request", time.time())
-        self._reactor = reactor
-
-    def __getattr__(self, name):
-        return getattr(self.conn, name)
-
-    def __setattr__(self, name, value):
-        setattr(self.conn, name, value)
-
-    def _time_things_out_maybe(self):
-        # We use a slightly shorter timeout here just in case the callLater is
-        # triggered early. Paranoia ftw.
-        # TODO: Cancel the previous callLater rather than comparing time.time()?
-        if time.time() - self.last_request >= 2.5 * 60:
-            self.abort()
-            # Abort the underlying TLS connection. The abort() method calls
-            # loseConnection() on the TLS connection which tries to
-            # shutdown the connection cleanly. We call abortConnection()
-            # since that will promptly close the TLS connection.
-            #
-            # In Twisted >18.4; the TLS connection will be None if it has closed
-            # which will make abortConnection() throw. Check that the TLS connection
-            # is not None before trying to close it.
-            if self.transport.getHandle() is not None:
-                self.transport.abortConnection()
-
-    def request(self, request):
-        self.last_request = time.time()
-
-        # Time this connection out if we haven't send a request in the last
-        # N minutes
-        # TODO: Cancel the previous callLater?
-        self._reactor.callLater(3 * 60, self._time_things_out_maybe)
-
-        d = self.conn.request(request)
-
-        def update_request_time(res):
-            self.last_request = time.time()
-            # TODO: Cancel the previous callLater?
-            self._reactor.callLater(3 * 60, self._time_things_out_maybe)
-            return res
-
-        d.addCallback(update_request_time)
-
-        return d
+        )
 
 
 class SRVClientEndpoint(object):
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index f2a42f97a6..250bb1ef91 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -298,9 +298,9 @@ class MatrixFederationHttpClient(object):
                     json = request.get_json()
                     if json:
                         headers_dict[b"Content-Type"] = [b"application/json"]
-                        self.sign_request(
+                        auth_headers = self.build_auth_headers(
                             destination_bytes, method_bytes, url_to_sign_bytes,
-                            headers_dict, json,
+                            json,
                         )
                         data = encode_canonical_json(json)
                         producer = FileBodyProducer(
@@ -309,34 +309,35 @@ class MatrixFederationHttpClient(object):
                         )
                     else:
                         producer = None
-                        self.sign_request(
+                        auth_headers = self.build_auth_headers(
                             destination_bytes, method_bytes, url_to_sign_bytes,
-                            headers_dict,
                         )
 
+                    headers_dict[b"Authorization"] = auth_headers
+
                     logger.info(
                         "{%s} [%s] Sending request: %s %s",
                         request.txn_id, request.destination, request.method,
                         url_str,
                     )
 
-                    # we don't want all the fancy cookie and redirect handling that
-                    # treq.request gives: just use the raw Agent.
-                    request_deferred = self.agent.request(
-                        method_bytes,
-                        url_bytes,
-                        headers=Headers(headers_dict),
-                        bodyProducer=producer,
-                    )
-
-                    request_deferred = timeout_deferred(
-                        request_deferred,
-                        timeout=_sec_timeout,
-                        reactor=self.hs.get_reactor(),
-                    )
-
                     try:
                         with Measure(self.clock, "outbound_request"):
+                            # we don't want all the fancy cookie and redirect handling
+                            # that treq.request gives: just use the raw Agent.
+                            request_deferred = self.agent.request(
+                                method_bytes,
+                                url_bytes,
+                                headers=Headers(headers_dict),
+                                bodyProducer=producer,
+                            )
+
+                            request_deferred = timeout_deferred(
+                                request_deferred,
+                                timeout=_sec_timeout,
+                                reactor=self.hs.get_reactor(),
+                            )
+
                             response = yield make_deferred_yieldable(
                                 request_deferred,
                             )
@@ -440,24 +441,23 @@ class MatrixFederationHttpClient(object):
 
             defer.returnValue(response)
 
-    def sign_request(self, destination, method, url_bytes, headers_dict,
-                     content=None, destination_is=None):
+    def build_auth_headers(
+        self, destination, method, url_bytes, content=None, destination_is=None,
+    ):
         """
-        Signs a request by adding an Authorization header to headers_dict
+        Builds the Authorization headers for a federation request
         Args:
             destination (bytes|None): The desination home server of the request.
                 May be None if the destination is an identity server, in which case
                 destination_is must be non-None.
             method (bytes): The HTTP method of the request
             url_bytes (bytes): The URI path of the request
-            headers_dict (dict[bytes, list[bytes]]): Dictionary of request headers to
-                append to
             content (object): The body of the request
             destination_is (bytes): As 'destination', but if the destination is an
                 identity server
 
         Returns:
-            None
+            list[bytes]: a list of headers to be added as "Authorization:" headers
         """
         request = {
             "method": method,
@@ -484,8 +484,7 @@ class MatrixFederationHttpClient(object):
                     self.server_name, key, sig,
                 )).encode('ascii')
             )
-
-        headers_dict[b"Authorization"] = auth_headers
+        return auth_headers
 
     @defer.inlineCallbacks
     def put_json(self, destination, path, args={}, data={},
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 69c5f9fe2e..882e844eb1 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -40,7 +40,11 @@ REQUIREMENTS = [
     "signedjson>=1.0.0",
     "pynacl>=1.2.1",
     "service_identity>=16.0.0",
-    "Twisted>=17.1.0",
+
+    # our logcontext handling relies on the ability to cancel inlineCallbacks
+    # (https://twistedmatrix.com/trac/ticket/4632) which landed in Twisted 18.7.
+    "Twisted>=18.7.0",
+
     "treq>=15.1",
     # Twisted has required pyopenssl 16.0 since about Twisted 16.6.
     "pyopenssl>=16.0.0",
@@ -52,15 +56,18 @@ REQUIREMENTS = [
     "pillow>=3.1.2",
     "sortedcontainers>=1.4.4",
     "psutil>=2.0.0",
-    "pymacaroons-pynacl>=0.9.3",
-    "msgpack-python>=0.4.2",
+    "pymacaroons>=0.13.0",
+    "msgpack>=0.5.0",
     "phonenumbers>=8.2.0",
     "six>=1.10",
     # prometheus_client 0.4.0 changed the format of counter metrics
     # (cf https://github.com/matrix-org/synapse/issues/4001)
     "prometheus_client>=0.0.18,<0.4.0",
+
     # we use attr.s(slots), which arrived in 16.0.0
-    "attrs>=16.0.0",
+    # Twisted 18.7.0 requires attrs>=17.4.0
+    "attrs>=17.4.0",
+
     "netaddr>=0.7.18",
 ]
 
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 2047110b1d..79e0276de6 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -739,7 +739,18 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         }
 
         events_map = {ev.event_id: ev for ev, _ in events_context}
-        room_version = yield self.get_room_version(room_id)
+
+        # We need to get the room version, which is in the create event.
+        # Normally that'd be in the database, but its also possible that we're
+        # currently trying to persist it.
+        room_version = None
+        for ev, _ in events_context:
+            if ev.type == EventTypes.Create and ev.state_key == "":
+                room_version = ev.content.get("room_version", "1")
+                break
+
+        if not room_version:
+            room_version = yield self.get_room_version(room_id)
 
         logger.debug("calling resolve_state_groups from preserve_events")
         res = yield self._state_resolution_handler.resolve_state_groups(
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index ec7b2c9672..430bb15f51 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -387,12 +387,14 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
     deferred that wraps and times out the given deferred, correctly handling
     the case where the given deferred's canceller throws.
 
+    (See https://twistedmatrix.com/trac/ticket/9534)
+
     NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred
 
     Args:
         deferred (Deferred)
         timeout (float): Timeout in seconds
-        reactor (twisted.internet.reactor): The twisted reactor to use
+        reactor (twisted.interfaces.IReactorTime): The twisted reactor to use
         on_timeout_cancel (callable): A callable which is called immediately
             after the deferred times out, and not if this deferred is
             otherwise cancelled before the timeout.
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index b2e38276d8..8426eee400 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -17,6 +17,7 @@ from mock import Mock
 
 from twisted.internet.defer import TimeoutError
 from twisted.internet.error import ConnectingCancelledError, DNSLookupError
+from twisted.test.proto_helpers import StringTransport
 from twisted.web.client import ResponseNeverReceived
 from twisted.web.http import HTTPChannel
 
@@ -44,7 +45,7 @@ class FederationClientTests(HomeserverTestCase):
 
     def test_dns_error(self):
         """
-        If the DNS raising returns an error, it will bubble up.
+        If the DNS lookup returns an error, it will bubble up.
         """
         d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
         self.pump()
@@ -63,7 +64,7 @@ class FederationClientTests(HomeserverTestCase):
         self.pump()
 
         # Nothing happened yet
-        self.assertFalse(d.called)
+        self.assertNoResult(d)
 
         # Make sure treq is trying to connect
         clients = self.reactor.tcpClients
@@ -72,7 +73,7 @@ class FederationClientTests(HomeserverTestCase):
         self.assertEqual(clients[0][1], 8008)
 
         # Deferred is still without a result
-        self.assertFalse(d.called)
+        self.assertNoResult(d)
 
         # Push by enough to time it out
         self.reactor.advance(10.5)
@@ -94,7 +95,7 @@ class FederationClientTests(HomeserverTestCase):
         self.pump()
 
         # Nothing happened yet
-        self.assertFalse(d.called)
+        self.assertNoResult(d)
 
         # Make sure treq is trying to connect
         clients = self.reactor.tcpClients
@@ -107,7 +108,7 @@ class FederationClientTests(HomeserverTestCase):
         client.makeConnection(conn)
 
         # Deferred is still without a result
-        self.assertFalse(d.called)
+        self.assertNoResult(d)
 
         # Push by enough to time it out
         self.reactor.advance(10.5)
@@ -135,7 +136,7 @@ class FederationClientTests(HomeserverTestCase):
         client.makeConnection(conn)
 
         # Deferred does not have a result
-        self.assertFalse(d.called)
+        self.assertNoResult(d)
 
         # Send it the HTTP response
         client.dataReceived(b"HTTP/1.1 200 OK\r\nServer: Fake\r\n\r\n")
@@ -159,7 +160,7 @@ class FederationClientTests(HomeserverTestCase):
         client.makeConnection(conn)
 
         # Deferred does not have a result
-        self.assertFalse(d.called)
+        self.assertNoResult(d)
 
         # Send it the HTTP response
         client.dataReceived(
@@ -195,3 +196,42 @@ class FederationClientTests(HomeserverTestCase):
         request = server.requests[0]
         content = request.content.read()
         self.assertEqual(content, b'{"a":"b"}')
+
+    def test_closes_connection(self):
+        """Check that the client closes unused HTTP connections"""
+        d = self.cl.get_json("testserv:8008", "foo/bar")
+
+        self.pump()
+
+        # there should have been a call to connectTCP
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (_host, _port, factory, _timeout, _bindAddress) = clients[0]
+
+        # complete the connection and wire it up to a fake transport
+        client = factory.buildProtocol(None)
+        conn = StringTransport()
+        client.makeConnection(conn)
+
+        # that should have made it send the request to the connection
+        self.assertRegex(conn.value(), b"^GET /foo/bar")
+
+        # Send the HTTP response
+        client.dataReceived(
+            b"HTTP/1.1 200 OK\r\n"
+            b"Content-Type: application/json\r\n"
+            b"Content-Length: 2\r\n"
+            b"\r\n"
+            b"{}"
+        )
+
+        # We should get a successful response
+        r = self.successResultOf(d)
+        self.assertEqual(r, {})
+
+        self.assertFalse(conn.disconnecting)
+
+        # wait for a while
+        self.pump(120)
+
+        self.assertTrue(conn.disconnecting)
diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_utils.py
new file mode 100644
index 0000000000..84dd71e47a
--- /dev/null
+++ b/tests/util/test_async_utils.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+from twisted.internet.defer import CancelledError, Deferred
+from twisted.internet.task import Clock
+
+from synapse.util import logcontext
+from synapse.util.async_helpers import timeout_deferred
+from synapse.util.logcontext import LoggingContext
+
+from tests.unittest import TestCase
+
+
+class TimeoutDeferredTest(TestCase):
+    def setUp(self):
+        self.clock = Clock()
+
+    def test_times_out(self):
+        """Basic test case that checks that the original deferred is cancelled and that
+        the timing-out deferred is errbacked
+        """
+        cancelled = [False]
+
+        def canceller(_d):
+            cancelled[0] = True
+
+        non_completing_d = Deferred(canceller)
+        timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
+
+        self.assertNoResult(timing_out_d)
+        self.assertFalse(cancelled[0], "deferred was cancelled prematurely")
+
+        self.clock.pump((1.0, ))
+
+        self.assertTrue(cancelled[0], "deferred was not cancelled by timeout")
+        self.failureResultOf(timing_out_d, defer.TimeoutError, )
+
+    def test_times_out_when_canceller_throws(self):
+        """Test that we have successfully worked around
+        https://twistedmatrix.com/trac/ticket/9534"""
+
+        def canceller(_d):
+            raise Exception("can't cancel this deferred")
+
+        non_completing_d = Deferred(canceller)
+        timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
+
+        self.assertNoResult(timing_out_d)
+
+        self.clock.pump((1.0, ))
+
+        self.failureResultOf(timing_out_d, defer.TimeoutError, )
+
+    def test_logcontext_is_preserved_on_cancellation(self):
+        blocking_was_cancelled = [False]
+
+        @defer.inlineCallbacks
+        def blocking():
+            non_completing_d = Deferred()
+            with logcontext.PreserveLoggingContext():
+                try:
+                    yield non_completing_d
+                except CancelledError:
+                    blocking_was_cancelled[0] = True
+                    raise
+
+        with logcontext.LoggingContext("one") as context_one:
+            # the errbacks should be run in the test logcontext
+            def errback(res, deferred_name):
+                self.assertIs(
+                    LoggingContext.current_context(), context_one,
+                    "errback %s run in unexpected logcontext %s" % (
+                        deferred_name, LoggingContext.current_context(),
+                    )
+                )
+                return res
+
+            original_deferred = blocking()
+            original_deferred.addErrback(errback, "orig")
+            timing_out_d = timeout_deferred(original_deferred, 1.0, self.clock)
+            self.assertNoResult(timing_out_d)
+            self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+            timing_out_d.addErrback(errback, "timingout")
+
+            self.clock.pump((1.0, ))
+
+            self.assertTrue(
+                blocking_was_cancelled[0],
+                "non-completing deferred was not cancelled",
+            )
+            self.failureResultOf(timing_out_d, defer.TimeoutError, )
+            self.assertIs(LoggingContext.current_context(), context_one)