summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/7836.misc1
-rw-r--r--changelog.d/7844.bugfix1
-rw-r--r--changelog.d/7847.feature1
-rw-r--r--changelog.d/7848.misc1
-rw-r--r--changelog.d/7851.misc1
-rw-r--r--changelog.d/7853.misc1
-rw-r--r--changelog.d/7854.bugfix1
-rw-r--r--changelog.d/7856.misc1
-rw-r--r--changelog.d/7870.misc1
-rw-r--r--docs/admin_api/user_admin_api.rst6
-rw-r--r--docs/jwt.md5
-rw-r--r--synapse/api/errors.py4
-rw-r--r--synapse/federation/federation_server.py2
-rw-r--r--synapse/federation/sender/transaction_manager.py2
-rw-r--r--synapse/handlers/_base.py10
-rw-r--r--synapse/handlers/deactivate_account.py48
-rw-r--r--synapse/handlers/e2e_keys.py147
-rw-r--r--synapse/handlers/e2e_room_keys.py75
-rw-r--r--synapse/handlers/typing.py4
-rw-r--r--synapse/handlers/ui_auth/checkers.py3
-rw-r--r--synapse/http/client.py28
-rw-r--r--synapse/http/servlet.py4
-rw-r--r--synapse/rest/admin/users.py10
-rw-r--r--synapse/rest/client/v1/login.py8
-rw-r--r--synapse/rest/client/v1/room.py13
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py4
-rw-r--r--synapse/server.py4
-rw-r--r--synapse/server.pyi5
-rw-r--r--synapse/storage/data_stores/main/events.py67
-rw-r--r--synapse/storage/data_stores/main/state.py65
-rw-r--r--tests/handlers/test_e2e_keys.py286
-rw-r--r--tests/handlers/test_e2e_room_keys.py373
-rw-r--r--tests/handlers/test_typing.py4
-rw-r--r--tests/replication/_base.py168
-rw-r--r--tests/replication/test_client_reader_shard.py59
-rw-r--r--tests/replication/test_federation_sender_shard.py191
-rw-r--r--tests/rest/admin/test_user.py47
-rw-r--r--tests/rest/client/v1/test_login.py43
-rw-r--r--tests/server.py26
39 files changed, 1029 insertions, 691 deletions
diff --git a/changelog.d/7836.misc b/changelog.d/7836.misc
new file mode 100644
index 0000000000..a3a97c7590
--- /dev/null
+++ b/changelog.d/7836.misc
@@ -0,0 +1 @@
+Ensure that calls to `json.dumps` are compatible with the standard library json.
diff --git a/changelog.d/7844.bugfix b/changelog.d/7844.bugfix
new file mode 100644
index 0000000000..ad296f1b3c
--- /dev/null
+++ b/changelog.d/7844.bugfix
@@ -0,0 +1 @@
+Errors which occur while using the non-standard JWT login now return the proper error: `403 Forbidden` with an error code of `M_FORBIDDEN`.
diff --git a/changelog.d/7847.feature b/changelog.d/7847.feature
new file mode 100644
index 0000000000..4b9a8d8569
--- /dev/null
+++ b/changelog.d/7847.feature
@@ -0,0 +1 @@
+Add the ability to re-activate an account from the admin API.
diff --git a/changelog.d/7848.misc b/changelog.d/7848.misc
new file mode 100644
index 0000000000..d9db1d8357
--- /dev/null
+++ b/changelog.d/7848.misc
@@ -0,0 +1 @@
+Remove redundant `retry_on_integrity_error` wrapper for event persistence code.
diff --git a/changelog.d/7851.misc b/changelog.d/7851.misc
new file mode 100644
index 0000000000..e5cf540edf
--- /dev/null
+++ b/changelog.d/7851.misc
@@ -0,0 +1 @@
+Convert E2E keys and room keys handlers to async/await.
diff --git a/changelog.d/7853.misc b/changelog.d/7853.misc
new file mode 100644
index 0000000000..b4f614084d
--- /dev/null
+++ b/changelog.d/7853.misc
@@ -0,0 +1 @@
+Add support for handling registration requests across multiple client reader workers.
diff --git a/changelog.d/7854.bugfix b/changelog.d/7854.bugfix
new file mode 100644
index 0000000000..b11f9dedfe
--- /dev/null
+++ b/changelog.d/7854.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in Synapse 1.10.0 which could cause a "no create event in auth events" error during room creation.
diff --git a/changelog.d/7856.misc b/changelog.d/7856.misc
new file mode 100644
index 0000000000..7d99fb67be
--- /dev/null
+++ b/changelog.d/7856.misc
@@ -0,0 +1 @@
+Small performance improvement in typing processing.
diff --git a/changelog.d/7870.misc b/changelog.d/7870.misc
new file mode 100644
index 0000000000..27cce2f2f9
--- /dev/null
+++ b/changelog.d/7870.misc
@@ -0,0 +1 @@
+Add some type annotations to `HomeServer` and `BaseHandler`.
diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index 7b030a6285..be05128b3e 100644
--- a/docs/admin_api/user_admin_api.rst
+++ b/docs/admin_api/user_admin_api.rst
@@ -91,10 +91,14 @@ Body parameters:
 
 - ``admin``, optional, defaults to ``false``.
 
-- ``deactivated``, optional, defaults to ``false``.
+- ``deactivated``, optional. If unspecified, deactivation state will be left
+  unchanged on existing accounts and set to ``false`` for new accounts.
 
 If the user already exists then optional parameters default to the current value.
 
+In order to re-activate an account ``deactivated`` must be set to ``false``. If
+users do not login via single-sign-on, a new ``password`` must be provided.
+
 List Accounts
 =============
 
diff --git a/docs/jwt.md b/docs/jwt.md
index 93b8d05236..5be9fd26e3 100644
--- a/docs/jwt.md
+++ b/docs/jwt.md
@@ -31,10 +31,7 @@ The `token` field should include the JSON web token with the following claims:
   Providing the audience claim when not configured will cause validation to fail.
 
 In the case that the token is not valid, the homeserver must respond with
-`401 Unauthorized` and an error code of `M_UNAUTHORIZED`.
-
-(Note that this differs from the token based logins which return a
-`403 Forbidden` and an error code of `M_FORBIDDEN` if an error occurs.)
+`403 Forbidden` and an error code of `M_FORBIDDEN`.
 
 As with other login types, there are additional fields (e.g. `device_id` and
 `initial_device_display_name`) which can be included in the above request.
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 94db877050..bd004350dd 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -16,12 +16,14 @@
 # limitations under the License.
 
 """Contains exceptions and error codes."""
-import json
+
 import logging
 import typing
 from http import HTTPStatus
 from typing import Dict, List, Optional, Union
 
+from canonicaljson import json
+
 from twisted.web import http
 
 if typing.TYPE_CHECKING:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 2aab9c5f55..8c53330c49 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -14,10 +14,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
 import logging
 from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union
 
+from canonicaljson import json
 from prometheus_client import Counter, Histogram
 
 from twisted.internet import defer
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index a2752a54a5..8280f8b900 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -61,8 +61,6 @@ class TransactionManager(object):
         # all the edus in that transaction. This needs to be done since there is
         # no active span here, so if the edus were not received by the remote the
         # span would have no causality and it would be forgotten.
-        # The span_contexts is a generator so that it won't be evaluated if
-        # opentracing is disabled. (Yay speed!)
 
         span_contexts = []
         keep_destination = whitelisted_homeserver(destination)
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 61dc4beafe..6a4944467a 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -17,6 +17,8 @@ import logging
 
 from twisted.internet import defer
 
+import synapse.state
+import synapse.storage
 import synapse.types
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.ratelimiting import Ratelimiter
@@ -28,10 +30,6 @@ logger = logging.getLogger(__name__)
 class BaseHandler(object):
     """
     Common base class for the event handlers.
-
-    Attributes:
-        store (synapse.storage.DataStore):
-        state_handler (synapse.state.StateHandler):
     """
 
     def __init__(self, hs):
@@ -39,10 +37,10 @@ class BaseHandler(object):
         Args:
             hs (synapse.server.HomeServer):
         """
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastore()  # type: synapse.storage.DataStore
         self.auth = hs.get_auth()
         self.notifier = hs.get_notifier()
-        self.state_handler = hs.get_state_handler()
+        self.state_handler = hs.get_state_handler()  # type: synapse.state.StateHandler
         self.distributor = hs.get_distributor()
         self.clock = hs.get_clock()
         self.hs = hs
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 591657a5c2..3789b1b495 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import Optional
 
 from synapse.api.errors import SynapseError
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -46,19 +47,20 @@ class DeactivateAccountHandler(BaseHandler):
 
         self._account_validity_enabled = hs.config.account_validity.enabled
 
-    async def deactivate_account(self, user_id, erase_data, id_server=None):
+    async def deactivate_account(
+        self, user_id: str, erase_data: bool, id_server: Optional[str] = None
+    ) -> bool:
         """Deactivate a user's account
 
         Args:
-            user_id (str): ID of user to be deactivated
-            erase_data (bool): whether to GDPR-erase the user's data
-            id_server (str|None): Use the given identity server when unbinding
+            user_id: ID of user to be deactivated
+            erase_data: whether to GDPR-erase the user's data
+            id_server: Use the given identity server when unbinding
                 any threepids. If None then will attempt to unbind using the
                 identity server specified when binding (if known).
 
         Returns:
-            Deferred[bool]: True if identity server supports removing
-            threepids, otherwise False.
+            True if identity server supports removing threepids, otherwise False.
         """
         # FIXME: Theoretically there is a race here wherein user resets
         # password using threepid.
@@ -138,11 +140,11 @@ class DeactivateAccountHandler(BaseHandler):
 
         return identity_server_supports_unbinding
 
-    async def _reject_pending_invites_for_user(self, user_id):
+    async def _reject_pending_invites_for_user(self, user_id: str):
         """Reject pending invites addressed to a given user ID.
 
         Args:
-            user_id (str): The user ID to reject pending invites for.
+            user_id: The user ID to reject pending invites for.
         """
         user = UserID.from_string(user_id)
         pending_invites = await self.store.get_invited_rooms_for_local_user(user_id)
@@ -170,22 +172,16 @@ class DeactivateAccountHandler(BaseHandler):
                     room.room_id,
                 )
 
-    def _start_user_parting(self):
+    def _start_user_parting(self) -> None:
         """
         Start the process that goes through the table of users
         pending deactivation, if it isn't already running.
-
-        Returns:
-            None
         """
         if not self._user_parter_running:
             run_as_background_process("user_parter_loop", self._user_parter_loop)
 
-    async def _user_parter_loop(self):
+    async def _user_parter_loop(self) -> None:
         """Loop that parts deactivated users from rooms
-
-        Returns:
-            None
         """
         self._user_parter_running = True
         logger.info("Starting user parter")
@@ -202,11 +198,8 @@ class DeactivateAccountHandler(BaseHandler):
         finally:
             self._user_parter_running = False
 
-    async def _part_user(self, user_id):
+    async def _part_user(self, user_id: str) -> None:
         """Causes the given user_id to leave all the rooms they're joined to
-
-        Returns:
-            None
         """
         user = UserID.from_string(user_id)
 
@@ -228,3 +221,18 @@ class DeactivateAccountHandler(BaseHandler):
                     user_id,
                     room_id,
                 )
+
+    async def activate_account(self, user_id: str) -> None:
+        """
+        Activate an account that was previously deactivated.
+
+        This simply marks the user as activate in the database and does not
+        attempt to rejoin rooms, re-add threepids, etc.
+
+        The user will also need a password hash set to actually login.
+
+        Args:
+            user_id: ID of user to be deactivated
+        """
+        # Mark the user as activate.
+        await self.store.set_user_deactivated_status(user_id, False)
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index a7e60cbc26..361dd64cd2 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -77,8 +77,7 @@ class E2eKeysHandler(object):
         )
 
     @trace
-    @defer.inlineCallbacks
-    def query_devices(self, query_body, timeout, from_user_id):
+    async def query_devices(self, query_body, timeout, from_user_id):
         """ Handle a device key query from a client
 
         {
@@ -124,7 +123,7 @@ class E2eKeysHandler(object):
         failures = {}
         results = {}
         if local_query:
-            local_result = yield self.query_local_devices(local_query)
+            local_result = await self.query_local_devices(local_query)
             for user_id, keys in local_result.items():
                 if user_id in local_query:
                     results[user_id] = keys
@@ -142,7 +141,7 @@ class E2eKeysHandler(object):
             (
                 user_ids_not_in_cache,
                 remote_results,
-            ) = yield self.store.get_user_devices_from_cache(query_list)
+            ) = await self.store.get_user_devices_from_cache(query_list)
             for user_id, devices in remote_results.items():
                 user_devices = results.setdefault(user_id, {})
                 for device_id, device in devices.items():
@@ -161,14 +160,13 @@ class E2eKeysHandler(object):
                 r[user_id] = remote_queries[user_id]
 
         # Get cached cross-signing keys
-        cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
+        cross_signing_keys = await self.get_cross_signing_keys_from_cache(
             device_keys_query, from_user_id
         )
 
         # Now fetch any devices that we don't have in our cache
         @trace
