-rw-r--r--  changelog.d/9685.misc | 1
-rw-r--r--  changelog.d/9686.misc | 1
-rw-r--r--  changelog.d/9691.feature | 1
-rw-r--r--  changelog.d/9700.feature | 1
-rw-r--r--  changelog.d/9710.feature | 1
-rw-r--r--  changelog.d/9711.bugfix | 1
-rw-r--r--  changelog.d/9717.feature | 1
-rw-r--r--  changelog.d/9718.removal | 1
-rw-r--r--  changelog.d/9719.doc | 1
-rw-r--r--  changelog.d/9736.misc | 1
-rw-r--r--  changelog.d/9742.misc | 1
-rw-r--r--  changelog.d/9743.misc | 1
-rw-r--r--  docker/conf/homeserver.yaml | 8
-rw-r--r--  docs/admin_api/user_admin_api.rst | 85
-rw-r--r--  docs/code_style.md | 3
-rw-r--r--  docs/sample_config.yaml | 35
-rw-r--r--  mypy.ini | 1
-rwxr-xr-x  scripts-dev/complement.sh | 49
-rw-r--r--  synapse/api/constants.py | 2
-rw-r--r--  synapse/api/ratelimiting.py | 100
-rw-r--r--  synapse/api/room_versions.py | 24
-rw-r--r--  synapse/config/api.py | 139
-rw-r--r--  synapse/config/experimental.py | 7
-rw-r--r--  synapse/config/registration.py | 4
-rw-r--r--  synapse/event_auth.py | 28
-rw-r--r--  synapse/federation/federation_server.py | 31
-rw-r--r--  synapse/federation/sender/per_destination_queue.py | 8
-rw-r--r--  synapse/federation/transport/server.py | 4
-rw-r--r--  synapse/handlers/_base.py | 14
-rw-r--r--  synapse/handlers/account_validity.py | 7
-rw-r--r--  synapse/handlers/auth.py | 24
-rw-r--r--  synapse/handlers/devicemessage.py | 40
-rw-r--r--  synapse/handlers/e2e_keys.py | 2
-rw-r--r--  synapse/handlers/federation.py | 163
-rw-r--r--  synapse/handlers/identity.py | 12
-rw-r--r--  synapse/handlers/message.py | 2
-rw-r--r--  synapse/handlers/register.py | 6
-rw-r--r--  synapse/handlers/room_member.py | 23
-rw-r--r--  synapse/handlers/sync.py | 18
-rw-r--r--  synapse/handlers/typing.py | 6
-rw-r--r--  synapse/http/client.py | 2
-rw-r--r--  synapse/logging/opentracing.py | 8
-rw-r--r--  synapse/notifier.py | 47
-rw-r--r--  synapse/replication/http/register.py | 2
-rw-r--r--  synapse/replication/tcp/redis.py | 2
-rw-r--r--  synapse/rest/admin/users.py | 21
-rw-r--r--  synapse/rest/client/v1/login.py | 14
-rw-r--r--  synapse/rest/client/v2_alpha/account.py | 10
-rw-r--r--  synapse/rest/client/v2_alpha/register.py | 8
-rw-r--r--  synapse/server.py | 1
-rw-r--r--  synapse/storage/databases/main/__init__.py | 26
-rw-r--r--  synapse/storage/databases/main/events_worker.py | 4
-rw-r--r--  synapse/storage/databases/main/group_server.py | 4
-rw-r--r--  synapse/storage/databases/main/stats.py | 25
-rw-r--r--  synapse/storage/prepare_database.py | 11
-rw-r--r--  synapse/util/caches/deferred_cache.py | 4
-rw-r--r--  tests/api/test_ratelimiting.py | 168
-rw-r--r--  tests/rest/admin/test_user.py | 121
-rw-r--r--  tests/rest/client/v2_alpha/test_auth.py | 7
-rw-r--r--  tests/storage/test_devices.py | 80
-rw-r--r--  tests/storage/test_directory.py | 44
-rw-r--r--  tests/storage/test_end_to_end_keys.py | 59
-rw-r--r--  tests/storage/test_event_push_actions.py | 133
-rw-r--r--  tests/storage/test_profile.py | 35
-rw-r--r--  tests/storage/test_redaction.py | 12
-rw-r--r--  tests/storage/test_registration.py | 108
-rw-r--r--  tests/storage/test_room.py | 61
-rw-r--r--  tests/storage/test_state.py | 145
-rw-r--r--  tests/storage/test_user_directory.py | 86
-rw-r--r--  tests/test_event_auth.py | 246
-rw-r--r--  tests/utils.py | 1
71 files changed, 1472 insertions, 880 deletions
diff --git a/changelog.d/9685.misc b/changelog.d/9685.misc
new file mode 100644
index 0000000000..0506d8af0c
--- /dev/null
+++ b/changelog.d/9685.misc
@@ -0,0 +1 @@
+Update `scripts-dev/complement.sh` to use a local checkout of Complement, allow running a subset of tests, and use Synapse's Complement test blacklist.
\ No newline at end of file
diff --git a/changelog.d/9686.misc b/changelog.d/9686.misc
new file mode 100644
index 0000000000..bb2335acf9
--- /dev/null
+++ b/changelog.d/9686.misc
@@ -0,0 +1 @@
+Improve Jaeger tracing for `to_device` messages.
diff --git a/changelog.d/9691.feature b/changelog.d/9691.feature
new file mode 100644
index 0000000000..3c711db4f5
--- /dev/null
+++ b/changelog.d/9691.feature
@@ -0,0 +1 @@
+Add `order_by` to the admin API `GET /_synapse/admin/v2/users`. Contributed by @dklimpel. 
\ No newline at end of file
diff --git a/changelog.d/9700.feature b/changelog.d/9700.feature
new file mode 100644
index 0000000000..037de8367f
--- /dev/null
+++ b/changelog.d/9700.feature
@@ -0,0 +1 @@
+Replace the `room_invite_state_types` configuration setting with `room_prejoin_state`.
diff --git a/changelog.d/9710.feature b/changelog.d/9710.feature
new file mode 100644
index 0000000000..fce308cc41
--- /dev/null
+++ b/changelog.d/9710.feature
@@ -0,0 +1 @@
+Experimental Spaces support: include `m.room.create` in the room state sent with room-invites.
diff --git a/changelog.d/9711.bugfix b/changelog.d/9711.bugfix
new file mode 100644
index 0000000000..4ca3438d46
--- /dev/null
+++ b/changelog.d/9711.bugfix
@@ -0,0 +1 @@
+Fix recently added ratelimits to correctly honour the application service `rate_limited` flag.
diff --git a/changelog.d/9717.feature b/changelog.d/9717.feature
new file mode 100644
index 0000000000..c2c74f13d5
--- /dev/null
+++ b/changelog.d/9717.feature
@@ -0,0 +1 @@
+Add experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership.
diff --git a/changelog.d/9718.removal b/changelog.d/9718.removal
new file mode 100644
index 0000000000..6de7814217
--- /dev/null
+++ b/changelog.d/9718.removal
@@ -0,0 +1 @@
+Replace deprecated `imp` module with successor `importlib`. Contributed by Cristina Muñoz.
diff --git a/changelog.d/9719.doc b/changelog.d/9719.doc
new file mode 100644
index 0000000000..f018606dd6
--- /dev/null
+++ b/changelog.d/9719.doc
@@ -0,0 +1 @@
+Make the allowed_local_3pids regex example in the sample config stricter.
diff --git a/changelog.d/9736.misc b/changelog.d/9736.misc
new file mode 100644
index 0000000000..1e445e4344
--- /dev/null
+++ b/changelog.d/9736.misc
@@ -0,0 +1 @@
+Convert various testcases to `HomeserverTestCase`.
diff --git a/changelog.d/9742.misc b/changelog.d/9742.misc
new file mode 100644
index 0000000000..681ab04df8
--- /dev/null
+++ b/changelog.d/9742.misc
@@ -0,0 +1 @@
+Start linting mypy with `no_implicit_optional`.
\ No newline at end of file
diff --git a/changelog.d/9743.misc b/changelog.d/9743.misc
new file mode 100644
index 0000000000..c2f75c1df9
--- /dev/null
+++ b/changelog.d/9743.misc
@@ -0,0 +1 @@
+Add missing type hints to federation handler and server.
diff --git a/docker/conf/homeserver.yaml b/docker/conf/homeserver.yaml
index 0dea62a87d..a792899540 100644
--- a/docker/conf/homeserver.yaml
+++ b/docker/conf/homeserver.yaml
@@ -173,18 +173,10 @@ report_stats: False
 
 ## API Configuration ##
 
-room_invite_state_types:
-    - "m.room.join_rules"
-    - "m.room.canonical_alias"
-    - "m.room.avatar"
-    - "m.room.name"
-
 {% if SYNAPSE_APPSERVICES %}
 app_service_config_files:
 {% for appservice in SYNAPSE_APPSERVICES %}    - "{{ appservice }}"
 {% endfor %}
-{% else %}
-app_service_config_files: []
 {% endif %}
 
 macaroon_secret_key: "{{ SYNAPSE_MACAROON_SECRET_KEY }}"
diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index 8d4ec5a6f9..a8a5a2628c 100644
--- a/docs/admin_api/user_admin_api.rst
+++ b/docs/admin_api/user_admin_api.rst
@@ -111,35 +111,16 @@ List Accounts
 =============
 
 This API returns all local user accounts.
+By default, the response is ordered by ascending user ID.
 
-The api is::
+The API is::
 
     GET /_synapse/admin/v2/users?from=0&limit=10&guests=false
 
 To use it, you will need to authenticate by providing an ``access_token`` for a
 server admin: see `README.rst <README.rst>`_.
 
-The parameter ``from`` is optional but used for pagination, denoting the
-offset in the returned results. This should be treated as an opaque value and
-not explicitly set to anything other than the return value of ``next_token``
-from a previous call.
-
-The parameter ``limit`` is optional but is used for pagination, denoting the
-maximum number of items to return in this call. Defaults to ``100``.
-
-The parameter ``user_id`` is optional and filters to only return users with user IDs
-that contain this value. This parameter is ignored when using the ``name`` parameter.
-
-The parameter ``name`` is optional and filters to only return users with user ID localparts
-**or** displaynames that contain this value.
-
-The parameter ``guests`` is optional and if ``false`` will **exclude** guest users.
-Defaults to ``true`` to include guest users.
-
-The parameter ``deactivated`` is optional and if ``true`` will **include** deactivated users.
-Defaults to ``false`` to exclude deactivated users.
-
-A JSON body is returned with the following shape:
+A response body like the following is returned:
 
 .. code:: json
 
@@ -175,6 +156,66 @@ with ``from`` set to the value of ``next_token``. This will return a new page.
 If the endpoint does not return a ``next_token`` then there are no more users
 to paginate through.
 
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- ``user_id`` - Is optional and filters to only return users with user IDs
+  that contain this value. This parameter is ignored when using the ``name`` parameter.
+- ``name`` - Is optional and filters to only return users with user ID localparts
+  **or** displaynames that contain this value.
+- ``guests`` - string representing a bool - Is optional and if ``false`` will **exclude** guest users.
+  Defaults to ``true`` to include guest users.
+- ``deactivated`` - string representing a bool - Is optional and if ``true`` will **include** deactivated users.
+  Defaults to ``false`` to exclude deactivated users.
+- ``limit`` - string representing a positive integer - Is optional but is used for pagination,
+  denoting the maximum number of items to return in this call. Defaults to ``100``.
+- ``from`` - string representing a positive integer - Is optional but used for pagination,
+  denoting the offset in the returned results. This should be treated as an opaque value and
+  not explicitly set to anything other than the return value of ``next_token`` from a previous call.
+  Defaults to ``0``.
+- ``order_by`` - The method by which to sort the returned list of users.
+  If the ordered field has duplicates, the second order is always by ascending ``name``,
+  which guarantees a stable ordering. Valid values are:
+
+  - ``name`` - Users are ordered alphabetically by ``name``. This is the default.
+  - ``is_guest`` - Users are ordered by ``is_guest`` status.
+  - ``admin`` - Users are ordered by ``admin`` status.
+  - ``user_type`` - Users are ordered alphabetically by ``user_type``.
+  - ``deactivated`` - Users are ordered by ``deactivated`` status.
+  - ``shadow_banned`` - Users are ordered by ``shadow_banned`` status.
+  - ``displayname`` - Users are ordered alphabetically by ``displayname``.
+  - ``avatar_url`` - Users are ordered alphabetically by avatar URL.
+
+- ``dir`` - Direction of user order. Either ``f`` for forwards or ``b`` for backwards.
+  Setting this value to ``b`` will reverse the above sort order. Defaults to ``f``.
+
+Caution: the database only has indexes on the columns ``name`` and ``created_ts``.
+Sorting by any other column (``is_guest``, ``admin``, ``user_type``, ``deactivated``,
+``shadow_banned``, ``avatar_url`` or ``displayname``) can therefore place a heavy
+load on the database, especially for large environments.
+
+**Response**
+
+The following fields are returned in the JSON response body:
+
+- ``users`` - An array of objects, each containing information about a user.
+  User objects contain the following fields:
+
+  - ``name`` - string - Fully-qualified user ID (for example ``@user:server.com``).
+  - ``is_guest`` - bool - Whether the user is a guest account.
+  - ``admin`` - bool - Whether the user is a server administrator.
+  - ``user_type`` - string - Type of the user. Normal users are type ``None``.
+    This allows user type specific behaviour. There are also types ``support`` and ``bot``.
+  - ``deactivated`` - bool - Whether the user has been marked as deactivated.
+  - ``shadow_banned`` - bool - Whether the user has been marked as shadow-banned.
+  - ``displayname`` - string - The user's display name if they have set one.
+  - ``avatar_url`` - string - The user's avatar URL if they have set one.
+
+- ``next_token`` - string representing a positive integer - Indication for pagination. See above.
+- ``total`` - integer - Total number of users.
+
+
 Query current sessions for a user
 =================================
 
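As a quick illustration of the `order_by` and `dir` parameters documented above, here is a minimal sketch (not part of this change) that pages through the endpoint using Python's `requests` library. The homeserver URL and admin access token are placeholders.

```python
# Minimal sketch: list users via GET /_synapse/admin/v2/users, ordered by
# display name in reverse. BASE_URL and ADMIN_TOKEN are placeholders.
import requests

BASE_URL = "http://localhost:8008"
ADMIN_TOKEN = "ADMIN_ACCESS_TOKEN"


def list_users(order_by="name", direction="f", limit=10, from_token="0"):
    """Fetch one page of users, honouring the documented query parameters."""
    resp = requests.get(
        f"{BASE_URL}/_synapse/admin/v2/users",
        headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
        params={
            "order_by": order_by,
            "dir": direction,
            "limit": limit,
            "from": from_token,
            "guests": "false",
        },
    )
    resp.raise_for_status()
    return resp.json()


page = list_users(order_by="displayname", direction="b")
for user in page["users"]:
    print(user["name"], user.get("displayname"))
# Pass page.get("next_token") as `from` to fetch the next page, if present.
```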
diff --git a/docs/code_style.md b/docs/code_style.md
index 190f8ab2de..28fb7277c4 100644
--- a/docs/code_style.md
+++ b/docs/code_style.md
@@ -128,6 +128,9 @@ Some guidelines follow:
     will be if no sub-options are enabled).
 -   Lines should be wrapped at 80 characters.
 -   Use two-space indents.
+-   `true` and `false` are spelt thus (as opposed to `True`, etc.)
+-   Use single quotes (`'`) rather than double-quotes (`"`) or backticks
+    (`` ` ``) to refer to configuration options.
 
 Example:
 
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 17cda71adc..b0bf987740 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1246,9 +1246,9 @@ account_validity:
 #
 #allowed_local_3pids:
 #  - medium: email
-#    pattern: '.*@matrix\.org'
+#    pattern: '^[^@]+@matrix\.org$'
 #  - medium: email
-#    pattern: '.*@vector\.im'
+#    pattern: '^[^@]+@vector\.im$'
 #  - medium: msisdn
 #    pattern: '\+44'
 
@@ -1451,14 +1451,31 @@ metrics_flags:
 
 ## API Configuration ##
 
-# A list of event types that will be included in the room_invite_state
+# Controls for the state that is shared with users who receive an invite
+# to a room
 #
-#room_invite_state_types:
-#  - "m.room.join_rules"
-#  - "m.room.canonical_alias"
-#  - "m.room.avatar"
-#  - "m.room.encryption"
-#  - "m.room.name"
+room_prejoin_state:
+   # By default, the following state event types are shared with users who
+   # receive invites to the room:
+   #
+   # - m.room.join_rules
+   # - m.room.canonical_alias
+   # - m.room.avatar
+   # - m.room.encryption
+   # - m.room.name
+   #
+   # Uncomment the following to disable these defaults (so that only the event
+   # types listed in 'additional_event_types' are shared). Defaults to 'false'.
+   #
+   #disable_default_event_types: true
+
+   # Additional state event types to share with users when they are invited
+   # to a room.
+   #
+   # By default, this list is empty (so only the default event types are shared).
+   #
+   #additional_event_types:
+   #  - org.example.custom.event.type
 
 
 # A list of application service config files to use
diff --git a/mypy.ini b/mypy.ini
index 3ae5d45787..32e6197409 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -8,6 +8,7 @@ show_traceback = True
 mypy_path = stubs
 warn_unreachable = True
 local_partial_types = True
+no_implicit_optional = True
 
 # To find all folders that pass mypy you run:
 #
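The kind of code this flag starts rejecting is shown below. This is an illustrative sketch, not taken from the changeset, mirroring the `Optional` fixes made elsewhere in this diff (e.g. `renew_account_for_user`).

```python
from typing import Optional


# With no_implicit_optional = True, mypy rejects this: a default of None no
# longer implicitly turns the annotation into Optional[int].
def renew_account(expiration_ts: int = None) -> int:  # mypy error
    return expiration_ts or 0


# The Optional now has to be spelled out explicitly.
def renew_account_explicit(expiration_ts: Optional[int] = None) -> int:
    return expiration_ts or 0
```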
diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh
index 3cde53f5c0..31cc20a826 100755
--- a/scripts-dev/complement.sh
+++ b/scripts-dev/complement.sh
@@ -1,22 +1,49 @@
-#! /bin/bash -eu
+#!/usr/bin/env bash
 # This script is designed for developers who want to test their code
 # against Complement.
 #
 # It makes a Synapse image which represents the current checkout,
-# then downloads Complement and runs it with that image.
+# builds a synapse-complement image on top, then runs tests with it.
+#
+# By default the script will fetch the latest Complement master branch and
+# run tests with that. This can be overridden to use a custom Complement
+# checkout by setting the COMPLEMENT_DIR environment variable to the
+# filepath of a local Complement checkout.
+#
+# A regular expression of test method names can be supplied as the first
+# argument to the script. Complement will then only run those tests. If
+# no regex is supplied, all tests are run. For example:
+#
+# ./complement.sh "TestOutboundFederation(Profile|Send)"
+#
+
+# Exit if a line returns a non-zero exit code
+set -e
 
+# Change to the repository root
 cd "$(dirname $0)/.."
 
+# Check for a user-specified Complement checkout
+if [[ -z "$COMPLEMENT_DIR" ]]; then
+  echo "COMPLEMENT_DIR not set. Fetching the latest Complement checkout..."
+  wget -Nq https://github.com/matrix-org/complement/archive/master.tar.gz
+  tar -xzf master.tar.gz
+  COMPLEMENT_DIR=complement-master
+  echo "Checkout available at 'complement-master'"
+fi
+
 # Build the base Synapse image from the local checkout
-docker build -t matrixdotorg/synapse:latest -f docker/Dockerfile .
+docker build -t matrixdotorg/synapse -f docker/Dockerfile .
+# Build the Synapse monolith image from Complement, based on the above image we just built
+docker build -t complement-synapse -f "$COMPLEMENT_DIR/dockerfiles/Synapse.Dockerfile" "$COMPLEMENT_DIR/dockerfiles"
 
-# Download Complement
-wget -N https://github.com/matrix-org/complement/archive/master.tar.gz
-tar -xzf master.tar.gz
-cd complement-master
+cd "$COMPLEMENT_DIR"
 
-# Build the Synapse image from Complement, based on the above image we just built
-docker build -t complement-synapse -f dockerfiles/Synapse.Dockerfile ./dockerfiles
+EXTRA_COMPLEMENT_ARGS=""
+if [[ -n "$1" ]]; then
+  # A test name regex has been set, supply it to Complement
+  EXTRA_COMPLEMENT_ARGS+="-run $1 "
+fi
 
-# Run the tests on the resulting image!
-COMPLEMENT_BASE_IMAGE=complement-synapse go test -v -count=1 ./tests
+# Run the tests!
+COMPLEMENT_BASE_IMAGE=complement-synapse go test -v -tags synapse_blacklist -count=1 $EXTRA_COMPLEMENT_ARGS ./tests
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 8f37d2cf3b..6856dab06c 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -59,6 +59,8 @@ class JoinRules:
     KNOCK = "knock"
     INVITE = "invite"
     PRIVATE = "private"
+    # As defined for MSC3083.
+    MSC3083_RESTRICTED = "restricted"
 
 
 class LoginType:
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index c3f07bc1a3..2244b8a340 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -17,6 +17,7 @@ from collections import OrderedDict
 from typing import Hashable, Optional, Tuple
 
 from synapse.api.errors import LimitExceededError
+from synapse.storage.databases.main import DataStore
 from synapse.types import Requester
 from synapse.util import Clock
 
@@ -31,10 +32,13 @@ class Ratelimiter:
         burst_count: How many actions that can be performed before being limited.
     """
 
-    def __init__(self, clock: Clock, rate_hz: float, burst_count: int):
+    def __init__(
+        self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int
+    ):
         self.clock = clock
         self.rate_hz = rate_hz
         self.burst_count = burst_count
+        self.store = store
 
         # A ordered dictionary keeping track of actions, when they were last
         # performed and how often. Each entry is a mapping from a key of arbitrary type
@@ -46,45 +50,10 @@ class Ratelimiter:
             OrderedDict()
         )  # type: OrderedDict[Hashable, Tuple[float, int, float]]
 
-    def can_requester_do_action(
-        self,
-        requester: Requester,
-        rate_hz: Optional[float] = None,
-        burst_count: Optional[int] = None,
-        update: bool = True,
-        _time_now_s: Optional[int] = None,
-    ) -> Tuple[bool, float]:
-        """Can the requester perform the action?
-
-        Args:
-            requester: The requester to key off when rate limiting. The user property
-                will be used.
-            rate_hz: The long term number of actions that can be performed in a second.
-                Overrides the value set during instantiation if set.
-            burst_count: How many actions that can be performed before being limited.
-                Overrides the value set during instantiation if set.
-            update: Whether to count this check as performing the action
-            _time_now_s: The current time. Optional, defaults to the current time according
-                to self.clock. Only used by tests.
-
-        Returns:
-            A tuple containing:
-                * A bool indicating if they can perform the action now
-                * The reactor timestamp for when the action can be performed next.
-                  -1 if rate_hz is less than or equal to zero
-        """
-        # Disable rate limiting of users belonging to any AS that is configured
-        # not to be rate limited in its registration file (rate_limited: true|false).
-        if requester.app_service and not requester.app_service.is_rate_limited():
-            return True, -1.0
-
-        return self.can_do_action(
-            requester.user.to_string(), rate_hz, burst_count, update, _time_now_s
-        )
-
-    def can_do_action(
+    async def can_do_action(
         self,
-        key: Hashable,
+        requester: Optional[Requester],
+        key: Optional[Hashable] = None,
         rate_hz: Optional[float] = None,
         burst_count: Optional[int] = None,
         update: bool = True,
@@ -92,9 +61,16 @@ class Ratelimiter:
     ) -> Tuple[bool, float]:
         """Can the entity (e.g. user or IP address) perform the action?
 
+        Checks if the user has ratelimiting disabled in the database by looking
+        for null/zero values in the `ratelimit_override` table. (Non-zero
+        values aren't honoured, as they're specific to the event sending
+        ratelimiter, rather than all ratelimiters)
+
         Args:
-            key: The key we should use when rate limiting. Can be a user ID
-                (when sending events), an IP address, etc.
+            requester: The requester that is doing the action, if any. Used to check
+                if the user has ratelimits disabled in the database.
+            key: An arbitrary key used to classify an action. Defaults to the
+                requester's user ID.
             rate_hz: The long term number of actions that can be performed in a second.
                 Overrides the value set during instantiation if set.
             burst_count: How many actions that can be performed before being limited.
@@ -109,6 +85,30 @@ class Ratelimiter:
                 * The reactor timestamp for when the action can be performed next.
                   -1 if rate_hz is less than or equal to zero
         """
+        if key is None:
+            if not requester:
+                raise ValueError("Must supply at least one of `requester` or `key`")
+
+            key = requester.user.to_string()
+
+        if requester:
+            # Disable rate limiting of users belonging to any AS that is configured
+            # not to be rate limited in its registration file (rate_limited: true|false).
+            if requester.app_service and not requester.app_service.is_rate_limited():
+                return True, -1.0
+
+            # Check if ratelimiting has been disabled for the user.
+            #
+            # Note that we don't use the returned rate/burst count, as the table
+            # is specifically for the event sending ratelimiter. Instead, we
+            # only use it to (somewhat cheekily) infer whether the user should
+            # be subject to any rate limiting or not.
+            override = await self.store.get_ratelimit_for_user(
+                requester.authenticated_entity
+            )
+            if override and not override.messages_per_second:
+                return True, -1.0
+
         # Override default values if set
         time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
         rate_hz = rate_hz if rate_hz is not None else self.rate_hz
@@ -175,9 +175,10 @@ class Ratelimiter:
             else:
                 del self.actions[key]
 
-    def ratelimit(
+    async def ratelimit(
         self,
-        key: Hashable,
+        requester: Optional[Requester],
+        key: Optional[Hashable] = None,
         rate_hz: Optional[float] = None,
         burst_count: Optional[int] = None,
         update: bool = True,
@@ -185,8 +186,16 @@ class Ratelimiter:
     ):
         """Checks if an action can be performed. If not, raises a LimitExceededError
 
+        Checks if the user has ratelimiting disabled in the database by looking
+        for null/zero values in the `ratelimit_override` table. (Non-zero
+        values aren't honoured, as they're specific to the event sending
+        ratelimiter, rather than all ratelimiters)
+
         Args:
-            key: An arbitrary key used to classify an action
+            requester: The requester that is doing the action, if any. Used to check
+                if the user has ratelimits disabled in the database.
+            key: An arbitrary key used to classify an action. Defaults to the
+                requester's user ID.
             rate_hz: The long term number of actions that can be performed in a second.
                 Overrides the value set during instantiation if set.
             burst_count: How many actions that can be performed before being limited.
@@ -201,7 +210,8 @@ class Ratelimiter:
         """
         time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
 
-        allowed, time_allowed = self.can_do_action(
+        allowed, time_allowed = await self.can_do_action(
+            requester,
             key,
             rate_hz=rate_hz,
             burst_count=burst_count,
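A rough sketch of the new calling convention follows. It assumes `hs` is a configured Synapse `HomeServer`, `requester` is a `synapse.types.Requester`, and the rate/burst numbers are arbitrary; it is not code from this changeset.

```python
# Sketch only: the Ratelimiter now takes the datastore, and its check methods
# are coroutines, so callers must await them.
from synapse.api.ratelimiting import Ratelimiter


async def limit_example(hs, requester):
    limiter = Ratelimiter(
        store=hs.get_datastore(),
        clock=hs.get_clock(),
        rate_hz=0.17,
        burst_count=3,
    )

    # Keyed on the requester's user ID by default. Raises LimitExceededError
    # once the burst is used up, unless the requester is exempt (an app
    # service with rate_limited disabled, or a null/zero ratelimit_override).
    await limiter.ratelimit(requester)

    # An explicit key can still be supplied, with or without a requester.
    allowed, retry_at = await limiter.can_do_action(None, ("login", "1.2.3.4"))
    return allowed, retry_at
```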
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index de2cc15d33..87038d436d 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -57,7 +57,7 @@ class RoomVersion:
     state_res = attr.ib(type=int)  # one of the StateResolutionVersions
     enforce_key_validity = attr.ib(type=bool)
 
-    # bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
+    # Before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
     special_case_aliases_auth = attr.ib(type=bool)
     # Strictly enforce canonicaljson, do not allow:
     # * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1]
@@ -69,6 +69,8 @@ class RoomVersion:
     limit_notifications_power_levels = attr.ib(type=bool)
     # MSC2174/MSC2176: Apply updated redaction rules algorithm.
     msc2176_redaction_rules = attr.ib(type=bool)
+    # MSC3083: Support the 'restricted' join_rule.
+    msc3083_join_rules = attr.ib(type=bool)
 
 
 class RoomVersions:
@@ -82,6 +84,7 @@ class RoomVersions:
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
         msc2176_redaction_rules=False,
+        msc3083_join_rules=False,
     )
     V2 = RoomVersion(
         "2",
@@ -93,6 +96,7 @@ class RoomVersions:
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
         msc2176_redaction_rules=False,
+        msc3083_join_rules=False,
     )
     V3 = RoomVersion(
         "3",
@@ -104,6 +108,7 @@ class RoomVersions:
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
         msc2176_redaction_rules=False,
+        msc3083_join_rules=False,
     )
     V4 = RoomVersion(
         "4",
@@ -115,6 +120,7 @@ class RoomVersions:
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
         msc2176_redaction_rules=False,
+        msc3083_join_rules=False,
     )
     V5 = RoomVersion(
         "5",
@@ -126,6 +132,7 @@ class RoomVersions:
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
         msc2176_redaction_rules=False,
+        msc3083_join_rules=False,
     )
     V6 = RoomVersion(
         "6",
@@ -137,6 +144,7 @@ class RoomVersions:
         strict_canonicaljson=True,
         limit_notifications_power_levels=True,
         msc2176_redaction_rules=False,
+        msc3083_join_rules=False,
     )
     MSC2176 = RoomVersion(
         "org.matrix.msc2176",
@@ -148,6 +156,19 @@ class RoomVersions:
         strict_canonicaljson=True,
         limit_notifications_power_levels=True,
         msc2176_redaction_rules=True,
+        msc3083_join_rules=False,
+    )
+    MSC3083 = RoomVersion(
+        "org.matrix.msc3083",
+        RoomDisposition.UNSTABLE,
+        EventFormatVersions.V3,
+        StateResolutionVersions.V2,
+        enforce_key_validity=True,
+        special_case_aliases_auth=False,
+        strict_canonicaljson=True,
+        limit_notifications_power_levels=True,
+        msc2176_redaction_rules=False,
+        msc3083_join_rules=True,
     )
 
 
@@ -162,4 +183,5 @@ KNOWN_ROOM_VERSIONS = {
         RoomVersions.V6,
         RoomVersions.MSC2176,
     )
+    # Note that we do not include MSC3083 here unless it is enabled in the config.
 }  # type: Dict[str, RoomVersion]
diff --git a/synapse/config/api.py b/synapse/config/api.py
index 74cd53a8ed..55c038c0c4 100644
--- a/synapse/config/api.py
+++ b/synapse/config/api.py
@@ -1,4 +1,4 @@
-# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,38 +12,131 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
+from typing import Iterable
+
 from synapse.api.constants import EventTypes
+from synapse.config._base import Config, ConfigError
+from synapse.config._util import validate_config
+from synapse.types import JsonDict
 
-from ._base import Config
+logger = logging.getLogger(__name__)
 
 
 class ApiConfig(Config):
     section = "api"
 
-    def read_config(self, config, **kwargs):
-        self.room_invite_state_types = config.get(
-            "room_invite_state_types",
-            [
-                EventTypes.JoinRules,
-                EventTypes.CanonicalAlias,
-                EventTypes.RoomAvatar,
-                EventTypes.RoomEncryption,
-                EventTypes.Name,
-            ],
+    def read_config(self, config: JsonDict, **kwargs):
+        validate_config(_MAIN_SCHEMA, config, ())
+        self.room_prejoin_state = list(self._get_prejoin_state_types(config))
+
+    def generate_config_section(cls, **kwargs) -> str:
+        formatted_default_state_types = "\n".join(
+            "           # - %s" % (t,) for t in _DEFAULT_PREJOIN_STATE_TYPES
         )
 
-    def generate_config_section(cls, **kwargs):
         return """\
         ## API Configuration ##
 
-        # A list of event types that will be included in the room_invite_state
+        # Controls for the state that is shared with users who receive an invite
+        # to a room
         #
-        #room_invite_state_types:
-        #  - "{JoinRules}"
-        #  - "{CanonicalAlias}"
-        #  - "{RoomAvatar}"
-        #  - "{RoomEncryption}"
-        #  - "{Name}"
-        """.format(
-            **vars(EventTypes)
-        )
+        room_prejoin_state:
+           # By default, the following state event types are shared with users who
+           # receive invites to the room:
+           #
+%(formatted_default_state_types)s
+           #
+           # Uncomment the following to disable these defaults (so that only the event
+           # types listed in 'additional_event_types' are shared). Defaults to 'false'.
+           #
+           #disable_default_event_types: true
+
+           # Additional state event types to share with users when they are invited
+           # to a room.
+           #
+           # By default, this list is empty (so only the default event types are shared).
+           #
+           #additional_event_types:
+           #  - org.example.custom.event.type
+        """ % {
+            "formatted_default_state_types": formatted_default_state_types
+        }
+
+    def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]:
+        """Get the event types to include in the prejoin state
+
+        Parses the config and returns an iterable of the event types to be included.
+        """
+        room_prejoin_state_config = config.get("room_prejoin_state") or {}
+
+        # backwards-compatibility support for room_invite_state_types
+        if "room_invite_state_types" in config:
+            # if both "room_invite_state_types" and "room_prejoin_state" are set, then
+            # we don't really know what to do.
+            if room_prejoin_state_config:
+                raise ConfigError(
+                    "Can't specify both 'room_invite_state_types' and 'room_prejoin_state' "
+                    "in config"
+                )
+
+            logger.warning(_ROOM_INVITE_STATE_TYPES_WARNING)
+
+            yield from config["room_invite_state_types"]
+            return
+
+        if not room_prejoin_state_config.get("disable_default_event_types"):
+            yield from _DEFAULT_PREJOIN_STATE_TYPES
+
+            if self.spaces_enabled:
+                # MSC1772 suggests adding m.room.create to the prejoin state
+                yield EventTypes.Create
+
+        yield from room_prejoin_state_config.get("additional_event_types", [])
+
+
+_ROOM_INVITE_STATE_TYPES_WARNING = """\
+WARNING: The 'room_invite_state_types' configuration setting is now deprecated,
+and replaced with 'room_prejoin_state'. New features may not work correctly
+unless 'room_invite_state_types' is removed. See the sample configuration file for
+details of 'room_prejoin_state'.
+--------------------------------------------------------------------------------
+"""
+
+_DEFAULT_PREJOIN_STATE_TYPES = [
+    EventTypes.JoinRules,
+    EventTypes.CanonicalAlias,
+    EventTypes.RoomAvatar,
+    EventTypes.RoomEncryption,
+    EventTypes.Name,
+]
+
+
+# room_prejoin_state can either be None (as it is in the default config), or
+# an object containing other config settings
+_ROOM_PREJOIN_STATE_CONFIG_SCHEMA = {
+    "oneOf": [
+        {
+            "type": "object",
+            "properties": {
+                "disable_default_event_types": {"type": "boolean"},
+                "additional_event_types": {
+                    "type": "array",
+                    "items": {"type": "string"},
+                },
+            },
+        },
+        {"type": "null"},
+    ]
+}
+
+# the legacy room_invite_state_types setting
+_ROOM_INVITE_STATE_TYPES_SCHEMA = {"type": "array", "items": {"type": "string"}}
+
+_MAIN_SCHEMA = {
+    "type": "object",
+    "properties": {
+        "room_prejoin_state": _ROOM_PREJOIN_STATE_CONFIG_SCHEMA,
+        "room_invite_state_types": _ROOM_INVITE_STATE_TYPES_SCHEMA,
+    },
+}
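For reference, here is a standalone restatement of the merge rules in `_get_prejoin_state_types` above, applied to a plain dict rather than a Synapse config object and assuming spaces support is off. It is a sketch, not the code being changed.

```python
# Standalone sketch of the prejoin-state config rules (spaces disabled).
_DEFAULTS = [
    "m.room.join_rules",
    "m.room.canonical_alias",
    "m.room.avatar",
    "m.room.encryption",
    "m.room.name",
]


def prejoin_state_types(config: dict) -> list:
    prejoin = config.get("room_prejoin_state") or {}

    # The legacy setting wins, but cannot be combined with the new one.
    if "room_invite_state_types" in config:
        if prejoin:
            raise ValueError(
                "Can't specify both 'room_invite_state_types' and 'room_prejoin_state'"
            )
        return list(config["room_invite_state_types"])

    types = []
    if not prejoin.get("disable_default_event_types"):
        types.extend(_DEFAULTS)
    types.extend(prejoin.get("additional_event_types", []))
    return types


assert prejoin_state_types({}) == _DEFAULTS
assert prejoin_state_types(
    {"room_prejoin_state": {
        "disable_default_event_types": True,
        "additional_event_types": ["org.example.custom.event.type"],
    }}
) == ["org.example.custom.event.type"]
```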
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 86f4d9af9d..eb96ecda74 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
 from synapse.config._base import Config
 from synapse.types import JsonDict
 
@@ -27,7 +28,11 @@ class ExperimentalConfig(Config):
 
         # MSC2858 (multiple SSO identity providers)
         self.msc2858_enabled = experimental.get("msc2858_enabled", False)  # type: bool
-        # Spaces (MSC1772, MSC2946, etc)
+
+        # Spaces (MSC1772, MSC2946, MSC3083, etc)
         self.spaces_enabled = experimental.get("spaces_enabled", False)  # type: bool
+        if self.spaces_enabled:
+            KNOWN_ROOM_VERSIONS[RoomVersions.MSC3083.identifier] = RoomVersions.MSC3083
+
         # MSC3026 (busy presence state)
         self.msc3026_enabled = experimental.get("msc3026_enabled", False)  # type: bool
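The effect of the `spaces_enabled` hook above can be sketched as follows. This mirrors the assignment in the diff run in a fresh process, rather than showing how a deployment would enable it (which is via the `experimental_features` section of the config).

```python
# Sketch: the MSC3083 room version only becomes "known" once spaces are enabled.
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions

assert "org.matrix.msc3083" not in KNOWN_ROOM_VERSIONS

# This is what ExperimentalConfig.read_config does when spaces_enabled is true:
KNOWN_ROOM_VERSIONS[RoomVersions.MSC3083.identifier] = RoomVersions.MSC3083

assert KNOWN_ROOM_VERSIONS["org.matrix.msc3083"].msc3083_join_rules
```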
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index ead007ba5a..f27d1e14ac 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -298,9 +298,9 @@ class RegistrationConfig(Config):
         #
         #allowed_local_3pids:
         #  - medium: email
-        #    pattern: '.*@matrix\\.org'
+        #    pattern: '^[^@]+@matrix\\.org$'
         #  - medium: email
-        #    pattern: '.*@vector\\.im'
+        #    pattern: '^[^@]+@vector\\.im$'
         #  - medium: msisdn
         #    pattern: '\\+44'
 
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 91ad5b3d3c..9863953f5c 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -162,7 +162,7 @@ def check(
         logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()])
 
     if event.type == EventTypes.Member:
-        _is_membership_change_allowed(event, auth_events)
+        _is_membership_change_allowed(room_version_obj, event, auth_events)
         logger.debug("Allowing! %s", event)
         return
 
@@ -220,8 +220,19 @@ def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
 
 
 def _is_membership_change_allowed(
-    event: EventBase, auth_events: StateMap[EventBase]
+    room_version: RoomVersion, event: EventBase, auth_events: StateMap[EventBase]
 ) -> None:
+    """
+    Confirms that the event which changes membership is an allowed change.
+
+    Args:
+        room_version: The version of the room.
+        event: The event to check.
+        auth_events: The current auth events of the room.
+
+    Raises:
+        AuthError if the event is not allowed.
+    """
     membership = event.content["membership"]
 
     # Check if this is the room creator joining:
@@ -315,14 +326,19 @@ def _is_membership_change_allowed(
             if user_level < invite_level:
                 raise AuthError(403, "You don't have permission to invite users")
     elif Membership.JOIN == membership:
-        # Joins are valid iff caller == target and they were:
-        # invited: They are accepting the invitation
-        # joined: It's a NOOP
+        # Joins are valid iff caller == target and:
+        # * They are not banned.
+        # * They are accepting a previously sent invitation.
+        # * They are already joined (it's a NOOP).
+        # * The room is public or restricted.
         if event.user_id != target_user_id:
             raise AuthError(403, "Cannot force another user to join.")
         elif target_banned:
             raise AuthError(403, "You are banned from this room")
-        elif join_rule == JoinRules.PUBLIC:
+        elif join_rule == JoinRules.PUBLIC or (
+            room_version.msc3083_join_rules
+            and join_rule == JoinRules.MSC3083_RESTRICTED
+        ):
             pass
         elif join_rule == JoinRules.INVITE:
             if not caller_in_room and not caller_invited:
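To make the new branch concrete, here is a simplified standalone restatement of the join-rule decision (ignoring bans, power levels and the caller/target check, which the surrounding code handles). It is a sketch, not Synapse code.

```python
# Simplified sketch of the join-rule check after this change.
def join_rule_allows_join(
    supports_msc3083: bool,
    join_rule: str,
    caller_invited: bool,
    caller_in_room: bool,
) -> bool:
    if join_rule == "public":
        return True
    # New: "restricted" behaves like "public" at this point, but only in room
    # versions that opt in via msc3083_join_rules.
    if supports_msc3083 and join_rule == "restricted":
        return True
    if join_rule == "invite":
        return caller_in_room or caller_invited
    # Anything else (including "restricted" in an older room version) is
    # rejected with a 403 by the real code.
    return False


assert join_rule_allows_join(True, "restricted", False, False)
assert not join_rule_allows_join(False, "restricted", False, False)
```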
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index d84e362070..b9f8d966a6 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -739,22 +739,20 @@ class FederationServer(FederationBase):
 
         await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
 
-    def __str__(self):
+    def __str__(self) -> str:
         return "<ReplicationLayer(%s)>" % self.server_name
 
     async def exchange_third_party_invite(
         self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict
-    ):
-        ret = await self.handler.exchange_third_party_invite(
+    ) -> None:
+        await self.handler.exchange_third_party_invite(
             sender_user_id, target_user_id, room_id, signed
         )
-        return ret
 
-    async def on_exchange_third_party_invite_request(self, event_dict: Dict):
-        ret = await self.handler.on_exchange_third_party_invite_request(event_dict)
-        return ret
+    async def on_exchange_third_party_invite_request(self, event_dict: Dict) -> None:
+        await self.handler.on_exchange_third_party_invite_request(event_dict)
 
-    async def check_server_matches_acl(self, server_name: str, room_id: str):
+    async def check_server_matches_acl(self, server_name: str, room_id: str) -> None:
         """Check if the given server is allowed by the server ACLs in the room
 
         Args:
@@ -870,6 +868,7 @@ class FederationHandlerRegistry:
 
         # A rate limiter for incoming room key requests per origin.
         self._room_key_request_rate_limiter = Ratelimiter(
+            store=hs.get_datastore(),
             clock=self.clock,
             rate_hz=self.config.rc_key_requests.per_second,
             burst_count=self.config.rc_key_requests.burst_count,
@@ -877,7 +876,7 @@ class FederationHandlerRegistry:
 
     def register_edu_handler(
         self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
-    ):
+    ) -> None:
         """Sets the handler callable that will be used to handle an incoming
         federation EDU of the given type.
 
@@ -896,7 +895,7 @@ class FederationHandlerRegistry:
 
     def register_query_handler(
         self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
-    ):
+    ) -> None:
         """Sets the handler callable that will be used to handle an incoming
         federation query of the given type.
 
@@ -914,15 +913,17 @@ class FederationHandlerRegistry:
 
         self.query_handlers[query_type] = handler
 
-    def register_instance_for_edu(self, edu_type: str, instance_name: str):
+    def register_instance_for_edu(self, edu_type: str, instance_name: str) -> None:
         """Register that the EDU handler is on a different instance than master."""
         self._edu_type_to_instance[edu_type] = [instance_name]
 
-    def register_instances_for_edu(self, edu_type: str, instance_names: List[str]):
+    def register_instances_for_edu(
+        self, edu_type: str, instance_names: List[str]
+    ) -> None:
         """Register that the EDU handler is on multiple instances."""
         self._edu_type_to_instance[edu_type] = instance_names
 
-    async def on_edu(self, edu_type: str, origin: str, content: dict):
+    async def on_edu(self, edu_type: str, origin: str, content: dict) -> None:
         if not self.config.use_presence and edu_type == EduTypes.Presence:
             return
 
@@ -930,7 +931,9 @@ class FederationHandlerRegistry:
         # the limit, drop them.
         if (
             edu_type == EduTypes.RoomKeyRequest
-            and not self._room_key_request_rate_limiter.can_do_action(origin)
+            and not await self._room_key_request_rate_limiter.can_do_action(
+                None, origin
+            )
         ):
             return
 
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 89df9a619b..e9c8a9f20a 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -29,6 +29,7 @@ from synapse.api.presence import UserPresenceState
 from synapse.events import EventBase
 from synapse.federation.units import Edu
 from synapse.handlers.presence import format_user_presence_state
+from synapse.logging.opentracing import SynapseTags, set_tag
 from synapse.metrics import sent_transactions_counter
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import ReadReceipt
@@ -557,6 +558,13 @@ class PerDestinationQueue:
         contents, stream_id = await self._store.get_new_device_msgs_for_remote(
             self._destination, last_device_stream_id, to_device_stream_id, limit
         )
+        for content in contents:
+            message_id = content.get("message_id")
+            if not message_id:
+                continue
+
+            set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
+
         edus = [
             Edu(
                 origin=self._server_name,
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 84e39c5a46..5ef0556ef7 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -620,8 +620,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
     PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
 
     async def on_PUT(self, origin, content, query, room_id):
-        content = await self.handler.on_exchange_third_party_invite_request(content)
-        return 200, content
+        await self.handler.on_exchange_third_party_invite_request(content)
+        return 200, {}
 
 
 class FederationClientKeysQueryServlet(BaseFederationServlet):
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index aade2c4a3a..fb899aa90d 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -49,7 +49,7 @@ class BaseHandler:
 
         # The rate_hz and burst_count are overridden on a per-user basis
         self.request_ratelimiter = Ratelimiter(
-            clock=self.clock, rate_hz=0, burst_count=0
+            store=self.store, clock=self.clock, rate_hz=0, burst_count=0
         )
         self._rc_message = self.hs.config.rc_message
 
@@ -57,6 +57,7 @@ class BaseHandler:
         # by the presence of rate limits in the config
         if self.hs.config.rc_admin_redaction:
             self.admin_redaction_ratelimiter = Ratelimiter(
+                store=self.store,
                 clock=self.clock,
                 rate_hz=self.hs.config.rc_admin_redaction.per_second,
                 burst_count=self.hs.config.rc_admin_redaction.burst_count,
@@ -91,11 +92,6 @@ class BaseHandler:
         if app_service is not None:
             return  # do not ratelimit app service senders
 
-        # Disable rate limiting of users belonging to any AS that is configured
-        # not to be rate limited in its registration file (rate_limited: true|false).
-        if requester.app_service and not requester.app_service.is_rate_limited():
-            return
-
         messages_per_second = self._rc_message.per_second
         burst_count = self._rc_message.burst_count
 
@@ -113,11 +109,11 @@ class BaseHandler:
         if is_admin_redaction and self.admin_redaction_ratelimiter:
             # If we have separate config for admin redactions, use a separate
             # ratelimiter as to not have user_ids clash
-            self.admin_redaction_ratelimiter.ratelimit(user_id, update=update)
+            await self.admin_redaction_ratelimiter.ratelimit(requester, update=update)
         else:
             # Override rate and burst count per-user
-            self.request_ratelimiter.ratelimit(
-                user_id,
+            await self.request_ratelimiter.ratelimit(
+                requester,
                 rate_hz=messages_per_second,
                 burst_count=burst_count,
                 update=update,
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index d781bb251d..bee1447c2e 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -18,7 +18,7 @@ import email.utils
 import logging
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
 from synapse.api.errors import StoreError, SynapseError
 from synapse.logging.context import make_deferred_yieldable
@@ -241,7 +241,10 @@ class AccountValidityHandler:
         return True
 
     async def renew_account_for_user(
-        self, user_id: str, expiration_ts: int = None, email_sent: bool = False
+        self,
+        user_id: str,
+        expiration_ts: Optional[int] = None,
+        email_sent: bool = False,
     ) -> int:
         """Renews the account attached to a given user by pushing back the
         expiration date by the current validity period in the server's
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index d537ea8137..08e413bc98 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -238,6 +238,7 @@ class AuthHandler(BaseHandler):
         # Ratelimiter for failed auth during UIA. Uses same ratelimit config
         # as per `rc_login.failed_attempts`.
         self._failed_uia_attempts_ratelimiter = Ratelimiter(
+            store=self.store,
             clock=self.clock,
             rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
             burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
@@ -248,6 +249,7 @@ class AuthHandler(BaseHandler):
 
         # Ratelimitier for failed /login attempts
         self._failed_login_attempts_ratelimiter = Ratelimiter(
+            store=self.store,
             clock=hs.get_clock(),
             rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
             burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
@@ -352,7 +354,7 @@ class AuthHandler(BaseHandler):
         requester_user_id = requester.user.to_string()
 
         # Check if we should be ratelimited due to too many previous failed attempts
-        self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
+        await self._failed_uia_attempts_ratelimiter.ratelimit(requester, update=False)
 
         # build a list of supported flows
         supported_ui_auth_types = await self._get_available_ui_auth_types(
@@ -373,7 +375,9 @@ class AuthHandler(BaseHandler):
             )
         except LoginError:
             # Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
-            self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
+            await self._failed_uia_attempts_ratelimiter.can_do_action(
+                requester,
+            )
             raise
 
         # find the completed login type
@@ -982,8 +986,8 @@ class AuthHandler(BaseHandler):
             # We also apply account rate limiting using the 3PID as a key, as
             # otherwise using 3PID bypasses the ratelimiting based on user ID.
             if ratelimit:
-                self._failed_login_attempts_ratelimiter.ratelimit(
-                    (medium, address), update=False
+                await self._failed_login_attempts_ratelimiter.ratelimit(
+                    None, (medium, address), update=False
                 )
 
             # Check for login providers that support 3pid login types
@@ -1016,8 +1020,8 @@ class AuthHandler(BaseHandler):
                 # this code path, which is fine as then the per-user ratelimit
                 # will kick in below.
                 if ratelimit:
-                    self._failed_login_attempts_ratelimiter.can_do_action(
-                        (medium, address)
+                    await self._failed_login_attempts_ratelimiter.can_do_action(
+                        None, (medium, address)
                     )
                 raise LoginError(403, "", errcode=Codes.FORBIDDEN)
 
@@ -1039,8 +1043,8 @@ class AuthHandler(BaseHandler):
 
         # Check if we've hit the failed ratelimit (but don't update it)
         if ratelimit:
-            self._failed_login_attempts_ratelimiter.ratelimit(
-                qualified_user_id.lower(), update=False
+            await self._failed_login_attempts_ratelimiter.ratelimit(
+                None, qualified_user_id.lower(), update=False
             )
 
         try:
@@ -1051,8 +1055,8 @@ class AuthHandler(BaseHandler):
             # exception and masking the LoginError. The actual ratelimiting
             # should have happened above.
             if ratelimit:
-                self._failed_login_attempts_ratelimiter.can_do_action(
-                    qualified_user_id.lower()
+                await self._failed_login_attempts_ratelimiter.can_do_action(
+                    None, qualified_user_id.lower()
                 )
             raise
 
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index eb547743be..c971eeb4d2 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -21,10 +21,10 @@ from synapse.api.errors import SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.logging.context import run_in_background
 from synapse.logging.opentracing import (
+    SynapseTags,
     get_active_span_text_map,
     log_kv,
     set_tag,
-    start_active_span,
 )
 from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
 from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
@@ -81,6 +81,7 @@ class DeviceMessageHandler:
             )
 
         self._ratelimiter = Ratelimiter(
+            store=self.store,
             clock=hs.get_clock(),
             rate_hz=hs.config.rc_key_requests.per_second,
             burst_count=hs.config.rc_key_requests.burst_count,
@@ -182,7 +183,10 @@ class DeviceMessageHandler:
     ) -> None:
         sender_user_id = requester.user.to_string()
 
-        set_tag("number_of_messages", len(messages))
+        message_id = random_string(16)
+        set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
+
+        log_kv({"number_of_to_device_messages": len(messages)})
         set_tag("sender", sender_user_id)
         local_messages = {}
         remote_messages = {}  # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
@@ -191,8 +195,8 @@ class DeviceMessageHandler:
             if (
                 message_type == EduTypes.RoomKeyRequest
                 and user_id != sender_user_id
-                and self._ratelimiter.can_do_action(
-                    (sender_user_id, requester.device_id)
+                and await self._ratelimiter.can_do_action(
+                    requester, (sender_user_id, requester.device_id)
                 )
             ):
                 continue
@@ -204,32 +208,35 @@ class DeviceMessageHandler:
                         "content": message_content,
                         "type": message_type,
                         "sender": sender_user_id,
+                        "message_id": message_id,
                     }
                     for device_id, message_content in by_device.items()
                 }
                 if messages_by_device:
                     local_messages[user_id] = messages_by_device
+                    log_kv(
+                        {
+                            "user_id": user_id,
+                            "device_id": list(messages_by_device),
+                        }
+                    )
             else:
                 destination = get_domain_from_id(user_id)
                 remote_messages.setdefault(destination, {})[user_id] = by_device
 
-        message_id = random_string(16)
-
         context = get_active_span_text_map()
 
         remote_edu_contents = {}
         for destination, messages in remote_messages.items():
-            with start_active_span("to_device_for_user"):
-                set_tag("destination", destination)
-                remote_edu_contents[destination] = {
-                    "messages": messages,
-                    "sender": sender_user_id,
-                    "type": message_type,
-                    "message_id": message_id,
-                    "org.matrix.opentracing_context": json_encoder.encode(context),
-                }
+            log_kv({"destination": destination})
+            remote_edu_contents[destination] = {
+                "messages": messages,
+                "sender": sender_user_id,
+                "type": message_type,
+                "message_id": message_id,
+                "org.matrix.opentracing_context": json_encoder.encode(context),
+            }
 
-        log_kv({"local_messages": local_messages})
         stream_id = await self.store.add_messages_to_device_inbox(
             local_messages, remote_edu_contents
         )
@@ -238,7 +245,6 @@ class DeviceMessageHandler:
             "to_device_key", stream_id, users=local_messages.keys()
         )
 
-        log_kv({"remote_messages": remote_messages})
         if self.federation_sender:
             for destination in remote_messages.keys():
                 # Enqueue a new federation transaction to send the new
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 2ad9b6d930..739653a3fa 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -1008,7 +1008,7 @@ class E2eKeysHandler:
         return signature_list, failures
 
     async def _get_e2e_cross_signing_verify_key(
-        self, user_id: str, key_type: str, from_user_id: str = None
+        self, user_id: str, key_type: str, from_user_id: Optional[str] = None
     ) -> Tuple[JsonDict, str, VerifyKey]:
         """Fetch locally or remotely query for a cross-signing public key.
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 598a66f74c..5ea8a7b603 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -21,7 +21,17 @@ import itertools
 import logging
 from collections.abc import Container
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)
 
 import attr
 from signedjson.key import decode_verify_key_bytes
@@ -171,15 +181,17 @@ class FederationHandler(BaseHandler):
 
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
 
-    async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
+    async def on_receive_pdu(
+        self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False
+    ) -> None:
         """Process a PDU received via a federation /send/ transaction, or
         via backfill of missing prev_events
 
         Args:
-            origin (str): server which initiated the /send/ transaction. Will
+            origin: server which initiated the /send/ transaction. Will
                 be used to fetch missing events or state.
-            pdu (FrozenEvent): received PDU
-            sent_to_us_directly (bool): True if this event was pushed to us; False if
+            pdu: received PDU
+            sent_to_us_directly: True if this event was pushed to us; False if
                 we pulled it as the result of a missing prev_event.
         """
 
@@ -411,13 +423,15 @@ class FederationHandler(BaseHandler):
 
         await self._process_received_pdu(origin, pdu, state=state)
 
-    async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+    async def _get_missing_events_for_pdu(
+        self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
+    ) -> None:
         """
         Args:
-            origin (str): Origin of the pdu. Will be called to get the missing events
+            origin: Origin of the pdu. Will be called to get the missing events
             pdu: received pdu
-            prevs (set(str)): List of event ids which we are missing
-            min_depth (int): Minimum depth of events to return.
+            prevs: Set of event ids which we are missing
+            min_depth: Minimum depth of events to return.
         """
 
         room_id = pdu.room_id
@@ -778,7 +792,7 @@ class FederationHandler(BaseHandler):
         origin: str,
         event: EventBase,
         state: Optional[Iterable[EventBase]],
-    ):
+    ) -> None:
         """Called when we have a new pdu. We need to do auth checks and put it
         through the StateHandler.
 
@@ -887,7 +901,9 @@ class FederationHandler(BaseHandler):
             logger.exception("Failed to resync device for %s", sender)
 
     @log_function
-    async def backfill(self, dest, room_id, limit, extremities):
+    async def backfill(
+        self, dest: str, room_id: str, limit: int, extremities: List[str]
+    ) -> List[EventBase]:
         """Trigger a backfill request to `dest` for the given `room_id`
 
         This will attempt to get more events from the remote. If the other side
@@ -1142,16 +1158,15 @@ class FederationHandler(BaseHandler):
 
         curr_state = await self.state_handler.get_current_state(room_id)
 
-        def get_domains_from_state(state):
+        def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
             """Get joined domains from state
 
             Args:
-                state (dict[tuple, FrozenEvent]): State map from type/state
-                    key to event.
+                state: State map from type/state key to event.
 
             Returns:
-                list[tuple[str, int]]: Returns a list of servers with the
-                lowest depth of their joins. Sorted by lowest depth first.
+                A list of servers with the lowest depth of their joins,
+                sorted by lowest depth first.
             """
             joined_users = [
                 (state_key, int(event.depth))
@@ -1179,7 +1194,7 @@ class FederationHandler(BaseHandler):
             domain for domain, depth in curr_domains if domain != self.server_name
         ]
 
-        async def try_backfill(domains):
+        async def try_backfill(domains: List[str]) -> bool:
             # TODO: Should we try multiple of these at a time?
             for dom in domains:
                 try:
@@ -1258,21 +1273,25 @@ class FederationHandler(BaseHandler):
         }
 
         for e_id, _ in sorted_extremeties_tuple:
-            likely_domains = get_domains_from_state(states[e_id])
+            likely_extremities_domains = get_domains_from_state(states[e_id])
 
             success = await try_backfill(
-                [dom for dom, _ in likely_domains if dom not in tried_domains]
+                [
+                    dom
+                    for dom, _ in likely_extremities_domains
+                    if dom not in tried_domains
+                ]
             )
             if success:
                 return True
 
-            tried_domains.update(dom for dom, _ in likely_domains)
+            tried_domains.update(dom for dom, _ in likely_extremities_domains)
 
         return False
 
     async def _get_events_and_persist(
         self, destination: str, room_id: str, events: Iterable[str]
-    ):
+    ) -> None:
         """Fetch the given events from a server, and persist them as outliers.
 
         This function *does not* recursively get missing auth events of the
@@ -1348,7 +1367,7 @@ class FederationHandler(BaseHandler):
             event_infos,
         )
 
-    def _sanity_check_event(self, ev):
+    def _sanity_check_event(self, ev: EventBase) -> None:
         """
         Do some early sanity checks of a received event
 
@@ -1357,9 +1376,7 @@ class FederationHandler(BaseHandler):
         or cascade of event fetches.
 
         Args:
-            ev (synapse.events.EventBase): event to be checked
-
-        Returns: None
+            ev: event to be checked
 
         Raises:
             SynapseError if the event does not pass muster
@@ -1380,7 +1397,7 @@ class FederationHandler(BaseHandler):
             )
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
 
-    async def send_invite(self, target_host, event):
+    async def send_invite(self, target_host: str, event: EventBase) -> EventBase:
         """Sends the invite to the remote server for signing.
 
         Invites must be signed by the invitee's server before distribution.
@@ -1528,12 +1545,13 @@ class FederationHandler(BaseHandler):
 
             run_in_background(self._handle_queued_pdus, room_queue)
 
-    async def _handle_queued_pdus(self, room_queue):
+    async def _handle_queued_pdus(
+        self, room_queue: List[Tuple[EventBase, str]]
+    ) -> None:
         """Process PDUs which got queued up while we were busy send_joining.
 
         Args:
-            room_queue (list[FrozenEvent, str]): list of PDUs to be processed
-                and the servers that sent them
+            room_queue: list of PDUs to be processed and the servers that sent them
         """
         for p, origin in room_queue:
             try:
@@ -1612,7 +1630,7 @@ class FederationHandler(BaseHandler):
 
         return event
 
-    async def on_send_join_request(self, origin, pdu):
+    async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict:
         """We have received a join event for a room. Fully process it and
         respond with the current state and auth chains.
         """
@@ -1668,7 +1686,7 @@ class FederationHandler(BaseHandler):
 
     async def on_invite_request(
         self, origin: str, event: EventBase, room_version: RoomVersion
-    ):
+    ) -> EventBase:
         """We've got an invite event. Process and persist it. Sign it.
 
         Respond with the now signed event.
@@ -1711,7 +1729,7 @@ class FederationHandler(BaseHandler):
         member_handler = self.hs.get_room_member_handler()
         # We don't rate limit based on room ID, as that should be done by
         # sending server.
-        member_handler.ratelimit_invite(None, event.state_key)
+        await member_handler.ratelimit_invite(None, None, event.state_key)
 
         # keep a record of the room version, if we don't yet know it.
         # (this may get overwritten if we later get a different room version in a
@@ -1841,7 +1859,7 @@ class FederationHandler(BaseHandler):
 
         return event
 
-    async def on_send_leave_request(self, origin, pdu):
+    async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None:
         """ We have received a leave event for a room. Fully process it."""
         event = pdu
 
@@ -1969,12 +1987,17 @@ class FederationHandler(BaseHandler):
         else:
             return None
 
-    async def get_min_depth_for_context(self, context):
+    async def get_min_depth_for_context(self, context: str) -> int:
         return await self.store.get_min_depth(context)
 
     async def _handle_new_event(
-        self, origin, event, state=None, auth_events=None, backfilled=False
-    ):
+        self,
+        origin: str,
+        event: EventBase,
+        state: Optional[Iterable[EventBase]] = None,
+        auth_events: Optional[MutableStateMap[EventBase]] = None,
+        backfilled: bool = False,
+    ) -> EventContext:
         context = await self._prep_event(
             origin, event, state=state, auth_events=auth_events, backfilled=backfilled
         )
@@ -2280,40 +2303,14 @@ class FederationHandler(BaseHandler):
             logger.warning("Soft-failing %r because %s", event, e)
             event.internal_metadata.soft_failed = True
 
-    async def on_query_auth(
-        self, origin, event_id, room_id, remote_auth_chain, rejects, missing
-    ):
-        in_room = await self.auth.check_host_in_room(room_id, origin)
-        if not in_room:
-            raise AuthError(403, "Host not in room.")
-
-        event = await self.store.get_event(event_id, check_room_id=room_id)
-
-        # Just go through and process each event in `remote_auth_chain`. We
-        # don't want to fall into the trap of `missing` being wrong.
-        for e in remote_auth_chain:
-            try:
-                await self._handle_new_event(origin, e)
-            except AuthError:
-                pass
-
-        # Now get the current auth_chain for the event.
-        local_auth_chain = await self.store.get_auth_chain(
-            room_id, list(event.auth_event_ids()), include_given=True
-        )
-
-        # TODO: Check if we would now reject event_id. If so we need to tell
-        # everyone.
-
-        ret = await self.construct_auth_difference(local_auth_chain, remote_auth_chain)
-
-        logger.debug("on_query_auth returning: %s", ret)
-
-        return ret
-
     async def on_get_missing_events(
-        self, origin, room_id, earliest_events, latest_events, limit
-    ):
+        self,
+        origin: str,
+        room_id: str,
+        earliest_events: List[str],
+        latest_events: List[str],
+        limit: int,
+    ) -> List[EventBase]:
         in_room = await self.auth.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
@@ -2617,8 +2614,8 @@ class FederationHandler(BaseHandler):
         assumes that we have already processed all events in remote_auth
 
         Params:
-            local_auth (list)
-            remote_auth (list)
+            local_auth
+            remote_auth
 
         Returns:
             dict
@@ -2742,8 +2739,8 @@ class FederationHandler(BaseHandler):
 
     @log_function
     async def exchange_third_party_invite(
-        self, sender_user_id, target_user_id, room_id, signed
-    ):
+        self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict
+    ) -> None:
         third_party_invite = {"signed": signed}
 
         event_dict = {
@@ -2835,8 +2832,12 @@ class FederationHandler(BaseHandler):
         await member_handler.send_membership_event(None, event, context)
 
     async def add_display_name_to_third_party_invite(
-        self, room_version, event_dict, event, context
-    ):
+        self,
+        room_version: str,
+        event_dict: JsonDict,
+        event: EventBase,
+        context: EventContext,
+    ) -> Tuple[EventBase, EventContext]:
         key = (
             EventTypes.ThirdPartyInvite,
             event.content["third_party_invite"]["signed"]["token"],
@@ -2872,13 +2873,13 @@ class FederationHandler(BaseHandler):
         EventValidator().validate_new(event, self.config)
         return (event, context)
 
-    async def _check_signature(self, event, context):
+    async def _check_signature(self, event: EventBase, context: EventContext) -> None:
         """
         Checks that the signature in the event is consistent with its invite.
 
         Args:
-            event (Event): The m.room.member event to check
-            context (EventContext):
+            event: The m.room.member event to check
+            context:
 
         Raises:
             AuthError: if signature didn't match any keys, or key has been
@@ -2964,13 +2965,13 @@ class FederationHandler(BaseHandler):
 
         raise last_exception
 
-    async def _check_key_revocation(self, public_key, url):
+    async def _check_key_revocation(self, public_key: str, url: str) -> None:
         """
         Checks whether public_key has been revoked.
 
         Args:
-            public_key (str): base-64 encoded public key.
-            url (str): Key revocation URL.
+            public_key: base-64 encoded public key.
+            url: Key revocation URL.
 
         Raises:
             AuthError: if the key has been revoked.
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 5f346f6d6d..d89fa5fb30 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -61,17 +61,19 @@ class IdentityHandler(BaseHandler):
 
         # Ratelimiters for `/requestToken` endpoints.
         self._3pid_validation_ratelimiter_ip = Ratelimiter(
+            store=self.store,
             clock=hs.get_clock(),
             rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
             burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
         )
         self._3pid_validation_ratelimiter_address = Ratelimiter(
+            store=self.store,
             clock=hs.get_clock(),
             rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
             burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
         )
 
-    def ratelimit_request_token_requests(
+    async def ratelimit_request_token_requests(
         self,
         request: SynapseRequest,
         medium: str,
@@ -85,8 +87,12 @@ class IdentityHandler(BaseHandler):
             address: The actual threepid ID, e.g. the phone number or email address
         """
 
-        self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
-        self._3pid_validation_ratelimiter_address.ratelimit((medium, address))
+        await self._3pid_validation_ratelimiter_ip.ratelimit(
+            None, (medium, request.getClientIP())
+        )
+        await self._3pid_validation_ratelimiter_address.ratelimit(
+            None, (medium, address)
+        )
 
     async def threepid_from_creds(
         self, id_server: str, creds: Dict[str, str]
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 1b7c065b34..6069968f7f 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -385,7 +385,7 @@ class EventCreationHandler:
         self._events_shard_config = self.config.worker.events_shard_config
         self._instance_name = hs.get_instance_name()
 
-        self.room_invite_state_types = self.hs.config.room_invite_state_types
+        self.room_invite_state_types = self.hs.config.api.room_prejoin_state
 
         self.membership_types_to_include_profile_data_in = (
             {Membership.JOIN, Membership.INVITE}
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 0fc2bf15d5..9701b76d0f 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -204,7 +204,7 @@ class RegistrationHandler(BaseHandler):
         Raises:
             SynapseError if there was a problem registering.
         """
-        self.check_registration_ratelimit(address)
+        await self.check_registration_ratelimit(address)
 
         result = await self.spam_checker.check_registration_for_spam(
             threepid,
@@ -583,7 +583,7 @@ class RegistrationHandler(BaseHandler):
                     errcode=Codes.EXCLUSIVE,
                 )
 
-    def check_registration_ratelimit(self, address: Optional[str]) -> None:
+    async def check_registration_ratelimit(self, address: Optional[str]) -> None:
         """A simple helper method to check whether the registration rate limit has been hit
         for a given IP address
 
@@ -597,7 +597,7 @@ class RegistrationHandler(BaseHandler):
         if not address:
             return
 
-        self.ratelimiter.ratelimit(address)
+        await self.ratelimiter.ratelimit(None, address)
 
     async def register_with_store(
         self,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 4d20ed8357..1cf12f3255 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -75,22 +75,26 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         self.allow_per_room_profiles = self.config.allow_per_room_profiles
 
         self._join_rate_limiter_local = Ratelimiter(
+            store=self.store,
             clock=self.clock,
             rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
             burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
         )
         self._join_rate_limiter_remote = Ratelimiter(
+            store=self.store,
             clock=self.clock,
             rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
             burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
         )
 
         self._invites_per_room_limiter = Ratelimiter(
+            store=self.store,
             clock=self.clock,
             rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
             burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
         )
         self._invites_per_user_limiter = Ratelimiter(
+            store=self.store,
             clock=self.clock,
             rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
             burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
@@ -159,15 +163,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
     async def forget(self, user: UserID, room_id: str) -> None:
         raise NotImplementedError()
 
-    def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str):
+    async def ratelimit_invite(
+        self,
+        requester: Optional[Requester],
+        room_id: Optional[str],
+        invitee_user_id: str,
+    ):
         """Ratelimit invites by room and by target user.
 
         If room ID is missing then we just rate limit by target user.
         """
         if room_id:
-            self._invites_per_room_limiter.ratelimit(room_id)
+            await self._invites_per_room_limiter.ratelimit(requester, room_id)
 
-        self._invites_per_user_limiter.ratelimit(invitee_user_id)
+        await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id)
 
     async def _local_membership_update(
         self,
@@ -237,7 +246,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 (
                     allowed,
                     time_allowed,
-                ) = self._join_rate_limiter_local.can_requester_do_action(requester)
+                ) = await self._join_rate_limiter_local.can_do_action(requester)
 
                 if not allowed:
                     raise LimitExceededError(
@@ -421,9 +430,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         if effective_membership_state == Membership.INVITE:
             target_id = target.to_string()
             if ratelimit:
-                # Don't ratelimit application services.
-                if not requester.app_service or requester.app_service.is_rate_limited():
-                    self.ratelimit_invite(room_id, target_id)
+                await self.ratelimit_invite(requester, room_id, target_id)
 
             # block any attempts to invite the server notices mxid
             if target_id == self._server_notices_mxid:
@@ -534,7 +541,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     (
                         allowed,
                         time_allowed,
-                    ) = self._join_rate_limiter_remote.can_requester_do_action(
+                    ) = await self._join_rate_limiter_remote.can_do_action(
                         requester,
                     )
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index ee607e6e65..7b356ba7e5 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -24,6 +24,7 @@ from synapse.api.constants import AccountDataTypes, EventTypes, Membership
 from synapse.api.filtering import FilterCollection
 from synapse.events import EventBase
 from synapse.logging.context import current_context
+from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
 from synapse.push.clientformat import format_push_rules_for_user
 from synapse.storage.roommember import MemberSummary
 from synapse.storage.state import StateFilter
@@ -340,7 +341,14 @@ class SyncHandler:
         full_state: bool = False,
     ) -> SyncResult:
         """Get the sync for client needed to match what the server has now."""
-        return await self.generate_sync_result(sync_config, since_token, full_state)
+        with start_active_span("current_sync_for_user"):
+            log_kv({"since_token": since_token})
+            sync_result = await self.generate_sync_result(
+                sync_config, since_token, full_state
+            )
+
+            set_tag(SynapseTags.SYNC_RESULT, bool(sync_result))
+            return sync_result
 
     async def push_rules_for_user(self, user: UserID) -> JsonDict:
         user_id = user.to_string()
@@ -964,6 +972,7 @@ class SyncHandler:
         # to query up to a given point.
         # Always use the `now_token` in `SyncResultBuilder`
         now_token = self.event_sources.get_current_token()
+        log_kv({"now_token": now_token})
 
         logger.debug(
             "Calculating sync response for %r between %s and %s",
@@ -1225,6 +1234,13 @@ class SyncHandler:
                 user_id, device_id, since_stream_id, now_token.to_device_key
             )
 
+            for message in messages:
+                # We pop here as we shouldn't be sending the message ID down
+                # `/sync`
+                message_id = message.pop("message_id", None)
+                if message_id:
+                    set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
+
             logger.debug(
                 "Returning %d to-device messages between %d and %d (current token: %d)",
                 len(messages),
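# A standalone sketch of the loop above: the internal "message_id" field is stripped
# from each to-device message before it is returned over /sync, and each ID is
# recorded for tracing. `record_tag` is a hypothetical stand-in for opentracing's
# set_tag with the SynapseTags.TO_DEVICE_MESSAGE_ID constant.
def strip_message_ids(messages, record_tag):
    for message in messages:
        message_id = message.pop("message_id", None)
        if message_id:
            record_tag("to_device.message_id", message_id)
    return messages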
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 096d199f4c..bb35af099d 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -19,7 +19,10 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api.errors import AuthError, ShadowBanError, SynapseError
 from synapse.appservice import ApplicationService
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+    run_as_background_process,
+    wrap_as_background_process,
+)
 from synapse.replication.tcp.streams import TypingStream
 from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -86,6 +89,7 @@ class FollowerTypingHandler:
         self._member_last_federation_poke = {}
         self.wheel_timer = WheelTimer(bucket_size=5000)
 
+    @wrap_as_background_process("typing._handle_timeouts")
     def _handle_timeouts(self) -> None:
         logger.debug("Checking for typing timeouts")
 
diff --git a/synapse/http/client.py b/synapse/http/client.py
index a0caba84e4..e691ba6d88 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -590,7 +590,7 @@ class SimpleHttpClient:
         uri: str,
         json_body: Any,
         args: Optional[QueryParams] = None,
-        headers: RawHeaders = None,
+        headers: Optional[RawHeaders] = None,
     ) -> Any:
         """Puts some json to the given URI.
 
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index aa146e8bb8..b8081f197e 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -259,6 +259,14 @@ except ImportError:
 logger = logging.getLogger(__name__)
 
 
+class SynapseTags:
+    # The message ID of any to_device message processed
+    TO_DEVICE_MESSAGE_ID = "to_device.message_id"
+
+    # Whether the sync response has new data to be returned to the client.
+    SYNC_RESULT = "sync.new_data"
+
+
 # Block everything by default
 # A regex which matches the server_names to expose traces for.
 # None means 'block everything'.
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 1374aae490..c178db57e3 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -39,6 +39,7 @@ from synapse.api.errors import AuthError
 from synapse.events import EventBase
 from synapse.handlers.presence import format_user_presence_state
 from synapse.logging.context import PreserveLoggingContext
+from synapse.logging.opentracing import log_kv, start_active_span
 from synapse.logging.utils import log_function
 from synapse.metrics import LaterGauge
 from synapse.streams.config import PaginationConfig
@@ -136,6 +137,15 @@ class _NotifierUserStream:
         self.last_notified_ms = time_now_ms
         noify_deferred = self.notify_deferred
 
+        log_kv(
+            {
+                "notify": self.user_id,
+                "stream": stream_key,
+                "stream_id": stream_id,
+                "listeners": self.count_listeners(),
+            }
+        )
+
         users_woken_by_stream_counter.labels(stream_key).inc()
 
         with PreserveLoggingContext():
@@ -404,6 +414,13 @@ class Notifier:
         with Measure(self.clock, "on_new_event"):
             user_streams = set()
 
+            log_kv(
+                {
+                    "waking_up_explicit_users": len(users),
+                    "waking_up_explicit_rooms": len(rooms),
+                }
+            )
+
             for user in users:
                 user_stream = self.user_to_user_stream.get(str(user))
                 if user_stream is not None:
@@ -476,12 +493,34 @@ class Notifier:
                         (end_time - now) / 1000.0,
                         self.hs.get_reactor(),
                     )
-                    with PreserveLoggingContext():
-                        await listener.deferred
+
+                    with start_active_span("wait_for_events.deferred"):
+                        log_kv(
+                            {
+                                "wait_for_events": "sleep",
+                                "token": prev_token,
+                            }
+                        )
+
+                        with PreserveLoggingContext():
+                            await listener.deferred
+
+                        log_kv(
+                            {
+                                "wait_for_events": "woken",
+                                "token": user_stream.current_token,
+                            }
+                        )
 
                     current_token = user_stream.current_token
 
                     result = await callback(prev_token, current_token)
+                    log_kv(
+                        {
+                            "wait_for_events": "result",
+                            "result": bool(result),
+                        }
+                    )
                     if result:
                         break
 
@@ -489,8 +528,10 @@ class Notifier:
                     # has happened between the old prev_token and the current_token
                     prev_token = current_token
                 except defer.TimeoutError:
+                    log_kv({"wait_for_events": "timeout"})
                     break
                 except defer.CancelledError:
+                    log_kv({"wait_for_events": "cancelled"})
                     break
 
         if result is None:
@@ -507,7 +548,7 @@ class Notifier:
         pagination_config: PaginationConfig,
         timeout: int,
         is_guest: bool = False,
-        explicit_room_id: str = None,
+        explicit_room_id: Optional[str] = None,
     ) -> EventStreamResult:
         """For the given user and rooms, return any new events for them. If
         there are no new events wait for up to `timeout` milliseconds for any
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index d005f38767..73d7477854 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -77,7 +77,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
     async def _handle_request(self, request, user_id):
         content = parse_json_object_from_request(request)
 
-        self.registration_handler.check_registration_ratelimit(content["address"])
+        await self.registration_handler.check_registration_ratelimit(content["address"])
 
         await self.registration_handler.register_with_store(
             user_id=user_id,
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 2f4d407f94..98bdeb0ec6 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -60,7 +60,7 @@ class ConstantProperty(Generic[T, V]):
 
     constant = attr.ib()  # type: V
 
-    def __get__(self, obj: Optional[T], objtype: Type[T] = None) -> V:
+    def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V:
         return self.constant
 
     def __set__(self, obj: Optional[T], value: V):
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 309bd2771b..fa7804583a 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -36,6 +36,7 @@ from synapse.rest.admin._base import (
 )
 from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.storage.databases.main.media_repository import MediaSortOrder
+from synapse.storage.databases.main.stats import UserSortOrder
 from synapse.types import JsonDict, UserID
 
 if TYPE_CHECKING:
@@ -117,8 +118,26 @@ class UsersRestServletV2(RestServlet):
         guests = parse_boolean(request, "guests", default=True)
         deactivated = parse_boolean(request, "deactivated", default=False)
 
+        order_by = parse_string(
+            request,
+            "order_by",
+            default=UserSortOrder.NAME.value,
+            allowed_values=(
+                UserSortOrder.NAME.value,
+                UserSortOrder.DISPLAYNAME.value,
+                UserSortOrder.GUEST.value,
+                UserSortOrder.ADMIN.value,
+                UserSortOrder.DEACTIVATED.value,
+                UserSortOrder.USER_TYPE.value,
+                UserSortOrder.AVATAR_URL.value,
+                UserSortOrder.SHADOW_BANNED.value,
+            ),
+        )
+
+        direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
+
         users, total = await self.store.get_users_paginate(
-            start, limit, user_id, name, guests, deactivated
+            start, limit, user_id, name, guests, deactivated, order_by, direction
         )
         ret = {"users": users, "total": total}
         if (start + limit) < total:
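# A usage sketch of the new sorting parameters on the list-users admin API. The URL
# and allowed parameter values are taken from this diff; the homeserver address and
# access token are placeholders, and `requests` is just one convenient HTTP client.
import requests

resp = requests.get(
    "https://homeserver.example.com/_synapse/admin/v2/users",
    params={"order_by": "displayname", "dir": "b", "limit": 10},
    headers={"Authorization": "Bearer <admin_access_token>"},
)
print(resp.json()["users"])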
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index e4c352f572..3151e72d4f 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -74,11 +74,13 @@ class LoginRestServlet(RestServlet):
 
         self._well_known_builder = WellKnownBuilder(hs)
         self._address_ratelimiter = Ratelimiter(
+            store=hs.get_datastore(),
             clock=hs.get_clock(),
             rate_hz=self.hs.config.rc_login_address.per_second,
             burst_count=self.hs.config.rc_login_address.burst_count,
         )
         self._account_ratelimiter = Ratelimiter(
+            store=hs.get_datastore(),
             clock=hs.get_clock(),
             rate_hz=self.hs.config.rc_login_account.per_second,
             burst_count=self.hs.config.rc_login_account.burst_count,
@@ -141,20 +143,22 @@ class LoginRestServlet(RestServlet):
                 appservice = self.auth.get_appservice_by_req(request)
 
                 if appservice.is_rate_limited():
-                    self._address_ratelimiter.ratelimit(request.getClientIP())
+                    await self._address_ratelimiter.ratelimit(
+                        None, request.getClientIP()
+                    )
 
                 result = await self._do_appservice_login(login_submission, appservice)
             elif self.jwt_enabled and (
                 login_submission["type"] == LoginRestServlet.JWT_TYPE
                 or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
             ):
-                self._address_ratelimiter.ratelimit(request.getClientIP())
+                await self._address_ratelimiter.ratelimit(None, request.getClientIP())
                 result = await self._do_jwt_login(login_submission)
             elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
-                self._address_ratelimiter.ratelimit(request.getClientIP())
+                await self._address_ratelimiter.ratelimit(None, request.getClientIP())
                 result = await self._do_token_login(login_submission)
             else:
-                self._address_ratelimiter.ratelimit(request.getClientIP())
+                await self._address_ratelimiter.ratelimit(None, request.getClientIP())
                 result = await self._do_other_login(login_submission)
         except KeyError:
             raise SynapseError(400, "Missing JSON keys.")
@@ -258,7 +262,7 @@ class LoginRestServlet(RestServlet):
         # too often. This happens here rather than before as we don't
         # necessarily know the user before now.
         if ratelimit:
-            self._account_ratelimiter.ratelimit(user_id.lower())
+            await self._account_ratelimiter.ratelimit(None, user_id.lower())
 
         if create_non_existent_users:
             canonical_uid = await self.auth_handler.check_user_exists(user_id)
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index c2ba790bab..411fb57c47 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -103,7 +103,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)
 
-        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+        await self.identity_handler.ratelimit_request_token_requests(
+            request, "email", email
+        )
 
         # The email will be sent to the stored address.
         # This avoids a potential account hijack by requesting a password reset to
@@ -387,7 +389,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
-        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+        await self.identity_handler.ratelimit_request_token_requests(
+            request, "email", email
+        )
 
         if next_link:
             # Raise if the provided next_link value isn't valid
@@ -468,7 +472,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
-        self.identity_handler.ratelimit_request_token_requests(
+        await self.identity_handler.ratelimit_request_token_requests(
             request, "msisdn", msisdn
         )
 
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 8f68d8dfc8..c212da0cb2 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -126,7 +126,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
-        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+        await self.identity_handler.ratelimit_request_token_requests(
+            request, "email", email
+        )
 
         existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
             "email", email
@@ -208,7 +210,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
-        self.identity_handler.ratelimit_request_token_requests(
+        await self.identity_handler.ratelimit_request_token_requests(
             request, "msisdn", msisdn
         )
 
@@ -406,7 +408,7 @@ class RegisterRestServlet(RestServlet):
 
         client_addr = request.getClientIP()
 
-        self.ratelimiter.ratelimit(client_addr, update=False)
+        await self.ratelimiter.ratelimit(None, client_addr, update=False)
 
         kind = b"user"
         if b"kind" in request.args:
diff --git a/synapse/server.py b/synapse/server.py
index e85b9391fa..e42f7b1a18 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -329,6 +329,7 @@ class HomeServer(metaclass=abc.ABCMeta):
     @cache_in_self
     def get_registration_ratelimiter(self) -> Ratelimiter:
         return Ratelimiter(
+            store=self.get_datastore(),
             clock=self.get_clock(),
             rate_hz=self.config.rc_registration.per_second,
             burst_count=self.config.rc_registration.burst_count,
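# A hedged sketch of the Ratelimiter calling convention used throughout this change:
# the limiter is constructed with the datastore, and its checks are awaited, taking an
# optional Requester first so per-user overrides stored in the database can be
# honoured. `hs` is assumed to be a configured HomeServer; the key is illustrative.
from synapse.api.ratelimiting import Ratelimiter

async def limit_example(hs, requester):
    limiter = Ratelimiter(
        store=hs.get_datastore(),
        clock=hs.get_clock(),
        rate_hz=0.1,
        burst_count=1,
    )
    # Raises LimitExceededError once the burst is used up, unless the requester's
    # user has ratelimiting disabled in the database.
    await limiter.ratelimit(requester, "example_key")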
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 1d44c3aa2c..b3d16ca7ac 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -21,6 +21,7 @@ from typing import List, Optional, Tuple
 from synapse.api.constants import PresenceState
 from synapse.config.homeserver import HomeServerConfig
 from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.stats import UserSortOrder
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
     IdGenerator,
@@ -292,6 +293,8 @@ class DataStore(
         name: Optional[str] = None,
         guests: bool = True,
         deactivated: bool = False,
+        order_by: UserSortOrder = UserSortOrder.USER_ID.value,
+        direction: str = "f",
     ) -> Tuple[List[JsonDict], int]:
         """Function to retrieve a paginated list of users from
         users list. This will return a json list of users and the
@@ -304,6 +307,8 @@ class DataStore(
             name: search for local part of user_id or display name
             guests: whether to include guest users
             deactivated: whether to include deactivated users
+            order_by: the sort order of the returned list
+            direction: sort ascending or descending
         Returns:
             A tuple of a list of mappings from user to information and a count of total users.
         """
@@ -312,6 +317,14 @@ class DataStore(
             filters = []
             args = [self.hs.config.server_name]
 
+            # Set ordering
+            order_by_column = UserSortOrder(order_by).value
+
+            if direction == "b":
+                order = "DESC"
+            else:
+                order = "ASC"
+
             # `name` is in database already in lower case
             if name:
                 filters.append("(name LIKE ? OR LOWER(displayname) LIKE ?)")
@@ -339,10 +352,15 @@ class DataStore(
             txn.execute(sql, args)
             count = txn.fetchone()[0]
 
-            sql = (
-                "SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url "
-                + sql_base
-                + " ORDER BY u.name LIMIT ? OFFSET ?"
+            sql = """
+                SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url
+                {sql_base}
+                ORDER BY {order_by_column} {order}, u.name ASC
+                LIMIT ? OFFSET ?
+            """.format(
+                sql_base=sql_base,
+                order_by_column=order_by_column,
+                order=order,
             )
             args += [limit, start]
             txn.execute(sql, args)
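# A standalone sketch of the ordering logic above: the ORDER BY column comes from a
# fixed Enum, so only whitelisted column names are ever interpolated into the SQL,
# while LIMIT/OFFSET remain bound parameters. The two-member enum and `sql_base` are
# simplified stand-ins for the real UserSortOrder and query.
from enum import Enum

class SortColumn(Enum):
    NAME = "name"
    DISPLAYNAME = "displayname"

def build_user_query(sql_base: str, order_by: str, direction: str) -> str:
    order_by_column = SortColumn(order_by).value  # ValueError for unknown columns
    order = "DESC" if direction == "b" else "ASC"
    return (
        "SELECT name, displayname {sql_base} "
        "ORDER BY {column} {order}, u.name ASC LIMIT ? OFFSET ?"
    ).format(sql_base=sql_base, column=order_by_column, order=order)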
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 952d4969b2..c00780969f 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -16,7 +16,7 @@
 import logging
 import threading
 from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Tuple, overload
+from typing import Container, Dict, Iterable, List, Optional, Tuple, overload
 
 from constantly import NamedConstant, Names
 from typing_extensions import Literal
@@ -544,7 +544,7 @@ class EventsWorkerStore(SQLBaseStore):
     async def get_stripped_room_state_from_event_context(
         self,
         context: EventContext,
-        state_types_to_include: List[EventTypes],
+        state_types_to_include: Container[str],
         membership_user_id: Optional[str] = None,
     ) -> List[JsonDict]:
         """
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index ac07e0197b..8f462dfc31 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -1027,8 +1027,8 @@ class GroupServerStore(GroupServerWorkerStore):
         user_id: str,
         is_admin: bool = False,
         is_public: bool = True,
-        local_attestation: dict = None,
-        remote_attestation: dict = None,
+        local_attestation: Optional[dict] = None,
+        remote_attestation: Optional[dict] = None,
     ) -> None:
         """Add a user to the group server.
 
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 1c99393c65..bce8946c21 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -66,18 +66,37 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")}
 class UserSortOrder(Enum):
     """
     Enum to define the sorting method used when returning users
-    with get_users_media_usage_paginate
+    with get_users_paginate in __init__.py
+    and get_users_media_usage_paginate in stats.py
 
-    MEDIA_LENGTH = ordered by size of uploaded media. Smallest to largest.
-    MEDIA_COUNT = ordered by number of uploaded media. Smallest to largest.
+    Moving this to __init__.py raises `builtins.ImportError`,
+    most likely due to a circular import.
+
+    MEDIA_LENGTH = ordered by size of uploaded media.
+    MEDIA_COUNT = ordered by number of uploaded media.
     USER_ID = ordered alphabetically by `user_id`.
+    NAME = ordered alphabetically by `user_id`. This is for compatibility reasons,
+    as the user_id is returned in the `name` field of the list users admin API response.
     DISPLAYNAME = ordered alphabetically by `displayname`
+    GUEST = ordered by `is_guest`
+    ADMIN = ordered by `admin`
+    DEACTIVATED = ordered by `deactivated`
+    USER_TYPE = ordered alphabetically by `user_type`
+    AVATAR_URL = ordered alphabetically by `avatar_url`
+    SHADOW_BANNED = ordered by `shadow_banned`
     """
 
     MEDIA_LENGTH = "media_length"
     MEDIA_COUNT = "media_count"
     USER_ID = "user_id"
+    NAME = "name"
     DISPLAYNAME = "displayname"
+    GUEST = "is_guest"
+    ADMIN = "admin"
+    DEACTIVATED = "deactivated"
+    USER_TYPE = "user_type"
+    AVATAR_URL = "avatar_url"
+    SHADOW_BANNED = "shadow_banned"
 
 
 class StatsStore(StateDeltasStore):
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 6c3c2da520..c7f0b8ccb5 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import imp
+import importlib.util
 import logging
 import os
 import re
@@ -454,8 +454,13 @@ def _upgrade_existing_database(
                     )
 
                 module_name = "synapse.storage.v%d_%s" % (v, root_name)
-                with open(absolute_path) as python_file:
-                    module = imp.load_source(module_name, absolute_path, python_file)  # type: ignore
+
+                spec = importlib.util.spec_from_file_location(
+                    module_name, absolute_path
+                )
+                module = importlib.util.module_from_spec(spec)
+                spec.loader.exec_module(module)  # type: ignore
+
                 logger.info("Running script %s", relative_path)
                 module.run_create(cur, database_engine)  # type: ignore
                 if not is_empty:
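# A standalone equivalent of the replacement above: loading a Python module from a
# file path with importlib.util rather than the long-deprecated `imp.load_source`.
import importlib.util

def load_module_from_path(module_name: str, absolute_path: str):
    spec = importlib.util.spec_from_file_location(module_name, absolute_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module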
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 1adc92eb90..dd392cf694 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -283,7 +283,9 @@ class DeferredCache(Generic[KT, VT]):
         # we return a new Deferred which will be called before any subsequent observers.
         return observable.observe()
 
-    def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
+    def prefill(
+        self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
+    ):
         callbacks = [callback] if callback else []
         self.cache.set(key, value, callbacks=callbacks)
 
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 483418192c..fa96ba07a5 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -5,38 +5,25 @@ from synapse.types import create_requester
 from tests import unittest
 
 
-class TestRatelimiter(unittest.TestCase):
+class TestRatelimiter(unittest.HomeserverTestCase):
     def test_allowed_via_can_do_action(self):
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
-        allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0)
-        self.assertTrue(allowed)
-        self.assertEquals(10.0, time_allowed)
-
-        allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5)
-        self.assertFalse(allowed)
-        self.assertEquals(10.0, time_allowed)
-
-        allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10)
-        self.assertTrue(allowed)
-        self.assertEquals(20.0, time_allowed)
-
-    def test_allowed_user_via_can_requester_do_action(self):
-        user_requester = create_requester("@user:example.com")
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
-        allowed, time_allowed = limiter.can_requester_do_action(
-            user_requester, _time_now_s=0
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(None, key="test_id", _time_now_s=0)
         )
         self.assertTrue(allowed)
         self.assertEquals(10.0, time_allowed)
 
-        allowed, time_allowed = limiter.can_requester_do_action(
-            user_requester, _time_now_s=5
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(None, key="test_id", _time_now_s=5)
         )
         self.assertFalse(allowed)
         self.assertEquals(10.0, time_allowed)
 
-        allowed, time_allowed = limiter.can_requester_do_action(
-            user_requester, _time_now_s=10
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(None, key="test_id", _time_now_s=10)
         )
         self.assertTrue(allowed)
         self.assertEquals(20.0, time_allowed)
@@ -51,21 +38,23 @@ class TestRatelimiter(unittest.TestCase):
         )
         as_requester = create_requester("@user:example.com", app_service=appservice)
 
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=0
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=0)
         )
         self.assertTrue(allowed)
         self.assertEquals(10.0, time_allowed)
 
-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=5
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=5)
         )
         self.assertFalse(allowed)
         self.assertEquals(10.0, time_allowed)
 
-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=10
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=10)
         )
         self.assertTrue(allowed)
         self.assertEquals(20.0, time_allowed)
@@ -80,73 +69,89 @@ class TestRatelimiter(unittest.TestCase):
         )
         as_requester = create_requester("@user:example.com", app_service=appservice)
 
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=0
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=0)
         )
         self.assertTrue(allowed)
         self.assertEquals(-1, time_allowed)
 
-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=5
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=5)
         )
         self.assertTrue(allowed)
         self.assertEquals(-1, time_allowed)
 
-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=10
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=10)
         )
         self.assertTrue(allowed)
         self.assertEquals(-1, time_allowed)
 
     def test_allowed_via_ratelimit(self):
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
 
         # Shouldn't raise
-        limiter.ratelimit(key="test_id", _time_now_s=0)
+        self.get_success_or_raise(limiter.ratelimit(None, key="test_id", _time_now_s=0))
 
         # Should raise
         with self.assertRaises(LimitExceededError) as context:
-            limiter.ratelimit(key="test_id", _time_now_s=5)
+            self.get_success_or_raise(
+                limiter.ratelimit(None, key="test_id", _time_now_s=5)
+            )
         self.assertEqual(context.exception.retry_after_ms, 5000)
 
         # Shouldn't raise
-        limiter.ratelimit(key="test_id", _time_now_s=10)
+        self.get_success_or_raise(
+            limiter.ratelimit(None, key="test_id", _time_now_s=10)
+        )
 
     def test_allowed_via_can_do_action_and_overriding_parameters(self):
         """Test that we can override options of can_do_action that would otherwise fail
         an action
         """
         # Create a Ratelimiter with a very low allowed rate_hz and burst_count
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
 
         # First attempt should be allowed
-        allowed, time_allowed = limiter.can_do_action(
-            ("test_id",),
-            _time_now_s=0,
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(
+                None,
+                ("test_id",),
+                _time_now_s=0,
+            )
         )
         self.assertTrue(allowed)
         self.assertEqual(10.0, time_allowed)
 
         # Second attempt, 1s later, will fail
-        allowed, time_allowed = limiter.can_do_action(
-            ("test_id",),
-            _time_now_s=1,
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(
+                None,
+                ("test_id",),
+                _time_now_s=1,
+            )
         )
         self.assertFalse(allowed)
         self.assertEqual(10.0, time_allowed)
 
         # But, if we allow 10 actions/sec for this request, we should be allowed
         # to continue.
-        allowed, time_allowed = limiter.can_do_action(
-            ("test_id",), _time_now_s=1, rate_hz=10.0
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(None, ("test_id",), _time_now_s=1, rate_hz=10.0)
         )
         self.assertTrue(allowed)
         self.assertEqual(1.1, time_allowed)
 
         # Similarly if we allow a burst of 10 actions
-        allowed, time_allowed = limiter.can_do_action(
-            ("test_id",), _time_now_s=1, burst_count=10
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(None, ("test_id",), _time_now_s=1, burst_count=10)
         )
         self.assertTrue(allowed)
         self.assertEqual(1.0, time_allowed)
@@ -156,29 +161,72 @@ class TestRatelimiter(unittest.TestCase):
         fail an action
         """
         # Create a Ratelimiter with a very low allowed rate_hz and burst_count
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
 
         # First attempt should be allowed
-        limiter.ratelimit(key=("test_id",), _time_now_s=0)
+        self.get_success_or_raise(
+            limiter.ratelimit(None, key=("test_id",), _time_now_s=0)
+        )
 
         # Second attempt, 1s later, will fail
         with self.assertRaises(LimitExceededError) as context:
-            limiter.ratelimit(key=("test_id",), _time_now_s=1)
+            self.get_success_or_raise(
+                limiter.ratelimit(None, key=("test_id",), _time_now_s=1)
+            )
         self.assertEqual(context.exception.retry_after_ms, 9000)
 
         # But, if we allow 10 actions/sec for this request, we should be allowed
         # to continue.
-        limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0)
+        self.get_success_or_raise(
+            limiter.ratelimit(None, key=("test_id",), _time_now_s=1, rate_hz=10.0)
+        )
 
         # Similarly if we allow a burst of 10 actions
-        limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10)
+        self.get_success_or_raise(
+            limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10)
+        )
 
     def test_pruning(self):
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
-        limiter.can_do_action(key="test_id_1", _time_now_s=0)
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
+        self.get_success_or_raise(
+            limiter.can_do_action(None, key="test_id_1", _time_now_s=0)
+        )
 
         self.assertIn("test_id_1", limiter.actions)
 
-        limiter.can_do_action(key="test_id_2", _time_now_s=10)
+        self.get_success_or_raise(
+            limiter.can_do_action(None, key="test_id_2", _time_now_s=10)
+        )
 
         self.assertNotIn("test_id_1", limiter.actions)
+
+    def test_db_user_override(self):
+        """Test that users that have ratelimiting disabled in the DB aren't
+        ratelimited.
+        """
+        store = self.hs.get_datastore()
+
+        user_id = "@user:test"
+        requester = create_requester(user_id)
+
+        self.get_success(
+            store.db_pool.simple_insert(
+                table="ratelimit_override",
+                values={
+                    "user_id": user_id,
+                    "messages_per_second": None,
+                    "burst_count": None,
+                },
+                desc="test_db_user_override",
+            )
+        )
+
+        limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1)
+
+        # Shouldn't raise
+        for _ in range(20):
+            self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0))
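
The hunks above show the Ratelimiter becoming an async, store-backed class: can_do_action and ratelimit now take a requester (or None) as their first argument so that per-user overrides from the ratelimit_override table can be honoured, and both must be awaited. A minimal usage sketch, assuming a homeserver object hs and a requester are already available and that the call signatures match what the tests exercise (check_action is an illustrative name, not part of this change set):

    from synapse.api.errors import LimitExceededError
    from synapse.api.ratelimiting import Ratelimiter

    async def check_action(hs, requester):
        # Store-backed limiter: 0.1 actions/sec with a burst of 1, as in the tests.
        limiter = Ratelimiter(
            store=hs.get_datastore(), clock=hs.get_clock(), rate_hz=0.1, burst_count=1
        )

        # Non-raising form: returns (allowed, time_allowed).
        allowed, time_allowed = await limiter.can_do_action(requester, ("some_key",))

        # Raising form: throws LimitExceededError once the limit is exceeded.
        try:
            await limiter.ratelimit(requester, key=("some_key",))
        except LimitExceededError as e:
            return e.retry_after_ms
        return None
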
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index cf61f284cb..0c9ec133c2 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -28,7 +28,7 @@ from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
 from synapse.api.room_versions import RoomVersions
 from synapse.rest.client.v1 import login, logout, profile, room
 from synapse.rest.client.v2_alpha import devices, sync
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
 
 from tests import unittest
 from tests.server import FakeSite, make_request
@@ -467,6 +467,8 @@ class UsersListTestCase(unittest.HomeserverTestCase):
     url = "/_synapse/admin/v2/users"
 
     def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
@@ -634,6 +636,26 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
 
+        # unknown order_by
+        channel = self.make_request(
+            "GET",
+            self.url + "?order_by=bar",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+        # invalid search order
+        channel = self.make_request(
+            "GET",
+            self.url + "?dir=bar",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
     def test_limit(self):
         """
         Testing list of users with limit
@@ -759,6 +781,103 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(channel.json_body["users"]), 1)
         self.assertNotIn("next_token", channel.json_body)
 
+    def test_order_by(self):
+        """
+        Testing order list with parameter `order_by`
+        """
+
+        user1 = self.register_user("user1", "pass1", admin=False, displayname="Name Z")
+        user2 = self.register_user("user2", "pass2", admin=False, displayname="Name Y")
+
+        # Modify user
+        self.get_success(self.store.set_user_deactivated_status(user1, True))
+        self.get_success(self.store.set_shadow_banned(UserID.from_string(user1), True))
+
+        # Set an avatar URL for all users so that no user has a NULL value, to avoid
+        # a different sort order between SQLite and PostgreSQL
+        self.get_success(self.store.set_profile_avatar_url("user1", "mxc://url3"))
+        self.get_success(self.store.set_profile_avatar_url("user2", "mxc://url2"))
+        self.get_success(self.store.set_profile_avatar_url("admin", "mxc://url1"))
+
+        # order by default (name)
+        self._order_test([self.admin_user, user1, user2], None)
+        self._order_test([self.admin_user, user1, user2], None, "f")
+        self._order_test([user2, user1, self.admin_user], None, "b")
+
+        # order by name
+        self._order_test([self.admin_user, user1, user2], "name")
+        self._order_test([self.admin_user, user1, user2], "name", "f")
+        self._order_test([user2, user1, self.admin_user], "name", "b")
+
+        # order by displayname
+        self._order_test([user2, user1, self.admin_user], "displayname")
+        self._order_test([user2, user1, self.admin_user], "displayname", "f")
+        self._order_test([self.admin_user, user1, user2], "displayname", "b")
+
+        # order by is_guest
+        # equivalent to sorting by ascending name, as there are no guest users here
+        self._order_test([self.admin_user, user1, user2], "is_guest")
+        self._order_test([self.admin_user, user1, user2], "is_guest", "f")
+        self._order_test([self.admin_user, user1, user2], "is_guest", "b")
+
+        # order by admin
+        self._order_test([user1, user2, self.admin_user], "admin")
+        self._order_test([user1, user2, self.admin_user], "admin", "f")
+        self._order_test([self.admin_user, user1, user2], "admin", "b")
+
+        # order by deactivated
+        self._order_test([self.admin_user, user2, user1], "deactivated")
+        self._order_test([self.admin_user, user2, user1], "deactivated", "f")
+        self._order_test([user1, self.admin_user, user2], "deactivated", "b")
+
+        # order by user_type
+        # equivalent to sorting by ascending name, as there are no special user types here
+        self._order_test([self.admin_user, user1, user2], "user_type")
+        self._order_test([self.admin_user, user1, user2], "user_type", "f")
+        self._order_test([self.admin_user, user1, user2], "is_guest", "b")
+
+        # order by shadow_banned
+        self._order_test([self.admin_user, user2, user1], "shadow_banned")
+        self._order_test([self.admin_user, user2, user1], "shadow_banned", "f")
+        self._order_test([user1, self.admin_user, user2], "shadow_banned", "b")
+
+        # order by avatar_url
+        self._order_test([self.admin_user, user2, user1], "avatar_url")
+        self._order_test([self.admin_user, user2, user1], "avatar_url", "f")
+        self._order_test([user1, user2, self.admin_user], "avatar_url", "b")
+
+    def _order_test(
+        self,
+        expected_user_list: List[str],
+        order_by: Optional[str],
+        dir: Optional[str] = None,
+    ):
+        """Request the list of users in a certain order. Assert that order is what
+        we expect
+        Args:
+            expected_user_list: The list of user_id in the order we expect to get
+                back from the server
+            order_by: The type of ordering to give the server
+            dir: The direction of ordering to give the server
+        """
+
+        url = self.url + "?deactivated=true&"
+        if order_by is not None:
+            url += "order_by=%s&" % (order_by,)
+        if dir is not None and dir in ("b", "f"):
+            url += "dir=%s" % (dir,)
+        channel = self.make_request(
+            "GET",
+            url.encode("ascii"),
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], len(expected_user_list))
+
+        returned_order = [row["name"] for row in channel.json_body["users"]]
+        self.assertEqual(expected_user_list, returned_order)
+        self._check_fields(channel.json_body["users"])
+
     def _check_fields(self, content: JsonDict):
         """Checks that the expected user attributes are present in content
         Args:
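
The ordering tests above exercise the new order_by and dir query parameters on the admin user list endpoint. Outside the test harness the same parameters can be passed to /_synapse/admin/v2/users directly; a small client-side sketch (the base URL, token and helper name are placeholders, not part of this change):

    import requests

    def list_users(base_url: str, admin_token: str, order_by: str = "name", direction: str = "f"):
        # order_by values covered by the tests include name, displayname, is_guest,
        # admin, deactivated, user_type, shadow_banned and avatar_url;
        # dir is "f" (forwards) or "b" (backwards).
        resp = requests.get(
            f"{base_url}/_synapse/admin/v2/users",
            params={"deactivated": "true", "order_by": order_by, "dir": direction},
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        resp.raise_for_status()
        return [row["name"] for row in resp.json()["users"]]
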
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 9734a2159a..ed433d9333 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Union
+from typing import Optional, Union
 
 from twisted.internet.defer import succeed
 
@@ -74,7 +74,10 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         return channel
 
     def recaptcha(
-        self, session: str, expected_post_response: int, post_session: str = None
+        self,
+        session: str,
+        expected_post_response: int,
+        post_session: Optional[str] = None,
     ) -> None:
         """Get and respond to a fallback recaptcha. Returns the second request."""
         if post_session is None:
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index dabc1c5f09..ef4cf8d0f1 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,32 +13,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
 import synapse.api.errors
 
-import tests.unittest
-import tests.utils
-
-
-class DeviceStoreTestCase(tests.unittest.TestCase):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.store = None  # type: synapse.storage.DataStore
+from tests.unittest import HomeserverTestCase
 
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
 
+class DeviceStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
 
-    @defer.inlineCallbacks
     def test_store_new_device(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device_id", "display_name")
         )
 
-        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+        res = self.get_success(self.store.get_device("user_id", "device_id"))
         self.assertDictContainsSubset(
             {
                 "user_id": "user_id",
@@ -48,19 +37,18 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
             res,
         )
 
-    @defer.inlineCallbacks
     def test_get_devices_by_user(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device1", "display_name 1")
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device2", "display_name 2")
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id2", "device3", "display_name 3")
         )
 
-        res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id"))
+        res = self.get_success(self.store.get_devices_by_user("user_id"))
         self.assertEqual(2, len(res.keys()))
         self.assertDictContainsSubset(
             {
@@ -79,43 +67,41 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
             res["device2"],
         )
 
-    @defer.inlineCallbacks
     def test_count_devices_by_users(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device1", "display_name 1")
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device2", "display_name 2")
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id2", "device3", "display_name 3")
         )
 
-        res = yield defer.ensureDeferred(self.store.count_devices_by_users())
+        res = self.get_success(self.store.count_devices_by_users())
         self.assertEqual(0, res)
 
-        res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"]))
+        res = self.get_success(self.store.count_devices_by_users(["unknown"]))
         self.assertEqual(0, res)
 
-        res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"]))
+        res = self.get_success(self.store.count_devices_by_users(["user_id"]))
         self.assertEqual(2, res)
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.store.count_devices_by_users(["user_id", "user_id2"])
         )
         self.assertEqual(3, res)
 
-    @defer.inlineCallbacks
     def test_get_device_updates_by_remote(self):
         device_ids = ["device_id1", "device_id2"]
 
         # Add two device updates with a single stream_id
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
         )
 
         # Get all device updates ever meant for this remote
-        now_stream_id, device_updates = yield defer.ensureDeferred(
+        now_stream_id, device_updates = self.get_success(
             self.store.get_device_updates_by_remote("somehost", -1, limit=100)
         )
 
@@ -131,37 +117,35 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         }
         self.assertEqual(received_device_ids, set(expected_device_ids))
 
-    @defer.inlineCallbacks
     def test_update_device(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device_id", "display_name 1")
         )
 
-        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+        res = self.get_success(self.store.get_device("user_id", "device_id"))
         self.assertEqual("display_name 1", res["display_name"])
 
         # do a no-op first
-        yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
-        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+        self.get_success(self.store.update_device("user_id", "device_id"))
+        res = self.get_success(self.store.get_device("user_id", "device_id"))
         self.assertEqual("display_name 1", res["display_name"])
 
         # do the update
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.update_device(
                 "user_id", "device_id", new_display_name="display_name 2"
             )
         )
 
         # check it worked
-        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+        res = self.get_success(self.store.get_device("user_id", "device_id"))
         self.assertEqual("display_name 2", res["display_name"])
 
-    @defer.inlineCallbacks
     def test_update_unknown_device(self):
-        with self.assertRaises(synapse.api.errors.StoreError) as cm:
-            yield defer.ensureDeferred(
-                self.store.update_device(
-                    "user_id", "unknown_device_id", new_display_name="display_name 2"
-                )
-            )
-        self.assertEqual(404, cm.exception.code)
+        exc = self.get_failure(
+            self.store.update_device(
+                "user_id", "unknown_device_id", new_display_name="display_name 2"
+            ),
+            synapse.api.errors.StoreError,
+        )
+        self.assertEqual(404, exc.value.code)
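
The conversion above is the pattern repeated throughout the storage tests in this diff: drop @defer.inlineCallbacks and yield defer.ensureDeferred(...), subclass HomeserverTestCase, move fixture setup into prepare(), and drive coroutines with get_success() (or get_failure() for expected errors). A condensed sketch of the resulting shape, with an illustrative class and test name:

    from tests.unittest import HomeserverTestCase

    class ExampleStoreTestCase(HomeserverTestCase):
        def prepare(self, reactor, clock, hs):
            # prepare() replaces the old inlineCallbacks-based setUp()
            self.store = hs.get_datastore()

        def test_store_and_fetch_device(self):
            # get_success() runs the coroutine on the test reactor and returns
            # its result, replacing `yield defer.ensureDeferred(...)`.
            self.get_success(self.store.store_device("user_id", "device_id", "name"))
            res = self.get_success(self.store.get_device("user_id", "device_id"))
            self.assertEqual("name", res["display_name"])
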
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index da93ca3980..0db233fd68 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,28 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from twisted.internet import defer
-
 from synapse.types import RoomAlias, RoomID
 
-from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.unittest import HomeserverTestCase
 
 
-class DirectoryStoreTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield setup_test_homeserver(self.addCleanup)
-
+class DirectoryStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
 
         self.room = RoomID.from_string("!abcde:test")
         self.alias = RoomAlias.from_string("#my-room:test")
 
-    @defer.inlineCallbacks
     def test_room_to_alias(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.create_room_alias_association(
                 room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
             )
@@ -42,16 +34,11 @@ class DirectoryStoreTestCase(unittest.TestCase):
 
         self.assertEquals(
             ["#my-room:test"],
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_aliases_for_room(self.room.to_string())
-                )
-            ),
+            (self.get_success(self.store.get_aliases_for_room(self.room.to_string()))),
         )
 
-    @defer.inlineCallbacks
     def test_alias_to_room(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.create_room_alias_association(
                 room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
             )
@@ -59,28 +46,19 @@ class DirectoryStoreTestCase(unittest.TestCase):
 
         self.assertObjectHasAttributes(
             {"room_id": self.room.to_string(), "servers": ["test"]},
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_association_from_room_alias(self.alias)
-                )
-            ),
+            (self.get_success(self.store.get_association_from_room_alias(self.alias))),
         )
 
-    @defer.inlineCallbacks
     def test_delete_alias(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.create_room_alias_association(
                 room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
             )
         )
 
-        room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias))
+        room_id = self.get_success(self.store.delete_room_alias(self.alias))
         self.assertEqual(self.room.to_string(), room_id)
 
         self.assertIsNone(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_association_from_room_alias(self.alias)
-                )
-            )
+            (self.get_success(self.store.get_association_from_room_alias(self.alias)))
         )
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 3fc4bb13b6..1e54b940fd 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,30 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
+from tests.unittest import HomeserverTestCase
 
-import tests.unittest
-import tests.utils
 
-
-class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class EndToEndKeyStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
 
-    @defer.inlineCallbacks
     def test_key_without_device_name(self):
         now = 1470174257070
         json = {"key": "value"}
 
-        yield defer.ensureDeferred(self.store.store_device("user", "device", None))
+        self.get_success(self.store.store_device("user", "device", None))
 
-        yield defer.ensureDeferred(
-            self.store.set_e2e_device_keys("user", "device", now, json)
-        )
+        self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
         )
         self.assertIn("user", res)
@@ -44,38 +36,32 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         dev = res["user"]["device"]
         self.assertDictContainsSubset(json, dev)
 
-    @defer.inlineCallbacks
     def test_reupload_key(self):
         now = 1470174257070
         json = {"key": "value"}
 
-        yield defer.ensureDeferred(self.store.store_device("user", "device", None))
+        self.get_success(self.store.store_device("user", "device", None))
 
-        changed = yield defer.ensureDeferred(
+        changed = self.get_success(
             self.store.set_e2e_device_keys("user", "device", now, json)
         )
         self.assertTrue(changed)
 
         # If we try to upload the same key then we should be told nothing
         # changed
-        changed = yield defer.ensureDeferred(
+        changed = self.get_success(
             self.store.set_e2e_device_keys("user", "device", now, json)
         )
         self.assertFalse(changed)
 
-    @defer.inlineCallbacks
     def test_get_key_with_device_name(self):
         now = 1470174257070
         json = {"key": "value"}
 
-        yield defer.ensureDeferred(
-            self.store.set_e2e_device_keys("user", "device", now, json)
-        )
-        yield defer.ensureDeferred(
-            self.store.store_device("user", "device", "display_name")
-        )
+        self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
+        self.get_success(self.store.store_device("user", "device", "display_name"))
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
         )
         self.assertIn("user", res)
@@ -85,29 +71,28 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
             {"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev
         )
 
-    @defer.inlineCallbacks
     def test_multiple_devices(self):
         now = 1470174257070
 
-        yield defer.ensureDeferred(self.store.store_device("user1", "device1", None))
-        yield defer.ensureDeferred(self.store.store_device("user1", "device2", None))
-        yield defer.ensureDeferred(self.store.store_device("user2", "device1", None))
-        yield defer.ensureDeferred(self.store.store_device("user2", "device2", None))
+        self.get_success(self.store.store_device("user1", "device1", None))
+        self.get_success(self.store.store_device("user1", "device2", None))
+        self.get_success(self.store.store_device("user2", "device1", None))
+        self.get_success(self.store.store_device("user2", "device2", None))
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
         )
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.store.get_e2e_device_keys_for_cs_api(
                 (("user1", "device1"), ("user2", "device2"))
             )
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 485f1ee033..239f7c9faf 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,10 +15,7 @@
 
 from mock import Mock
 
-from twisted.internet import defer
-
-import tests.unittest
-import tests.utils
+from tests.unittest import HomeserverTestCase
 
 USER_ID = "@user:example.com"
 
@@ -30,37 +27,31 @@ HIGHLIGHT = [
 ]
 
 
-class EventPushActionsStoreTestCase(tests.unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class EventPushActionsStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
         self.persist_events_store = hs.get_datastores().persist_events
 
-    @defer.inlineCallbacks
     def test_get_unread_push_actions_for_user_in_range_for_http(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.get_unread_push_actions_for_user_in_range_for_http(
                 USER_ID, 0, 1000, 20
             )
         )
 
-    @defer.inlineCallbacks
     def test_get_unread_push_actions_for_user_in_range_for_email(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.get_unread_push_actions_for_user_in_range_for_email(
                 USER_ID, 0, 1000, 20
             )
         )
 
-    @defer.inlineCallbacks
     def test_count_aggregation(self):
         room_id = "!foo:example.com"
         user_id = "@user1235:example.com"
 
-        @defer.inlineCallbacks
         def _assert_counts(noitf_count, highlight_count):
-            counts = yield defer.ensureDeferred(
+            counts = self.get_success(
                 self.store.db_pool.runInteraction(
                     "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
                 )
@@ -74,7 +65,6 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
                 },
             )
 
-        @defer.inlineCallbacks
         def _inject_actions(stream, action):
             event = Mock()
             event.room_id = room_id
@@ -82,14 +72,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             event.internal_metadata.stream_ordering = stream
             event.depth = stream
 
-            yield defer.ensureDeferred(
+            self.get_success(
                 self.store.add_push_actions_to_staging(
                     event.event_id,
                     {user_id: action},
                     False,
                 )
             )
-            yield defer.ensureDeferred(
+            self.get_success(
                 self.store.db_pool.runInteraction(
                     "",
                     self.persist_events_store._set_push_actions_for_event_and_users_txn,
@@ -99,14 +89,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             )
 
         def _rotate(stream):
-            return defer.ensureDeferred(
+            self.get_success(
                 self.store.db_pool.runInteraction(
                     "", self.store._rotate_notifs_before_txn, stream
                 )
             )
 
         def _mark_read(stream, depth):
-            return defer.ensureDeferred(
+            self.get_success(
                 self.store.db_pool.runInteraction(
                     "",
                     self.store._remove_old_push_actions_before_txn,
@@ -116,49 +106,48 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
                 )
             )
 
-        yield _assert_counts(0, 0)
-        yield _inject_actions(1, PlAIN_NOTIF)
-        yield _assert_counts(1, 0)
-        yield _rotate(2)
-        yield _assert_counts(1, 0)
+        _assert_counts(0, 0)
+        _inject_actions(1, PlAIN_NOTIF)
+        _assert_counts(1, 0)
+        _rotate(2)
+        _assert_counts(1, 0)
 
-        yield _inject_actions(3, PlAIN_NOTIF)
-        yield _assert_counts(2, 0)
-        yield _rotate(4)
-        yield _assert_counts(2, 0)
+        _inject_actions(3, PlAIN_NOTIF)
+        _assert_counts(2, 0)
+        _rotate(4)
+        _assert_counts(2, 0)
 
-        yield _inject_actions(5, PlAIN_NOTIF)
-        yield _mark_read(3, 3)
-        yield _assert_counts(1, 0)
+        _inject_actions(5, PlAIN_NOTIF)
+        _mark_read(3, 3)
+        _assert_counts(1, 0)
 
-        yield _mark_read(5, 5)
-        yield _assert_counts(0, 0)
+        _mark_read(5, 5)
+        _assert_counts(0, 0)
 
-        yield _inject_actions(6, PlAIN_NOTIF)
-        yield _rotate(7)
+        _inject_actions(6, PlAIN_NOTIF)
+        _rotate(7)
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.db_pool.simple_delete(
                 table="event_push_actions", keyvalues={"1": 1}, desc=""
             )
         )
 
-        yield _assert_counts(1, 0)
+        _assert_counts(1, 0)
 
-        yield _mark_read(7, 7)
-        yield _assert_counts(0, 0)
+        _mark_read(7, 7)
+        _assert_counts(0, 0)
 
-        yield _inject_actions(8, HIGHLIGHT)
-        yield _assert_counts(1, 1)
-        yield _rotate(9)
-        yield _assert_counts(1, 1)
-        yield _rotate(10)
-        yield _assert_counts(1, 1)
+        _inject_actions(8, HIGHLIGHT)
+        _assert_counts(1, 1)
+        _rotate(9)
+        _assert_counts(1, 1)
+        _rotate(10)
+        _assert_counts(1, 1)
 
-    @defer.inlineCallbacks
     def test_find_first_stream_ordering_after_ts(self):
         def add_event(so, ts):
-            return defer.ensureDeferred(
+            self.get_success(
                 self.store.db_pool.simple_insert(
                     "events",
                     {
@@ -177,24 +166,16 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             )
 
         # start with the base case where there are no events in the table
-        r = yield defer.ensureDeferred(
-            self.store.find_first_stream_ordering_after_ts(11)
-        )
+        r = self.get_success(self.store.find_first_stream_ordering_after_ts(11))
         self.assertEqual(r, 0)
 
         # now with one event
-        yield add_event(2, 10)
-        r = yield defer.ensureDeferred(
-            self.store.find_first_stream_ordering_after_ts(9)
-        )
+        add_event(2, 10)
+        r = self.get_success(self.store.find_first_stream_ordering_after_ts(9))
         self.assertEqual(r, 2)
-        r = yield defer.ensureDeferred(
-            self.store.find_first_stream_ordering_after_ts(10)
-        )
+        r = self.get_success(self.store.find_first_stream_ordering_after_ts(10))
         self.assertEqual(r, 2)
-        r = yield defer.ensureDeferred(
-            self.store.find_first_stream_ordering_after_ts(11)
-        )
+        r = self.get_success(self.store.find_first_stream_ordering_after_ts(11))
         self.assertEqual(r, 3)
 
         # add a bunch of dummy events to the events table
@@ -205,39 +186,27 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             (10, 130),
             (20, 140),
         ):
-            yield add_event(stream_ordering, ts)
+            add_event(stream_ordering, ts)
 
-        r = yield defer.ensureDeferred(
-            self.store.find_first_stream_ordering_after_ts(110)
-        )
+        r = self.get_success(self.store.find_first_stream_ordering_after_ts(110))
         self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r)
 
         # 4 and 5 are both after 120: we want 4 rather than 5
-        r = yield defer.ensureDeferred(
-            self.store.find_first_stream_ordering_after_ts(120)
-        )
+        r = self.get_success(self.store.find_first_stream_ordering_after_ts(120))
         self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r)
 
-        r = yield defer.ensureDeferred(
-            self.store.find_first_stream_ordering_after_ts(129)
-        )
+        r = self.get_success(self.store.find_first_stream_ordering_after_ts(129))
         self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r)
 
         # check we can get the last event
-        r = yield defer.ensureDeferred(
-            self.store.find_first_stream_ordering_after_ts(140)
-        )
+        r = self.get_success(self.store.find_first_stream_ordering_after_ts(140))
         self.assertEqual(r, 20, "First event after 140ms should be 20, was %i" % r)
 
         # off the end
-        r = yield defer.ensureDeferred(
-            self.store.find_first_stream_ordering_after_ts(160)
-        )
+        r = self.get_success(self.store.find_first_stream_ordering_after_ts(160))
         self.assertEqual(r, 21)
 
         # check we can find an event at ordering zero
-        yield add_event(0, 5)
-        r = yield defer.ensureDeferred(
-            self.store.find_first_stream_ordering_after_ts(1)
-        )
+        add_event(0, 5)
+        r = self.get_success(self.store.find_first_stream_ordering_after_ts(1))
         self.assertEqual(r, 0)
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index ea63bd56b4..d18ceb41a9 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,59 +13,50 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from twisted.internet import defer
-
 from synapse.types import UserID
 
 from tests import unittest
-from tests.utils import setup_test_homeserver
-
 
-class ProfileStoreTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield setup_test_homeserver(self.addCleanup)
 
+class ProfileStoreTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
 
         self.u_frank = UserID.from_string("@frank:test")
 
-    @defer.inlineCallbacks
     def test_displayname(self):
-        yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
+        self.get_success(self.store.create_profile(self.u_frank.localpart))
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
         )
 
         self.assertEquals(
             "Frank",
             (
-                yield defer.ensureDeferred(
+                self.get_success(
                     self.store.get_profile_displayname(self.u_frank.localpart)
                 )
             ),
         )
 
         # test set to None
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_displayname(self.u_frank.localpart, None)
         )
 
         self.assertIsNone(
             (
-                yield defer.ensureDeferred(
+                self.get_success(
                     self.store.get_profile_displayname(self.u_frank.localpart)
                 )
             )
         )
 
-    @defer.inlineCallbacks
     def test_avatar_url(self):
-        yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
+        self.get_success(self.store.create_profile(self.u_frank.localpart))
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_avatar_url(
                 self.u_frank.localpart, "http://my.site/here"
             )
@@ -74,20 +65,20 @@ class ProfileStoreTestCase(unittest.TestCase):
         self.assertEquals(
             "http://my.site/here",
             (
-                yield defer.ensureDeferred(
+                self.get_success(
                     self.store.get_profile_avatar_url(self.u_frank.localpart)
                 )
             ),
         )
 
         # test set to None
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_avatar_url(self.u_frank.localpart, None)
         )
 
         self.assertIsNone(
             (
-                yield defer.ensureDeferred(
+                self.get_success(
                     self.store.get_profile_avatar_url(self.u_frank.localpart)
                 )
             )
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index b2a0e60856..2622207639 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -1,6 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,8 +15,6 @@
 
 from canonicaljson import json
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.types import RoomID, UserID
@@ -230,10 +227,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
                 self._base_builder = base_builder
                 self._event_id = event_id
 
-            @defer.inlineCallbacks
-            def build(self, prev_event_ids, auth_event_ids):
-                built_event = yield defer.ensureDeferred(
-                    self._base_builder.build(prev_event_ids, auth_event_ids)
+            async def build(self, prev_event_ids, auth_event_ids):
+                built_event = await self._base_builder.build(
+                    prev_event_ids, auth_event_ids
                 )
 
                 built_event._event_id = self._event_id
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 4eb41c46e8..c82cf15bc2 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,21 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from twisted.internet import defer
-
 from synapse.api.constants import UserTypes
 from synapse.api.errors import ThreepidValidationError
 
-from tests import unittest
-from tests.utils import setup_test_homeserver
-
+from tests.unittest import HomeserverTestCase
 
-class RegistrationStoreTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield setup_test_homeserver(self.addCleanup)
 
+class RegistrationStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
 
         self.user_id = "@my-user:test"
@@ -35,9 +28,8 @@ class RegistrationStoreTestCase(unittest.TestCase):
         self.pwhash = "{xx1}123456789"
         self.device_id = "akgjhdjklgshg"
 
-    @defer.inlineCallbacks
     def test_register(self):
-        yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
+        self.get_success(self.store.register_user(self.user_id, self.pwhash))
 
         self.assertEquals(
             {
@@ -49,93 +41,81 @@ class RegistrationStoreTestCase(unittest.TestCase):
                 "consent_version": None,
                 "consent_server_notice_sent": None,
                 "appservice_id": None,
-                "creation_ts": 1000,
+                "creation_ts": 0,
                 "user_type": None,
                 "deactivated": 0,
                 "shadow_banned": 0,
             },
-            (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
+            (self.get_success(self.store.get_user_by_id(self.user_id))),
         )
 
-    @defer.inlineCallbacks
     def test_add_tokens(self):
-        yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
-        yield defer.ensureDeferred(
+        self.get_success(self.store.register_user(self.user_id, self.pwhash))
+        self.get_success(
             self.store.add_access_token_to_user(
                 self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
             )
         )
 
-        result = yield defer.ensureDeferred(
-            self.store.get_user_by_access_token(self.tokens[1])
-        )
+        result = self.get_success(self.store.get_user_by_access_token(self.tokens[1]))
 
         self.assertEqual(result.user_id, self.user_id)
         self.assertEqual(result.device_id, self.device_id)
         self.assertIsNotNone(result.token_id)
 
-    @defer.inlineCallbacks
     def test_user_delete_access_tokens(self):
         # add some tokens
-        yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
-        yield defer.ensureDeferred(
+        self.get_success(self.store.register_user(self.user_id, self.pwhash))
+        self.get_success(
             self.store.add_access_token_to_user(
                 self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
             )
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.add_access_token_to_user(
                 self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
             )
         )
 
         # now delete some
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id)
         )
 
         # check they were deleted
-        user = yield defer.ensureDeferred(
-            self.store.get_user_by_access_token(self.tokens[1])
-        )
+        user = self.get_success(self.store.get_user_by_access_token(self.tokens[1]))
         self.assertIsNone(user, "access token was not deleted by device_id")
 
         # check the one not associated with the device was not deleted
-        user = yield defer.ensureDeferred(
-            self.store.get_user_by_access_token(self.tokens[0])
-        )
+        user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
         self.assertEqual(self.user_id, user.user_id)
 
         # now delete the rest
-        yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
+        self.get_success(self.store.user_delete_access_tokens(self.user_id))
 
-        user = yield defer.ensureDeferred(
-            self.store.get_user_by_access_token(self.tokens[0])
-        )
+        user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
         self.assertIsNone(user, "access token was not deleted without device_id")
 
-    @defer.inlineCallbacks
     def test_is_support_user(self):
         TEST_USER = "@test:test"
         SUPPORT_USER = "@support:test"
 
-        res = yield defer.ensureDeferred(self.store.is_support_user(None))
+        res = self.get_success(self.store.is_support_user(None))
         self.assertFalse(res)
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.register_user(user_id=TEST_USER, password_hash=None)
         )
-        res = yield defer.ensureDeferred(self.store.is_support_user(TEST_USER))
+        res = self.get_success(self.store.is_support_user(TEST_USER))
         self.assertFalse(res)
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.register_user(
                 user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
             )
         )
-        res = yield defer.ensureDeferred(self.store.is_support_user(SUPPORT_USER))
+        res = self.get_success(self.store.is_support_user(SUPPORT_USER))
         self.assertTrue(res)
 
-    @defer.inlineCallbacks
     def test_3pid_inhibit_invalid_validation_session_error(self):
         """Tests that enabling the configuration option to inhibit 3PID errors on
         /requestToken also inhibits validation errors caused by an unknown session ID.
@@ -143,30 +123,28 @@ class RegistrationStoreTestCase(unittest.TestCase):
 
         # Check that, with the config setting set to false (the default value), a
         # validation error is caused by the unknown session ID.
-        try:
-            yield defer.ensureDeferred(
-                self.store.validate_threepid_session(
-                    "fake_sid",
-                    "fake_client_secret",
-                    "fake_token",
-                    0,
-                )
-            )
-        except ThreepidValidationError as e:
-            self.assertEquals(e.msg, "Unknown session_id", e)
+        e = self.get_failure(
+            self.store.validate_threepid_session(
+                "fake_sid",
+                "fake_client_secret",
+                "fake_token",
+                0,
+            ),
+            ThreepidValidationError,
+        )
+        self.assertEquals(e.value.msg, "Unknown session_id", e)
 
         # Set the config setting to true.
         self.store._ignore_unknown_session_error = True
 
         # Check that now the validation error is caused by the token not matching.
-        try:
-            yield defer.ensureDeferred(
-                self.store.validate_threepid_session(
-                    "fake_sid",
-                    "fake_client_secret",
-                    "fake_token",
-                    0,
-                )
-            )
-        except ThreepidValidationError as e:
-            self.assertEquals(e.msg, "Validation token not found or has expired", e)
+        e = self.get_failure(
+            self.store.validate_threepid_session(
+                "fake_sid",
+                "fake_client_secret",
+                "fake_token",
+                0,
+            ),
+            ThreepidValidationError,
+        )
+        self.assertEquals(e.value.msg, "Validation token not found or has expired", e)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index bc8400f240..0089d33c93 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,22 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes
 from synapse.api.room_versions import RoomVersions
 from synapse.types import RoomAlias, RoomID, UserID
 
-from tests import unittest
-from tests.utils import setup_test_homeserver
-
+from tests.unittest import HomeserverTestCase
 
-class RoomStoreTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield setup_test_homeserver(self.addCleanup)
 
+class RoomStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         # We can't test RoomStore on its own without the DirectoryStore, for
         # management of the 'room_aliases' table
         self.store = hs.get_datastore()
@@ -37,7 +30,7 @@ class RoomStoreTestCase(unittest.TestCase):
         self.alias = RoomAlias.from_string("#a-room-name:test")
         self.u_creator = UserID.from_string("@creator:test")
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_room(
                 self.room.to_string(),
                 room_creator_user_id=self.u_creator.to_string(),
@@ -46,7 +39,6 @@ class RoomStoreTestCase(unittest.TestCase):
             )
         )
 
-    @defer.inlineCallbacks
     def test_get_room(self):
         self.assertDictContainsSubset(
             {
@@ -54,16 +46,12 @@ class RoomStoreTestCase(unittest.TestCase):
                 "creator": self.u_creator.to_string(),
                 "is_public": True,
             },
-            (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
+            (self.get_success(self.store.get_room(self.room.to_string()))),
         )
 
-    @defer.inlineCallbacks
     def test_get_room_unknown_room(self):
-        self.assertIsNone(
-            (yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
-        )
+        self.assertIsNone((self.get_success(self.store.get_room("!uknown:test"))))
 
-    @defer.inlineCallbacks
     def test_get_room_with_stats(self):
         self.assertDictContainsSubset(
             {
@@ -71,29 +59,17 @@ class RoomStoreTestCase(unittest.TestCase):
                 "creator": self.u_creator.to_string(),
                 "public": True,
             },
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_room_with_stats(self.room.to_string())
-                )
-            ),
+            (self.get_success(self.store.get_room_with_stats(self.room.to_string()))),
         )
 
-    @defer.inlineCallbacks
     def test_get_room_with_stats_unknown_room(self):
         self.assertIsNone(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_room_with_stats("!uknown:test")
-                )
-            ),
+            (self.get_success(self.store.get_room_with_stats("!uknown:test"))),
         )
 
 
-class RoomEventsStoreTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = setup_test_homeserver(self.addCleanup)
-
+class RoomEventsStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         # Room events need the full datastore, for persist_event() and
         # get_room_state()
         self.store = hs.get_datastore()
@@ -102,7 +78,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
 
         self.room = RoomID.from_string("!abcde:test")
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_room(
                 self.room.to_string(),
                 room_creator_user_id="@creator:text",
@@ -111,23 +87,21 @@ class RoomEventsStoreTestCase(unittest.TestCase):
             )
         )
 
-    @defer.inlineCallbacks
     def inject_room_event(self, **kwargs):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.storage.persistence.persist_event(
                 self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
             )
         )
 
-    @defer.inlineCallbacks
     def STALE_test_room_name(self):
         name = "A-Room-Name"
 
-        yield self.inject_room_event(
+        self.inject_room_event(
             etype=EventTypes.Name, name=name, content={"name": name}, depth=1
         )
 
-        state = yield defer.ensureDeferred(
+        state = self.get_success(
             self.store.get_current_state(room_id=self.room.to_string())
         )
 
@@ -137,15 +111,14 @@ class RoomEventsStoreTestCase(unittest.TestCase):
             state[0],
         )
 
-    @defer.inlineCallbacks
     def STALE_test_room_topic(self):
         topic = "A place for things"
 
-        yield self.inject_room_event(
+        self.inject_room_event(
             etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
         )
 
-        state = yield defer.ensureDeferred(
+        state = self.get_success(
             self.store.get_current_state(room_id=self.room.to_string())
         )
 
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 2471f1267d..f06b452fa9 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,24 +15,18 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.storage.state import StateFilter
 from synapse.types import RoomID, UserID
 
-import tests.unittest
-import tests.utils
+from tests.unittest import HomeserverTestCase
 
 logger = logging.getLogger(__name__)
 
 
-class StateStoreTestCase(tests.unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
-
+class StateStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
         self.storage = hs.get_storage()
         self.state_datastore = self.storage.state.stores.state
@@ -44,7 +38,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         self.room = RoomID.from_string("!abc123:test")
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_room(
                 self.room.to_string(),
                 room_creator_user_id="@creator:text",
@@ -53,7 +47,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
             )
         )
 
-    @defer.inlineCallbacks
     def inject_state_event(self, room, sender, typ, state_key, content):
         builder = self.event_builder_factory.for_room_version(
             RoomVersions.V1,
@@ -66,13 +59,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
             },
         )
 
-        event, context = yield defer.ensureDeferred(
+        event, context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        yield defer.ensureDeferred(
-            self.storage.persistence.persist_event(event, context)
-        )
+        self.get_success(self.storage.persistence.persist_event(event, context))
 
         return event
 
@@ -82,16 +73,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.assertEqual(s1[t].event_id, s2[t].event_id)
         self.assertEqual(len(s1), len(s2))
 
-    @defer.inlineCallbacks
     def test_get_state_groups_ids(self):
-        e1 = yield self.inject_state_event(
-            self.room, self.u_alice, EventTypes.Create, "", {}
-        )
-        e2 = yield self.inject_state_event(
+        e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+        e2 = self.inject_state_event(
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
-        state_group_map = yield defer.ensureDeferred(
+        state_group_map = self.get_success(
             self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
         )
         self.assertEqual(len(state_group_map), 1)
@@ -101,16 +89,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
             {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
         )
 
-    @defer.inlineCallbacks
     def test_get_state_groups(self):
-        e1 = yield self.inject_state_event(
-            self.room, self.u_alice, EventTypes.Create, "", {}
-        )
-        e2 = yield self.inject_state_event(
+        e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+        e2 = self.inject_state_event(
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
-        state_group_map = yield defer.ensureDeferred(
+        state_group_map = self.get_success(
             self.storage.state.get_state_groups(self.room, [e2.event_id])
         )
         self.assertEqual(len(state_group_map), 1)
@@ -118,32 +103,29 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
 
-    @defer.inlineCallbacks
     def test_get_state_for_event(self):
 
         # this defaults to a linear DAG as each new injection defaults to whatever
         # forward extremities are currently in the DB for this room.
-        e1 = yield self.inject_state_event(
-            self.room, self.u_alice, EventTypes.Create, "", {}
-        )
-        e2 = yield self.inject_state_event(
+        e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+        e2 = self.inject_state_event(
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
-        e3 = yield self.inject_state_event(
+        e3 = self.inject_state_event(
             self.room,
             self.u_alice,
             EventTypes.Member,
             self.u_alice.to_string(),
             {"membership": Membership.JOIN},
         )
-        e4 = yield self.inject_state_event(
+        e4 = self.inject_state_event(
             self.room,
             self.u_bob,
             EventTypes.Member,
             self.u_bob.to_string(),
             {"membership": Membership.JOIN},
         )
-        e5 = yield self.inject_state_event(
+        e5 = self.inject_state_event(
             self.room,
             self.u_bob,
             EventTypes.Member,
@@ -152,9 +134,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check we get the full state as of the final event
-        state = yield defer.ensureDeferred(
-            self.storage.state.get_state_for_event(e5.event_id)
-        )
+        state = self.get_success(self.storage.state.get_state_for_event(e5.event_id))
 
         self.assertIsNotNone(e4)
 
@@ -170,7 +150,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check we can filter to the m.room.name event (with a '' state key)
-        state = yield defer.ensureDeferred(
+        state = self.get_success(
             self.storage.state.get_state_for_event(
                 e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
             )
@@ -179,7 +159,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can filter to the m.room.name event (with a wildcard None state key)
-        state = yield defer.ensureDeferred(
+        state = self.get_success(
             self.storage.state.get_state_for_event(
                 e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
             )
@@ -188,7 +168,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can grab the m.room.member events (with a wildcard None state key)
-        state = yield defer.ensureDeferred(
+        state = self.get_success(
             self.storage.state.get_state_for_event(
                 e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
             )
@@ -200,7 +180,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # check we can grab a specific room member without filtering out the
         # other event types
-        state = yield defer.ensureDeferred(
+        state = self.get_success(
             self.storage.state.get_state_for_event(
                 e5.event_id,
                 state_filter=StateFilter(
@@ -220,7 +200,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check that we can grab everything except members
-        state = yield defer.ensureDeferred(
+        state = self.get_success(
             self.storage.state.get_state_for_event(
                 e5.event_id,
                 state_filter=StateFilter(
@@ -238,17 +218,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
         #######################################################
 
         room_id = self.room.to_string()
-        group_ids = yield defer.ensureDeferred(
+        group_ids = self.get_success(
             self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
         )
         group = list(group_ids.keys())[0]
 
         # test _get_state_for_group_using_cache correctly filters out members
         # with types=[]
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
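+        # `_get_state_for_group_using_cache` is a plain (non-async) cache read
+        # returning a `(state_dict, is_all)` tuple, so no deferred handling is
+        # needed here.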
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -265,10 +242,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_dict,
         )
 
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -281,10 +255,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with wildcard types
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -301,10 +272,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_dict,
         )
 
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -324,10 +292,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -344,10 +309,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_dict,
         )
 
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -360,10 +322,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -413,10 +372,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         # test _get_state_for_group_using_cache correctly filters out members
         # with types=[]
         room_id = self.room.to_string()
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -428,10 +384,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
         room_id = self.room.to_string()
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -444,10 +397,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # wildcard types
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -458,10 +408,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertEqual(is_all, False)
         self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -480,10 +427,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -494,10 +438,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertEqual(is_all, False)
         self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -510,10 +451,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -524,10 +462,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertEqual(is_all, False)
         self.assertDictEqual({}, state_dict)
 
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index a6f63f4aaf..019c5b7b14 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,10 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
-from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.unittest import HomeserverTestCase, override_config
 
 ALICE = "@alice:a"
 BOB = "@bob:b"
@@ -25,73 +22,52 @@ BOBBY = "@bobby:a"
 BELA = "@somenickname:a"
 
 
-class UserDirectoryStoreTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield setup_test_homeserver(self.addCleanup)
-        self.store = self.hs.get_datastore()
+class UserDirectoryStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
 
         # alice and bob are both in !room_id. bobby is not but shares
         # a homeserver with alice.
-        yield defer.ensureDeferred(
-            self.store.update_profile_in_user_dir(ALICE, "alice", None)
-        )
-        yield defer.ensureDeferred(
-            self.store.update_profile_in_user_dir(BOB, "bob", None)
-        )
-        yield defer.ensureDeferred(
-            self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
-        )
-        yield defer.ensureDeferred(
-            self.store.update_profile_in_user_dir(BELA, "Bela", None)
-        )
-        yield defer.ensureDeferred(
-            self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
-        )
+        self.get_success(self.store.update_profile_in_user_dir(ALICE, "alice", None))
+        self.get_success(self.store.update_profile_in_user_dir(BOB, "bob", None))
+        self.get_success(self.store.update_profile_in_user_dir(BOBBY, "bobby", None))
+        self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None))
+        self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)))
 
-    @defer.inlineCallbacks
     def test_search_user_dir(self):
         # normally when alice searches the directory she should just find
         # bob because bobby doesn't share a room with her.
-        r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
+        r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
         self.assertFalse(r["limited"])
         self.assertEqual(1, len(r["results"]))
         self.assertDictEqual(
             r["results"][0], {"user_id": BOB, "display_name": "bob", "avatar_url": None}
         )
 
-    @defer.inlineCallbacks
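+    # `search_all_users` lifts the restriction that results must share a room
+    # with the searcher, so bobby should show up alongside bob.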
+    @override_config({"user_directory": {"search_all_users": True}})
     def test_search_user_dir_all_users(self):
-        self.hs.config.user_directory_search_all_users = True
-        try:
-            r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
-            self.assertFalse(r["limited"])
-            self.assertEqual(2, len(r["results"]))
-            self.assertDictEqual(
-                r["results"][0],
-                {"user_id": BOB, "display_name": "bob", "avatar_url": None},
-            )
-            self.assertDictEqual(
-                r["results"][1],
-                {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None},
-            )
-        finally:
-            self.hs.config.user_directory_search_all_users = False
+        r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
+        self.assertFalse(r["limited"])
+        self.assertEqual(2, len(r["results"]))
+        self.assertDictEqual(
+            r["results"][0],
+            {"user_id": BOB, "display_name": "bob", "avatar_url": None},
+        )
+        self.assertDictEqual(
+            r["results"][1],
+            {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None},
+        )
 
-    @defer.inlineCallbacks
+    @override_config({"user_directory": {"search_all_users": True}})
     def test_search_user_dir_stop_words(self):
         """Tests that a user can look up another user by searching for the start if its
         display name even if that name happens to be a common English word that would
         usually be ignored in full text searches.
         """
-        self.hs.config.user_directory_search_all_users = True
-        try:
-            r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "be", 10))
-            self.assertFalse(r["limited"])
-            self.assertEqual(1, len(r["results"]))
-            self.assertDictEqual(
-                r["results"][0],
-                {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
-            )
-        finally:
-            self.hs.config.user_directory_search_all_users = False
+        r = self.get_success(self.store.search_user_dir(ALICE, "be", 10))
+        self.assertFalse(r["limited"])
+        self.assertEqual(1, len(r["results"]))
+        self.assertDictEqual(
+            r["results"][0],
+            {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
+        )
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 3f2691ee6b..b5f18344dc 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -207,6 +207,226 @@ class EventAuthTestCase(unittest.TestCase):
                 do_sig_check=False,
             )
 
+    def test_join_rules_public(self):
+        """
+        Test joining a public room.
+        """
+        creator = "@creator:example.com"
+        pleb = "@joiner:example.com"
+
+        auth_events = {
+            ("m.room.create", ""): _create_event(creator),
+            ("m.room.member", creator): _join_event(creator),
+            ("m.room.join_rules", ""): _join_rules_event(creator, "public"),
+        }
+
+        # Check join.
+        event_auth.check(
+            RoomVersions.V6,
+            _join_event(pleb),
+            auth_events,
+            do_sig_check=False,
+        )
+
+        # A user cannot be force-joined to a room.
+        with self.assertRaises(AuthError):
+            event_auth.check(
+                RoomVersions.V6,
+                _member_event(pleb, "join", sender=creator),
+                auth_events,
+                do_sig_check=False,
+            )
+
+        # Banned should be rejected.
+        auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
+        with self.assertRaises(AuthError):
+            event_auth.check(
+                RoomVersions.V6,
+                _join_event(pleb),
+                auth_events,
+                do_sig_check=False,
+            )
+
+        # A user who left can re-join.
+        auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
+        event_auth.check(
+            RoomVersions.V6,
+            _join_event(pleb),
+            auth_events,
+            do_sig_check=False,
+        )
+
+        # A user can send a join if they're in the room.
+        auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
+        event_auth.check(
+            RoomVersions.V6,
+            _join_event(pleb),
+            auth_events,
+            do_sig_check=False,
+        )
+
+        # A user can accept an invite.
+        auth_events[("m.room.member", pleb)] = _member_event(
+            pleb, "invite", sender=creator
+        )
+        event_auth.check(
+            RoomVersions.V6,
+            _join_event(pleb),
+            auth_events,
+            do_sig_check=False,
+        )
+
+    def test_join_rules_invite(self):
+        """
+        Test joining an invite only room.
+        """
+        creator = "@creator:example.com"
+        pleb = "@joiner:example.com"
+
+        auth_events = {
+            ("m.room.create", ""): _create_event(creator),
+            ("m.room.member", creator): _join_event(creator),
+            ("m.room.join_rules", ""): _join_rules_event(creator, "invite"),
+        }
+
+        # A join without an invite is rejected.
+        with self.assertRaises(AuthError):
+            event_auth.check(
+                RoomVersions.V6,
+                _join_event(pleb),
+                auth_events,
+                do_sig_check=False,
+            )
+
+        # A user cannot be force-joined to a room.
+        with self.assertRaises(AuthError):
+            event_auth.check(
+                RoomVersions.V6,
+                _member_event(pleb, "join", sender=creator),
+                auth_events,
+                do_sig_check=False,
+            )
+
+        # Banned should be rejected.
+        auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
+        with self.assertRaises(AuthError):
+            event_auth.check(
+                RoomVersions.V6,
+                _join_event(pleb),
+                auth_events,
+                do_sig_check=False,
+            )
+
+        # A user who left cannot re-join.
+        auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
+        with self.assertRaises(AuthError):
+            event_auth.check(
+                RoomVersions.V6,
+                _join_event(pleb),
+                auth_events,
+                do_sig_check=False,
+            )
+
+        # A user can send a join if they're in the room.
+        auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
+        event_auth.check(
+            RoomVersions.V6,
+            _join_event(pleb),
+            auth_events,
+            do_sig_check=False,
+        )
+
+        # A user can accept an invite.
+        auth_events[("m.room.member", pleb)] = _member_event(
+            pleb, "invite", sender=creator
+        )
+        event_auth.check(
+            RoomVersions.V6,
+            _join_event(pleb),
+            auth_events,
+            do_sig_check=False,
+        )
+
+    def test_join_rules_msc3083_restricted(self):
+        """
+        Test joining a restricted room from MSC3083.
+
+        This is pretty much the same test as for a public room.
+        """
+        creator = "@creator:example.com"
+        pleb = "@joiner:example.com"
+
+        auth_events = {
+            ("m.room.create", ""): _create_event(creator),
+            ("m.room.member", creator): _join_event(creator),
+            ("m.room.join_rules", ""): _join_rules_event(creator, "restricted"),
+        }
+
+        # Older room versions don't understand this join rule
+        with self.assertRaises(AuthError):
+            event_auth.check(
+                RoomVersions.V6,
+                _join_event(pleb),
+                auth_events,
+                do_sig_check=False,
+            )
+
+        # Check join.
+        event_auth.check(
+            RoomVersions.MSC3083,
+            _join_event(pleb),
+            auth_events,
+            do_sig_check=False,
+        )
+
+        # A user cannot be force-joined to a room.
+        with self.assertRaises(AuthError):
+            event_auth.check(
+                RoomVersions.MSC3083,
+                _member_event(pleb, "join", sender=creator),
+                auth_events,
+                do_sig_check=False,
+            )
+
+        # Banned should be rejected.
+        auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
+        with self.assertRaises(AuthError):
+            event_auth.check(
+                RoomVersions.MSC3083,
+                _join_event(pleb),
+                auth_events,
+                do_sig_check=False,
+            )
+
+        # A user who left can re-join.
+        auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
+        event_auth.check(
+            RoomVersions.MSC3083,
+            _join_event(pleb),
+            auth_events,
+            do_sig_check=False,
+        )
+
+        # A user can send a join if they're in the room.
+        auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
+        event_auth.check(
+            RoomVersions.MSC3083,
+            _join_event(pleb),
+            auth_events,
+            do_sig_check=False,
+        )
+
+        # A user can accept an invite.
+        auth_events[("m.room.member", pleb)] = _member_event(
+            pleb, "invite", sender=creator
+        )
+        event_auth.check(
+            RoomVersions.MSC3083,
+            _join_event(pleb),
+            auth_events,
+            do_sig_check=False,
+        )
+
 
 # helpers for making events
 
@@ -225,19 +445,24 @@ def _create_event(user_id):
     )
 
 
-def _join_event(user_id):
+def _member_event(user_id, membership, sender=None):
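+    """Build an m.room.member event; `sender` defaults to the target user."""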
     return make_event_from_dict(
         {
             "room_id": TEST_ROOM_ID,
             "event_id": _get_event_id(),
             "type": "m.room.member",
-            "sender": user_id,
+            "sender": sender or user_id,
             "state_key": user_id,
-            "content": {"membership": "join"},
+            "content": {"membership": membership},
+            "prev_events": [],
         }
     )
 
 
+def _join_event(user_id):
+    return _member_event(user_id, "join")
+
+
 def _power_levels_event(sender, content):
     return make_event_from_dict(
         {
@@ -277,6 +502,21 @@ def _random_state_event(sender):
     )
 
 
+def _join_rules_event(sender, join_rule):
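+    """Build an m.room.join_rules state event with the given join rule."""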
+    return make_event_from_dict(
+        {
+            "room_id": TEST_ROOM_ID,
+            "event_id": _get_event_id(),
+            "type": "m.room.join_rules",
+            "sender": sender,
+            "state_key": "",
+            "content": {
+                "join_rule": join_rule,
+            },
+        }
+    )
+
+
 event_count = 0
 
 
diff --git a/tests/utils.py b/tests/utils.py
index be80b13760..a141ee6496 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -122,7 +122,6 @@ def default_config(name, parse=False):
         "enable_registration_captcha": False,
         "macaroon_secret_key": "not even a little secret",
         "trusted_third_party_id_servers": [],
-        "room_invite_state_types": [],
         "password_providers": [],
         "worker_replication_url": "",
         "worker_app": None,