-        @defer.inlineCallbacks
-        def do_remote_query(destination):
+        async def do_remote_query(destination):
             """This is called when we are querying the device list of a user on
             a remote homeserver and their device list is not in the device list
             cache. If we share a room with this user and we're not querying for
@@ -192,7 +190,7 @@ class E2eKeysHandler(object):
                 if device_list:
                     continue
 
-                room_ids = yield self.store.get_rooms_for_user(user_id)
+                room_ids = await self.store.get_rooms_for_user(user_id)
                 if not room_ids:
                     continue
 
@@ -201,11 +199,11 @@ class E2eKeysHandler(object):
                 # done an initial sync on the device list so we do it now.
                 try:
                     if self._is_master:
-                        user_devices = yield self.device_handler.device_list_updater.user_device_resync(
+                        user_devices = await self.device_handler.device_list_updater.user_device_resync(
                             user_id
                         )
                     else:
-                        user_devices = yield self._user_device_resync_client(
+                        user_devices = await self._user_device_resync_client(
                             user_id=user_id
                         )
 
@@ -227,7 +225,7 @@ class E2eKeysHandler(object):
                 destination_query.pop(user_id)
 
             try:
-                remote_result = yield self.federation.query_client_keys(
+                remote_result = await self.federation.query_client_keys(
                     destination, {"device_keys": destination_query}, timeout=timeout
                 )
 
@@ -251,7 +249,7 @@ class E2eKeysHandler(object):
                 set_tag("error", True)
                 set_tag("reason", failure)
 
-        yield make_deferred_yieldable(
+        await make_deferred_yieldable(
             defer.gatherResults(
                 [
                     run_in_background(do_remote_query, destination)
@@ -267,8 +265,7 @@ class E2eKeysHandler(object):
 
         return ret
 
-    @defer.inlineCallbacks
-    def get_cross_signing_keys_from_cache(self, query, from_user_id):
+    async def get_cross_signing_keys_from_cache(self, query, from_user_id):
         """Get cross-signing keys for users from the database
 
         Args:
@@ -289,7 +286,7 @@ class E2eKeysHandler(object):
 
         user_ids = list(query)
 
-        keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id)
+        keys = await self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id)
 
         for user_id, user_info in keys.items():
             if user_info is None:
@@ -315,8 +312,7 @@ class E2eKeysHandler(object):
         }
 
     @trace
-    @defer.inlineCallbacks
-    def query_local_devices(self, query):
+    async def query_local_devices(self, query):
         """Get E2E device keys for local users
 
         Args:
@@ -354,7 +350,7 @@ class E2eKeysHandler(object):
             # make sure that each queried user appears in the result dict
             result_dict[user_id] = {}
 
-        results = yield self.store.get_e2e_device_keys(local_query)
+        results = await self.store.get_e2e_device_keys(local_query)
 
         # Build the result structure
         for user_id, device_keys in results.items():
@@ -364,16 +360,15 @@ class E2eKeysHandler(object):
         log_kv(results)
         return result_dict
 
-    @defer.inlineCallbacks
-    def on_federation_query_client_keys(self, query_body):
+    async def on_federation_query_client_keys(self, query_body):
         """ Handle a device key query from a federated server
         """
         device_keys_query = query_body.get("device_keys", {})
-        res = yield self.query_local_devices(device_keys_query)
+        res = await self.query_local_devices(device_keys_query)
         ret = {"device_keys": res}
 
         # add in the cross-signing keys
-        cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
+        cross_signing_keys = await self.get_cross_signing_keys_from_cache(
             device_keys_query, None
         )
 
@@ -382,8 +377,7 @@ class E2eKeysHandler(object):
         return ret
 
     @trace
-    @defer.inlineCallbacks
-    def claim_one_time_keys(self, query, timeout):
+    async def claim_one_time_keys(self, query, timeout):
         local_query = []
         remote_queries = {}
 
@@ -399,7 +393,7 @@ class E2eKeysHandler(object):
         set_tag("local_key_query", local_query)
         set_tag("remote_key_query", remote_queries)
 
-        results = yield self.store.claim_e2e_one_time_keys(local_query)
+        results = await self.store.claim_e2e_one_time_keys(local_query)
 
         json_result = {}
         failures = {}
@@ -411,12 +405,11 @@ class E2eKeysHandler(object):
                     }
 
         @trace
-        @defer.inlineCallbacks
-        def claim_client_keys(destination):
+        async def claim_client_keys(destination):
             set_tag("destination", destination)
             device_keys = remote_queries[destination]
             try:
-                remote_result = yield self.federation.claim_client_keys(
+                remote_result = await self.federation.claim_client_keys(
                     destination, {"one_time_keys": device_keys}, timeout=timeout
                 )
                 for user_id, keys in remote_result["one_time_keys"].items():
@@ -429,7 +422,7 @@ class E2eKeysHandler(object):
                 set_tag("error", True)
                 set_tag("reason", failure)
 
-        yield make_deferred_yieldable(
+        await make_deferred_yieldable(
             defer.gatherResults(
                 [
                     run_in_background(claim_client_keys, destination)
@@ -454,9 +447,8 @@ class E2eKeysHandler(object):
         log_kv({"one_time_keys": json_result, "failures": failures})
         return {"one_time_keys": json_result, "failures": failures}
 
-    @defer.inlineCallbacks
     @tag_args
-    def upload_keys_for_user(self, user_id, device_id, keys):
+    async def upload_keys_for_user(self, user_id, device_id, keys):
 
         time_now = self.clock.time_msec()
 
@@ -477,12 +469,12 @@ class E2eKeysHandler(object):
                 }
             )
             # TODO: Sign the JSON with the server key
-            changed = yield self.store.set_e2e_device_keys(
+            changed = await self.store.set_e2e_device_keys(
                 user_id, device_id, time_now, device_keys
             )
             if changed:
                 # Only notify about device updates *if* the keys actually changed
-                yield self.device_handler.notify_device_update(user_id, [device_id])
+                await self.device_handler.notify_device_update(user_id, [device_id])
         else:
             log_kv({"message": "Not updating device_keys for user", "user_id": user_id})
         one_time_keys = keys.get("one_time_keys", None)
@@ -494,7 +486,7 @@ class E2eKeysHandler(object):
                     "device_id": device_id,
                 }
             )
-            yield self._upload_one_time_keys_for_user(
+            await self._upload_one_time_keys_for_user(
                 user_id, device_id, time_now, one_time_keys
             )
         else:
@@ -507,15 +499,14 @@ class E2eKeysHandler(object):
         # old access_token without an associated device_id. Either way, we
         # need to double-check the device is registered to avoid ending up with
         # keys without a corresponding device.
-        yield self.device_handler.check_device_registered(user_id, device_id)
+        await self.device_handler.check_device_registered(user_id, device_id)
 
-        result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
+        result = await self.store.count_e2e_one_time_keys(user_id, device_id)
 
         set_tag("one_time_key_counts", result)
         return {"one_time_key_counts": result}
 
-    @defer.inlineCallbacks
-    def _upload_one_time_keys_for_user(
+    async def _upload_one_time_keys_for_user(
         self, user_id, device_id, time_now, one_time_keys
     ):
         logger.info(
@@ -533,7 +524,7 @@ class E2eKeysHandler(object):
             key_list.append((algorithm, key_id, key_obj))
 
         # First we check if we have already persisted any of the keys.
-        existing_key_map = yield self.store.get_e2e_one_time_keys(
+        existing_key_map = await self.store.get_e2e_one_time_keys(
             user_id, device_id, [k_id for _, k_id, _ in key_list]
         )
 
@@ -556,10 +547,9 @@ class E2eKeysHandler(object):
                 )
 
         log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
-        yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
+        await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
 
-    @defer.inlineCallbacks
-    def upload_signing_keys_for_user(self, user_id, keys):
+    async def upload_signing_keys_for_user(self, user_id, keys):
         """Upload signing keys for cross-signing
 
         Args:
@@ -574,7 +564,7 @@ class E2eKeysHandler(object):
 
             _check_cross_signing_key(master_key, user_id, "master")
         else:
-            master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
+            master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
 
         # if there is no master key, then we can't do anything, because all the
         # other cross-signing keys need to be signed by the master key
@@ -613,10 +603,10 @@ class E2eKeysHandler(object):
         # if everything checks out, then store the keys and send notifications
         deviceids = []
         if "master_key" in keys:
-            yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
+            await self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
             deviceids.append(master_verify_key.version)
         if "self_signing_key" in keys:
-            yield self.store.set_e2e_cross_signing_key(
+            await self.store.set_e2e_cross_signing_key(
                 user_id, "self_signing", self_signing_key
             )
             try:
@@ -626,23 +616,22 @@ class E2eKeysHandler(object):
             except ValueError:
                 raise SynapseError(400, "Invalid self-signing key", Codes.INVALID_PARAM)
         if "user_signing_key" in keys:
-            yield self.store.set_e2e_cross_signing_key(
+            await self.store.set_e2e_cross_signing_key(
                 user_id, "user_signing", user_signing_key
             )
             # the signature stream matches the semantics that we want for
             # user-signing key updates: only the user themselves is notified of
             # their own user-signing key updates
-            yield self.device_handler.notify_user_signature_update(user_id, [user_id])
+            await self.device_handler.notify_user_signature_update(user_id, [user_id])
 
         # master key and self-signing key updates match the semantics of device
         # list updates: all users who share an encrypted room are notified
         if len(deviceids):
-            yield self.device_handler.notify_device_update(user_id, deviceids)
+            await self.device_handler.notify_device_update(user_id, deviceids)
 
         return {}
 
-    @defer.inlineCallbacks
-    def upload_signatures_for_device_keys(self, user_id, signatures):
+    async def upload_signatures_for_device_keys(self, user_id, signatures):
         """Upload device signatures for cross-signing
 
         Args:
@@ -667,13 +656,13 @@ class E2eKeysHandler(object):
         self_signatures = signatures.get(user_id, {})
         other_signatures = {k: v for k, v in signatures.items() if k != user_id}
 
-        self_signature_list, self_failures = yield self._process_self_signatures(
+        self_signature_list, self_failures = await self._process_self_signatures(
             user_id, self_signatures
         )
         signature_list.extend(self_signature_list)
         failures.update(self_failures)
 
-        other_signature_list, other_failures = yield self._process_other_signatures(
+        other_signature_list, other_failures = await self._process_other_signatures(
             user_id, other_signatures
         )
         signature_list.extend(other_signature_list)
@@ -681,21 +670,20 @@ class E2eKeysHandler(object):
 
         # store the signature, and send the appropriate notifications for sync
         logger.debug("upload signature failures: %r", failures)
-        yield self.store.store_e2e_cross_signing_signatures(user_id, signature_list)
+        await self.store.store_e2e_cross_signing_signatures(user_id, signature_list)
 
         self_device_ids = [item.target_device_id for item in self_signature_list]
         if self_device_ids:
-            yield self.device_handler.notify_device_update(user_id, self_device_ids)
+            await self.device_handler.notify_device_update(user_id, self_device_ids)
         signed_users = [item.target_user_id for item in other_signature_list]
         if signed_users:
-            yield self.device_handler.notify_user_signature_update(
+            await self.device_handler.notify_user_signature_update(
                 user_id, signed_users
             )
 
         return {"failures": failures}
 
-    @defer.inlineCallbacks
-    def _process_self_signatures(self, user_id, signatures):
+    async def _process_self_signatures(self, user_id, signatures):
         """Process uploaded signatures of the user's own keys.
 
         Signatures of the user's own keys from this API come in two forms:
@@ -728,7 +716,7 @@ class E2eKeysHandler(object):
                 _,
                 self_signing_key_id,
                 self_signing_verify_key,
-            ) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing")
+            ) = await self._get_e2e_cross_signing_verify_key(user_id, "self_signing")
 
             # get our master key, since we may have received a signature of it.
             # We need to fetch it here so that we know what its key ID is, so
@@ -738,12 +726,12 @@ class E2eKeysHandler(object):
                 master_key,
                 _,
                 master_verify_key,
-            ) = yield self._get_e2e_cross_signing_verify_key(user_id, "master")
+            ) = await self._get_e2e_cross_signing_verify_key(user_id, "master")
 
             # fetch our stored devices.  This is used to 1. verify
             # signatures on the master key, and 2. to compare with what
             # was sent if the device was signed
-            devices = yield self.store.get_e2e_device_keys([(user_id, None)])
+            devices = await self.store.get_e2e_device_keys([(user_id, None)])
 
             if user_id not in devices:
                 raise NotFoundError("No device keys found")
@@ -853,8 +841,7 @@ class E2eKeysHandler(object):
 
         return master_key_signature_list
 
-    @defer.inlineCallbacks
-    def _process_other_signatures(self, user_id, signatures):
+    async def _process_other_signatures(self, user_id, signatures):
         """Process uploaded signatures of other users' keys.  These will be the
         target user's master keys, signed by the uploading user's user-signing
         key.
@@ -882,7 +869,7 @@ class E2eKeysHandler(object):
                 user_signing_key,
                 user_signing_key_id,
                 user_signing_verify_key,
-            ) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing")
+            ) = await self._get_e2e_cross_signing_verify_key(user_id, "user_signing")
         except SynapseError as e:
             failure = _exception_to_failure(e)
             for user, devicemap in signatures.items():
@@ -905,7 +892,7 @@ class E2eKeysHandler(object):
                     master_key,
                     master_key_id,
                     _,
-                ) = yield self._get_e2e_cross_signing_verify_key(
+                ) = await self._get_e2e_cross_signing_verify_key(
                     target_user, "master", user_id
                 )
 
@@ -958,8 +945,7 @@ class E2eKeysHandler(object):
 
         return signature_list, failures
 
-    @defer.inlineCallbacks
-    def _get_e2e_cross_signing_verify_key(
+    async def _get_e2e_cross_signing_verify_key(
         self, user_id: str, key_type: str, from_user_id: str = None
     ):
         """Fetch locally or remotely query for a cross-signing public key.
@@ -983,7 +969,7 @@ class E2eKeysHandler(object):
             SynapseError: if `user_id` is invalid
         """
         user = UserID.from_string(user_id)
-        key = yield self.store.get_e2e_cross_signing_key(
+        key = await self.store.get_e2e_cross_signing_key(
             user_id, key_type, from_user_id
         )
 
@@ -1009,15 +995,14 @@ class E2eKeysHandler(object):
             key,
             key_id,
             verify_key,
-        ) = yield self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
+        ) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
 
         if key is None:
             raise NotFoundError("No %s key found for %s" % (key_type, user_id))
 
         return key, key_id, verify_key
 
-    @defer.inlineCallbacks
-    def _retrieve_cross_signing_keys_for_remote_user(
+    async def _retrieve_cross_signing_keys_for_remote_user(
         self, user: UserID, desired_key_type: str,
     ):
         """Queries cross-signing keys for a remote user and saves them to the database
@@ -1035,7 +1020,7 @@ class E2eKeysHandler(object):
             If the key cannot be retrieved, all values in the tuple will instead be None.
         """
         try:
-            remote_result = yield self.federation.query_user_devices(
+            remote_result = await self.federation.query_user_devices(
                 user.domain, user.to_string()
             )
         except Exception as e:
@@ -1101,14 +1086,14 @@ class E2eKeysHandler(object):
                 desired_key_id = key_id
 
             # At the same time, store this key in the db for subsequent queries
-            yield self.store.set_e2e_cross_signing_key(
+            await self.store.set_e2e_cross_signing_key(
                 user.to_string(), key_type, key_content
             )
 
         # Notify clients that new devices for this user have been discovered
         if retrieved_device_ids:
             # XXX is this necessary?
-            yield self.device_handler.notify_device_update(
+            await self.device_handler.notify_device_update(
                 user.to_string(), retrieved_device_ids
             )
 
@@ -1250,8 +1235,7 @@ class SigningKeyEduUpdater(object):
             iterable=True,
         )
 
-    @defer.inlineCallbacks
-    def incoming_signing_key_update(self, origin, edu_content):
+    async def incoming_signing_key_update(self, origin, edu_content):
         """Called on incoming signing key update from federation. Responsible for
         parsing the EDU and adding to pending updates list.
 
@@ -1268,7 +1252,7 @@ class SigningKeyEduUpdater(object):
             logger.warning("Got signing key update edu for %r from %r", user_id, origin)
             return
 
-        room_ids = yield self.store.get_rooms_for_user(user_id)
+        room_ids = await self.store.get_rooms_for_user(user_id)
         if not room_ids:
             # We don't share any rooms with this user. Ignore update, as we
             # probably won't get any further updates.
@@ -1278,10 +1262,9 @@ class SigningKeyEduUpdater(object):
             (master_key, self_signing_key)
         )
 
-        yield self._handle_signing_key_updates(user_id)
+        await self._handle_signing_key_updates(user_id)
 
-    @defer.inlineCallbacks
-    def _handle_signing_key_updates(self, user_id):
+    async def _handle_signing_key_updates(self, user_id):
         """Actually handle pending updates.
 
         Args:
@@ -1291,7 +1274,7 @@ class SigningKeyEduUpdater(object):
         device_handler = self.e2e_keys_handler.device_handler
         device_list_updater = device_handler.device_list_updater
 
-        with (yield self._remote_edu_linearizer.queue(user_id)):
+        with (await self._remote_edu_linearizer.queue(user_id)):
             pending_updates = self._pending_updates.pop(user_id, [])
             if not pending_updates:
                 # This can happen since we batch updates
@@ -1302,9 +1285,9 @@ class SigningKeyEduUpdater(object):
             logger.info("pending updates: %r", pending_updates)
 
             for master_key, self_signing_key in pending_updates:
-                new_device_ids = yield device_list_updater.process_cross_signing_key_update(
+                new_device_ids = await device_list_updater.process_cross_signing_key_update(
                     user_id, master_key, self_signing_key,
                 )
                 device_ids = device_ids + new_device_ids
 
-            yield device_handler.notify_device_update(user_id, device_ids)
+            await device_handler.notify_device_update(user_id, device_ids)
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index f55470a707..0bb983dc28 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -16,8 +16,6 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.api.errors import (
     Codes,
     NotFoundError,
@@ -50,8 +48,7 @@ class E2eRoomKeysHandler(object):
         self._upload_linearizer = Linearizer("upload_room_keys_lock")
 
     @trace
-    @defer.inlineCallbacks
-    def get_room_keys(self, user_id, version, room_id=None, session_id=None):
+    async def get_room_keys(self, user_id, version, room_id=None, session_id=None):
         """Bulk get the E2E room keys for a given backup, optionally filtered to a given
         room, or a given session.
         See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
@@ -71,17 +68,17 @@ class E2eRoomKeysHandler(object):
 
         # we deliberately take the lock to get keys so that changing the version
         # works atomically
-        with (yield self._upload_linearizer.queue(user_id)):
+        with (await self._upload_linearizer.queue(user_id)):
             # make sure the backup version exists
             try:
-                yield self.store.get_e2e_room_keys_version_info(user_id, version)
+                await self.store.get_e2e_room_keys_version_info(user_id, version)
             except StoreError as e:
                 if e.code == 404:
                     raise NotFoundError("Unknown backup version")
                 else:
                     raise
 
-            results = yield self.store.get_e2e_room_keys(
+            results = await self.store.get_e2e_room_keys(
                 user_id, version, room_id, session_id
             )
 
@@ -89,8 +86,7 @@ class E2eRoomKeysHandler(object):
             return results
 
     @trace
-    @defer.inlineCallbacks
-    def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
+    async def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
         """Bulk delete the E2E room keys for a given backup, optionally filtered to a given
         room or a given session.
         See EndToEndRoomKeyStore.delete_e2e_room_keys for full details.
@@ -109,10 +105,10 @@ class E2eRoomKeysHandler(object):
         """
 
         # lock for consistency with uploading
-        with (yield self._upload_linearizer.queue(user_id)):
+        with (await self._upload_linearizer.queue(user_id)):
             # make sure the backup version exists
             try:
-                version_info = yield self.store.get_e2e_room_keys_version_info(
+                version_info = await self.store.get_e2e_room_keys_version_info(
                     user_id, version
                 )
             except StoreError as e:
@@ -121,19 +117,18 @@ class E2eRoomKeysHandler(object):
                 else:
                     raise
 
-            yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id)
+            await self.store.delete_e2e_room_keys(user_id, version, room_id, session_id)
 
             version_etag = version_info["etag"] + 1
-            yield self.store.update_e2e_room_keys_version(
+            await self.store.update_e2e_room_keys_version(
                 user_id, version, None, version_etag
             )
 
-            count = yield self.store.count_e2e_room_keys(user_id, version)
+            count = await self.store.count_e2e_room_keys(user_id, version)
             return {"etag": str(version_etag), "count": count}
 
     @trace
-    @defer.inlineCallbacks
-    def upload_room_keys(self, user_id, version, room_keys):
+    async def upload_room_keys(self, user_id, version, room_keys):
         """Bulk upload a list of room keys into a given backup version, asserting
         that the given version is the current backup version.  room_keys are merged
         into the current backup as described in RoomKeysServlet.on_PUT().
@@ -169,11 +164,11 @@ class E2eRoomKeysHandler(object):
         # TODO: Validate the JSON to make sure it has the right keys.
 
         # XXX: perhaps we should use a finer grained lock here?
-        with (yield self._upload_linearizer.queue(user_id)):
+        with (await self._upload_linearizer.queue(user_id)):
 
             # Check that the version we're trying to upload is the current version
             try:
-                version_info = yield self.store.get_e2e_room_keys_version_info(user_id)
+                version_info = await self.store.get_e2e_room_keys_version_info(user_id)
             except StoreError as e:
                 if e.code == 404:
                     raise NotFoundError("Version '%s' not found" % (version,))
@@ -183,7 +178,7 @@ class E2eRoomKeysHandler(object):
             if version_info["version"] != version:
                 # Check that the version we're trying to upload actually exists
                 try:
-                    version_info = yield self.store.get_e2e_room_keys_version_info(
+                    version_info = await self.store.get_e2e_room_keys_version_info(
                         user_id, version
                     )
                     # if we get this far, the version must exist
@@ -198,7 +193,7 @@ class E2eRoomKeysHandler(object):
             # submitted.  Then compare them with the submitted keys.  If the
             # key is new, insert it; if the key should be updated, then update
             # it; otherwise, drop it.
-            existing_keys = yield self.store.get_e2e_room_keys_multi(
+            existing_keys = await self.store.get_e2e_room_keys_multi(
                 user_id, version, room_keys["rooms"]
             )
             to_insert = []  # batch the inserts together
@@ -227,7 +222,7 @@ class E2eRoomKeysHandler(object):
                             # updates are done one at a time in the DB, so send
                             # updates right away rather than batching them up,
                             # like we do with the inserts
-                            yield self.store.update_e2e_room_key(
+                            await self.store.update_e2e_room_key(
                                 user_id, version, room_id, session_id, room_key
                             )
                             changed = True
@@ -246,16 +241,16 @@ class E2eRoomKeysHandler(object):
                         changed = True
 
             if len(to_insert):
-                yield self.store.add_e2e_room_keys(user_id, version, to_insert)
+                await self.store.add_e2e_room_keys(user_id, version, to_insert)
 
             version_etag = version_info["etag"]
             if changed:
                 version_etag = version_etag + 1
-                yield self.store.update_e2e_room_keys_version(
+                await self.store.update_e2e_room_keys_version(
                     user_id, version, None, version_etag
                 )
 
-            count = yield self.store.count_e2e_room_keys(user_id, version)
+            count = await self.store.count_e2e_room_keys(user_id, version)
             return {"etag": str(version_etag), "count": count}
 
     @staticmethod
@@ -291,8 +286,7 @@ class E2eRoomKeysHandler(object):
         return True
 
     @trace
-    @defer.inlineCallbacks
-    def create_version(self, user_id, version_info):
+    async def create_version(self, user_id, version_info):
         """Create a new backup version.  This automatically becomes the new
         backup version for the user's keys; previous backups will no longer be
         writeable to.
@@ -313,14 +307,13 @@ class E2eRoomKeysHandler(object):
         # TODO: Validate the JSON to make sure it has the right keys.
 
         # lock everyone out until we've switched version
-        with (yield self._upload_linearizer.queue(user_id)):
-            new_version = yield self.store.create_e2e_room_keys_version(
+        with (await self._upload_linearizer.queue(user_id)):
+            new_version = await self.store.create_e2e_room_keys_version(
                 user_id, version_info
             )
             return new_version
 
-    @defer.inlineCallbacks
-    def get_version_info(self, user_id, version=None):
+    async def get_version_info(self, user_id, version=None):
         """Get the info about a given version of the user's backup
 
         Args:
@@ -339,22 +332,21 @@ class E2eRoomKeysHandler(object):
         }
         """
 
-        with (yield self._upload_linearizer.queue(user_id)):
+        with (await self._upload_linearizer.queue(user_id)):
             try:
-                res = yield self.store.get_e2e_room_keys_version_info(user_id, version)
+                res = await self.store.get_e2e_room_keys_version_info(user_id, version)
             except StoreError as e:
                 if e.code == 404:
                     raise NotFoundError("Unknown backup version")
                 else:
                     raise
 
-            res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"])
+            res["count"] = await self.store.count_e2e_room_keys(user_id, res["version"])
             res["etag"] = str(res["etag"])
             return res
 
     @trace
-    @defer.inlineCallbacks
-    def delete_version(self, user_id, version=None):
+    async def delete_version(self, user_id, version=None):
         """Deletes a given version of the user's e2e_room_keys backup
 
         Args:
@@ -364,9 +356,9 @@ class E2eRoomKeysHandler(object):
             NotFoundError: if this backup version doesn't exist
         """
 
-        with (yield self._upload_linearizer.queue(user_id)):
+        with (await self._upload_linearizer.queue(user_id)):
             try:
-                yield self.store.delete_e2e_room_keys_version(user_id, version)
+                await self.store.delete_e2e_room_keys_version(user_id, version)
             except StoreError as e:
                 if e.code == 404:
                     raise NotFoundError("Unknown backup version")
@@ -374,8 +366,7 @@ class E2eRoomKeysHandler(object):
                     raise
 
     @trace
-    @defer.inlineCallbacks
-    def update_version(self, user_id, version, version_info):
+    async def update_version(self, user_id, version, version_info):
         """Update the info about a given version of the user's backup
 
         Args:
@@ -393,9 +384,9 @@ class E2eRoomKeysHandler(object):
             raise SynapseError(
                 400, "Version in body does not match", Codes.INVALID_PARAM
             )
-        with (yield self._upload_linearizer.queue(user_id)):
+        with (await self._upload_linearizer.queue(user_id)):
             try:
-                old_info = yield self.store.get_e2e_room_keys_version_info(
+                old_info = await self.store.get_e2e_room_keys_version_info(
                     user_id, version
                 )
             except StoreError as e:
@@ -406,7 +397,7 @@ class E2eRoomKeysHandler(object):
             if old_info["algorithm"] != version_info["algorithm"]:
                 raise SynapseError(400, "Algorithm does not match", Codes.INVALID_PARAM)
 
-            yield self.store.update_e2e_room_keys_version(
+            await self.store.update_e2e_room_keys_version(
                 user_id, version, version_info
             )
 
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 879c4c07c6..846ddbdc6c 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -185,7 +185,7 @@ class TypingHandler(object):
 
     async def _push_remote(self, member, typing):
         try:
-            users = await self.state.get_current_users_in_room(member.room_id)
+            users = await self.store.get_users_in_room(member.room_id)
             self._member_last_federation_poke[member] = self.clock.time_msec()
 
             now = self.clock.time_msec()
@@ -224,7 +224,7 @@ class TypingHandler(object):
             )
             return
 
-        users = await self.state.get_current_users_in_room(room_id)
+        users = await self.store.get_users_in_room(room_id)
         domains = {get_domain_from_id(u) for u in users}
 
         if self.server_name in domains:
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 8b24a73319..a140e9391e 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import logging
 
 from canonicaljson import json
@@ -117,7 +118,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
         except PartialDownloadError as pde:
             # Twisted is silly
             data = pde.response
-            resp_body = json.loads(data)
+            resp_body = json.loads(data.decode("utf-8"))
 
         if "success" in resp_body:
             # Note that we do NOT check the hostname here: we explicitly
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 505872ee90..6bc51202cd 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -13,13 +13,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
+
 import logging
 import urllib
 from io import BytesIO
 
 import treq
-from canonicaljson import encode_canonical_json
+from canonicaljson import encode_canonical_json, json
 from netaddr import IPAddress
 from prometheus_client import Counter
 from zope.interface import implementer, provider
@@ -31,6 +31,7 @@ from twisted.internet.interfaces import (
     IReactorPluggableNameResolver,
     IResolutionReceiver,
 )
+from twisted.internet.task import Cooperator
 from twisted.python.failure import Failure
 from twisted.web._newclient import ResponseDone
 from twisted.web.client import Agent, HTTPConnectionPool, readBody
@@ -69,6 +70,21 @@ def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
     return False
 
 
+_EPSILON = 0.00000001
+
+
+def _make_scheduler(reactor):
+    """Makes a schedular suitable for a Cooperator using the given reactor.
+
+    (This is effectively just a copy from `twisted.internet.task`)
+    """
+
+    def _scheduler(x):
+        return reactor.callLater(_EPSILON, x)
+
+    return _scheduler
+
+
 class IPBlacklistingResolver(object):
     """
     A proxy for reactor.nameResolver which only produces non-blacklisted IP
@@ -212,6 +228,10 @@ class SimpleHttpClient(object):
         if hs.config.user_agent_suffix:
             self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix)
 
+        # We use this for our body producers to ensure that they use the correct
+        # reactor.
+        self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_reactor()))
+
         self.user_agent = self.user_agent.encode("ascii")
 
         if self._ip_blacklist:
@@ -292,7 +312,9 @@ class SimpleHttpClient(object):
             try:
                 body_producer = None
                 if data is not None:
-                    body_producer = QuieterFileBodyProducer(BytesIO(data))
+                    body_producer = QuieterFileBodyProducer(
+                        BytesIO(data), cooperator=self._cooperator,
+                    )
 
                 request_deferred = treq.request(
                     method,
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 3cabe9d02e..a34e5ead88 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -14,9 +14,11 @@
 # limitations under the License.
 
 """ This module contains base REST classes for constructing REST servlets. """
-import json
+
 import logging
 
+from canonicaljson import json
+
 from synapse.api.errors import Codes, SynapseError
 
 logger = logging.getLogger(__name__)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index e4330c39d6..cc0bdfa5c9 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -239,6 +239,15 @@ class UserRestServletV2(RestServlet):
                     await self.deactivate_account_handler.deactivate_account(
                         target_user.to_string(), False
                     )
+                elif not deactivate and user["deactivated"]:
+                    if "password" not in body:
+                        raise SynapseError(
+                            400, "Must provide a password to re-activate an account."
+                        )
+
+                    await self.deactivate_account_handler.activate_account(
+                        target_user.to_string()
+                    )
 
             user = await self.admin_handler.get_user(target_user)
             return 200, user
@@ -254,7 +263,6 @@ class UserRestServletV2(RestServlet):
             admin = body.get("admin", None)
             user_type = body.get("user_type", None)
             displayname = body.get("displayname", None)
-            threepids = body.get("threepids", None)
 
             if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
                 raise SynapseError(400, "Invalid user type")
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 326ffa0056..379f668d6f 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -371,7 +371,7 @@ class LoginRestServlet(RestServlet):
         token = login_submission.get("token", None)
         if token is None:
             raise LoginError(
-                401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED
+                403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
             )
 
         import jwt
@@ -387,14 +387,12 @@ class LoginRestServlet(RestServlet):
         except jwt.PyJWTError as e:
             # A JWT error occurred, return some info back to the client.
             raise LoginError(
-                401,
-                "JWT validation failed: %s" % (str(e),),
-                errcode=Codes.UNAUTHORIZED,
+                403, "JWT validation failed: %s" % (str(e),), errcode=Codes.FORBIDDEN,
             )
 
         user = payload.get("sub", None)
         if user is None:
-            raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
+            raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
 
         user_id = UserID(user, self.hs.hostname).to_string()
         result = await self._complete_login(
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 01b80b86fa..c5a84af047 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 """ This module contains REST servlets to do with rooms: /rooms/<paths> """
+
 import logging
 import re
 from typing import List, Optional
@@ -515,9 +516,9 @@ class RoomMessageListRestServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         pagination_config = PaginationConfig.from_request(request, default_limit=10)
         as_client_event = b"raw" not in request.args
-        filter_bytes = parse_string(request, b"filter", encoding=None)
-        if filter_bytes:
-            filter_json = urlparse.unquote(filter_bytes.decode("UTF-8"))
+        filter_str = parse_string(request, b"filter", encoding="utf-8")
+        if filter_str:
+            filter_json = urlparse.unquote(filter_str)
             event_filter = Filter(json.loads(filter_json))  # type: Optional[Filter]
             if (
                 event_filter
@@ -627,9 +628,9 @@ class RoomEventContextServlet(RestServlet):
         limit = parse_integer(request, "limit", default=10)
 
         # picking the API shape for symmetry with /messages
-        filter_bytes = parse_string(request, "filter")
-        if filter_bytes:
-            filter_json = urlparse.unquote(filter_bytes)
+        filter_str = parse_string(request, b"filter", encoding="utf-8")
+        if filter_str:
+            filter_json = urlparse.unquote(filter_str)
             event_filter = Filter(json.loads(filter_json))  # type: Optional[Filter]
         else:
             event_filter = None
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index e149ac1733..9b3f85b306 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -202,9 +202,11 @@ class RemoteKey(DirectServeJsonResource):
 
                 if miss:
                     cache_misses.setdefault(server_name, set()).add(key_id)
+                # Cast to bytes since postgresql returns a memoryview.
                 json_results.add(bytes(most_recent_result["key_json"]))
             else:
                 for ts_added, result in results:
+                    # Cast to bytes since postgresql returns a memoryview.
                     json_results.add(bytes(result["key_json"]))
 
         if cache_misses and query_remote_on_cache_miss:
@@ -213,7 +215,7 @@ class RemoteKey(DirectServeJsonResource):
         else:
             signed_keys = []
             for key_json in json_results:
-                key_json = json.loads(key_json)
+                key_json = json.loads(key_json.decode("utf-8"))
                 for signing_key in self.config.key_server_signing_keys:
                     key_json = sign_json(key_json, self.config.server_name, signing_key)
 
diff --git a/synapse/server.py b/synapse/server.py
index ca42c2195a..f838a03d71 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -106,7 +106,7 @@ from synapse.server_notices.worker_server_notices_sender import (
     WorkerServerNoticesSender,
 )
 from synapse.state import StateHandler, StateResolutionHandler
-from synapse.storage import DataStores, Storage
+from synapse.storage import DataStore, DataStores, Storage
 from synapse.streams.events import EventSources
 from synapse.util import Clock
 from synapse.util.distributor import Distributor
@@ -314,7 +314,7 @@ class HomeServer(object):
     def get_clock(self):
         return self.clock
 
-    def get_datastore(self):
+    def get_datastore(self) -> DataStore:
         return self.datastores.main
 
     def get_datastores(self):
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 58cd099e6d..cd50c721b8 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -20,6 +20,7 @@ import synapse.handlers.room
 import synapse.handlers.room_member
 import synapse.handlers.set_password
 import synapse.http.client
+import synapse.http.matrixfederationclient
 import synapse.notifier
 import synapse.push.pusherpool
 import synapse.replication.tcp.client
@@ -143,3 +144,7 @@ class HomeServer(object):
         pass
     def get_replication_streams(self) -> Dict[str, Stream]:
         pass
+    def get_http_client(
+        self,
+    ) -> synapse.http.matrixfederationclient.MatrixFederationHttpClient:
+        pass
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 230fb5cd7f..66f01aad84 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -17,7 +17,6 @@
 import itertools
 import logging
 from collections import OrderedDict, namedtuple
-from functools import wraps
 from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
 
 import attr
@@ -69,27 +68,6 @@ def encode_json(json_object):
 _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
 
 
-def _retry_on_integrity_error(func):
-    """Wraps a database function so that it gets retried on IntegrityError,
-    with `delete_existing=True` passed in.
-
-    Args:
-        func: function that returns a Deferred and accepts a `delete_existing` arg
-    """
-
-    @wraps(func)
-    @defer.inlineCallbacks
-    def f(self, *args, **kwargs):
-        try:
-            res = yield func(self, *args, delete_existing=False, **kwargs)
-        except self.database_engine.module.IntegrityError:
-            logger.exception("IntegrityError, retrying.")
-            res = yield func(self, *args, delete_existing=True, **kwargs)
-        return res
-
-    return f
-
-
 @attr.s(slots=True)
 class DeltaState:
     """Deltas to use to update the `current_state_events` table.
@@ -134,7 +112,6 @@ class PersistEventsStore:
             hs.config.worker.writers.events == hs.get_instance_name()
         ), "Can only instantiate EventsStore on master"
 
-    @_retry_on_integrity_error
     @defer.inlineCallbacks
     def _persist_events_and_state_updates(
         self,
@@ -143,7 +120,6 @@ class PersistEventsStore:
         state_delta_for_room: Dict[str, DeltaState],
         new_forward_extremeties: Dict[str, List[str]],
         backfilled: bool = False,
-        delete_existing: bool = False,
     ):
         """Persist a set of events alongside updates to the current state and
         forward extremities tables.
@@ -157,7 +133,6 @@ class PersistEventsStore:
             new_forward_extremities: Map from room_id to list of event IDs
                 that are the new forward extremities of the room.
             backfilled
-            delete_existing
 
         Returns:
             Deferred: resolves when the events have been persisted
@@ -197,7 +172,6 @@ class PersistEventsStore:
                 self._persist_events_txn,
                 events_and_contexts=events_and_contexts,
                 backfilled=backfilled,
-                delete_existing=delete_existing,
                 state_delta_for_room=state_delta_for_room,
                 new_forward_extremeties=new_forward_extremeties,
             )
@@ -341,7 +315,6 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         backfilled: bool,
-        delete_existing: bool = False,
         state_delta_for_room: Dict[str, DeltaState] = {},
         new_forward_extremeties: Dict[str, List[str]] = {},
     ):
@@ -393,13 +366,6 @@ class PersistEventsStore:
         # From this point onwards the events are only events that we haven't
         # seen before.
 
-        if delete_existing:
-            # For paranoia reasons, we go and delete all the existing entries
-            # for these events so we can reinsert them.
-            # This gets around any problems with some tables already having
-            # entries.
-            self._delete_existing_rows_txn(txn, events_and_contexts=events_and_contexts)
-
         self._store_event_txn(txn, events_and_contexts=events_and_contexts)
 
         # Insert into event_to_state_groups.
@@ -797,39 +763,6 @@ class PersistEventsStore:
 
         return [ec for ec in events_and_contexts if ec[0] not in to_remove]
 
-    @classmethod
-    def _delete_existing_rows_txn(cls, txn, events_and_contexts):
-        if not events_and_contexts:
-            # nothing to do here
-            return
-
-        logger.info("Deleting existing")
-
-        for table in (
-            "events",
-            "event_auth",
-            "event_json",
-            "event_edges",
-            "event_forward_extremities",
-            "event_reference_hashes",
-            "event_search",
-            "event_to_state_groups",
-            "state_events",
-            "rejections",
-            "redactions",
-            "room_memberships",
-        ):
-            txn.executemany(
-                "DELETE FROM %s WHERE event_id = ?" % (table,),
-                [(ev.event_id,) for ev, _ in events_and_contexts],
-            )
-
-        for table in ("event_push_actions",):
-            txn.executemany(
-                "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,),
-                [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts],
-            )
-
     def _store_event_txn(self, txn, events_and_contexts):
         """Insert new events into the event and event_json tables
 
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 347cc50778..bb38a04ede 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -353,6 +353,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
         last_room_id = progress.get("last_room_id", "")
 
         def _background_remove_left_rooms_txn(txn):
+            # get a batch of room ids to consider
             sql = """
                 SELECT DISTINCT room_id FROM current_state_events
                 WHERE room_id > ? ORDER BY room_id LIMIT ?
@@ -363,24 +364,68 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
             if not room_ids:
                 return True, set()
 
+            ###########################################################################
+            #
+            # exclude rooms where we have active members
+
             sql = """
                 SELECT room_id
-                FROM current_state_events
+                FROM local_current_membership
                 WHERE
                     room_id > ? AND room_id <= ?
-                    AND type = 'm.room.member'
                     AND membership = 'join'
-                    AND state_key LIKE ?
                 GROUP BY room_id
             """
 
-            txn.execute(sql, (last_room_id, room_ids[-1], "%:" + self.server_name))
-
+            txn.execute(sql, (last_room_id, room_ids[-1]))
             joined_room_ids = {row[0] for row in txn}
+            to_delete = set(room_ids) - joined_room_ids
+
+            ###########################################################################
+            #
+            # exclude rooms which we are in the process of constructing; these otherwise
+            # qualify as "rooms with no local users", and would have their
+            # forward extremities cleaned up.
+
+            # the following query will return a list of rooms which have forward
+            # extremities that are *not* also the create event in the room - ie
+            # those that are not being created currently.
+
+            sql = """
+                SELECT DISTINCT efe.room_id
+                FROM event_forward_extremities efe
+                LEFT JOIN current_state_events cse ON
+                    cse.event_id = efe.event_id
+                    AND cse.type = 'm.room.create'
+                    AND cse.state_key = ''
+                WHERE
+                    cse.event_id IS NULL
+                    AND efe.room_id > ? AND efe.room_id <= ?
+            """
+
+            txn.execute(sql, (last_room_id, room_ids[-1]))
+
+            # build a set of those rooms within `to_delete` that do not appear in
+            # the above, leaving us with the rooms in `to_delete` that *are* being
+            # created.
+            creating_rooms = to_delete.difference(row[0] for row in txn)
+            logger.info("skipping rooms which are being created: %s", creating_rooms)
+
+            # now remove the rooms being created from the list of those to delete.
+            #
+            # (we could have just taken the intersection of `to_delete` with the result
+            # of the sql query, but it's useful to be able to log `creating_rooms`; and
+            # having done so, it's quicker to remove the (few) creating rooms from
+            # `to_delete` than it is to form the intersection with the (larger) list of
+            # not-creating-rooms)
+
+            to_delete -= creating_rooms
 
-            left_rooms = set(room_ids) - joined_room_ids
+            ###########################################################################
+            #
+            # now clear the state for the rooms
 
-            logger.info("Deleting current state left rooms: %r", left_rooms)
+            logger.info("Deleting current state left rooms: %r", to_delete)
 
             # First we get all users that we still think were joined to the
             # room. This is so that we can mark those device lists as
@@ -391,7 +436,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
                 txn,
                 table="current_state_events",
                 column="room_id",
-                iterable=left_rooms,
+                iterable=to_delete,
                 keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN},
                 retcols=("state_key",),
             )
@@ -403,7 +448,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
                 txn,
                 table="current_state_events",
                 column="room_id",
-                iterable=left_rooms,
+                iterable=to_delete,
                 keyvalues={},
             )
 
@@ -411,7 +456,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
                 txn,
                 table="event_forward_extremities",
                 column="room_id",
-                iterable=left_rooms,
+                iterable=to_delete,
                 keyvalues={},
             )
 
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 1acf287ca4..cdd093ffa8 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -46,7 +46,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         """If the user has no devices, we expect an empty list.
         """
         local_user = "@boris:" + self.hs.hostname
-        res = yield self.handler.query_local_devices({local_user: None})
+        res = yield defer.ensureDeferred(
+            self.handler.query_local_devices({local_user: None})
+        )
         self.assertDictEqual(res, {local_user: {}})
 
     @defer.inlineCallbacks
@@ -60,15 +62,19 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "alg2:k3": {"key": "key3"},
         }
 
-        res = yield self.handler.upload_keys_for_user(
-            local_user, device_id, {"one_time_keys": keys}
+        res = yield defer.ensureDeferred(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": keys}
+            )
         )
         self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
 
         # we should be able to change the signature without a problem
         keys["alg2:k2"]["signatures"]["k1"] = "sig2"
-        res = yield self.handler.upload_keys_for_user(
-            local_user, device_id, {"one_time_keys": keys}
+        res = yield defer.ensureDeferred(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": keys}
+            )
         )
         self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
 
@@ -84,44 +90,56 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "alg2:k3": {"key": "key3"},
         }
 
-        res = yield self.handler.upload_keys_for_user(
-            local_user, device_id, {"one_time_keys": keys}
+        res = yield defer.ensureDeferred(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": keys}
+            )
         )
         self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
 
         try:
-            yield self.handler.upload_keys_for_user(
-                local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
+            yield defer.ensureDeferred(
+                self.handler.upload_keys_for_user(
+                    local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
+                )
             )
             self.fail("No error when changing string key")
         except errors.SynapseError:
             pass
 
         try:
-            yield self.handler.upload_keys_for_user(
-                local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
+            yield defer.ensureDeferred(
+                self.handler.upload_keys_for_user(
+                    local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
+                )
             )
             self.fail("No error when replacing dict key with string")
         except errors.SynapseError:
             pass
 
         try:
-            yield self.handler.upload_keys_for_user(
-                local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}}
+            yield defer.ensureDeferred(
+                self.handler.upload_keys_for_user(
+                    local_user,
+                    device_id,
+                    {"one_time_keys": {"alg1:k1": {"key": "key"}}},
+                )
             )
             self.fail("No error when replacing string key with dict")
         except errors.SynapseError:
             pass
 
         try:
-            yield self.handler.upload_keys_for_user(
-                local_user,
-                device_id,
-                {
-                    "one_time_keys": {
-                        "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
-                    }
-                },
+            yield defer.ensureDeferred(
+                self.handler.upload_keys_for_user(
+                    local_user,
+                    device_id,
+                    {
+                        "one_time_keys": {
+                            "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
+                        }
+                    },
+                )
             )
             self.fail("No error when replacing dict key")
         except errors.SynapseError:
@@ -133,13 +151,17 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         device_id = "xyz"
         keys = {"alg1:k1": "key1"}
 
-        res = yield self.handler.upload_keys_for_user(
-            local_user, device_id, {"one_time_keys": keys}
+        res = yield defer.ensureDeferred(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": keys}
+            )
         )
         self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}})
 
-        res2 = yield self.handler.claim_one_time_keys(
-            {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+        res2 = yield defer.ensureDeferred(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+            )
         )
         self.assertEqual(
             res2,
@@ -163,7 +185,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
                 },
             }
         }
-        yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+        yield defer.ensureDeferred(
+            self.handler.upload_signing_keys_for_user(local_user, keys1)
+        )
 
         keys2 = {
             "master_key": {
@@ -175,10 +199,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
                 },
             }
         }
-        yield self.handler.upload_signing_keys_for_user(local_user, keys2)
+        yield defer.ensureDeferred(
+            self.handler.upload_signing_keys_for_user(local_user, keys2)
+        )
 
-        devices = yield self.handler.query_devices(
-            {"device_keys": {local_user: []}}, 0, local_user
+        devices = yield defer.ensureDeferred(
+            self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
         )
         self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
 
@@ -215,7 +241,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
             "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0",
         )
-        yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+        yield defer.ensureDeferred(
+            self.handler.upload_signing_keys_for_user(local_user, keys1)
+        )
 
         # upload two device keys, which will be signed later by the self-signing key
         device_key_1 = {
@@ -245,18 +273,24 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "signatures": {local_user: {"ed25519:def": "base64+signature"}},
         }
 
-        yield self.handler.upload_keys_for_user(
-            local_user, "abc", {"device_keys": device_key_1}
+        yield defer.ensureDeferred(
+            self.handler.upload_keys_for_user(
+                local_user, "abc", {"device_keys": device_key_1}
+            )
         )
-        yield self.handler.upload_keys_for_user(
-            local_user, "def", {"device_keys": device_key_2}
+        yield defer.ensureDeferred(
+            self.handler.upload_keys_for_user(
+                local_user, "def", {"device_keys": device_key_2}
+            )
         )
 
         # sign the first device key and upload it
         del device_key_1["signatures"]
         sign.sign_json(device_key_1, local_user, signing_key)
-        yield self.handler.upload_signatures_for_device_keys(
-            local_user, {local_user: {"abc": device_key_1}}
+        yield defer.ensureDeferred(
+            self.handler.upload_signatures_for_device_keys(
+                local_user, {local_user: {"abc": device_key_1}}
+            )
         )
 
         # sign the second device key and upload both device keys.  The server
@@ -264,14 +298,16 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         # signature for it
         del device_key_2["signatures"]
         sign.sign_json(device_key_2, local_user, signing_key)
-        yield self.handler.upload_signatures_for_device_keys(
-            local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
+        yield defer.ensureDeferred(
+            self.handler.upload_signatures_for_device_keys(
+                local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
+            )
         )
 
         device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
         device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
-        devices = yield self.handler.query_devices(
-            {"device_keys": {local_user: []}}, 0, local_user
+        devices = yield defer.ensureDeferred(
+            self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
         )
         del devices["device_keys"][local_user]["abc"]["unsigned"]
         del devices["device_keys"][local_user]["def"]["unsigned"]
@@ -292,7 +328,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
                 },
             }
         }
-        yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+        yield defer.ensureDeferred(
+            self.handler.upload_signing_keys_for_user(local_user, keys1)
+        )
 
         res = None
         try:
@@ -305,7 +343,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             res = e.code
         self.assertEqual(res, 400)
 
-        res = yield self.handler.query_local_devices({local_user: None})
+        res = yield defer.ensureDeferred(
+            self.handler.query_local_devices({local_user: None})
+        )
         self.assertDictEqual(res, {local_user: {}})
 
     @defer.inlineCallbacks
@@ -331,8 +371,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA"
         )
 
-        yield self.handler.upload_keys_for_user(
-            local_user, device_id, {"device_keys": device_key}
+        yield defer.ensureDeferred(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"device_keys": device_key}
+            )
         )
 
         # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
@@ -372,7 +414,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "user_signing_key": usersigning_key,
             "self_signing_key": selfsigning_key,
         }
-        yield self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
+        yield defer.ensureDeferred(
+            self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
+        )
 
         # set up another user with a master key.  This user will be signed by
         # the first user
@@ -384,76 +428,90 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "usage": ["master"],
             "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
         }
-        yield self.handler.upload_signing_keys_for_user(
-            other_user, {"master_key": other_master_key}
+        yield defer.ensureDeferred(
+            self.handler.upload_signing_keys_for_user(
+                other_user, {"master_key": other_master_key}
+            )
         )
 
         # test various signature failures (see below)
-        ret = yield self.handler.upload_signatures_for_device_keys(
-            local_user,
-            {
-                local_user: {
-                    # fails because the signature is invalid
-                    # should fail with INVALID_SIGNATURE
-                    device_id: {
-                        "user_id": local_user,
-                        "device_id": device_id,
-                        "algorithms": [
-                            "m.olm.curve25519-aes-sha2",
-                            RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
-                        ],
-                        "keys": {
-                            "curve25519:xyz": "curve25519+key",
-                            # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
-                            "ed25519:xyz": device_pubkey,
-                        },
-                        "signatures": {
-                            local_user: {"ed25519:" + selfsigning_pubkey: "something"}
+        ret = yield defer.ensureDeferred(
+            self.handler.upload_signatures_for_device_keys(
+                local_user,
+                {
+                    local_user: {
+                        # fails because the signature is invalid
+                        # should fail with INVALID_SIGNATURE
+                        device_id: {
+                            "user_id": local_user,
+                            "device_id": device_id,
+                            "algorithms": [
+                                "m.olm.curve25519-aes-sha2",
+                                RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+                            ],
+                            "keys": {
+                                "curve25519:xyz": "curve25519+key",
+                                # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
+                                "ed25519:xyz": device_pubkey,
+                            },
+                            "signatures": {
+                                local_user: {
+                                    "ed25519:" + selfsigning_pubkey: "something"
+                                }
+                            },
                         },
-                    },
-                    # fails because device is unknown
-                    # should fail with NOT_FOUND
-                    "unknown": {
-                        "user_id": local_user,
-                        "device_id": "unknown",
-                        "signatures": {
-                            local_user: {"ed25519:" + selfsigning_pubkey: "something"}
+                        # fails because device is unknown
+                        # should fail with NOT_FOUND
+                        "unknown": {
+                            "user_id": local_user,
+                            "device_id": "unknown",
+                            "signatures": {
+                                local_user: {
+                                    "ed25519:" + selfsigning_pubkey: "something"
+                                }
+                            },
                         },
-                    },
-                    # fails because the signature is invalid
-                    # should fail with INVALID_SIGNATURE
-                    master_pubkey: {
-                        "user_id": local_user,
-                        "usage": ["master"],
-                        "keys": {"ed25519:" + master_pubkey: master_pubkey},
-                        "signatures": {
-                            local_user: {"ed25519:" + device_pubkey: "something"}
+                        # fails because the signature is invalid
+                        # should fail with INVALID_SIGNATURE
+                        master_pubkey: {
+                            "user_id": local_user,
+                            "usage": ["master"],
+                            "keys": {"ed25519:" + master_pubkey: master_pubkey},
+                            "signatures": {
+                                local_user: {"ed25519:" + device_pubkey: "something"}
+                            },
                         },
                     },
-                },
-                other_user: {
-                    # fails because the device is not the user's master-signing key
-                    # should fail with NOT_FOUND
-                    "unknown": {
-                        "user_id": other_user,
-                        "device_id": "unknown",
-                        "signatures": {
-                            local_user: {"ed25519:" + usersigning_pubkey: "something"}
+                    other_user: {
+                        # fails because the device is not the user's master-signing key
+                        # should fail with NOT_FOUND
+                        "unknown": {
+                            "user_id": other_user,
+                            "device_id": "unknown",
+                            "signatures": {
+                                local_user: {
+                                    "ed25519:" + usersigning_pubkey: "something"
+                                }
+                            },
                         },
-                    },
-                    other_master_pubkey: {
-                        # fails because the key doesn't match what the server has
-                        # should fail with UNKNOWN
-                        "user_id": other_user,
-                        "usage": ["master"],
-                        "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
-                        "something": "random",
-                        "signatures": {
-                            local_user: {"ed25519:" + usersigning_pubkey: "something"}
+                        other_master_pubkey: {
+                            # fails because the key doesn't match what the server has
+                            # should fail with UNKNOWN
+                            "user_id": other_user,
+                            "usage": ["master"],
+                            "keys": {
+                                "ed25519:" + other_master_pubkey: other_master_pubkey
+                            },
+                            "something": "random",
+                            "signatures": {
+                                local_user: {
+                                    "ed25519:" + usersigning_pubkey: "something"
+                                }
+                            },
                         },
                     },
                 },
-            },
+            )
         )
 
         user_failures = ret["failures"][local_user]
@@ -478,19 +536,23 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         sign.sign_json(device_key, local_user, selfsigning_signing_key)
         sign.sign_json(master_key, local_user, device_signing_key)
         sign.sign_json(other_master_key, local_user, usersigning_signing_key)
-        ret = yield self.handler.upload_signatures_for_device_keys(
-            local_user,
-            {
-                local_user: {device_id: device_key, master_pubkey: master_key},
-                other_user: {other_master_pubkey: other_master_key},
-            },
+        ret = yield defer.ensureDeferred(
+            self.handler.upload_signatures_for_device_keys(
+                local_user,
+                {
+                    local_user: {device_id: device_key, master_pubkey: master_key},
+                    other_user: {other_master_pubkey: other_master_key},
+                },
+            )
         )
 
         self.assertEqual(ret["failures"], {})
 
         # fetch the signed keys/devices and make sure that the signatures are there
-        ret = yield self.handler.query_devices(
-            {"device_keys": {local_user: [], other_user: []}}, 0, local_user
+        ret = yield defer.ensureDeferred(
+            self.handler.query_devices(
+                {"device_keys": {local_user: [], other_user: []}}, 0, local_user
+            )
         )
 
         self.assertEqual(
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 822ea42dde..3362050ce0 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -66,7 +66,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         """
         res = None
         try:
-            yield self.handler.get_version_info(self.local_user)
+            yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
         except errors.SynapseError as e:
             res = e.code
         self.assertEqual(res, 404)
@@ -78,7 +78,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         """
         res = None
         try:
-            yield self.handler.get_version_info(self.local_user, "bogus_version")
+            yield defer.ensureDeferred(
+                self.handler.get_version_info(self.local_user, "bogus_version")
+            )
         except errors.SynapseError as e:
             res = e.code
         self.assertEqual(res, 404)
@@ -87,14 +89,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
     def test_create_version(self):
         """Check that we can create and then retrieve versions.
         """
-        res = yield self.handler.create_version(
-            self.local_user,
-            {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+        res = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "first_version_auth_data",
+                },
+            )
         )
         self.assertEqual(res, "1")
 
         # check we can retrieve it as the current version
-        res = yield self.handler.get_version_info(self.local_user)
+        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
         version_etag = res["etag"]
         self.assertIsInstance(version_etag, str)
         del res["etag"]
@@ -109,7 +116,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
 
         # check we can retrieve it as a specific version
-        res = yield self.handler.get_version_info(self.local_user, "1")
+        res = yield defer.ensureDeferred(
+            self.handler.get_version_info(self.local_user, "1")
+        )
         self.assertEqual(res["etag"], version_etag)
         del res["etag"]
         self.assertDictEqual(
@@ -123,17 +132,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
 
         # upload a new one...
-        res = yield self.handler.create_version(
-            self.local_user,
-            {
-                "algorithm": "m.megolm_backup.v1",
-                "auth_data": "second_version_auth_data",
-            },
+        res = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "second_version_auth_data",
+                },
+            )
         )
         self.assertEqual(res, "2")
 
         # check we can retrieve it as the current version
-        res = yield self.handler.get_version_info(self.local_user)
+        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
         del res["etag"]
         self.assertDictEqual(
             res,
@@ -149,25 +160,32 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
     def test_update_version(self):
         """Check that we can update versions.
         """
-        version = yield self.handler.create_version(
-            self.local_user,
-            {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+        version = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "first_version_auth_data",
+                },
+            )
         )
         self.assertEqual(version, "1")
 
-        res = yield self.handler.update_version(
-            self.local_user,
-            version,
-            {
-                "algorithm": "m.megolm_backup.v1",
-                "auth_data": "revised_first_version_auth_data",
-                "version": version,
-            },
+        res = yield defer.ensureDeferred(
+            self.handler.update_version(
+                self.local_user,
+                version,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "revised_first_version_auth_data",
+                    "version": version,
+                },
+            )
         )
         self.assertDictEqual(res, {})
 
         # check we can retrieve it as the current version
-        res = yield self.handler.get_version_info(self.local_user)
+        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
         del res["etag"]
         self.assertDictEqual(
             res,
@@ -185,14 +203,16 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         """
         res = None
         try:
-            yield self.handler.update_version(
-                self.local_user,
-                "1",
-                {
-                    "algorithm": "m.megolm_backup.v1",
-                    "auth_data": "revised_first_version_auth_data",
-                    "version": "1",
-                },
+            yield defer.ensureDeferred(
+                self.handler.update_version(
+                    self.local_user,
+                    "1",
+                    {
+                        "algorithm": "m.megolm_backup.v1",
+                        "auth_data": "revised_first_version_auth_data",
+                        "version": "1",
+                    },
+                )
             )
         except errors.SynapseError as e:
             res = e.code
@@ -202,23 +222,30 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
     def test_update_omitted_version(self):
         """Check that the update succeeds if the version is missing from the body
         """
-        version = yield self.handler.create_version(
-            self.local_user,
-            {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+        version = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "first_version_auth_data",
+                },
+            )
         )
         self.assertEqual(version, "1")
 
-        yield self.handler.update_version(
-            self.local_user,
-            version,
-            {
-                "algorithm": "m.megolm_backup.v1",
-                "auth_data": "revised_first_version_auth_data",
-            },
+        yield defer.ensureDeferred(
+            self.handler.update_version(
+                self.local_user,
+                version,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "revised_first_version_auth_data",
+                },
+            )
         )
 
         # check we can retrieve it as the current version
-        res = yield self.handler.get_version_info(self.local_user)
+        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
         del res["etag"]  # etag is opaque, so don't test its contents
         self.assertDictEqual(
             res,
@@ -234,22 +261,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
     def test_update_bad_version(self):
         """Check that we get a 400 if the version in the body doesn't match
         """
-        version = yield self.handler.create_version(
-            self.local_user,
-            {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+        version = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "first_version_auth_data",
+                },
+            )
         )
         self.assertEqual(version, "1")
 
         res = None
         try:
-            yield self.handler.update_version(
-                self.local_user,
-                version,
-                {
-                    "algorithm": "m.megolm_backup.v1",
-                    "auth_data": "revised_first_version_auth_data",
-                    "version": "incorrect",
-                },
+            yield defer.ensureDeferred(
+                self.handler.update_version(
+                    self.local_user,
+                    version,
+                    {
+                        "algorithm": "m.megolm_backup.v1",
+                        "auth_data": "revised_first_version_auth_data",
+                        "version": "incorrect",
+                    },
+                )
             )
         except errors.SynapseError as e:
             res = e.code
@@ -261,7 +295,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         """
         res = None
         try:
-            yield self.handler.delete_version(self.local_user, "1")
+            yield defer.ensureDeferred(
+                self.handler.delete_version(self.local_user, "1")
+            )
         except errors.SynapseError as e:
             res = e.code
         self.assertEqual(res, 404)
@@ -272,7 +308,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         """
         res = None
         try:
-            yield self.handler.delete_version(self.local_user)
+            yield defer.ensureDeferred(self.handler.delete_version(self.local_user))
         except errors.SynapseError as e:
             res = e.code
         self.assertEqual(res, 404)
@@ -281,19 +317,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
     def test_delete_version(self):
         """Check that we can create and then delete versions.
         """
-        res = yield self.handler.create_version(
-            self.local_user,
-            {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+        res = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "first_version_auth_data",
+                },
+            )
         )
         self.assertEqual(res, "1")
 
         # check we can delete it
-        yield self.handler.delete_version(self.local_user, "1")
+        yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1"))
 
         # check that it's gone
         res = None
         try:
-            yield self.handler.get_version_info(self.local_user, "1")
+            yield defer.ensureDeferred(
+                self.handler.get_version_info(self.local_user, "1")
+            )
         except errors.SynapseError as e:
             res = e.code
         self.assertEqual(res, 404)
@@ -304,7 +347,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         """
         res = None
         try:
-            yield self.handler.get_room_keys(self.local_user, "bogus_version")
+            yield defer.ensureDeferred(
+                self.handler.get_room_keys(self.local_user, "bogus_version")
+            )
         except errors.SynapseError as e:
             res = e.code
         self.assertEqual(res, 404)
@@ -313,13 +358,20 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
     def test_get_missing_room_keys(self):
         """Check we get an empty response from an empty backup
         """
-        version = yield self.handler.create_version(
-            self.local_user,
-            {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+        version = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "first_version_auth_data",
+                },
+            )
         )
         self.assertEqual(version, "1")
 
-        res = yield self.handler.get_room_keys(self.local_user, version)
+        res = yield defer.ensureDeferred(
+            self.handler.get_room_keys(self.local_user, version)
+        )
         self.assertDictEqual(res, {"rooms": {}})
 
     # TODO: test the locking semantics when uploading room_keys,
@@ -331,8 +383,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         """
         res = None
         try:
-            yield self.handler.upload_room_keys(
-                self.local_user, "no_version", room_keys
+            yield defer.ensureDeferred(
+                self.handler.upload_room_keys(self.local_user, "no_version", room_keys)
             )
         except errors.SynapseError as e:
             res = e.code
@@ -343,16 +395,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         """Check that we get a 404 on uploading keys when an nonexistent version
         is specified
         """
-        version = yield self.handler.create_version(
-            self.local_user,
-            {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+        version = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "first_version_auth_data",
+                },
+            )
         )
         self.assertEqual(version, "1")
 
         res = None
         try:
-            yield self.handler.upload_room_keys(
-                self.local_user, "bogus_version", room_keys
+            yield defer.ensureDeferred(
+                self.handler.upload_room_keys(
+                    self.local_user, "bogus_version", room_keys
+                )
             )
         except errors.SynapseError as e:
             res = e.code
@@ -362,24 +421,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
     def test_upload_room_keys_wrong_version(self):
         """Check that we get a 403 on uploading keys for an old version
         """
-        version = yield self.handler.create_version(
-            self.local_user,
-            {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+        version = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "first_version_auth_data",
+                },
+            )
         )
         self.assertEqual(version, "1")
 
-        version = yield self.handler.create_version(
-            self.local_user,
-            {
-                "algorithm": "m.megolm_backup.v1",
-                "auth_data": "second_version_auth_data",
-            },
+        version = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "second_version_auth_data",
+                },
+            )
         )
         self.assertEqual(version, "2")
 
         res = None
         try:
-            yield self.handler.upload_room_keys(self.local_user, "1", room_keys)
+            yield defer.ensureDeferred(
+                self.handler.upload_room_keys(self.local_user, "1", room_keys)
+            )
         except errors.SynapseError as e:
             res = e.code
         self.assertEqual(res, 403)
@@ -388,26 +456,39 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
     def test_upload_room_keys_insert(self):
         """Check that we can insert and retrieve keys for a session
         """
-        version = yield self.handler.create_version(
-            self.local_user,
-            {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+        version = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "first_version_auth_data",
+                },
+            )
         )
         self.assertEqual(version, "1")
 
-        yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+        yield defer.ensureDeferred(
+            self.handler.upload_room_keys(self.local_user, version, room_keys)
+        )
 
-        res = yield self.handler.get_room_keys(self.local_user, version)
+        res = yield defer.ensureDeferred(
+            self.handler.get_room_keys(self.local_user, version)
+        )
         self.assertDictEqual(res, room_keys)
 
         # check getting room_keys for a given room
-        res = yield self.handler.get_room_keys(
-            self.local_user, version, room_id="!abc:matrix.org"
+        res = yield defer.ensureDeferred(
+            self.handler.get_room_keys(
+                self.local_user, version, room_id="!abc:matrix.org"
+            )
         )
         self.assertDictEqual(res, room_keys)
 
         # check getting room_keys for a given session_id
-        res = yield self.handler.get_room_keys(
-            self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+        res = yield defer.ensureDeferred(
+            self.handler.get_room_keys(
+                self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+            )
         )
         self.assertDictEqual(res, room_keys)
 
@@ -415,16 +496,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
     def test_upload_room_keys_merge(self):
         """Check that we can upload a new room_key for an existing session and
         have it correctly merged"""
-        version = yield self.handler.create_version(
-            self.local_user,
-            {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+        version = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "first_version_auth_data",
+                },
+            )
         )
         self.assertEqual(version, "1")
 
-        yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+        yield defer.ensureDeferred(
+            self.handler.upload_room_keys(self.local_user, version, room_keys)
+        )
 
         # get the etag to compare to future versions
-        res = yield self.handler.get_version_info(self.local_user)
+        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
         backup_etag = res["etag"]
         self.assertEqual(res["count"], 1)
 
@@ -434,29 +522,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         # test that increasing the message_index doesn't replace the existing session
         new_room_key["first_message_index"] = 2
         new_room_key["session_data"] = "new"
-        yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+        yield defer.ensureDeferred(
+            self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+        )
 
-        res = yield self.handler.get_room_keys(self.local_user, version)
+        res = yield defer.ensureDeferred(
+            self.handler.get_room_keys(self.local_user, version)
+        )
         self.assertEqual(
             res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
             "SSBBTSBBIEZJU0gK",
         )
 
         # the etag should be the same since the session did not change
-        res = yield self.handler.get_version_info(self.local_user)
+        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
         self.assertEqual(res["etag"], backup_etag)
 
         # test that marking the session as verified however /does/ replace it
         new_room_key["is_verified"] = True
-        yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+        yield defer.ensureDeferred(
+            self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+        )
 
-        res = yield self.handler.get_room_keys(self.local_user, version)
+        res = yield defer.ensureDeferred(
+            self.handler.get_room_keys(self.local_user, version)
+        )
         self.assertEqual(
             res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
         )
 
         # the etag should NOT be equal now, since the key changed
-        res = yield self.handler.get_version_info(self.local_user)
+        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
         self.assertNotEqual(res["etag"], backup_etag)
         backup_etag = res["etag"]
 
@@ -464,15 +560,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         # with a lower forwarding count
         new_room_key["forwarded_count"] = 2
         new_room_key["session_data"] = "other"
-        yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+        yield defer.ensureDeferred(
+            self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+        )
 
-        res = yield self.handler.get_room_keys(self.local_user, version)
+        res = yield defer.ensureDeferred(
+            self.handler.get_room_keys(self.local_user, version)
+        )
         self.assertEqual(
             res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
         )
 
         # the etag should be the same since the session did not change
-        res = yield self.handler.get_version_info(self.local_user)
+        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
         self.assertEqual(res["etag"], backup_etag)
 
         # TODO: check edge cases as well as the common variations here
@@ -481,36 +581,59 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
     def test_delete_room_keys(self):
         """Check that we can insert and delete keys for a session
         """
-        version = yield self.handler.create_version(
-            self.local_user,
-            {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+        version = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "first_version_auth_data",
+                },
+            )
         )
         self.assertEqual(version, "1")
 
         # check for bulk-delete
-        yield self.handler.upload_room_keys(self.local_user, version, room_keys)
-        yield self.handler.delete_room_keys(self.local_user, version)
-        res = yield self.handler.get_room_keys(
-            self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+        yield defer.ensureDeferred(
+            self.handler.upload_room_keys(self.local_user, version, room_keys)
+        )
+        yield defer.ensureDeferred(
+            self.handler.delete_room_keys(self.local_user, version)
+        )
+        res = yield defer.ensureDeferred(
+            self.handler.get_room_keys(
+                self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+            )
         )
         self.assertDictEqual(res, {"rooms": {}})
 
         # check for bulk-delete per room
-        yield self.handler.upload_room_keys(self.local_user, version, room_keys)
-        yield self.handler.delete_room_keys(
-            self.local_user, version, room_id="!abc:matrix.org"
+        yield defer.ensureDeferred(
+            self.handler.upload_room_keys(self.local_user, version, room_keys)
+        )
+        yield defer.ensureDeferred(
+            self.handler.delete_room_keys(
+                self.local_user, version, room_id="!abc:matrix.org"
+            )
         )
-        res = yield self.handler.get_room_keys(
-            self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+        res = yield defer.ensureDeferred(
+            self.handler.get_room_keys(
+                self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+            )
         )
         self.assertDictEqual(res, {"rooms": {}})
 
         # check for bulk-delete per session
-        yield self.handler.upload_room_keys(self.local_user, version, room_keys)
-        yield self.handler.delete_room_keys(
-            self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+        yield defer.ensureDeferred(
+            self.handler.upload_room_keys(self.local_user, version, room_keys)
+        )
+        yield defer.ensureDeferred(
+            self.handler.delete_room_keys(
+                self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+            )
         )
-        res = yield self.handler.get_room_keys(
-            self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+        res = yield defer.ensureDeferred(
+            self.handler.get_room_keys(
+                self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+            )
         )
         self.assertDictEqual(res, {"rooms": {}})
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 1e6a53bf7f..5878f74175 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -138,10 +138,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
 
-        def get_current_users_in_room(room_id):
+        def get_users_in_room(room_id):
             return defer.succeed({str(u) for u in self.room_members})
 
-        hs.get_state_handler().get_current_users_in_room = get_current_users_in_room
+        self.datastore.get_users_in_room = get_users_in_room
 
         self.datastore.get_user_directory_stream_pos.return_value = (
             # we deliberately return a non-None stream pos to avoid doing an initial_spam
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 9d4f0bbe44..06575ba0a6 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, List, Optional, Tuple
+from typing import Any, Callable, List, Optional, Tuple
 
 import attr
 
@@ -26,8 +26,9 @@ from synapse.app.generic_worker import (
     GenericWorkerReplicationHandler,
     GenericWorkerServer,
 )
+from synapse.http.server import JsonResource
 from synapse.http.site import SynapseRequest
-from synapse.replication.http import streams
+from synapse.replication.http import ReplicationRestResource, streams
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -35,7 +36,7 @@ from synapse.server import HomeServer
 from synapse.util import Clock
 
 from tests import unittest
-from tests.server import FakeTransport
+from tests.server import FakeTransport, render
 
 logger = logging.getLogger(__name__)
 
@@ -180,6 +181,159 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self.assertEqual(request.method, b"GET")
 
 
+class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
+    """Base class for tests running multiple workers.
+
+    Automatically handle HTTP replication requests from workers to master,
+    unlike `BaseStreamTestCase`.
+    """
+
+    servlets = []  # type: List[Callable[[HomeServer, JsonResource], None]]
+
+    def setUp(self):
+        super().setUp()
+
+        # build a replication server
+        self.server_factory = ReplicationStreamProtocolFactory(self.hs)
+        self.streamer = self.hs.get_replication_streamer()
+
+        store = self.hs.get_datastore()
+        self.database = store.db
+
+        self.reactor.lookups["testserv"] = "1.2.3.4"
+
+        self._worker_hs_to_resource = {}
+
+        # When we see a connection attempt to the master replication listener we
+        # automatically set up the connection. This is so that tests don't
+        # manually have to go and explicitly set it up each time (plus sometimes
+        # it is impossible to write the handling explicitly in the tests).
+        self.reactor.add_tcp_client_callback(
+            "1.2.3.4", 8765, self._handle_http_replication_attempt
+        )
+
+    def create_test_json_resource(self):
+        """Overrides `HomeserverTestCase.create_test_json_resource`.
+        """
+        # We override this so that it automatically registers all the HTTP
+        # replication servlets, without having to explicitly do that in all
+        # subclassses.
+
+        resource = ReplicationRestResource(self.hs)
+
+        for servlet in self.servlets:
+            servlet(self.hs, resource)
+
+        return resource
+
+    def make_worker_hs(
+        self, worker_app: str, extra_config: dict = {}, **kwargs
+    ) -> HomeServer:
+        """Make a new worker HS instance, correctly connecting replcation
+        stream to the master HS.
+
+        Args:
+            worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
+            extra_config: Any extra config to use for this instances.
+            **kwargs: Options that get passed to `self.setup_test_homeserver`,
+                useful to e.g. pass some mocks for things like `http_client`
+
+        Returns:
+            The new worker HomeServer instance.
+        """
+
+        config = self._get_worker_hs_config()
+        config["worker_app"] = worker_app
+        config.update(extra_config)
+
+        worker_hs = self.setup_test_homeserver(
+            homeserverToUse=GenericWorkerServer,
+            config=config,
+            reactor=self.reactor,
+            **kwargs
+        )
+
+        store = worker_hs.get_datastore()
+        store.db._db_pool = self.database._db_pool
+
+        repl_handler = ReplicationCommandHandler(worker_hs)
+        client = ClientReplicationStreamProtocol(
+            worker_hs, "client", "test", self.clock, repl_handler,
+        )
+        server = self.server_factory.buildProtocol(None)
+
+        client_transport = FakeTransport(server, self.reactor)
+        client.makeConnection(client_transport)
+
+        server_transport = FakeTransport(client, self.reactor)
+        server.makeConnection(server_transport)
+
+        # Set up a resource for the worker
+        resource = ReplicationRestResource(self.hs)
+
+        for servlet in self.servlets:
+            servlet(worker_hs, resource)
+
+        self._worker_hs_to_resource[worker_hs] = resource
+
+        return worker_hs
+
+    def _get_worker_hs_config(self) -> dict:
+        config = self.default_config()
+        config["worker_replication_host"] = "testserv"
+        config["worker_replication_http_port"] = "8765"
+        return config
+
+    def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
+        render(request, self._worker_hs_to_resource[worker_hs], self.reactor)
+
+    def replicate(self):
+        """Tell the master side of replication that something has happened, and then
+        wait for the replication to occur.
+        """
+        self.streamer.on_notifier_poke()
+        self.pump()
+
+    def _handle_http_replication_attempt(self):
+        """Handles a connection attempt to the master replication HTTP
+        listener.
+        """
+
+        # We should have at least one outbound connection attempt, where the
+        # last is one to the HTTP repication IP/port.
+        clients = self.reactor.tcpClients
+        self.assertGreaterEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
+        self.assertEqual(host, "1.2.3.4")
+        self.assertEqual(port, 8765)
+
+        # Set up client side protocol
+        client_protocol = client_factory.buildProtocol(None)
+
+        request_factory = OneShotRequestFactory()
+
+        # Set up the server side protocol
+        channel = _PushHTTPChannel(self.reactor)
+        channel.requestFactory = request_factory
+        channel.site = self.site
+
+        # Connect client to server and vice versa.
+        client_to_server_transport = FakeTransport(
+            channel, self.reactor, client_protocol
+        )
+        client_protocol.makeConnection(client_to_server_transport)
+
+        server_to_client_transport = FakeTransport(
+            client_protocol, self.reactor, channel
+        )
+        channel.makeConnection(server_to_client_transport)
+
+        # Note: at this point we've wired everything up, but we need to return
+        # before the data starts flowing over the connections as this is called
+        # inside `connecTCP` before the connection has been passed back to the
+        # code that requested the TCP connection.
+
+
 class TestReplicationDataHandler(GenericWorkerReplicationHandler):
     """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
 
@@ -241,6 +395,14 @@ class _PushHTTPChannel(HTTPChannel):
             # We need to manually stop the _PullToPushProducer.
             self._pull_to_push_producer.stop()
 
+    def checkPersistence(self, request, version):
+        """Check whether the connection can be re-used
+        """
+        # We hijack this to always say no for ease of wiring stuff up in
+        # `handle_http_replication_attempt`.
+        request.responseHeaders.setRawHeaders(b"connection", [b"close"])
+        return False
+
 
 class _PullToPushProducer:
     """A push producer that wraps a pull producer.
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index b7d753e0a3..86c03fd89c 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -15,63 +15,26 @@
 import logging
 
 from synapse.api.constants import LoginType
-from synapse.app.generic_worker import GenericWorkerServer
-from synapse.http.server import JsonResource
 from synapse.http.site import SynapseRequest
-from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
 from synapse.rest.client.v2_alpha import register
 
-from tests import unittest
+from tests.replication._base import BaseMultiWorkerStreamTestCase
 from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
-from tests.server import FakeChannel, render
+from tests.server import FakeChannel
 
 logger = logging.getLogger(__name__)
 
 
-class ClientReaderTestCase(unittest.HomeserverTestCase):
+class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
     """Base class for tests of the replication streams"""
 
-    servlets = [
-        register.register_servlets,
-    ]
+    servlets = [register.register_servlets]
 
     def prepare(self, reactor, clock, hs):
-        # build a replication server
-        self.server_factory = ReplicationStreamProtocolFactory(hs)
-        self.streamer = hs.get_replication_streamer()
-
-        store = hs.get_datastore()
-        self.database = store.db
-
         self.recaptcha_checker = DummyRecaptchaChecker(hs)
         auth_handler = hs.get_auth_handler()
         auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
 
-        self.reactor.lookups["testserv"] = "1.2.3.4"
-
-    def make_worker_hs(self, extra_config={}):
-        config = self._get_worker_hs_config()
-        config.update(extra_config)
-
-        worker_hs = self.setup_test_homeserver(
-            homeserverToUse=GenericWorkerServer, config=config, reactor=self.reactor,
-        )
-
-        store = worker_hs.get_datastore()
-        store.db._db_pool = self.database._db_pool
-
-        # Register the expected servlets, essentially this is HomeserverTestCase.create_test_json_resource.
-        resource = JsonResource(self.hs)
-
-        for servlet in self.servlets:
-            servlet(worker_hs, resource)
-
-        # Essentially HomeserverTestCase.render.
-        def _render(request):
-            render(request, self.resource, self.reactor)
-
-        return worker_hs, _render
-
     def _get_worker_hs_config(self) -> dict:
         config = self.default_config()
         config["worker_app"] = "synapse.app.client_reader"
@@ -82,14 +45,14 @@ class ClientReaderTestCase(unittest.HomeserverTestCase):
     def test_register_single_worker(self):
         """Test that registration works when using a single client reader worker.
         """
-        _, worker_render = self.make_worker_hs()
+        worker_hs = self.make_worker_hs("synapse.app.client_reader")
 
         request_1, channel_1 = self.make_request(
             "POST",
             "register",
             {"username": "user", "type": "m.login.password", "password": "bar"},
         )  # type: SynapseRequest, FakeChannel
-        worker_render(request_1)
+        self.render_on_worker(worker_hs, request_1)
         self.assertEqual(request_1.code, 401)
 
         # Grab the session
@@ -99,7 +62,7 @@ class ClientReaderTestCase(unittest.HomeserverTestCase):
         request_2, channel_2 = self.make_request(
             "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
         )  # type: SynapseRequest, FakeChannel
-        worker_render(request_2)
+        self.render_on_worker(worker_hs, request_2)
         self.assertEqual(request_2.code, 200)
 
         # We're given a registered user.
@@ -108,15 +71,15 @@ class ClientReaderTestCase(unittest.HomeserverTestCase):
     def test_register_multi_worker(self):
         """Test that registration works when using multiple client reader workers.
         """
-        _, worker_render_1 = self.make_worker_hs()
-        _, worker_render_2 = self.make_worker_hs()
+        worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
+        worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
 
         request_1, channel_1 = self.make_request(
             "POST",
             "register",
             {"username": "user", "type": "m.login.password", "password": "bar"},
         )  # type: SynapseRequest, FakeChannel
-        worker_render_1(request_1)
+        self.render_on_worker(worker_hs_1, request_1)
         self.assertEqual(request_1.code, 401)
 
         # Grab the session
@@ -126,7 +89,7 @@ class ClientReaderTestCase(unittest.HomeserverTestCase):
         request_2, channel_2 = self.make_request(
             "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
         )  # type: SynapseRequest, FakeChannel
-        worker_render_2(request_2)
+        self.render_on_worker(worker_hs_2, request_2)
         self.assertEqual(request_2.code, 200)
 
         # We're given a registered user.
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 519a2dc510..8d4dbf232e 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -19,132 +19,40 @@ from mock import Mock
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.app.generic_worker import GenericWorkerServer
 from synapse.events.builder import EventBuilderFactory
-from synapse.replication.http import streams
-from synapse.replication.tcp.handler import ReplicationCommandHandler
-from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
 from synapse.rest.admin import register_servlets_for_client_rest_resource
 from synapse.rest.client.v1 import login, room
 from synapse.types import UserID
 
-from tests import unittest
-from tests.server import FakeTransport
+from tests.replication._base import BaseMultiWorkerStreamTestCase
 
 logger = logging.getLogger(__name__)
 
 
-class BaseStreamTestCase(unittest.HomeserverTestCase):
-    """Base class for tests of the replication streams"""
-
+class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
     servlets = [
-        streams.register_servlets,
+        login.register_servlets,
+        register_servlets_for_client_rest_resource,
+        room.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
-        # build a replication server
-        self.server_factory = ReplicationStreamProtocolFactory(hs)
-        self.streamer = hs.get_replication_streamer()
-
-        store = hs.get_datastore()
-        self.database = store.db
-
-        self.reactor.lookups["testserv"] = "1.2.3.4"
-
     def default_config(self):
         conf = super().default_config()
         conf["send_federation"] = False
         return conf
 
-    def make_worker_hs(self, extra_config={}):
-        config = self._get_worker_hs_config()
-        config.update(extra_config)
-
-        mock_federation_client = Mock(spec=["put_json"])
-        mock_federation_client.put_json.side_effect = lambda *_, **__: defer.succeed({})
-
-        worker_hs = self.setup_test_homeserver(
-            http_client=mock_federation_client,
-            homeserverToUse=GenericWorkerServer,
-            config=config,
-            reactor=self.reactor,
-        )
-
-        store = worker_hs.get_datastore()
-        store.db._db_pool = self.database._db_pool
-
-        repl_handler = ReplicationCommandHandler(worker_hs)
-        client = ClientReplicationStreamProtocol(
-            worker_hs, "client", "test", self.clock, repl_handler,
-        )
-        server = self.server_factory.buildProtocol(None)
-
-        client_transport = FakeTransport(server, self.reactor)
-        client.makeConnection(client_transport)
-
-        server_transport = FakeTransport(client, self.reactor)
-        server.makeConnection(server_transport)
-
-        return worker_hs
-
-    def _get_worker_hs_config(self) -> dict:
-        config = self.default_config()
-        config["worker_app"] = "synapse.app.federation_sender"
-        config["worker_replication_host"] = "testserv"
-        config["worker_replication_http_port"] = "8765"
-        return config
-
-    def replicate(self):
-        """Tell the master side of replication that something has happened, and then
-        wait for the replication to occur.
-        """
-        self.streamer.on_notifier_poke()
-        self.pump()
-
-    def create_room_with_remote_server(self, user, token, remote_server="other_server"):
-        room = self.helper.create_room_as(user, tok=token)
-        store = self.hs.get_datastore()
-        federation = self.hs.get_handlers().federation_handler
-
-        prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
-        room_version = self.get_success(store.get_room_version(room))
-
-        factory = EventBuilderFactory(self.hs)
-        factory.hostname = remote_server
-
-        user_id = UserID("user", remote_server).to_string()
-
-        event_dict = {
-            "type": EventTypes.Member,
-            "state_key": user_id,
-            "content": {"membership": Membership.JOIN},
-            "sender": user_id,
-            "room_id": room,
-        }
-
-        builder = factory.for_room_version(room_version, event_dict)
-        join_event = self.get_success(builder.build(prev_event_ids))
-
-        self.get_success(federation.on_send_join_request(remote_server, join_event))
-        self.replicate()
-
-        return room
-
-
-class FederationSenderTestCase(BaseStreamTestCase):
-    servlets = [
-        login.register_servlets,
-        register_servlets_for_client_rest_resource,
-        room.register_servlets,
-    ]
-
     def test_send_event_single_sender(self):
         """Test that using a single federation sender worker correctly sends a
         new event.
         """
-        worker_hs = self.make_worker_hs({"send_federation": True})
-        mock_client = worker_hs.get_http_client()
+        mock_client = Mock(spec=["put_json"])
+        mock_client.put_json.side_effect = lambda *_, **__: defer.succeed({})
+
+        self.make_worker_hs(
+            "synapse.app.federation_sender",
+            {"send_federation": True},
+            http_client=mock_client,
+        )
 
         user = self.register_user("user", "pass")
         token = self.login("user", "pass")
@@ -165,23 +73,29 @@ class FederationSenderTestCase(BaseStreamTestCase):
         """Test that using two federation sender workers correctly sends
         new events.
         """
-        worker1 = self.make_worker_hs(
+        mock_client1 = Mock(spec=["put_json"])
+        mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
+        self.make_worker_hs(
+            "synapse.app.federation_sender",
             {
                 "send_federation": True,
                 "worker_name": "sender1",
                 "federation_sender_instances": ["sender1", "sender2"],
-            }
+            },
+            http_client=mock_client1,
         )
-        mock_client1 = worker1.get_http_client()
 
-        worker2 = self.make_worker_hs(
+        mock_client2 = Mock(spec=["put_json"])
+        mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
+        self.make_worker_hs(
+            "synapse.app.federation_sender",
             {
                 "send_federation": True,
                 "worker_name": "sender2",
                 "federation_sender_instances": ["sender1", "sender2"],
-            }
+            },
+            http_client=mock_client2,
         )
-        mock_client2 = worker2.get_http_client()
 
         user = self.register_user("user2", "pass")
         token = self.login("user2", "pass")
@@ -191,8 +105,8 @@ class FederationSenderTestCase(BaseStreamTestCase):
         for i in range(20):
             server_name = "other_server_%d" % (i,)
             room = self.create_room_with_remote_server(user, token, server_name)
-            mock_client1.reset_mock()
-            mock_client2.reset_mock()
+            mock_client1.reset_mock()  # type: ignore[attr-defined]
+            mock_client2.reset_mock()  # type: ignore[attr-defined]
 
             self.create_and_send_event(room, UserID.from_string(user))
             self.replicate()
@@ -222,23 +136,29 @@ class FederationSenderTestCase(BaseStreamTestCase):
         """Test that using two federation sender workers correctly sends
         new typing EDUs.
         """
-        worker1 = self.make_worker_hs(
+        mock_client1 = Mock(spec=["put_json"])
+        mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
+        self.make_worker_hs(
+            "synapse.app.federation_sender",
             {
                 "send_federation": True,
                 "worker_name": "sender1",
                 "federation_sender_instances": ["sender1", "sender2"],
-            }
+            },
+            http_client=mock_client1,
         )
-        mock_client1 = worker1.get_http_client()
 
-        worker2 = self.make_worker_hs(
+        mock_client2 = Mock(spec=["put_json"])
+        mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
+        self.make_worker_hs(
+            "synapse.app.federation_sender",
             {
                 "send_federation": True,
                 "worker_name": "sender2",
                 "federation_sender_instances": ["sender1", "sender2"],
-            }
+            },
+            http_client=mock_client2,
         )
-        mock_client2 = worker2.get_http_client()
 
         user = self.register_user("user3", "pass")
         token = self.login("user3", "pass")
@@ -250,8 +170,8 @@ class FederationSenderTestCase(BaseStreamTestCase):
         for i in range(20):
             server_name = "other_server_%d" % (i,)
             room = self.create_room_with_remote_server(user, token, server_name)
-            mock_client1.reset_mock()
-            mock_client2.reset_mock()
+            mock_client1.reset_mock()  # type: ignore[attr-defined]
+            mock_client2.reset_mock()  # type: ignore[attr-defined]
 
             self.get_success(
                 typing_handler.started_typing(
@@ -284,3 +204,32 @@ class FederationSenderTestCase(BaseStreamTestCase):
 
         self.assertTrue(sent_on_1)
         self.assertTrue(sent_on_2)
+
+    def create_room_with_remote_server(self, user, token, remote_server="other_server"):
+        room = self.helper.create_room_as(user, tok=token)
+        store = self.hs.get_datastore()
+        federation = self.hs.get_handlers().federation_handler
+
+        prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
+        room_version = self.get_success(store.get_room_version(room))
+
+        factory = EventBuilderFactory(self.hs)
+        factory.hostname = remote_server
+
+        user_id = UserID("user", remote_server).to_string()
+
+        event_dict = {
+            "type": EventTypes.Member,
+            "state_key": user_id,
+            "content": {"membership": Membership.JOIN},
+            "sender": user_id,
+            "room_id": room,
+        }
+
+        builder = factory.for_room_version(room_version, event_dict)
+        join_event = self.get_success(builder.build(prev_event_ids))
+
+        self.get_success(federation.on_send_join_request(remote_server, join_event))
+        self.replicate()
+
+        return room
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index cca5f548e6..f16eef15f7 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -857,6 +857,53 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("@user:test", channel.json_body["name"])
         self.assertEqual(True, channel.json_body["deactivated"])
 
+    def test_reactivate_user(self):
+        """
+        Test reactivating another user.
+        """
+
+        # Deactivate the user.
+        request, channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content=json.dumps({"deactivated": True}).encode(encoding="utf_8"),
+        )
+        self.render(request)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Attempt to reactivate the user (without a password).
+        request, channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content=json.dumps({"deactivated": False}).encode(encoding="utf_8"),
+        )
+        self.render(request)
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Reactivate the user.
+        request, channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content=json.dumps({"deactivated": False, "password": "foo"}).encode(
+                encoding="utf_8"
+            ),
+        )
+        self.render(request)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Get user
+        request, channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(False, channel.json_body["deactivated"])
+
     def test_set_user_as_admin(self):
         """
         Test setting the admin flag on a user.
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 4413bb3932..db52725cfe 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -547,8 +547,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
 
     def test_login_jwt_invalid_signature(self):
         channel = self.jwt_login({"sub": "frog"}, "notsecret")
-        self.assertEqual(channel.result["code"], b"401", channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
             "JWT validation failed: Signature verification failed",
@@ -556,8 +556,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
 
     def test_login_jwt_expired(self):
         channel = self.jwt_login({"sub": "frog", "exp": 864000})
-        self.assertEqual(channel.result["code"], b"401", channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"], "JWT validation failed: Signature has expired"
         )
@@ -565,8 +565,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
     def test_login_jwt_not_before(self):
         now = int(time.time())
         channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
-        self.assertEqual(channel.result["code"], b"401", channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
             "JWT validation failed: The token is not yet valid (nbf)",
@@ -574,8 +574,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
 
     def test_login_no_sub(self):
         channel = self.jwt_login({"username": "root"})
-        self.assertEqual(channel.result["code"], b"401", channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(channel.json_body["error"], "Invalid JWT")
 
     @override_config(
@@ -597,16 +597,16 @@ class JWTTestCase(unittest.HomeserverTestCase):
 
         # An invalid issuer.
         channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
-        self.assertEqual(channel.result["code"], b"401", channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"], "JWT validation failed: Invalid issuer"
         )
 
         # Not providing an issuer.
         channel = self.jwt_login({"sub": "kermit"})
-        self.assertEqual(channel.result["code"], b"401", channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
             'JWT validation failed: Token is missing the "iss" claim',
@@ -637,16 +637,16 @@ class JWTTestCase(unittest.HomeserverTestCase):
 
         # An invalid audience.
         channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
-        self.assertEqual(channel.result["code"], b"401", channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"], "JWT validation failed: Invalid audience"
         )
 
         # Not providing an audience.
         channel = self.jwt_login({"sub": "kermit"})
-        self.assertEqual(channel.result["code"], b"401", channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
             'JWT validation failed: Token is missing the "aud" claim',
@@ -655,7 +655,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
     def test_login_aud_no_config(self):
         """Test providing an audience without requiring it in the configuration."""
         channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
-        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"], "JWT validation failed: Invalid audience"
         )
@@ -664,8 +665,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
         params = json.dumps({"type": "org.matrix.login.jwt"})
         request, channel = self.make_request(b"POST", LOGIN_URL, params)
         self.render(request)
-        self.assertEqual(channel.result["code"], b"401", channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
 
 
@@ -747,8 +748,8 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
 
     def test_login_jwt_invalid_signature(self):
         channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
-        self.assertEqual(channel.result["code"], b"401", channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
             "JWT validation failed: Signature verification failed",
diff --git a/tests/server.py b/tests/server.py
index a5e57c52fa..b6e0b14e78 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -237,6 +237,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
     def __init__(self):
         self.threadpool = ThreadPool(self)
 
+        self._tcp_callbacks = {}
         self._udp = []
         lookups = self.lookups = {}
 
@@ -268,6 +269,29 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
     def getThreadPool(self):
         return self.threadpool
 
+    def add_tcp_client_callback(self, host, port, callback):
+        """Add a callback that will be invoked when we receive a connection
+        attempt to the given IP/port using `connectTCP`.
+
+        Note that the callback gets run before we return the connection to the
+        client, which means callbacks cannot block while waiting for writes.
+        """
+        self._tcp_callbacks[(host, port)] = callback
+
+    def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
+        """Fake L{IReactorTCP.connectTCP}.
+        """
+
+        conn = super().connectTCP(
+            host, port, factory, timeout=timeout, bindAddress=None
+        )
+
+        callback = self._tcp_callbacks.get((host, port))
+        if callback:
+            callback()
+
+        return conn
+
 
 class ThreadPool:
     """
@@ -486,7 +510,7 @@ class FakeTransport(object):
         try:
             self.other.dataReceived(to_write)
         except Exception as e:
-            logger.warning("Exception writing to protocol: %s", e)
+            logger.exception("Exception writing to protocol: %s", e)
             return
 
         self.buffer = self.buffer[len(to_write) :]