summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12087.bugfix1
-rw-r--r--changelog.d/12198.misc1
-rw-r--r--changelog.d/12199.misc1
-rw-r--r--changelog.d/12216.misc1
-rw-r--r--changelog.d/12219.misc1
-rw-r--r--changelog.d/12224.misc1
-rw-r--r--changelog.d/12225.misc1
-rw-r--r--changelog.d/12227.misc1
-rw-r--r--changelog.d/12228.bugfix1
-rw-r--r--changelog.d/12231.doc1
-rw-r--r--changelog.d/12232.misc1
-rw-r--r--changelog.d/12237.misc1
-rw-r--r--changelog.d/12240.misc1
-rw-r--r--changelog.d/12242.misc1
-rw-r--r--changelog.d/12243.doc1
-rw-r--r--changelog.d/12244.misc1
-rw-r--r--changelog.d/12246.doc1
-rw-r--r--changelog.d/12248.misc1
-rw-r--r--changelog.d/12256.misc1
-rw-r--r--changelog.d/12258.misc1
-rw-r--r--docs/modules/spam_checker_callbacks.md10
-rw-r--r--docs/workers.md30
-rw-r--r--mypy.ini17
-rwxr-xr-xscripts-dev/release.py41
-rw-r--r--synapse/app/homeserver.py5
-rw-r--r--synapse/config/spam_checker.py4
-rw-r--r--synapse/events/spamcheck.py7
-rw-r--r--synapse/events/utils.py2
-rw-r--r--synapse/federation/federation_server.py7
-rw-r--r--synapse/handlers/federation.py61
-rw-r--r--synapse/handlers/pagination.py10
-rw-r--r--synapse/handlers/relations.py264
-rw-r--r--synapse/handlers/room.py5
-rw-r--r--synapse/handlers/search.py3
-rw-r--r--synapse/handlers/sync.py39
-rw-r--r--synapse/handlers/user_directory.py4
-rw-r--r--synapse/module_api/__init__.py2
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py2
-rw-r--r--synapse/rest/client/relations.py75
-rw-r--r--synapse/rest/client/room.py3
-rw-r--r--synapse/rest/client/user_directory.py4
-rw-r--r--synapse/rest/media/v1/preview_html.py39
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py23
-rw-r--r--synapse/server.py5
-rw-r--r--synapse/storage/database.py61
-rw-r--r--synapse/storage/databases/main/account_data.py41
-rw-r--r--synapse/storage/databases/main/cache.py57
-rw-r--r--synapse/storage/databases/main/group_server.py156
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py38
-rw-r--r--synapse/storage/databases/main/relations.py149
-rw-r--r--synapse/storage/databases/main/user_directory.py23
-rw-r--r--synapse/types.py11
-rw-r--r--synapse/util/check_dependencies.py24
-rw-r--r--synapse/visibility.py18
-rw-r--r--tests/handlers/test_cas.py19
-rw-r--r--tests/handlers/test_directory.py84
-rw-r--r--tests/handlers/test_e2e_keys.py36
-rw-r--r--tests/handlers/test_federation.py36
-rw-r--r--tests/handlers/test_oidc.py94
-rw-r--r--tests/handlers/test_presence.py13
-rw-r--r--tests/handlers/test_profile.py43
-rw-r--r--tests/handlers/test_saml.py24
-rw-r--r--tests/handlers/test_typing.py35
-rw-r--r--tests/push/test_http.py40
-rw-r--r--tests/push/test_push_rule_evaluator.py23
-rw-r--r--tests/rest/client/test_relations.py1210
-rw-r--r--tests/rest/media/v1/test_html_preview.py54
-rw-r--r--tests/server.py20
-rw-r--r--tests/storage/test_account_data.py17
-rw-r--r--tests/storage/test_background_update.py48
-rw-r--r--tests/storage/test_database.py162
-rw-r--r--tests/storage/test_id_generators.py80
-rw-r--r--tests/util/test_check_dependencies.py15
73 files changed, 1859 insertions, 1454 deletions
diff --git a/changelog.d/12087.bugfix b/changelog.d/12087.bugfix
new file mode 100644
index 0000000000..6dacdddd0d
--- /dev/null
+++ b/changelog.d/12087.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug which caused the `/_matrix/federation/v1/state` and `.../state_ids` endpoints to return incorrect or invalid data when called for an event which we have stored as an "outlier".
diff --git a/changelog.d/12198.misc b/changelog.d/12198.misc
new file mode 100644
index 0000000000..6b184a9053
--- /dev/null
+++ b/changelog.d/12198.misc
@@ -0,0 +1 @@
+Add tests for database transaction callbacks.
diff --git a/changelog.d/12199.misc b/changelog.d/12199.misc
new file mode 100644
index 0000000000..16dec1d26d
--- /dev/null
+++ b/changelog.d/12199.misc
@@ -0,0 +1 @@
+Handle cancellation in `DatabasePool.runInteraction()`.
diff --git a/changelog.d/12216.misc b/changelog.d/12216.misc
new file mode 100644
index 0000000000..dc398ac1e0
--- /dev/null
+++ b/changelog.d/12216.misc
@@ -0,0 +1 @@
+Add missing type hints for cache storage.
diff --git a/changelog.d/12219.misc b/changelog.d/12219.misc
new file mode 100644
index 0000000000..6079414092
--- /dev/null
+++ b/changelog.d/12219.misc
@@ -0,0 +1 @@
+Clean-up logic around rebasing URLs for URL image previews.
diff --git a/changelog.d/12224.misc b/changelog.d/12224.misc
new file mode 100644
index 0000000000..b67a701dbb
--- /dev/null
+++ b/changelog.d/12224.misc
@@ -0,0 +1 @@
+Add type hints to tests files.
diff --git a/changelog.d/12225.misc b/changelog.d/12225.misc
new file mode 100644
index 0000000000..23105c727c
--- /dev/null
+++ b/changelog.d/12225.misc
@@ -0,0 +1 @@
+Use the `ignored_users` table in additional places instead of re-parsing the account data.
diff --git a/changelog.d/12227.misc b/changelog.d/12227.misc
new file mode 100644
index 0000000000..41c9dcbd37
--- /dev/null
+++ b/changelog.d/12227.misc
@@ -0,0 +1 @@
+Refactor the relations endpoints to add a `RelationsHandler`.
diff --git a/changelog.d/12228.bugfix b/changelog.d/12228.bugfix
new file mode 100644
index 0000000000..4755777139
--- /dev/null
+++ b/changelog.d/12228.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in v1.53.0 where an unnecessary query could be performed when fetching bundled aggregations for threads.
diff --git a/changelog.d/12231.doc b/changelog.d/12231.doc
new file mode 100644
index 0000000000..16593d2b92
--- /dev/null
+++ b/changelog.d/12231.doc
@@ -0,0 +1 @@
+Fix the link to the module documentation in the legacy spam checker warning message.
diff --git a/changelog.d/12232.misc b/changelog.d/12232.misc
new file mode 100644
index 0000000000..4a4132edff
--- /dev/null
+++ b/changelog.d/12232.misc
@@ -0,0 +1 @@
+Refactor relations tests to improve code re-use.
diff --git a/changelog.d/12237.misc b/changelog.d/12237.misc
new file mode 100644
index 0000000000..41c9dcbd37
--- /dev/null
+++ b/changelog.d/12237.misc
@@ -0,0 +1 @@
+Refactor the relations endpoints to add a `RelationsHandler`.
diff --git a/changelog.d/12240.misc b/changelog.d/12240.misc
new file mode 100644
index 0000000000..c5b6356799
--- /dev/null
+++ b/changelog.d/12240.misc
@@ -0,0 +1 @@
+Add type hints to tests files.
\ No newline at end of file
diff --git a/changelog.d/12242.misc b/changelog.d/12242.misc
new file mode 100644
index 0000000000..38e7e0f7d1
--- /dev/null
+++ b/changelog.d/12242.misc
@@ -0,0 +1 @@
+Generate announcement links in the release script.
diff --git a/changelog.d/12243.doc b/changelog.d/12243.doc
new file mode 100644
index 0000000000..b2031f0a40
--- /dev/null
+++ b/changelog.d/12243.doc
@@ -0,0 +1 @@
+Remove incorrect prefixes in the worker documentation for some endpoints.
diff --git a/changelog.d/12244.misc b/changelog.d/12244.misc
new file mode 100644
index 0000000000..950d48e4c6
--- /dev/null
+++ b/changelog.d/12244.misc
@@ -0,0 +1 @@
+Improve error message when dependencies check finds a broken installation.
\ No newline at end of file
diff --git a/changelog.d/12246.doc b/changelog.d/12246.doc
new file mode 100644
index 0000000000..e7fcc1b99c
--- /dev/null
+++ b/changelog.d/12246.doc
@@ -0,0 +1 @@
+Correct `check_username_for_spam` annotations and docs.
\ No newline at end of file
diff --git a/changelog.d/12248.misc b/changelog.d/12248.misc
new file mode 100644
index 0000000000..2b1290d1e1
--- /dev/null
+++ b/changelog.d/12248.misc
@@ -0,0 +1 @@
+Add missing type hints for storage.
\ No newline at end of file
diff --git a/changelog.d/12256.misc b/changelog.d/12256.misc
new file mode 100644
index 0000000000..c5b6356799
--- /dev/null
+++ b/changelog.d/12256.misc
@@ -0,0 +1 @@
+Add type hints to tests files.
\ No newline at end of file
diff --git a/changelog.d/12258.misc b/changelog.d/12258.misc
new file mode 100644
index 0000000000..80024c8e91
--- /dev/null
+++ b/changelog.d/12258.misc
@@ -0,0 +1 @@
+Compress metrics HTTP resource when enabled. Contributed by Nick @ Beeper.
diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md
index 2b672b78f9..472d957180 100644
--- a/docs/modules/spam_checker_callbacks.md
+++ b/docs/modules/spam_checker_callbacks.md
@@ -172,7 +172,7 @@ any of the subsequent implementations of this callback.
 _First introduced in Synapse v1.37.0_
 
 ```python
-async def check_username_for_spam(user_profile: Dict[str, str]) -> bool
+async def check_username_for_spam(user_profile: synapse.module_api.UserProfile) -> bool
 ```
 
 Called when computing search results in the user directory. The module must return a
@@ -182,9 +182,11 @@ search results; otherwise return `False`.
 
 The profile is represented as a dictionary with the following keys:
 
-* `user_id`: The Matrix ID for this user.
-* `display_name`: The user's display name.
-* `avatar_url`: The `mxc://` URL to the user's avatar.
+* `user_id: str`. The Matrix ID for this user.
+* `display_name: Optional[str]`. The user's display name, or `None` if this user
+  has not set a display name.
+* `avatar_url: Optional[str]`. The `mxc://` URL to the user's avatar, or `None`
+  if this user has not set an avatar.
 
 The module is given a copy of the original dictionary, so modifying it from within the
 module cannot modify a user's profile when included in user directory search results.
diff --git a/docs/workers.md b/docs/workers.md
index 8751134e65..9eb4194e4d 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -185,8 +185,8 @@ worker: refer to the [stream writers](#stream-writers) section below for further
 information.
 
     # Sync requests
-    ^/_matrix/client/(v2_alpha|r0|v3)/sync$
-    ^/_matrix/client/(api/v1|v2_alpha|r0|v3)/events$
+    ^/_matrix/client/(r0|v3)/sync$
+    ^/_matrix/client/(api/v1|r0|v3)/events$
     ^/_matrix/client/(api/v1|r0|v3)/initialSync$
     ^/_matrix/client/(api/v1|r0|v3)/rooms/[^/]+/initialSync$
 
@@ -200,13 +200,9 @@ information.
     ^/_matrix/federation/v1/query/
     ^/_matrix/federation/v1/make_join/
     ^/_matrix/federation/v1/make_leave/
-    ^/_matrix/federation/v1/send_join/
-    ^/_matrix/federation/v2/send_join/
-    ^/_matrix/federation/v1/send_leave/
-    ^/_matrix/federation/v2/send_leave/
-    ^/_matrix/federation/v1/invite/
-    ^/_matrix/federation/v2/invite/
-    ^/_matrix/federation/v1/query_auth/
+    ^/_matrix/federation/(v1|v2)/send_join/
+    ^/_matrix/federation/(v1|v2)/send_leave/
+    ^/_matrix/federation/(v1|v2)/invite/
     ^/_matrix/federation/v1/event_auth/
     ^/_matrix/federation/v1/exchange_third_party_invite/
     ^/_matrix/federation/v1/user/devices/
@@ -274,6 +270,8 @@ information.
 Additionally, the following REST endpoints can be handled for GET requests:
 
     ^/_matrix/federation/v1/groups/
+    ^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/
+    ^/_matrix/client/(r0|v3|unstable)/groups/
 
 Pagination requests can also be handled, but all requests for a given
 room must be routed to the same instance. Additionally, care must be taken to
@@ -397,23 +395,23 @@ the stream writer for the `typing` stream:
 The following endpoints should be routed directly to the worker configured as
 the stream writer for the `to_device` stream:
 
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/sendToDevice/
+    ^/_matrix/client/(r0|v3|unstable)/sendToDevice/
 
 ##### The `account_data` stream
 
 The following endpoints should be routed directly to the worker configured as
 the stream writer for the `account_data` stream:
 
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/tags
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/account_data
+    ^/_matrix/client/(r0|v3|unstable)/.*/tags
+    ^/_matrix/client/(r0|v3|unstable)/.*/account_data
 
 ##### The `receipts` stream
 
 The following endpoints should be routed directly to the worker configured as
 the stream writer for the `receipts` stream:
 
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/receipt
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/read_markers
+    ^/_matrix/client/(r0|v3|unstable)/rooms/.*/receipt
+    ^/_matrix/client/(r0|v3|unstable)/rooms/.*/read_markers
 
 ##### The `presence` stream
 
@@ -528,7 +526,7 @@ Note that if a reverse proxy is used , then `/_matrix/media/` must be routed for
 Handles searches in the user directory. It can handle REST endpoints matching
 the following regular expressions:
 
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/user_directory/search$
+    ^/_matrix/client/(r0|v3|unstable)/user_directory/search$
 
 When using this worker you must also set `update_user_directory: False` in the
 shared configuration file to stop the main synapse running background
@@ -540,7 +538,7 @@ Proxies some frequently-requested client endpoints to add caching and remove
 load from the main synapse. It can handle REST endpoints matching the following
 regular expressions:
 
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/upload
+    ^/_matrix/client/(r0|v3|unstable)/keys/upload
 
 If `use_presence` is False in the homeserver config, it can also handle REST
 endpoints matching the following regular expressions:
diff --git a/mypy.ini b/mypy.ini
index f9c39fcaae..24d4ba15d4 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -42,9 +42,6 @@ exclude = (?x)
    |synapse/storage/databases/main/cache.py
    |synapse/storage/databases/main/devices.py
    |synapse/storage/databases/main/event_federation.py
-   |synapse/storage/databases/main/group_server.py
-   |synapse/storage/databases/main/metrics.py
-   |synapse/storage/databases/main/monthly_active_users.py
    |synapse/storage/databases/main/push_rule.py
    |synapse/storage/databases/main/receipts.py
    |synapse/storage/databases/main/roommember.py
@@ -66,14 +63,6 @@ exclude = (?x)
    |tests/federation/test_federation_server.py
    |tests/federation/transport/test_knocking.py
    |tests/federation/transport/test_server.py
-   |tests/handlers/test_cas.py
-   |tests/handlers/test_directory.py
-   |tests/handlers/test_e2e_keys.py
-   |tests/handlers/test_federation.py
-   |tests/handlers/test_oidc.py
-   |tests/handlers/test_presence.py
-   |tests/handlers/test_profile.py
-   |tests/handlers/test_saml.py
    |tests/handlers/test_typing.py
    |tests/http/federation/test_matrix_federation_agent.py
    |tests/http/federation/test_srv_resolver.py
@@ -85,7 +74,6 @@ exclude = (?x)
    |tests/logging/test_terse_json.py
    |tests/module_api/test_api.py
    |tests/push/test_email.py
-   |tests/push/test_http.py
    |tests/push/test_presentable_names.py
    |tests/push/test_push_rule_evaluator.py
    |tests/rest/client/test_transactions.py
@@ -94,12 +82,7 @@ exclude = (?x)
    |tests/server.py
    |tests/server_notices/test_resource_limits_server_notices.py
    |tests/state/test_v2.py
-   |tests/storage/test_background_update.py
    |tests/storage/test_base.py
-   |tests/storage/test_client_ips.py
-   |tests/storage/test_database.py
-   |tests/storage/test_event_federation.py
-   |tests/storage/test_id_generators.py
    |tests/storage/test_roommember.py
    |tests/test_metrics.py
    |tests/test_phone_home.py
diff --git a/scripts-dev/release.py b/scripts-dev/release.py
index 046453e65f..685fa32b03 100755
--- a/scripts-dev/release.py
+++ b/scripts-dev/release.py
@@ -66,11 +66,15 @@ def cli():
 
         ./scripts-dev/release.py tag
 
-        # ... wait for asssets to build ...
+        # ... wait for assets to build ...
 
         ./scripts-dev/release.py publish
         ./scripts-dev/release.py upload
 
+        # Optional: generate some nice links for the announcement
+
+        ./scripts-dev/release.py upload
+
     If the env var GH_TOKEN (or GITHUB_TOKEN) is set, or passed into the
     `tag`/`publish` command, then a new draft release will be created/published.
     """
@@ -415,6 +419,41 @@ def upload():
     )
 
 
+@cli.command()
+def announce():
+    """Generate markdown to announce the release."""
+
+    current_version, _, _ = parse_version_from_module()
+    tag_name = f"v{current_version}"
+
+    click.echo(
+        f"""
+Hi everyone. Synapse {current_version} has just been released.
+
+[notes](https://github.com/matrix-org/synapse/releases/tag/{tag_name}) |\
+[docker](https://hub.docker.com/r/matrixdotorg/synapse/tags?name={tag_name}) | \
+[debs](https://packages.matrix.org/debian/) | \
+[pypi](https://pypi.org/project/matrix-synapse/{current_version}/)"""
+    )
+
+    if "rc" in tag_name:
+        click.echo(
+            """
+Announce the RC in
+- #homeowners:matrix.org (Synapse Announcements)
+- #synapse-dev:matrix.org"""
+        )
+    else:
+        click.echo(
+            """
+Announce the release in
+- #homeowners:matrix.org (Synapse Announcements), bumping the version in the topic
+- #synapse:matrix.org (Synapse Admins), bumping the version in the topic
+- #synapse-dev:matrix.org
+- #synapse-package-maintainers:matrix.org"""
+        )
+
+
 def parse_version_from_module() -> Tuple[
     version.Version, redbaron.RedBaron, redbaron.Node
 ]:
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index e4dc04c0b4..ad2b7c9515 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -261,7 +261,10 @@ class SynapseHomeServer(HomeServer):
             resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
 
         if name == "metrics" and self.config.metrics.enable_metrics:
-            resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
+            metrics_resource: Resource = MetricsResource(RegistryProxy)
+            if compress:
+                metrics_resource = gz_wrap(metrics_resource)
+            resources[METRICS_PREFIX] = metrics_resource
 
         if name == "replication":
             resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
index a233a9ce03..4c52103b1c 100644
--- a/synapse/config/spam_checker.py
+++ b/synapse/config/spam_checker.py
@@ -25,8 +25,8 @@ logger = logging.getLogger(__name__)
 LEGACY_SPAM_CHECKER_WARNING = """
 This server is using a spam checker module that is implementing the deprecated spam
 checker interface. Please check with the module's maintainer to see if a new version
-supporting Synapse's generic modules system is available.
-For more information, please see https://matrix-org.github.io/synapse/latest/modules.html
+supporting Synapse's generic modules system is available. For more information, please
+see https://matrix-org.github.io/synapse/latest/modules/index.html
 ---------------------------------------------------------------------------------------"""
 
 
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 60904a55f5..cd80fcf9d1 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -21,7 +21,6 @@ from typing import (
     Awaitable,
     Callable,
     Collection,
-    Dict,
     List,
     Optional,
     Tuple,
@@ -31,7 +30,7 @@ from typing import (
 from synapse.rest.media.v1._base import FileInfo
 from synapse.rest.media.v1.media_storage import ReadableFileWrapper
 from synapse.spam_checker_api import RegistrationBehaviour
-from synapse.types import RoomAlias
+from synapse.types import RoomAlias, UserProfile
 from synapse.util.async_helpers import maybe_awaitable
 
 if TYPE_CHECKING:
@@ -50,7 +49,7 @@ USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[[str, str, str, str], Awaitable[bo
 USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]]
 USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]]
 USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
-CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[Dict[str, str]], Awaitable[bool]]
+CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[UserProfile], Awaitable[bool]]
 LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
     [
         Optional[dict],
@@ -383,7 +382,7 @@ class SpamChecker:
 
         return True
 
-    async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
+    async def check_username_for_spam(self, user_profile: UserProfile) -> bool:
         """Checks if a user ID or display name are considered "spammy" by this server.
 
         If the server considers a username spammy, then it will not be included in
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index a0520068e0..7120062127 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -38,8 +38,8 @@ from synapse.util.frozenutils import unfreeze
 from . import EventBase
 
 if TYPE_CHECKING:
+    from synapse.handlers.relations import BundledAggregations
     from synapse.server import HomeServer
-    from synapse.storage.databases.main.relations import BundledAggregations
 
 
 # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 482bbdd867..af2d0f7d79 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -22,7 +22,6 @@ from typing import (
     Callable,
     Collection,
     Dict,
-    Iterable,
     List,
     Optional,
     Tuple,
@@ -577,10 +576,10 @@ class FederationServer(FederationBase):
     async def _on_context_state_request_compute(
         self, room_id: str, event_id: Optional[str]
     ) -> Dict[str, list]:
+        pdus: Collection[EventBase]
         if event_id:
-            pdus: Iterable[EventBase] = await self.handler.get_state_for_pdu(
-                room_id, event_id
-            )
+            event_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
+            pdus = await self.store.get_events_as_list(event_ids)
         else:
             pdus = (await self.state.get_current_state(room_id)).values()
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index db39aeabde..350ec9c03a 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -950,54 +950,35 @@ class FederationHandler:
 
         return event
 
-    async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
-        """Returns the state at the event. i.e. not including said event."""
-
-        event = await self.store.get_event(event_id, check_room_id=room_id)
-
-        state_groups = await self.state_store.get_state_groups(room_id, [event_id])
-
-        if state_groups:
-            _, state = list(state_groups.items()).pop()
-            results = {(e.type, e.state_key): e for e in state}
-
-            if event.is_state():
-                # Get previous state
-                if "replaces_state" in event.unsigned:
-                    prev_id = event.unsigned["replaces_state"]
-                    if prev_id != event.event_id:
-                        prev_event = await self.store.get_event(prev_id)
-                        results[(event.type, event.state_key)] = prev_event
-                else:
-                    del results[(event.type, event.state_key)]
-
-            res = list(results.values())
-            return res
-        else:
-            return []
-
     async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
         """Returns the state at the event. i.e. not including said event."""
         event = await self.store.get_event(event_id, check_room_id=room_id)
+        if event.internal_metadata.outlier:
+            raise NotFoundError("State not known at event %s" % (event_id,))
 
         state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
 
-        if state_groups:
-            _, state = list(state_groups.items()).pop()
-            results = state
+        # get_state_groups_ids should return exactly one result
+        assert len(state_groups) == 1
 
-            if event.is_state():
-                # Get previous state
-                if "replaces_state" in event.unsigned:
-                    prev_id = event.unsigned["replaces_state"]
-                    if prev_id != event.event_id:
-                        results[(event.type, event.state_key)] = prev_id
-                else:
-                    results.pop((event.type, event.state_key), None)
+        state_map = next(iter(state_groups.values()))
 
-            return list(results.values())
-        else:
-            return []
+        state_key = event.get_state_key()
+        if state_key is not None:
+            # the event was not rejected (get_event raises a NotFoundError for rejected
+            # events) so the state at the event should include the event itself.
+            assert (
+                state_map.get((event.type, state_key)) == event.event_id
+            ), "State at event did not include event itself"
+
+            # ... but we need the state *before* that event
+            if "replaces_state" in event.unsigned:
+                prev_id = event.unsigned["replaces_state"]
+                state_map[(event.type, state_key)] = prev_id
+            else:
+                del state_map[(event.type, state_key)]
+
+        return list(state_map.values())
 
     async def on_backfill_request(
         self, origin: str, room_id: str, pdu_list: List[str], limit: int
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 60059fec3e..876b879483 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set
 
 import attr
 
@@ -134,6 +134,7 @@ class PaginationHandler:
         self.clock = hs.get_clock()
         self._server_name = hs.hostname
         self._room_shutdown_handler = hs.get_room_shutdown_handler()
+        self._relations_handler = hs.get_relations_handler()
 
         self.pagination_lock = ReadWriteLock()
         # IDs of rooms in which there currently an active purge *or delete* operation.
@@ -422,7 +423,7 @@ class PaginationHandler:
         pagin_config: PaginationConfig,
         as_client_event: bool = True,
         event_filter: Optional[Filter] = None,
-    ) -> Dict[str, Any]:
+    ) -> JsonDict:
         """Get messages in a room.
 
         Args:
@@ -431,6 +432,7 @@ class PaginationHandler:
             pagin_config: The pagination config rules to apply, if any.
             as_client_event: True to get events in client-server format.
             event_filter: Filter to apply to results or None
+
         Returns:
             Pagination API results
         """
@@ -538,7 +540,9 @@ class PaginationHandler:
                 state_dict = await self.store.get_events(list(state_ids.values()))
                 state = state_dict.values()
 
-        aggregations = await self.store.get_bundled_aggregations(events, user_id)
+        aggregations = await self._relations_handler.get_bundled_aggregations(
+            events, user_id
+        )
 
         time_now = self.clock.time_msec()
 
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
new file mode 100644
index 0000000000..57135d4519
--- /dev/null
+++ b/synapse/handlers/relations.py
@@ -0,0 +1,264 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast
+
+import attr
+from frozendict import frozendict
+
+from synapse.api.constants import RelationTypes
+from synapse.api.errors import SynapseError
+from synapse.events import EventBase
+from synapse.types import JsonDict, Requester, StreamToken
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+    from synapse.storage.databases.main import DataStore
+
+
+logger = logging.getLogger(__name__)
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _ThreadAggregation:
+    # The latest event in the thread.
+    latest_event: EventBase
+    # The latest edit to the latest event in the thread.
+    latest_edit: Optional[EventBase]
+    # The total number of events in the thread.
+    count: int
+    # True if the current user has sent an event to the thread.
+    current_user_participated: bool
+
+
+@attr.s(slots=True, auto_attribs=True)
+class BundledAggregations:
+    """
+    The bundled aggregations for an event.
+
+    Some values require additional processing during serialization.
+    """
+
+    annotations: Optional[JsonDict] = None
+    references: Optional[JsonDict] = None
+    replace: Optional[EventBase] = None
+    thread: Optional[_ThreadAggregation] = None
+
+    def __bool__(self) -> bool:
+        return bool(self.annotations or self.references or self.replace or self.thread)
+
+
+class RelationsHandler:
+    def __init__(self, hs: "HomeServer"):
+        self._main_store = hs.get_datastores().main
+        self._auth = hs.get_auth()
+        self._clock = hs.get_clock()
+        self._event_handler = hs.get_event_handler()
+        self._event_serializer = hs.get_event_client_serializer()
+
+    async def get_relations(
+        self,
+        requester: Requester,
+        event_id: str,
+        room_id: str,
+        relation_type: Optional[str] = None,
+        event_type: Optional[str] = None,
+        aggregation_key: Optional[str] = None,
+        limit: int = 5,
+        direction: str = "b",
+        from_token: Optional[StreamToken] = None,
+        to_token: Optional[StreamToken] = None,
+    ) -> JsonDict:
+        """Get related events of a event, ordered by topological ordering.
+
+        TODO Accept a PaginationConfig instead of individual pagination parameters.
+
+        Args:
+            requester: The user requesting the relations.
+            event_id: Fetch events that relate to this event ID.
+            room_id: The room the event belongs to.
+            relation_type: Only fetch events with this relation type, if given.
+            event_type: Only fetch events with this event type, if given.
+            aggregation_key: Only fetch events with this aggregation key, if given.
+            limit: Only fetch the most recent `limit` events.
+            direction: Whether to fetch the most recent first (`"b"`) or the
+                oldest first (`"f"`).
+            from_token: Fetch rows from the given token, or from the start if None.
+            to_token: Fetch rows up to the given token, or up to the end if None.
+
+        Returns:
+            The pagination chunk.
+        """
+
+        user_id = requester.user.to_string()
+
+        await self._auth.check_user_in_room_or_world_readable(
+            room_id, user_id, allow_departed_users=True
+        )
+
+        # This gets the original event and checks that a) the event exists and
+        # b) the user is allowed to view it.
+        event = await self._event_handler.get_event(requester.user, room_id, event_id)
+        if event is None:
+            raise SynapseError(404, "Unknown parent event.")
+
+        pagination_chunk = await self._main_store.get_relations_for_event(
+            event_id=event_id,
+            event=event,
+            room_id=room_id,
+            relation_type=relation_type,
+            event_type=event_type,
+            aggregation_key=aggregation_key,
+            limit=limit,
+            direction=direction,
+            from_token=from_token,
+            to_token=to_token,
+        )
+
+        events = await self._main_store.get_events_as_list(
+            [c["event_id"] for c in pagination_chunk.chunk]
+        )
+
+        now = self._clock.time_msec()
+        # Do not bundle aggregations when retrieving the original event because
+        # we want the content before relations are applied to it.
+        original_event = self._event_serializer.serialize_event(
+            event, now, bundle_aggregations=None
+        )
+        # The relations returned for the requested event do include their
+        # bundled aggregations.
+        aggregations = await self.get_bundled_aggregations(
+            events, requester.user.to_string()
+        )
+        serialized_events = self._event_serializer.serialize_events(
+            events, now, bundle_aggregations=aggregations
+        )
+
+        return_value = await pagination_chunk.to_dict(self._main_store)
+        return_value["chunk"] = serialized_events
+        return_value["original_event"] = original_event
+
+        return return_value
+
+    async def _get_bundled_aggregation_for_event(
+        self, event: EventBase, user_id: str
+    ) -> Optional[BundledAggregations]:
+        """Generate bundled aggregations for an event.
+
+        Note that this does not use a cache, but depends on cached methods.
+
+        Args:
+            event: The event to calculate bundled aggregations for.
+            user_id: The user requesting the bundled aggregations.
+
+        Returns:
+            The bundled aggregations for an event, if bundled aggregations are
+            enabled and the event can have bundled aggregations.
+        """
+
+        # Do not bundle aggregations for an event which represents an edit or an
+        # annotation. It does not make sense for them to have related events.
+        relates_to = event.content.get("m.relates_to")
+        if isinstance(relates_to, (dict, frozendict)):
+            relation_type = relates_to.get("rel_type")
+            if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
+                return None
+
+        event_id = event.event_id
+        room_id = event.room_id
+
+        # The bundled aggregations to include, a mapping of relation type to a
+        # type-specific value. Some types include the direct return type here
+        # while others need more processing during serialization.
+        aggregations = BundledAggregations()
+
+        annotations = await self._main_store.get_aggregation_groups_for_event(
+            event_id, room_id
+        )
+        if annotations.chunk:
+            aggregations.annotations = await annotations.to_dict(
+                cast("DataStore", self)
+            )
+
+        references = await self._main_store.get_relations_for_event(
+            event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
+        )
+        if references.chunk:
+            aggregations.references = await references.to_dict(cast("DataStore", self))
+
+        # Store the bundled aggregations in the event metadata for later use.
+        return aggregations
+
+    async def get_bundled_aggregations(
+        self, events: Iterable[EventBase], user_id: str
+    ) -> Dict[str, BundledAggregations]:
+        """Generate bundled aggregations for events.
+
+        Args:
+            events: The iterable of events to calculate bundled aggregations for.
+            user_id: The user requesting the bundled aggregations.
+
+        Returns:
+            A map of event ID to the bundled aggregation for the event. Not all
+            events may have bundled aggregations in the results.
+        """
+        # De-duplicate events by ID to handle the same event requested multiple times.
+        #
+        # State events do not get bundled aggregations.
+        events_by_id = {
+            event.event_id: event for event in events if not event.is_state()
+        }
+
+        # event ID -> bundled aggregation in non-serialized form.
+        results: Dict[str, BundledAggregations] = {}
+
+        # Fetch other relations per event.
+        for event in events_by_id.values():
+            event_result = await self._get_bundled_aggregation_for_event(event, user_id)
+            if event_result:
+                results[event.event_id] = event_result
+
+        # Fetch any edits (but not for redacted events).
+        edits = await self._main_store.get_applicable_edits(
+            [
+                event_id
+                for event_id, event in events_by_id.items()
+                if not event.internal_metadata.is_redacted()
+            ]
+        )
+        for event_id, edit in edits.items():
+            results.setdefault(event_id, BundledAggregations()).replace = edit
+
+        # Fetch thread summaries.
+        summaries = await self._main_store.get_thread_summaries(events_by_id.keys())
+        # Only fetch participated for a limited selection based on what had
+        # summaries.
+        participated = await self._main_store.get_threads_participated(
+            [event_id for event_id, summary in summaries.items() if summary], user_id
+        )
+        for event_id, summary in summaries.items():
+            if summary:
+                thread_count, latest_thread_event, edit = summary
+                results.setdefault(
+                    event_id, BundledAggregations()
+                ).thread = _ThreadAggregation(
+                    latest_event=latest_thread_event,
+                    latest_edit=edit,
+                    count=thread_count,
+                    # If there's a thread summary it must also exist in the
+                    # participated dictionary.
+                    current_user_participated=participated[event_id],
+                )
+
+        return results
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index b9735631fc..092e185c99 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -60,8 +60,8 @@ from synapse.events import EventBase
 from synapse.events.utils import copy_power_levels_contents
 from synapse.federation.federation_client import InvalidResponseError
 from synapse.handlers.federation import get_domains_from_state
+from synapse.handlers.relations import BundledAggregations
 from synapse.rest.admin._base import assert_user_is_admin
-from synapse.storage.databases.main.relations import BundledAggregations
 from synapse.storage.state import StateFilter
 from synapse.streams import EventSource
 from synapse.types import (
@@ -1118,6 +1118,7 @@ class RoomContextHandler:
         self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
+        self._relations_handler = hs.get_relations_handler()
 
     async def get_event_context(
         self,
@@ -1190,7 +1191,7 @@ class RoomContextHandler:
         event = filtered[0]
 
         # Fetch the aggregations.
-        aggregations = await self.store.get_bundled_aggregations(
+        aggregations = await self._relations_handler.get_bundled_aggregations(
             itertools.chain(events_before, (event,), events_after),
             user.to_string(),
         )
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index aa16e417eb..30eddda65f 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -54,6 +54,7 @@ class SearchHandler:
         self.clock = hs.get_clock()
         self.hs = hs
         self._event_serializer = hs.get_event_client_serializer()
+        self._relations_handler = hs.get_relations_handler()
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
         self.auth = hs.get_auth()
@@ -354,7 +355,7 @@ class SearchHandler:
 
         aggregations = None
         if self._msc3666_enabled:
-            aggregations = await self.store.get_bundled_aggregations(
+            aggregations = await self._relations_handler.get_bundled_aggregations(
                 # Generate an iterable of EventBase for all the events that will be
                 # returned, including contextual events.
                 itertools.chain(
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 0aa3052fd6..6c569cfb1c 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -28,16 +28,16 @@ from typing import (
 import attr
 from prometheus_client import Counter
 
-from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes
+from synapse.api.constants import EventTypes, Membership, ReceiptTypes
 from synapse.api.filtering import FilterCollection
 from synapse.api.presence import UserPresenceState
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
+from synapse.handlers.relations import BundledAggregations
 from synapse.logging.context import current_context
 from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
 from synapse.push.clientformat import format_push_rules_for_user
 from synapse.storage.databases.main.event_push_actions import NotifCounts
-from synapse.storage.databases.main.relations import BundledAggregations
 from synapse.storage.roommember import MemberSummary
 from synapse.storage.state import StateFilter
 from synapse.types import (
@@ -269,6 +269,7 @@ class SyncHandler:
         self.store = hs.get_datastores().main
         self.notifier = hs.get_notifier()
         self.presence_handler = hs.get_presence_handler()
+        self._relations_handler = hs.get_relations_handler()
         self.event_sources = hs.get_event_sources()
         self.clock = hs.get_clock()
         self.state = hs.get_state_handler()
@@ -638,8 +639,10 @@ class SyncHandler:
         # as clients will have all the necessary information.
         bundled_aggregations = None
         if limited or newly_joined_room:
-            bundled_aggregations = await self.store.get_bundled_aggregations(
-                recents, sync_config.user.to_string()
+            bundled_aggregations = (
+                await self._relations_handler.get_bundled_aggregations(
+                    recents, sync_config.user.to_string()
+                )
             )
 
         return TimelineBatch(
@@ -1601,7 +1604,7 @@ class SyncHandler:
                         return set(), set(), set(), set()
 
         # 3. Work out which rooms need reporting in the sync response.
-        ignored_users = await self._get_ignored_users(user_id)
+        ignored_users = await self.store.ignored_users(user_id)
         if since_token:
             room_changes = await self._get_rooms_changed(
                 sync_result_builder, ignored_users
@@ -1627,7 +1630,6 @@ class SyncHandler:
             logger.debug("Generating room entry for %s", room_entry.room_id)
             await self._generate_room_entry(
                 sync_result_builder,
-                ignored_users,
                 room_entry,
                 ephemeral=ephemeral_by_room.get(room_entry.room_id, []),
                 tags=tags_by_room.get(room_entry.room_id),
@@ -1657,29 +1659,6 @@ class SyncHandler:
             newly_left_users,
         )
 
-    async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]:
-        """Retrieve the users ignored by the given user from their global account_data.
-
-        Returns an empty set if
-        - there is no global account_data entry for ignored_users
-        - there is such an entry, but it's not a JSON object.
-        """
-        # TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead?
-        ignored_account_data = (
-            await self.store.get_global_account_data_by_type_for_user(
-                user_id=user_id, data_type=AccountDataTypes.IGNORED_USER_LIST
-            )
-        )
-
-        # If there is ignored users account data and it matches the proper type,
-        # then use it.
-        ignored_users: FrozenSet[str] = frozenset()
-        if ignored_account_data:
-            ignored_users_data = ignored_account_data.get("ignored_users", {})
-            if isinstance(ignored_users_data, dict):
-                ignored_users = frozenset(ignored_users_data.keys())
-        return ignored_users
-
     async def _have_rooms_changed(
         self, sync_result_builder: "SyncResultBuilder"
     ) -> bool:
@@ -2022,7 +2001,6 @@ class SyncHandler:
     async def _generate_room_entry(
         self,
         sync_result_builder: "SyncResultBuilder",
-        ignored_users: FrozenSet[str],
         room_builder: "RoomSyncResultBuilder",
         ephemeral: List[JsonDict],
         tags: Optional[Dict[str, Dict[str, Any]]],
@@ -2051,7 +2029,6 @@ class SyncHandler:
 
         Args:
             sync_result_builder
-            ignored_users: Set of users ignored by user.
             room_builder
             ephemeral: List of new ephemeral events for room
             tags: List of *all* tags for room, or None if there has been
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index d27ed2be6a..048fd4bb82 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -19,8 +19,8 @@ import synapse.metrics
 from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership
 from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.user_directory import SearchResult
 from synapse.storage.roommember import ProfileInfo
-from synapse.types import JsonDict
 from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
@@ -78,7 +78,7 @@ class UserDirectoryHandler(StateDeltasHandler):
 
     async def search_users(
         self, user_id: str, search_term: str, limit: int
-    ) -> JsonDict:
+    ) -> SearchResult:
         """Searches for users in directory
 
         Returns:
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index d735c1d461..aa8256b36f 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -111,6 +111,7 @@ from synapse.types import (
     StateMap,
     UserID,
     UserInfo,
+    UserProfile,
     create_requester,
 )
 from synapse.util import Clock
@@ -150,6 +151,7 @@ __all__ = [
     "EventBase",
     "StateMap",
     "ProfileInfo",
+    "UserProfile",
 ]
 
 logger = logging.getLogger(__name__)
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 8140afcb6b..030898e4d0 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -213,7 +213,7 @@ class BulkPushRuleEvaluator:
         if not event.is_state():
             ignorers = await self.store.ignored_by(event.sender)
         else:
-            ignorers = set()
+            ignorers = frozenset()
 
         for uid, rules in rules_by_user.items():
             if event.sender == uid:
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index d9a6be43f7..c16078b187 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -51,9 +51,7 @@ class RelationPaginationServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
-        self.clock = hs.get_clock()
-        self._event_serializer = hs.get_event_client_serializer()
-        self.event_handler = hs.get_event_handler()
+        self._relations_handler = hs.get_relations_handler()
 
     async def on_GET(
         self,
@@ -65,16 +63,6 @@ class RelationPaginationServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
-        await self.auth.check_user_in_room_or_world_readable(
-            room_id, requester.user.to_string(), allow_departed_users=True
-        )
-
-        # This gets the original event and checks that a) the event exists and
-        # b) the user is allowed to view it.
-        event = await self.event_handler.get_event(requester.user, room_id, parent_id)
-        if event is None:
-            raise SynapseError(404, "Unknown parent event.")
-
         limit = parse_integer(request, "limit", default=5)
         direction = parse_string(
             request, "org.matrix.msc3715.dir", default="b", allowed_values=["f", "b"]
@@ -90,9 +78,9 @@ class RelationPaginationServlet(RestServlet):
         if to_token_str:
             to_token = await StreamToken.from_string(self.store, to_token_str)
 
-        pagination_chunk = await self.store.get_relations_for_event(
+        result = await self._relations_handler.get_relations(
+            requester=requester,
             event_id=parent_id,
-            event=event,
             room_id=room_id,
             relation_type=relation_type,
             event_type=event_type,
@@ -102,30 +90,7 @@ class RelationPaginationServlet(RestServlet):
             to_token=to_token,
         )
 
-        events = await self.store.get_events_as_list(
-            [c["event_id"] for c in pagination_chunk.chunk]
-        )
-
-        now = self.clock.time_msec()
-        # Do not bundle aggregations when retrieving the original event because
-        # we want the content before relations are applied to it.
-        original_event = self._event_serializer.serialize_event(
-            event, now, bundle_aggregations=None
-        )
-        # The relations returned for the requested event do include their
-        # bundled aggregations.
-        aggregations = await self.store.get_bundled_aggregations(
-            events, requester.user.to_string()
-        )
-        serialized_events = self._event_serializer.serialize_events(
-            events, now, bundle_aggregations=aggregations
-        )
-
-        return_value = await pagination_chunk.to_dict(self.store)
-        return_value["chunk"] = serialized_events
-        return_value["original_event"] = original_event
-
-        return 200, return_value
+        return 200, result
 
 
 class RelationAggregationPaginationServlet(RestServlet):
@@ -245,9 +210,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
-        self.clock = hs.get_clock()
-        self._event_serializer = hs.get_event_client_serializer()
-        self.event_handler = hs.get_event_handler()
+        self._relations_handler = hs.get_relations_handler()
 
     async def on_GET(
         self,
@@ -260,18 +223,6 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
-        await self.auth.check_user_in_room_or_world_readable(
-            room_id,
-            requester.user.to_string(),
-            allow_departed_users=True,
-        )
-
-        # This checks that a) the event exists and b) the user is allowed to
-        # view it.
-        event = await self.event_handler.get_event(requester.user, room_id, parent_id)
-        if event is None:
-            raise SynapseError(404, "Unknown parent event.")
-
         if relation_type != RelationTypes.ANNOTATION:
             raise SynapseError(400, "Relation type must be 'annotation'")
 
@@ -286,9 +237,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
         if to_token_str:
             to_token = await StreamToken.from_string(self.store, to_token_str)
 
-        result = await self.store.get_relations_for_event(
+        result = await self._relations_handler.get_relations(
+            requester=requester,
             event_id=parent_id,
-            event=event,
             room_id=room_id,
             relation_type=relation_type,
             event_type=event_type,
@@ -298,17 +249,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
             to_token=to_token,
         )
 
-        events = await self.store.get_events_as_list(
-            [c["event_id"] for c in result.chunk]
-        )
-
-        now = self.clock.time_msec()
-        serialized_events = self._event_serializer.serialize_events(events, now)
-
-        return_value = await result.to_dict(self.store)
-        return_value["chunk"] = serialized_events
-
-        return 200, return_value
+        return 200, result
 
 
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 8a06ab8c5f..47e152c8cc 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -645,6 +645,7 @@ class RoomEventServlet(RestServlet):
         self._store = hs.get_datastores().main
         self.event_handler = hs.get_event_handler()
         self._event_serializer = hs.get_event_client_serializer()
+        self._relations_handler = hs.get_relations_handler()
         self.auth = hs.get_auth()
 
     async def on_GET(
@@ -663,7 +664,7 @@ class RoomEventServlet(RestServlet):
 
         if event:
             # Ensure there are bundled aggregations available.
-            aggregations = await self._store.get_bundled_aggregations(
+            aggregations = await self._relations_handler.get_bundled_aggregations(
                 [event], requester.user.to_string()
             )
 
diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py
index a47d9bd01d..116c982ce6 100644
--- a/synapse/rest/client/user_directory.py
+++ b/synapse/rest/client/user_directory.py
@@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.http.site import SynapseRequest
-from synapse.types import JsonDict
+from synapse.types import JsonMapping
 
 from ._base import client_patterns
 
@@ -38,7 +38,7 @@ class UserDirectorySearchRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.user_directory_handler = hs.get_user_directory_handler()
 
-    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonMapping]:
         """Searches for users in directory
 
         Returns:
diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py
index 872a9e72e8..4cc9c66fbe 100644
--- a/synapse/rest/media/v1/preview_html.py
+++ b/synapse/rest/media/v1/preview_html.py
@@ -16,7 +16,6 @@ import itertools
 import logging
 import re
 from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union
-from urllib import parse as urlparse
 
 if TYPE_CHECKING:
     from lxml import etree
@@ -144,9 +143,7 @@ def decode_body(
     return etree.fromstring(body, parser)
 
 
-def parse_html_to_open_graph(
-    tree: "etree.Element", media_uri: str
-) -> Dict[str, Optional[str]]:
+def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
     """
     Parse the HTML document into an Open Graph response.
 
@@ -155,7 +152,6 @@ def parse_html_to_open_graph(
 
     Args:
         tree: The parsed HTML document.
-        media_url: The URI used to download the body.
 
     Returns:
         The Open Graph response as a dictionary.
@@ -209,7 +205,7 @@ def parse_html_to_open_graph(
             "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
         )
         if meta_image:
-            og["og:image"] = rebase_url(meta_image[0], media_uri)
+            og["og:image"] = meta_image[0]
         else:
             # TODO: consider inlined CSS styles as well as width & height attribs
             images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
@@ -320,37 +316,6 @@ def _iterate_over_text(
             )
 
 
-def rebase_url(url: str, base: str) -> str:
-    """
-    Resolves a potentially relative `url` against an absolute `base` URL.
-
-    For example:
-
-        >>> rebase_url("subpage", "https://example.com/foo/")
-        'https://example.com/foo/subpage'
-        >>> rebase_url("sibling", "https://example.com/foo")
-        'https://example.com/sibling'
-        >>> rebase_url("/bar", "https://example.com/foo/")
-        'https://example.com/bar'
-        >>> rebase_url("https://alice.com/a/", "https://example.com/foo/")
-        'https://alice.com/a'
-    """
-    base_parts = urlparse.urlparse(base)
-    # Convert the parsed URL to a list for (potential) modification.
-    url_parts = list(urlparse.urlparse(url))
-    # Add a scheme, if one does not exist.
-    if not url_parts[0]:
-        url_parts[0] = base_parts.scheme or "http"
-    # Fix up the hostname, if this is not a data URL.
-    if url_parts[0] != "data" and not url_parts[1]:
-        url_parts[1] = base_parts.netloc
-        # If the path does not start with a /, nest it under the base path's last
-        # directory.
-        if not url_parts[2].startswith("/"):
-            url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts.path) + url_parts[2]
-    return urlparse.urlunparse(url_parts)
-
-
 def summarize_paragraphs(
     text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
 ) -> Optional[str]:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 14ea88b240..d47af8ead6 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -22,7 +22,7 @@ import shutil
 import sys
 import traceback
 from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, Tuple
-from urllib import parse as urlparse
+from urllib.parse import urljoin, urlparse, urlsplit
 from urllib.request import urlopen
 
 import attr
@@ -44,11 +44,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.media.v1._base import get_filename_from_headers
 from synapse.rest.media.v1.media_storage import MediaStorage
 from synapse.rest.media.v1.oembed import OEmbedProvider
-from synapse.rest.media.v1.preview_html import (
-    decode_body,
-    parse_html_to_open_graph,
-    rebase_url,
-)
+from synapse.rest.media.v1.preview_html import decode_body, parse_html_to_open_graph
 from synapse.types import JsonDict, UserID
 from synapse.util import json_encoder
 from synapse.util.async_helpers import ObservableDeferred
@@ -187,7 +183,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             ts = self.clock.time_msec()
 
         # XXX: we could move this into _do_preview if we wanted.
-        url_tuple = urlparse.urlsplit(url)
+        url_tuple = urlsplit(url)
         for entry in self.url_preview_url_blacklist:
             match = True
             for attrib in entry:
@@ -322,7 +318,7 @@ class PreviewUrlResource(DirectServeJsonResource):
 
                 # Parse Open Graph information from the HTML in case the oEmbed
                 # response failed or is incomplete.
-                og_from_html = parse_html_to_open_graph(tree, media_info.uri)
+                og_from_html = parse_html_to_open_graph(tree)
 
                 # Compile the Open Graph response by using the scraped
                 # information from the HTML and overlaying any information
@@ -588,12 +584,17 @@ class PreviewUrlResource(DirectServeJsonResource):
         if "og:image" not in og or not og["og:image"]:
             return
 
+        # The image URL from the HTML might be relative to the previewed page,
+        # convert it to an URL which can be requested directly.
+        image_url = og["og:image"]
+        url_parts = urlparse(image_url)
+        if url_parts.scheme != "data":
+            image_url = urljoin(media_info.uri, image_url)
+
         # FIXME: it might be cleaner to use the same flow as the main /preview_url
         # request itself and benefit from the same caching etc.  But for now we
         # just rely on the caching on the master request to speed things up.
-        image_info = await self._handle_url(
-            rebase_url(og["og:image"], media_info.uri), user, allow_data_urls=True
-        )
+        image_info = await self._handle_url(image_url, user, allow_data_urls=True)
 
         if _is_media(image_info.media_type):
             # TODO: make sure we don't choke on white-on-transparent images
diff --git a/synapse/server.py b/synapse/server.py
index 2fcf18a7a6..380369db92 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -94,6 +94,7 @@ from synapse.handlers.profile import ProfileHandler
 from synapse.handlers.read_marker import ReadMarkerHandler
 from synapse.handlers.receipts import ReceiptsHandler
 from synapse.handlers.register import RegistrationHandler
+from synapse.handlers.relations import RelationsHandler
 from synapse.handlers.room import (
     RoomContextHandler,
     RoomCreationHandler,
@@ -720,6 +721,10 @@ class HomeServer(metaclass=abc.ABCMeta):
         return PaginationHandler(self)
 
     @cache_in_self
+    def get_relations_handler(self) -> RelationsHandler:
+        return RelationsHandler(self)
+
+    @cache_in_self
     def get_room_context_handler(self) -> RoomContextHandler:
         return RoomContextHandler(self)
 
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 99802228c9..9749f0c06e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -41,6 +41,7 @@ from prometheus_client import Histogram
 from typing_extensions import Literal
 
 from twisted.enterprise import adbapi
+from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 from synapse.config.database import DatabaseConnectionConfig
@@ -55,6 +56,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
+from synapse.util.async_helpers import delay_cancellation
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -732,34 +734,45 @@ class DatabasePool:
         Returns:
             The result of func
         """
-        after_callbacks: List[_CallbackListEntry] = []
-        exception_callbacks: List[_CallbackListEntry] = []
 
-        if not current_context():
-            logger.warning("Starting db txn '%s' from sentinel context", desc)
+        async def _runInteraction() -> R:
+            after_callbacks: List[_CallbackListEntry] = []
+            exception_callbacks: List[_CallbackListEntry] = []
 
-        try:
-            with opentracing.start_active_span(f"db.{desc}"):
-                result = await self.runWithConnection(
-                    self.new_transaction,
-                    desc,
-                    after_callbacks,
-                    exception_callbacks,
-                    func,
-                    *args,
-                    db_autocommit=db_autocommit,
-                    isolation_level=isolation_level,
-                    **kwargs,
-                )
+            if not current_context():
+                logger.warning("Starting db txn '%s' from sentinel context", desc)
 
-            for after_callback, after_args, after_kwargs in after_callbacks:
-                after_callback(*after_args, **after_kwargs)
-        except Exception:
-            for after_callback, after_args, after_kwargs in exception_callbacks:
-                after_callback(*after_args, **after_kwargs)
-            raise
+            try:
+                with opentracing.start_active_span(f"db.{desc}"):
+                    result = await self.runWithConnection(
+                        self.new_transaction,
+                        desc,
+                        after_callbacks,
+                        exception_callbacks,
+                        func,
+                        *args,
+                        db_autocommit=db_autocommit,
+                        isolation_level=isolation_level,
+                        **kwargs,
+                    )
 
-        return cast(R, result)
+                for after_callback, after_args, after_kwargs in after_callbacks:
+                    after_callback(*after_args, **after_kwargs)
+
+                return cast(R, result)
+            except Exception:
+                for after_callback, after_args, after_kwargs in exception_callbacks:
+                    after_callback(*after_args, **after_kwargs)
+                raise
+
+        # To handle cancellation, we ensure that `after_callback`s and
+        # `exception_callback`s are always run, since the transaction will complete
+        # on another thread regardless of cancellation.
+        #
+        # We also wait until everything above is done before releasing the
+        # `CancelledError`, so that logging contexts won't get used after they have been
+        # finished.
+        return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
 
     async def runWithConnection(
         self,
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 52146aacc8..9af9f4f18e 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,7 +14,17 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    FrozenSet,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    cast,
+)
 
 from synapse.api.constants import AccountDataTypes
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -365,7 +375,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         )
 
     @cached(max_entries=5000, iterable=True)
-    async def ignored_by(self, user_id: str) -> Set[str]:
+    async def ignored_by(self, user_id: str) -> FrozenSet[str]:
         """
         Get users which ignore the given user.
 
@@ -375,7 +385,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         Return:
             The user IDs which ignore the given user.
         """
-        return set(
+        return frozenset(
             await self.db_pool.simple_select_onecol(
                 table="ignored_users",
                 keyvalues={"ignored_user_id": user_id},
@@ -384,6 +394,26 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             )
         )
 
+    @cached(max_entries=5000, iterable=True)
+    async def ignored_users(self, user_id: str) -> FrozenSet[str]:
+        """
+        Get users which the given user ignores.
+
+        Params:
+            user_id: The user ID which is making the request.
+
+        Return:
+            The user IDs which are ignored by the given user.
+        """
+        return frozenset(
+            await self.db_pool.simple_select_onecol(
+                table="ignored_users",
+                keyvalues={"ignorer_user_id": user_id},
+                retcol="ignored_user_id",
+                desc="ignored_users",
+            )
+        )
+
     def process_replication_rows(
         self,
         stream_name: str,
@@ -529,6 +559,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         else:
             currently_ignored_users = set()
 
+        # If the data has not changed, nothing to do.
+        if previously_ignored_users == currently_ignored_users:
+            return
+
         # Delete entries which are no longer ignored.
         self.db_pool.simple_delete_many_txn(
             txn,
@@ -551,6 +585,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         # Invalidate the cache for any ignored users which were added or removed.
         for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
             self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
+        self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
 
     async def purge_account_data_for_user(self, user_id: str) -> None:
         """
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index d6a2df1afe..2d7511d613 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import (
     EventsStream,
     EventsStreamCurrentStateRow,
     EventsStreamEventRow,
+    EventsStreamRow,
 )
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
@@ -31,6 +32,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.engines import PostgresEngine
+from synapse.util.caches.descriptors import _CachedFunction
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -82,7 +84,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_updated_caches_txn(txn):
+        def get_all_updated_caches_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
             # We purposefully don't bound by the current token, as we want to
             # send across cache invalidations as quickly as possible. Cache
             # invalidations are idempotent, so duplicates are fine.
@@ -107,7 +111,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             "get_all_updated_caches", get_all_updated_caches_txn
         )
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    ) -> None:
         if stream_name == EventsStream.NAME:
             for row in rows:
                 self._process_event_stream_row(token, row)
@@ -142,10 +148,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
-    def _process_event_stream_row(self, token, row):
+    def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
         data = row.data
 
         if row.type == EventsStreamEventRow.TypeId:
+            assert isinstance(data, EventsStreamEventRow)
             self._invalidate_caches_for_event(
                 token,
                 data.event_id,
@@ -157,9 +164,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 backfilled=False,
             )
         elif row.type == EventsStreamCurrentStateRow.TypeId:
-            self._curr_state_delta_stream_cache.entity_has_changed(
-                row.data.room_id, token
-            )
+            assert isinstance(data, EventsStreamCurrentStateRow)
+            self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)
 
             if data.type == EventTypes.Member:
                 self.get_rooms_for_user_with_stream_ordering.invalidate(
@@ -170,15 +176,15 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
     def _invalidate_caches_for_event(
         self,
-        stream_ordering,
-        event_id,
-        room_id,
-        etype,
-        state_key,
-        redacts,
-        relates_to,
-        backfilled,
-    ):
+        stream_ordering: int,
+        event_id: str,
+        room_id: str,
+        etype: str,
+        state_key: Optional[str],
+        redacts: Optional[str],
+        relates_to: Optional[str],
+        backfilled: bool,
+    ) -> None:
         self._invalidate_get_event_cache(event_id)
         self.have_seen_event.invalidate((room_id, event_id))
 
@@ -207,7 +213,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self.get_thread_summary.invalidate((relates_to,))
             self.get_thread_participated.invalidate((relates_to,))
 
-    async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
+    async def invalidate_cache_and_stream(
+        self, cache_name: str, keys: Tuple[Any, ...]
+    ) -> None:
         """Invalidates the cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
 
@@ -227,7 +235,12 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             keys,
         )
 
-    def _invalidate_cache_and_stream(self, txn, cache_func, keys):
+    def _invalidate_cache_and_stream(
+        self,
+        txn: LoggingTransaction,
+        cache_func: _CachedFunction,
+        keys: Tuple[Any, ...],
+    ) -> None:
         """Invalidates the cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
 
@@ -238,7 +251,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         txn.call_after(cache_func.invalidate, keys)
         self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
 
-    def _invalidate_all_cache_and_stream(self, txn, cache_func):
+    def _invalidate_all_cache_and_stream(
+        self, txn: LoggingTransaction, cache_func: _CachedFunction
+    ) -> None:
         """Invalidates the entire cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
         """
@@ -279,8 +294,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             )
 
     def _send_invalidation_to_replication(
-        self, txn, cache_name: str, keys: Optional[Iterable[Any]]
-    ):
+        self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
+    ) -> None:
         """Notifies replication that given cache has been invalidated.
 
         Note that this does *not* invalidate the cache locally.
@@ -315,7 +330,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                     "instance_name": self._instance_name,
                     "cache_func": cache_name,
                     "keys": keys,
-                    "invalidation_ts": self.clock.time_msec(),
+                    "invalidation_ts": self._clock.time_msec(),
                 },
             )
 
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 3f6086050b..0aef121d83 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -13,13 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
 
 from typing_extensions import TypedDict
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 
@@ -75,7 +79,7 @@ class GroupServerWorkerStore(SQLBaseStore):
     ) -> List[Dict[str, Any]]:
         # TODO: Pagination
 
-        keyvalues = {"group_id": group_id}
+        keyvalues: JsonDict = {"group_id": group_id}
         if not include_private:
             keyvalues["is_public"] = True
 
@@ -117,7 +121,7 @@ class GroupServerWorkerStore(SQLBaseStore):
 
         # TODO: Pagination
 
-        def _get_rooms_in_group_txn(txn):
+        def _get_rooms_in_group_txn(txn: LoggingTransaction) -> List[_RoomInGroup]:
             sql = """
             SELECT room_id, is_public FROM group_rooms
                 WHERE group_id = ?
@@ -176,8 +180,10 @@ class GroupServerWorkerStore(SQLBaseStore):
                         * "order": int, the sort order of rooms in this category
         """
 
-        def _get_rooms_for_summary_txn(txn):
-            keyvalues = {"group_id": group_id}
+        def _get_rooms_for_summary_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+            keyvalues: JsonDict = {"group_id": group_id}
             if not include_private:
                 keyvalues["is_public"] = True
 
@@ -241,7 +247,7 @@ class GroupServerWorkerStore(SQLBaseStore):
             "get_rooms_for_summary", _get_rooms_for_summary_txn
         )
 
-    async def get_group_categories(self, group_id):
+    async def get_group_categories(self, group_id: str) -> JsonDict:
         rows = await self.db_pool.simple_select_list(
             table="group_room_categories",
             keyvalues={"group_id": group_id},
@@ -257,7 +263,7 @@ class GroupServerWorkerStore(SQLBaseStore):
             for row in rows
         }
 
-    async def get_group_category(self, group_id, category_id):
+    async def get_group_category(self, group_id: str, category_id: str) -> JsonDict:
         category = await self.db_pool.simple_select_one(
             table="group_room_categories",
             keyvalues={"group_id": group_id, "category_id": category_id},
@@ -269,7 +275,7 @@ class GroupServerWorkerStore(SQLBaseStore):
 
         return category
 
-    async def get_group_roles(self, group_id):
+    async def get_group_roles(self, group_id: str) -> JsonDict:
         rows = await self.db_pool.simple_select_list(
             table="group_roles",
             keyvalues={"group_id": group_id},
@@ -285,7 +291,7 @@ class GroupServerWorkerStore(SQLBaseStore):
             for row in rows
         }
 
-    async def get_group_role(self, group_id, role_id):
+    async def get_group_role(self, group_id: str, role_id: str) -> JsonDict:
         role = await self.db_pool.simple_select_one(
             table="group_roles",
             keyvalues={"group_id": group_id, "role_id": role_id},
@@ -311,15 +317,19 @@ class GroupServerWorkerStore(SQLBaseStore):
             desc="get_local_groups_for_room",
         )
 
-    async def get_users_for_summary_by_role(self, group_id, include_private=False):
+    async def get_users_for_summary_by_role(
+        self, group_id: str, include_private: bool = False
+    ) -> Tuple[List[JsonDict], JsonDict]:
         """Get the users and roles that should be included in a summary request
 
         Returns:
             ([users], [roles])
         """
 
-        def _get_users_for_summary_txn(txn):
-            keyvalues = {"group_id": group_id}
+        def _get_users_for_summary_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[JsonDict], JsonDict]:
+            keyvalues: JsonDict = {"group_id": group_id}
             if not include_private:
                 keyvalues["is_public"] = True
 
@@ -406,7 +416,9 @@ class GroupServerWorkerStore(SQLBaseStore):
             allow_none=True,
         )
 
-    async def get_users_membership_info_in_group(self, group_id, user_id):
+    async def get_users_membership_info_in_group(
+        self, group_id: str, user_id: str
+    ) -> JsonDict:
         """Get a dict describing the membership of a user in a group.
 
         Example if joined:
@@ -421,7 +433,7 @@ class GroupServerWorkerStore(SQLBaseStore):
              An empty dict if the user is not join/invite/etc
         """
 
-        def _get_users_membership_in_group_txn(txn):
+        def _get_users_membership_in_group_txn(txn: LoggingTransaction) -> JsonDict:
             row = self.db_pool.simple_select_one_txn(
                 txn,
                 table="group_users",
@@ -463,10 +475,14 @@ class GroupServerWorkerStore(SQLBaseStore):
             desc="get_publicised_groups_for_user",
         )
 
-    async def get_attestations_need_renewals(self, valid_until_ms):
+    async def get_attestations_need_renewals(
+        self, valid_until_ms: int
+    ) -> List[Dict[str, Any]]:
         """Get all attestations that need to be renewed until givent time"""
 
-        def _get_attestations_need_renewals_txn(txn):
+        def _get_attestations_need_renewals_txn(
+            txn: LoggingTransaction,
+        ) -> List[Dict[str, Any]]:
             sql = """
                 SELECT group_id, user_id FROM group_attestations_renewals
                 WHERE valid_until_ms <= ?
@@ -478,7 +494,9 @@ class GroupServerWorkerStore(SQLBaseStore):
             "get_attestations_need_renewals", _get_attestations_need_renewals_txn
         )
 
-    async def get_remote_attestation(self, group_id, user_id):
+    async def get_remote_attestation(
+        self, group_id: str, user_id: str
+    ) -> Optional[JsonDict]:
         """Get the attestation that proves the remote agrees that the user is
         in the group.
         """
@@ -504,8 +522,8 @@ class GroupServerWorkerStore(SQLBaseStore):
             desc="get_joined_groups",
         )
 
-    async def get_all_groups_for_user(self, user_id, now_token):
-        def _get_all_groups_for_user_txn(txn):
+    async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]:
+        def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
             sql = """
                 SELECT group_id, type, membership, u.content
                 FROM local_group_updates AS u
@@ -528,15 +546,16 @@ class GroupServerWorkerStore(SQLBaseStore):
             "get_all_groups_for_user", _get_all_groups_for_user_txn
         )
 
-    async def get_groups_changes_for_user(self, user_id, from_token, to_token):
-        from_token = int(from_token)
-        has_changed = self._group_updates_stream_cache.has_entity_changed(
+    async def get_groups_changes_for_user(
+        self, user_id: str, from_token: int, to_token: int
+    ) -> List[JsonDict]:
+        has_changed = self._group_updates_stream_cache.has_entity_changed(  # type: ignore[attr-defined]
             user_id, from_token
         )
         if not has_changed:
             return []
 
-        def _get_groups_changes_for_user_txn(txn):
+        def _get_groups_changes_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
             sql = """
                 SELECT group_id, membership, type, u.content
                 FROM local_group_updates AS u
@@ -583,12 +602,14 @@ class GroupServerWorkerStore(SQLBaseStore):
         """
 
         last_id = int(last_id)
-        has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id)
+        has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id)  # type: ignore[attr-defined]
 
         if not has_changed:
             return [], current_id, False
 
-        def _get_all_groups_changes_txn(txn):
+        def _get_all_groups_changes_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
             sql = """
                 SELECT stream_id, group_id, user_id, type, content
                 FROM local_group_updates
@@ -596,10 +617,13 @@ class GroupServerWorkerStore(SQLBaseStore):
                 LIMIT ?
             """
             txn.execute(sql, (last_id, current_id, limit))
-            updates = [
-                (stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
-                for stream_id, group_id, user_id, gtype, content_json in txn
-            ]
+            updates = cast(
+                List[Tuple[int, tuple]],
+                [
+                    (stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
+                    for stream_id, group_id, user_id, gtype, content_json in txn
+                ],
+            )
 
             limited = False
             upto_token = current_id
@@ -633,8 +657,8 @@ class GroupServerStore(GroupServerWorkerStore):
         self,
         group_id: str,
         room_id: str,
-        category_id: str,
-        order: int,
+        category_id: Optional[str],
+        order: Optional[int],
         is_public: Optional[bool],
     ) -> None:
         """Add (or update) room's entry in summary.
@@ -661,11 +685,11 @@ class GroupServerStore(GroupServerWorkerStore):
 
     def _add_room_to_summary_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         group_id: str,
         room_id: str,
-        category_id: str,
-        order: int,
+        category_id: Optional[str],
+        order: Optional[int],
         is_public: Optional[bool],
     ) -> None:
         """Add (or update) room's entry in summary.
@@ -750,7 +774,7 @@ class GroupServerStore(GroupServerWorkerStore):
                 WHERE group_id = ? AND category_id = ?
             """
             txn.execute(sql, (group_id, category_id))
-            (order,) = txn.fetchone()
+            (order,) = cast(Tuple[int], txn.fetchone())
 
         if existing:
             to_update = {}
@@ -766,7 +790,7 @@ class GroupServerStore(GroupServerWorkerStore):
                     "category_id": category_id,
                     "room_id": room_id,
                 },
-                values=to_update,
+                updatevalues=to_update,
             )
         else:
             if is_public is None:
@@ -785,7 +809,7 @@ class GroupServerStore(GroupServerWorkerStore):
             )
 
     async def remove_room_from_summary(
-        self, group_id: str, room_id: str, category_id: str
+        self, group_id: str, room_id: str, category_id: Optional[str]
     ) -> int:
         if category_id is None:
             category_id = _DEFAULT_CATEGORY_ID
@@ -808,8 +832,8 @@ class GroupServerStore(GroupServerWorkerStore):
         is_public: Optional[bool],
     ) -> None:
         """Add/update room category for group"""
-        insertion_values = {}
-        update_values = {"category_id": category_id}  # This cannot be empty
+        insertion_values: JsonDict = {}
+        update_values: JsonDict = {"category_id": category_id}  # This cannot be empty
 
         if profile is None:
             insertion_values["profile"] = "{}"
@@ -844,8 +868,8 @@ class GroupServerStore(GroupServerWorkerStore):
         is_public: Optional[bool],
     ) -> None:
         """Add/remove user role"""
-        insertion_values = {}
-        update_values = {"role_id": role_id}  # This cannot be empty
+        insertion_values: JsonDict = {}
+        update_values: JsonDict = {"role_id": role_id}  # This cannot be empty
 
         if profile is None:
             insertion_values["profile"] = "{}"
@@ -876,8 +900,8 @@ class GroupServerStore(GroupServerWorkerStore):
         self,
         group_id: str,
         user_id: str,
-        role_id: str,
-        order: int,
+        role_id: Optional[str],
+        order: Optional[int],
         is_public: Optional[bool],
     ) -> None:
         """Add (or update) user's entry in summary.
@@ -904,13 +928,13 @@ class GroupServerStore(GroupServerWorkerStore):
 
     def _add_user_to_summary_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         group_id: str,
         user_id: str,
-        role_id: str,
-        order: int,
+        role_id: Optional[str],
+        order: Optional[int],
         is_public: Optional[bool],
-    ):
+    ) -> None:
         """Add (or update) user's entry in summary.
 
         Args:
@@ -989,7 +1013,7 @@ class GroupServerStore(GroupServerWorkerStore):
                 WHERE group_id = ? AND role_id = ?
             """
             txn.execute(sql, (group_id, role_id))
-            (order,) = txn.fetchone()
+            (order,) = cast(Tuple[int], txn.fetchone())
 
         if existing:
             to_update = {}
@@ -1005,7 +1029,7 @@ class GroupServerStore(GroupServerWorkerStore):
                     "role_id": role_id,
                     "user_id": user_id,
                 },
-                values=to_update,
+                updatevalues=to_update,
             )
         else:
             if is_public is None:
@@ -1024,7 +1048,7 @@ class GroupServerStore(GroupServerWorkerStore):
             )
 
     async def remove_user_from_summary(
-        self, group_id: str, user_id: str, role_id: str
+        self, group_id: str, user_id: str, role_id: Optional[str]
     ) -> int:
         if role_id is None:
             role_id = _DEFAULT_ROLE_ID
@@ -1065,7 +1089,7 @@ class GroupServerStore(GroupServerWorkerStore):
                 Optional if the user and group are on the same server
         """
 
-        def _add_user_to_group_txn(txn):
+        def _add_user_to_group_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_insert_txn(
                 txn,
                 table="group_users",
@@ -1108,7 +1132,7 @@ class GroupServerStore(GroupServerWorkerStore):
         await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
 
     async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
-        def _remove_user_from_group_txn(txn):
+        def _remove_user_from_group_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_delete_txn(
                 txn,
                 table="group_users",
@@ -1159,7 +1183,7 @@ class GroupServerStore(GroupServerWorkerStore):
         )
 
     async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
-        def _remove_room_from_group_txn(txn):
+        def _remove_room_from_group_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_delete_txn(
                 txn,
                 table="group_rooms",
@@ -1216,7 +1240,9 @@ class GroupServerStore(GroupServerWorkerStore):
 
         content = content or {}
 
-        def _register_user_group_membership_txn(txn, next_id):
+        def _register_user_group_membership_txn(
+            txn: LoggingTransaction, next_id: int
+        ) -> int:
             # TODO: Upsert?
             self.db_pool.simple_delete_txn(
                 txn,
@@ -1249,7 +1275,7 @@ class GroupServerStore(GroupServerWorkerStore):
                     ),
                 },
             )
-            self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
+            self._group_updates_stream_cache.entity_has_changed(user_id, next_id)  # type: ignore[attr-defined]
 
             # TODO: Insert profile to ensure it comes down stream if its a join.
 
@@ -1289,7 +1315,7 @@ class GroupServerStore(GroupServerWorkerStore):
 
             return next_id
 
-        async with self._group_updates_id_gen.get_next() as next_id:
+        async with self._group_updates_id_gen.get_next() as next_id:  # type: ignore[attr-defined]
             res = await self.db_pool.runInteraction(
                 "register_user_group_membership",
                 _register_user_group_membership_txn,
@@ -1298,7 +1324,13 @@ class GroupServerStore(GroupServerWorkerStore):
         return res
 
     async def create_group(
-        self, group_id, user_id, name, avatar_url, short_description, long_description
+        self,
+        group_id: str,
+        user_id: str,
+        name: str,
+        avatar_url: str,
+        short_description: str,
+        long_description: str,
     ) -> None:
         await self.db_pool.simple_insert(
             table="groups",
@@ -1313,7 +1345,7 @@ class GroupServerStore(GroupServerWorkerStore):
             desc="create_group",
         )
 
-    async def update_group_profile(self, group_id, profile):
+    async def update_group_profile(self, group_id: str, profile: JsonDict) -> None:
         await self.db_pool.simple_update_one(
             table="groups",
             keyvalues={"group_id": group_id},
@@ -1361,8 +1393,8 @@ class GroupServerStore(GroupServerWorkerStore):
             desc="remove_attestation_renewal",
         )
 
-    def get_group_stream_token(self):
-        return self._group_updates_id_gen.get_current_token()
+    def get_group_stream_token(self) -> int:
+        return self._group_updates_id_gen.get_current_token()  # type: ignore[attr-defined]
 
     async def delete_group(self, group_id: str) -> None:
         """Deletes a group fully from the database.
@@ -1371,7 +1403,7 @@ class GroupServerStore(GroupServerWorkerStore):
             group_id: The group ID to delete.
         """
 
-        def _delete_group_txn(txn):
+        def _delete_group_txn(txn: LoggingTransaction) -> None:
             tables = [
                 "groups",
                 "group_users",
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index e9a0cdc6be..216622964a 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -12,15 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
+    LoggingTransaction,
     make_in_list_sql_clause,
 )
+from synapse.storage.databases.main.registration import RegistrationWorkerStore
 from synapse.util.caches.descriptors import cached
 from synapse.util.threepids import canonicalise_email
 
@@ -56,7 +58,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
             Number of current monthly active users
         """
 
-        def _count_users(txn):
+        def _count_users(txn: LoggingTransaction) -> int:
             # Exclude app service users
             sql = """
                 SELECT COUNT(*)
@@ -66,7 +68,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
                 WHERE (users.appservice_id IS NULL OR users.appservice_id = '');
             """
             txn.execute(sql)
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction("count_users", _count_users)
@@ -84,7 +86,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
 
         """
 
-        def _count_users_by_service(txn):
+        def _count_users_by_service(txn: LoggingTransaction) -> Dict[str, int]:
             sql = """
                 SELECT COALESCE(appservice_id, 'native'), COUNT(*)
                 FROM monthly_active_users
@@ -93,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
             """
 
             txn.execute(sql)
-            result = txn.fetchall()
+            result = cast(List[Tuple[str, int]], txn.fetchall())
             return dict(result)
 
         return await self.db_pool.runInteraction(
@@ -141,12 +143,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
         )
 
     @wrap_as_background_process("reap_monthly_active_users")
-    async def reap_monthly_active_users(self):
+    async def reap_monthly_active_users(self) -> None:
         """Cleans out monthly active user table to ensure that no stale
         entries exist.
         """
 
-        def _reap_users(txn, reserved_users):
+        def _reap_users(txn: LoggingTransaction, reserved_users: List[str]) -> None:
             """
             Args:
                 reserved_users (tuple): reserved users to preserve
@@ -210,10 +212,10 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
             # is racy.
             # Have resolved to invalidate the whole cache for now and do
             # something about it if and when the perf becomes significant
-            self._invalidate_all_cache_and_stream(
+            self._invalidate_all_cache_and_stream(  # type: ignore[attr-defined]
                 txn, self.user_last_seen_monthly_active
             )
-            self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
+            self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())  # type: ignore[attr-defined]
 
         reserved_users = await self.get_registered_reserved_users()
         await self.db_pool.runInteraction(
@@ -221,7 +223,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
         )
 
 
-class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
+class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -242,13 +244,15 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
             hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
         )
 
-    def _initialise_reserved_users(self, txn, threepids):
+    def _initialise_reserved_users(
+        self, txn: LoggingTransaction, threepids: List[dict]
+    ) -> None:
         """Ensures that reserved threepids are accounted for in the MAU table, should
         be called on start up.
 
         Args:
-            txn (cursor):
-            threepids (list[dict]): List of threepid dicts to reserve
+            txn:
+            threepids: List of threepid dicts to reserve
         """
 
         # XXX what is this function trying to achieve?  It upserts into
@@ -299,7 +303,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
             "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
         )
 
-    def upsert_monthly_active_user_txn(self, txn, user_id):
+    def upsert_monthly_active_user_txn(
+        self, txn: LoggingTransaction, user_id: str
+    ) -> None:
         """Updates or inserts monthly active user member
 
         We consciously do not call is_support_txn from this method because it
@@ -336,7 +342,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
             txn, self.user_last_seen_monthly_active, (user_id,)
         )
 
-    async def populate_monthly_active_users(self, user_id):
+    async def populate_monthly_active_users(self, user_id: str) -> None:
         """Checks on the state of monthly active user limits and optionally
         add the user to the monthly active tables
 
@@ -345,7 +351,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
         """
         if self._limit_usage_by_mau or self._mau_stats_only:
             # Trial users and guests should not be included as part of MAU group
-            is_guest = await self.is_guest(user_id)
+            is_guest = await self.is_guest(user_id)  # type: ignore[attr-defined]
             if is_guest:
                 return
             is_trial = await self.is_trial_user(user_id)
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index c4869d64e6..b2295fd51f 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -27,7 +27,6 @@ from typing import (
 )
 
 import attr
-from frozendict import frozendict
 
 from synapse.api.constants import RelationTypes
 from synapse.events import EventBase
@@ -41,45 +40,15 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
-from synapse.types import JsonDict, RoomStreamToken, StreamToken
+from synapse.types import RoomStreamToken, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
-    from synapse.storage.databases.main import DataStore
 
 logger = logging.getLogger(__name__)
 
 
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class _ThreadAggregation:
-    # The latest event in the thread.
-    latest_event: EventBase
-    # The latest edit to the latest event in the thread.
-    latest_edit: Optional[EventBase]
-    # The total number of events in the thread.
-    count: int
-    # True if the current user has sent an event to the thread.
-    current_user_participated: bool
-
-
-@attr.s(slots=True, auto_attribs=True)
-class BundledAggregations:
-    """
-    The bundled aggregations for an event.
-
-    Some values require additional processing during serialization.
-    """
-
-    annotations: Optional[JsonDict] = None
-    references: Optional[JsonDict] = None
-    replace: Optional[EventBase] = None
-    thread: Optional[_ThreadAggregation] = None
-
-    def __bool__(self) -> bool:
-        return bool(self.annotations or self.references or self.replace or self.thread)
-
-
 class RelationsWorkerStore(SQLBaseStore):
     def __init__(
         self,
@@ -384,7 +353,7 @@ class RelationsWorkerStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
-    async def _get_applicable_edits(
+    async def get_applicable_edits(
         self, event_ids: Collection[str]
     ) -> Dict[str, Optional[EventBase]]:
         """Get the most recent edit (if any) that has happened for the given
@@ -473,7 +442,7 @@ class RelationsWorkerStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
-    async def _get_thread_summaries(
+    async def get_thread_summaries(
         self, event_ids: Collection[str]
     ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
         """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
@@ -587,7 +556,7 @@ class RelationsWorkerStore(SQLBaseStore):
         latest_events = await self.get_events(latest_event_ids.values())  # type: ignore[attr-defined]
 
         # Check to see if any of those events are edited.
-        latest_edits = await self._get_applicable_edits(latest_event_ids.values())
+        latest_edits = await self.get_applicable_edits(latest_event_ids.values())
 
         # Map to the event IDs to the thread summary.
         #
@@ -610,7 +579,7 @@ class RelationsWorkerStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
-    async def _get_threads_participated(
+    async def get_threads_participated(
         self, event_ids: Collection[str], user_id: str
     ) -> Dict[str, bool]:
         """Get whether the requesting user participated in the given threads.
@@ -766,114 +735,6 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
         )
 
-    async def _get_bundled_aggregation_for_event(
-        self, event: EventBase, user_id: str
-    ) -> Optional[BundledAggregations]:
-        """Generate bundled aggregations for an event.
-
-        Note that this does not use a cache, but depends on cached methods.
-
-        Args:
-            event: The event to calculate bundled aggregations for.
-            user_id: The user requesting the bundled aggregations.
-
-        Returns:
-            The bundled aggregations for an event, if bundled aggregations are
-            enabled and the event can have bundled aggregations.
-        """
-
-        # Do not bundle aggregations for an event which represents an edit or an
-        # annotation. It does not make sense for them to have related events.
-        relates_to = event.content.get("m.relates_to")
-        if isinstance(relates_to, (dict, frozendict)):
-            relation_type = relates_to.get("rel_type")
-            if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
-                return None
-
-        event_id = event.event_id
-        room_id = event.room_id
-
-        # The bundled aggregations to include, a mapping of relation type to a
-        # type-specific value. Some types include the direct return type here
-        # while others need more processing during serialization.
-        aggregations = BundledAggregations()
-
-        annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
-        if annotations.chunk:
-            aggregations.annotations = await annotations.to_dict(
-                cast("DataStore", self)
-            )
-
-        references = await self.get_relations_for_event(
-            event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
-        )
-        if references.chunk:
-            aggregations.references = await references.to_dict(cast("DataStore", self))
-
-        # Store the bundled aggregations in the event metadata for later use.
-        return aggregations
-
-    async def get_bundled_aggregations(
-        self, events: Iterable[EventBase], user_id: str
-    ) -> Dict[str, BundledAggregations]:
-        """Generate bundled aggregations for events.
-
-        Args:
-            events: The iterable of events to calculate bundled aggregations for.
-            user_id: The user requesting the bundled aggregations.
-
-        Returns:
-            A map of event ID to the bundled aggregation for the event. Not all
-            events may have bundled aggregations in the results.
-        """
-        # De-duplicate events by ID to handle the same event requested multiple times.
-        #
-        # State events do not get bundled aggregations.
-        events_by_id = {
-            event.event_id: event for event in events if not event.is_state()
-        }
-
-        # event ID -> bundled aggregation in non-serialized form.
-        results: Dict[str, BundledAggregations] = {}
-
-        # Fetch other relations per event.
-        for event in events_by_id.values():
-            event_result = await self._get_bundled_aggregation_for_event(event, user_id)
-            if event_result:
-                results[event.event_id] = event_result
-
-        # Fetch any edits (but not for redacted events).
-        edits = await self._get_applicable_edits(
-            [
-                event_id
-                for event_id, event in events_by_id.items()
-                if not event.internal_metadata.is_redacted()
-            ]
-        )
-        for event_id, edit in edits.items():
-            results.setdefault(event_id, BundledAggregations()).replace = edit
-
-        # Fetch thread summaries.
-        summaries = await self._get_thread_summaries(events_by_id.keys())
-        # Only fetch participated for a limited selection based on what had
-        # summaries.
-        participated = await self._get_threads_participated(summaries.keys(), user_id)
-        for event_id, summary in summaries.items():
-            if summary:
-                thread_count, latest_thread_event, edit = summary
-                results.setdefault(
-                    event_id, BundledAggregations()
-                ).thread = _ThreadAggregation(
-                    latest_event=latest_thread_event,
-                    latest_edit=edit,
-                    count=thread_count,
-                    # If there's a thread summary it must also exist in the
-                    # participated dictionary.
-                    current_user_participated=participated[event_id],
-                )
-
-        return results
-
 
 class RelationsStore(RelationsWorkerStore):
     pass
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index e7fddd2426..55cc9178f0 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -26,6 +26,8 @@ from typing import (
     cast,
 )
 
+from typing_extensions import TypedDict
+
 from synapse.api.errors import StoreError
 
 if TYPE_CHECKING:
@@ -40,7 +42,12 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.state import StateFilter
 from synapse.storage.databases.main.state_deltas import StateDeltasStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
+from synapse.types import (
+    JsonDict,
+    UserProfile,
+    get_domain_from_id,
+    get_localpart_from_id,
+)
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
@@ -591,6 +598,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         )
 
 
+class SearchResult(TypedDict):
+    limited: bool
+    results: List[UserProfile]
+
+
 class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
     # How many records do we calculate before sending it to
     # add_users_who_share_private_rooms?
@@ -777,7 +789,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
 
     async def search_user_dir(
         self, user_id: str, search_term: str, limit: int
-    ) -> JsonDict:
+    ) -> SearchResult:
         """Searches for users in directory
 
         Returns:
@@ -910,8 +922,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
 
-        results = await self.db_pool.execute(
-            "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
+        results = cast(
+            List[UserProfile],
+            await self.db_pool.execute(
+                "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
+            ),
         )
 
         limited = len(results) > limit
diff --git a/synapse/types.py b/synapse/types.py
index 53be3583a0..5ce2a5b0a5 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -34,6 +34,7 @@ from typing import (
 import attr
 from frozendict import frozendict
 from signedjson.key import decode_verify_key_bytes
+from typing_extensions import TypedDict
 from unpaddedbase64 import decode_base64
 from zope.interface import Interface
 
@@ -63,6 +64,10 @@ MutableStateMap = MutableMapping[StateKey, T]
 # JSON types. These could be made stronger, but will do for now.
 # A JSON-serialisable dict.
 JsonDict = Dict[str, Any]
+# A JSON-serialisable mapping; roughly speaking an immutable JSONDict.
+# Useful when you have a TypedDict which isn't going to be mutated and you don't want
+# to cast to JsonDict everywhere.
+JsonMapping = Mapping[str, Any]
 # A JSON-serialisable object.
 JsonSerializable = object
 
@@ -791,3 +796,9 @@ class UserInfo:
     is_deactivated: bool
     is_guest: bool
     is_shadow_banned: bool
+
+
+class UserProfile(TypedDict):
+    user_id: str
+    display_name: Optional[str]
+    avatar_url: Optional[str]
diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py
index 12cd804939..66f1da7502 100644
--- a/synapse/util/check_dependencies.py
+++ b/synapse/util/check_dependencies.py
@@ -128,6 +128,19 @@ def _incorrect_version(
         )
 
 
+def _no_reported_version(requirement: Requirement, extra: Optional[str] = None) -> str:
+    if extra:
+        return (
+            f"Synapse {VERSION} needs {requirement} for {extra}, "
+            f"but can't determine {requirement.name}'s version"
+        )
+    else:
+        return (
+            f"Synapse {VERSION} needs {requirement}, "
+            f"but can't determine {requirement.name}'s version"
+        )
+
+
 def check_requirements(extra: Optional[str] = None) -> None:
     """Check Synapse's dependencies are present and correctly versioned.
 
@@ -163,8 +176,17 @@ def check_requirements(extra: Optional[str] = None) -> None:
                 deps_unfulfilled.append(requirement.name)
                 errors.append(_not_installed(requirement, extra))
         else:
+            if dist.version is None:
+                # This shouldn't happen---it suggests a borked virtualenv. (See #12223)
+                # Try to give a vaguely helpful error message anyway.
+                # Type-ignore: the annotations don't reflect reality: see
+                #     https://github.com/python/typeshed/issues/7513
+                #     https://bugs.python.org/issue47060
+                deps_unfulfilled.append(requirement.name)  # type: ignore[unreachable]
+                errors.append(_no_reported_version(requirement, extra))
+
             # We specify prereleases=True to allow prereleases such as RCs.
-            if not requirement.specifier.contains(dist.version, prereleases=True):
+            elif not requirement.specifier.contains(dist.version, prereleases=True):
                 deps_unfulfilled.append(requirement.name)
                 errors.append(_incorrect_version(requirement, dist.version, extra))
 
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 281cbe4d88..49519eb8f5 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -14,12 +14,7 @@
 import logging
 from typing import Dict, FrozenSet, List, Optional
 
-from synapse.api.constants import (
-    AccountDataTypes,
-    EventTypes,
-    HistoryVisibility,
-    Membership,
-)
+from synapse.api.constants import EventTypes, HistoryVisibility, Membership
 from synapse.events import EventBase
 from synapse.events.utils import prune_event
 from synapse.storage import Storage
@@ -87,15 +82,8 @@ async def filter_events_for_client(
         state_filter=StateFilter.from_types(types),
     )
 
-    ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user(
-        user_id, AccountDataTypes.IGNORED_USER_LIST
-    )
-
-    ignore_list: FrozenSet[str] = frozenset()
-    if ignore_dict_content:
-        ignored_users_dict = ignore_dict_content.get("ignored_users", {})
-        if isinstance(ignored_users_dict, dict):
-            ignore_list = frozenset(ignored_users_dict.keys())
+    # Get the users who are ignored by the requesting user.
+    ignore_list = await storage.main.ignored_users(user_id)
 
     erased_senders = await storage.main.are_users_erased(e.sender for e in events)
 
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index a267228846..a54aa29cf1 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -11,9 +11,14 @@
 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
+from typing import Any, Dict
 from unittest.mock import Mock
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.handlers.cas import CasResponse
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.test_utils import simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
@@ -24,7 +29,7 @@ SERVER_URL = "https://issuer/"
 
 
 class CasHandlerTestCase(HomeserverTestCase):
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
         config["public_baseurl"] = BASE_URL
         cas_config = {
@@ -40,7 +45,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         return config
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         hs = self.setup_test_homeserver()
 
         self.handler = hs.get_cas_handler()
@@ -51,7 +56,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         return hs
 
-    def test_map_cas_user_to_user(self):
+    def test_map_cas_user_to_user(self) -> None:
         """Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
 
         # stub out the auth handler
@@ -75,7 +80,7 @@ class CasHandlerTestCase(HomeserverTestCase):
             auth_provider_session_id=None,
         )
 
-    def test_map_cas_user_to_existing_user(self):
+    def test_map_cas_user_to_existing_user(self) -> None:
         """Existing users can log in with CAS account."""
         store = self.hs.get_datastores().main
         self.get_success(
@@ -119,7 +124,7 @@ class CasHandlerTestCase(HomeserverTestCase):
             auth_provider_session_id=None,
         )
 
-    def test_map_cas_user_to_invalid_localpart(self):
+    def test_map_cas_user_to_invalid_localpart(self) -> None:
         """CAS automaps invalid characters to base-64 encoding."""
 
         # stub out the auth handler
@@ -150,7 +155,7 @@ class CasHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_required_attributes(self):
+    def test_required_attributes(self) -> None:
         """The required attributes must be met from the CAS response."""
 
         # stub out the auth handler
@@ -166,7 +171,7 @@ class CasHandlerTestCase(HomeserverTestCase):
         auth_handler.complete_sso_login.assert_not_called()
 
         # The response doesn't have any department.
-        cas_response = CasResponse("test_user", {"userGroup": "staff"})
+        cas_response = CasResponse("test_user", {"userGroup": ["staff"]})
         request.reset_mock()
         self.get_success(
             self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 6e403a87c5..11ad44223d 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -12,14 +12,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from typing import Any, Awaitable, Callable, Dict
 from unittest.mock import Mock
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.api.errors
 import synapse.rest.admin
 from synapse.api.constants import EventTypes
 from synapse.rest.client import directory, login, room
-from synapse.types import RoomAlias, create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, RoomAlias, create_requester
+from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils import make_awaitable
@@ -28,13 +32,15 @@ from tests.test_utils import make_awaitable
 class DirectoryTestCase(unittest.HomeserverTestCase):
     """Tests the directory service."""
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.mock_federation = Mock()
         self.mock_registry = Mock()
 
-        self.query_handlers = {}
+        self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
 
-        def register_query_handler(query_type, handler):
+        def register_query_handler(
+            query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
+        ) -> None:
             self.query_handlers[query_type] = handler
 
         self.mock_registry.register_query_handler = register_query_handler
@@ -54,7 +60,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
 
         return hs
 
-    def test_get_local_association(self):
+    def test_get_local_association(self) -> None:
         self.get_success(
             self.store.create_room_alias_association(
                 self.my_room, "!8765qwer:test", ["test"]
@@ -65,7 +71,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
 
-    def test_get_remote_association(self):
+    def test_get_remote_association(self) -> None:
         self.mock_federation.make_query.return_value = make_awaitable(
             {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
         )
@@ -83,7 +89,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
             ignore_backoff=True,
         )
 
-    def test_incoming_fed_query(self):
+    def test_incoming_fed_query(self) -> None:
         self.get_success(
             self.store.create_room_alias_association(
                 self.your_room, "!8765asdf:test", ["test"]
@@ -105,7 +111,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
         directory.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.handler = hs.get_directory_handler()
 
         # Create user
@@ -125,7 +131,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
         self.test_user_tok = self.login("user", "pass")
         self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
 
-    def test_create_alias_joined_room(self):
+    def test_create_alias_joined_room(self) -> None:
         """A user can create an alias for a room they're in."""
         self.get_success(
             self.handler.create_association(
@@ -135,7 +141,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
             )
         )
 
-    def test_create_alias_other_room(self):
+    def test_create_alias_other_room(self) -> None:
         """A user cannot create an alias for a room they're NOT in."""
         other_room_id = self.helper.create_room_as(
             self.admin_user, tok=self.admin_user_tok
@@ -150,7 +156,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
             synapse.api.errors.SynapseError,
         )
 
-    def test_create_alias_admin(self):
+    def test_create_alias_admin(self) -> None:
         """An admin can create an alias for a room they're NOT in."""
         other_room_id = self.helper.create_room_as(
             self.test_user, tok=self.test_user_tok
@@ -173,7 +179,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
         directory.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
         self.handler = hs.get_directory_handler()
         self.state_handler = hs.get_state_handler()
@@ -195,7 +201,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
         self.test_user_tok = self.login("user", "pass")
         self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
 
-    def _create_alias(self, user):
+    def _create_alias(self, user) -> None:
         # Create a new alias to this room.
         self.get_success(
             self.store.create_room_alias_association(
@@ -203,7 +209,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
             )
         )
 
-    def test_delete_alias_not_allowed(self):
+    def test_delete_alias_not_allowed(self) -> None:
         """A user that doesn't meet the expected guidelines cannot delete an alias."""
         self._create_alias(self.admin_user)
         self.get_failure(
@@ -213,7 +219,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
             synapse.api.errors.AuthError,
         )
 
-    def test_delete_alias_creator(self):
+    def test_delete_alias_creator(self) -> None:
         """An alias creator can delete their own alias."""
         # Create an alias from a different user.
         self._create_alias(self.test_user)
@@ -232,7 +238,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
             synapse.api.errors.SynapseError,
         )
 
-    def test_delete_alias_admin(self):
+    def test_delete_alias_admin(self) -> None:
         """A server admin can delete an alias created by another user."""
         # Create an alias from a different user.
         self._create_alias(self.test_user)
@@ -251,7 +257,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
             synapse.api.errors.SynapseError,
         )
 
-    def test_delete_alias_sufficient_power(self):
+    def test_delete_alias_sufficient_power(self) -> None:
         """A user with a sufficient power level should be able to delete an alias."""
         self._create_alias(self.admin_user)
 
@@ -288,7 +294,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
         directory.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
         self.handler = hs.get_directory_handler()
         self.state_handler = hs.get_state_handler()
@@ -317,7 +323,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
         )
         return room_alias
 
-    def _set_canonical_alias(self, content):
+    def _set_canonical_alias(self, content) -> None:
         """Configure the canonical alias state on the room."""
         self.helper.send_state(
             self.room_id,
@@ -334,7 +340,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
             )
         )
 
-    def test_remove_alias(self):
+    def test_remove_alias(self) -> None:
         """Removing an alias that is the canonical alias should remove it there too."""
         # Set this new alias as the canonical alias for this room
         self._set_canonical_alias(
@@ -356,7 +362,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
         self.assertNotIn("alias", data["content"])
         self.assertNotIn("alt_aliases", data["content"])
 
-    def test_remove_other_alias(self):
+    def test_remove_other_alias(self) -> None:
         """Removing an alias listed as in alt_aliases should remove it there too."""
         # Create a second alias.
         other_test_alias = "#test2:test"
@@ -393,7 +399,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
 
     servlets = [directory.register_servlets, room.register_servlets]
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
 
         # Add custom alias creation rules to the config.
@@ -403,7 +409,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
 
         return config
 
-    def test_denied(self):
+    def test_denied(self) -> None:
         room_id = self.helper.create_room_as(self.user_id)
 
         channel = self.make_request(
@@ -413,7 +419,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
         )
         self.assertEqual(403, channel.code, channel.result)
 
-    def test_allowed(self):
+    def test_allowed(self) -> None:
         room_id = self.helper.create_room_as(self.user_id)
 
         channel = self.make_request(
@@ -423,7 +429,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
         )
         self.assertEqual(200, channel.code, channel.result)
 
-    def test_denied_during_creation(self):
+    def test_denied_during_creation(self) -> None:
         """A room alias that is not allowed should be rejected during creation."""
         # Invalid room alias.
         self.helper.create_room_as(
@@ -432,7 +438,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
             extra_content={"room_alias_name": "foo"},
         )
 
-    def test_allowed_during_creation(self):
+    def test_allowed_during_creation(self) -> None:
         """A valid room alias should be allowed during creation."""
         room_id = self.helper.create_room_as(
             self.user_id,
@@ -459,7 +465,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
     data = {"room_alias_name": "unofficial_test"}
     allowed_localpart = "allowed"
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
 
         # Add custom room list publication rules to the config.
@@ -474,7 +480,9 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
 
         return config
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
+    ) -> HomeServer:
         self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
         self.allowed_access_token = self.login(self.allowed_localpart, "pass")
 
@@ -483,7 +491,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
 
         return hs
 
-    def test_denied_without_publication_permission(self):
+    def test_denied_without_publication_permission(self) -> None:
         """
         Try to create a room, register an alias for it, and publish it,
         as a user without permission to publish rooms.
@@ -497,7 +505,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
             expect_code=403,
         )
 
-    def test_allowed_when_creating_private_room(self):
+    def test_allowed_when_creating_private_room(self) -> None:
         """
         Try to create a room, register an alias for it, and NOT publish it,
         as a user without permission to publish rooms.
@@ -511,7 +519,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
             expect_code=200,
         )
 
-    def test_allowed_with_publication_permission(self):
+    def test_allowed_with_publication_permission(self) -> None:
         """
         Try to create a room, register an alias for it, and publish it,
         as a user WITH permission to publish rooms.
@@ -525,7 +533,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
             expect_code=200,
         )
 
-    def test_denied_publication_with_invalid_alias(self):
+    def test_denied_publication_with_invalid_alias(self) -> None:
         """
         Try to create a room, register an alias for it, and publish it,
         as a user WITH permission to publish rooms.
@@ -538,7 +546,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
             expect_code=403,
         )
 
-    def test_can_create_as_private_room_after_rejection(self):
+    def test_can_create_as_private_room_after_rejection(self) -> None:
         """
         After failing to publish a room with an alias as a user without publish permission,
         retry as the same user, but without publishing the room.
@@ -549,7 +557,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
         self.test_denied_without_publication_permission()
         self.test_allowed_when_creating_private_room()
 
-    def test_can_create_with_permission_after_rejection(self):
+    def test_can_create_with_permission_after_rejection(self) -> None:
         """
         After failing to publish a room with an alias as a user without publish permission,
         retry as someone with permission, using the same alias.
@@ -566,7 +574,9 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
 
     servlets = [directory.register_servlets, room.register_servlets]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
+    ) -> HomeServer:
         room_id = self.helper.create_room_as(self.user_id)
 
         channel = self.make_request(
@@ -579,7 +589,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
 
         return hs
 
-    def test_disabling_room_list(self):
+    def test_disabling_room_list(self) -> None:
         self.room_list_handler.enable_room_list_search = True
         self.directory_handler.enable_room_list_search = True
 
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 9338ab92e9..ac21a28c43 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -20,33 +20,37 @@ from parameterized import parameterized
 from signedjson import key as key, sign as sign
 
 from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.api.errors import Codes, SynapseError
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils import make_awaitable
 
 
 class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         return self.setup_test_homeserver(federation_client=mock.Mock())
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.handler = hs.get_e2e_keys_handler()
         self.store = self.hs.get_datastores().main
 
-    def test_query_local_devices_no_devices(self):
+    def test_query_local_devices_no_devices(self) -> None:
         """If the user has no devices, we expect an empty list."""
         local_user = "@boris:" + self.hs.hostname
         res = self.get_success(self.handler.query_local_devices({local_user: None}))
         self.assertDictEqual(res, {local_user: {}})
 
-    def test_reupload_one_time_keys(self):
+    def test_reupload_one_time_keys(self) -> None:
         """we should be able to re-upload the same keys"""
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
-        keys = {
+        keys: JsonDict = {
             "alg1:k1": "key1",
             "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
             "alg2:k3": {"key": "key3"},
@@ -74,7 +78,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
         )
 
-    def test_change_one_time_keys(self):
+    def test_change_one_time_keys(self) -> None:
         """attempts to change one-time-keys should be rejected"""
 
         local_user = "@boris:" + self.hs.hostname
@@ -134,7 +138,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             SynapseError,
         )
 
-    def test_claim_one_time_key(self):
+    def test_claim_one_time_key(self) -> None:
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
         keys = {"alg1:k1": "key1"}
@@ -161,7 +165,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             },
         )
 
-    def test_fallback_key(self):
+    def test_fallback_key(self) -> None:
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
         fallback_key = {"alg1:k1": "fallback_key1"}
@@ -294,7 +298,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
         )
 
-    def test_replace_master_key(self):
+    def test_replace_master_key(self) -> None:
         """uploading a new signing key should make the old signing key unavailable"""
         local_user = "@boris:" + self.hs.hostname
         keys1 = {
@@ -328,7 +332,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         )
         self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
 
-    def test_reupload_signatures(self):
+    def test_reupload_signatures(self) -> None:
         """re-uploading a signature should not fail"""
         local_user = "@boris:" + self.hs.hostname
         keys1 = {
@@ -433,7 +437,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
         self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
 
-    def test_self_signing_key_doesnt_show_up_as_device(self):
+    def test_self_signing_key_doesnt_show_up_as_device(self) -> None:
         """signing keys should be hidden when fetching a user's devices"""
         local_user = "@boris:" + self.hs.hostname
         keys1 = {
@@ -462,7 +466,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         res = self.get_success(self.handler.query_local_devices({local_user: None}))
         self.assertDictEqual(res, {local_user: {}})
 
-    def test_upload_signatures(self):
+    def test_upload_signatures(self) -> None:
         """should check signatures that are uploaded"""
         # set up a user with cross-signing keys and a device.  This user will
         # try uploading signatures
@@ -686,7 +690,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
         )
 
-    def test_query_devices_remote_no_sync(self):
+    def test_query_devices_remote_no_sync(self) -> None:
         """Tests that querying keys for a remote user that we don't share a room
         with returns the cross signing keys correctly.
         """
@@ -759,7 +763,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             },
         )
 
-    def test_query_devices_remote_sync(self):
+    def test_query_devices_remote_sync(self) -> None:
         """Tests that querying keys for a remote user that we share a room with,
         but haven't yet fetched the keys for, returns the cross signing keys
         correctly.
@@ -845,7 +849,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             (["device_1", "device_2"],),
         ]
     )
-    def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
+    def test_query_all_devices_caches_result(self, device_ids: Iterable[str]) -> None:
         """Test that requests for all of a remote user's devices are cached.
 
         We do this by asserting that only one call over federation was made, and that
@@ -853,7 +857,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         """
         local_user_id = "@test:test"
         remote_user_id = "@test:other"
-        request_body = {"device_keys": {remote_user_id: []}}
+        request_body: JsonDict = {"device_keys": {remote_user_id: []}}
 
         response_devices = [
             {
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index e8b4e39d1a..89078fc637 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import List
+from typing import List, cast
 from unittest import TestCase
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
 from synapse.api.room_versions import RoomVersions
@@ -23,7 +25,9 @@ from synapse.federation.federation_base import event_from_pdu_json
 from synapse.logging.context import LoggingContext, run_in_background
 from synapse.rest import admin
 from synapse.rest.client import login, room
+from synapse.server import HomeServer
 from synapse.types import create_requester
+from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
 from tests import unittest
@@ -42,7 +46,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         hs = self.setup_test_homeserver(federation_http_client=None)
         self.handler = hs.get_federation_handler()
         self.store = hs.get_datastores().main
@@ -50,7 +54,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self._event_auth_handler = hs.get_event_auth_handler()
         return hs
 
-    def test_exchange_revoked_invite(self):
+    def test_exchange_revoked_invite(self) -> None:
         user_id = self.register_user("kermit", "test")
         tok = self.login("kermit", "test")
 
@@ -96,7 +100,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
         self.assertEqual(failure.msg, "You are not invited to this room.")
 
-    def test_rejected_message_event_state(self):
+    def test_rejected_message_event_state(self) -> None:
         """
         Check that we store the state group correctly for rejected non-state events.
 
@@ -126,7 +130,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
                 "content": {},
                 "room_id": room_id,
                 "sender": "@yetanotheruser:" + OTHER_SERVER,
-                "depth": join_event["depth"] + 1,
+                "depth": cast(int, join_event["depth"]) + 1,
                 "prev_events": [join_event.event_id],
                 "auth_events": [],
                 "origin_server_ts": self.clock.time_msec(),
@@ -149,7 +153,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(sg, sg2)
 
-    def test_rejected_state_event_state(self):
+    def test_rejected_state_event_state(self) -> None:
         """
         Check that we store the state group correctly for rejected state events.
 
@@ -180,7 +184,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
                 "content": {},
                 "room_id": room_id,
                 "sender": "@yetanotheruser:" + OTHER_SERVER,
-                "depth": join_event["depth"] + 1,
+                "depth": cast(int, join_event["depth"]) + 1,
                 "prev_events": [join_event.event_id],
                 "auth_events": [],
                 "origin_server_ts": self.clock.time_msec(),
@@ -203,7 +207,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(sg, sg2)
 
-    def test_backfill_with_many_backward_extremities(self):
+    def test_backfill_with_many_backward_extremities(self) -> None:
         """
         Check that we can backfill with many backward extremities.
         The goal is to make sure that when we only use a portion
@@ -262,7 +266,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
             )
         self.get_success(d)
 
-    def test_backfill_floating_outlier_membership_auth(self):
+    def test_backfill_floating_outlier_membership_auth(self) -> None:
         """
         As the local homeserver, check that we can properly process a federated
         event from the OTHER_SERVER with auth_events that include a floating
@@ -377,7 +381,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
                 for ae in auth_events
             ]
 
-        self.handler.federation_client.get_event_auth = get_event_auth
+        self.handler.federation_client.get_event_auth = get_event_auth  # type: ignore[assignment]
 
         with LoggingContext("receive_pdu"):
             # Fake the OTHER_SERVER federating the message event over to our local homeserver
@@ -397,7 +401,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
     @unittest.override_config(
         {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
     )
-    def test_invite_by_user_ratelimit(self):
+    def test_invite_by_user_ratelimit(self) -> None:
         """Tests that invites from federation to a particular user are
         actually rate-limited.
         """
@@ -446,7 +450,9 @@ class FederationTestCase(unittest.HomeserverTestCase):
             exc=LimitExceededError,
         )
 
-    def _build_and_send_join_event(self, other_server, other_user, room_id):
+    def _build_and_send_join_event(
+        self, other_server: str, other_user: str, room_id: str
+    ) -> EventBase:
         join_event = self.get_success(
             self.handler.on_make_join_request(other_server, room_id, other_user)
         )
@@ -469,7 +475,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
 
 
 class EventFromPduTestCase(TestCase):
-    def test_valid_json(self):
+    def test_valid_json(self) -> None:
         """Valid JSON should be turned into an event."""
         ev = event_from_pdu_json(
             {
@@ -487,7 +493,7 @@ class EventFromPduTestCase(TestCase):
 
         self.assertIsInstance(ev, EventBase)
 
-    def test_invalid_numbers(self):
+    def test_invalid_numbers(self) -> None:
         """Invalid values for an integer should be rejected, all floats should be rejected."""
         for value in [
             -(2 ** 53),
@@ -512,7 +518,7 @@ class EventFromPduTestCase(TestCase):
                     RoomVersions.V6,
                 )
 
-    def test_invalid_nested(self):
+    def test_invalid_nested(self) -> None:
         """List and dictionaries are recursively searched."""
         with self.assertRaises(SynapseError):
             event_from_pdu_json(
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index e8418b6638..014815db6e 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,14 +13,18 @@
 # limitations under the License.
 import json
 import os
+from typing import Any, Dict
 from unittest.mock import ANY, Mock, patch
 from urllib.parse import parse_qs, urlparse
 
 import pymacaroons
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.handlers.sso import MappingException
 from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+from synapse.util import Clock
 from synapse.util.macaroons import get_value_from_macaroon
 
 from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
@@ -98,7 +102,7 @@ class TestMappingProviderFailures(TestMappingProvider):
         }
 
 
-async def get_json(url):
+async def get_json(url: str) -> JsonDict:
     # Mock get_json calls to handle jwks & oidc discovery endpoints
     if url == WELL_KNOWN:
         # Minimal discovery document, as defined in OpenID.Discovery
@@ -116,6 +120,8 @@ async def get_json(url):
     elif url == JWKS_URI:
         return {"keys": []}
 
+    return {}
+
 
 def _key_file_path() -> str:
     """path to a file containing the private half of a test key"""
@@ -147,12 +153,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
     if not HAS_OIDC:
         skip = "requires OIDC"
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
         config["public_baseurl"] = BASE_URL
         return config
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.http_client = Mock(spec=["get_json"])
         self.http_client.get_json.side_effect = get_json
         self.http_client.user_agent = b"Synapse Test"
@@ -164,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         sso_handler = hs.get_sso_handler()
         # Mock the render error method.
         self.render_error = Mock(return_value=None)
-        sso_handler.render_error = self.render_error
+        sso_handler.render_error = self.render_error  # type: ignore[assignment]
 
         # Reduce the number of attempts when generating MXIDs.
         sso_handler._MAP_USERNAME_RETRIES = 3
@@ -193,14 +199,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return args
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_config(self):
+    def test_config(self) -> None:
         """Basic config correctly sets up the callback URL and client auth correctly."""
         self.assertEqual(self.provider._callback_url, CALLBACK_URL)
         self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
         self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
 
     @override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
-    def test_discovery(self):
+    def test_discovery(self) -> None:
         """The handler should discover the endpoints from OIDC discovery document."""
         # This would throw if some metadata were invalid
         metadata = self.get_success(self.provider.load_metadata())
@@ -219,13 +225,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.http_client.get_json.assert_not_called()
 
     @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
-    def test_no_discovery(self):
+    def test_no_discovery(self) -> None:
         """When discovery is disabled, it should not try to load from discovery document."""
         self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_not_called()
 
     @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
-    def test_load_jwks(self):
+    def test_load_jwks(self) -> None:
         """JWKS loading is done once (then cached) if used."""
         jwks = self.get_success(self.provider.load_jwks())
         self.http_client.get_json.assert_called_once_with(JWKS_URI)
@@ -253,7 +259,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_validate_config(self):
+    def test_validate_config(self) -> None:
         """Provider metadatas are extensively validated."""
         h = self.provider
 
@@ -336,14 +342,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
             force_load_metadata()
 
     @override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
-    def test_skip_verification(self):
+    def test_skip_verification(self) -> None:
         """Provider metadata validation can be disabled by config."""
         with self.metadata_edit({"issuer": "http://insecure"}):
             # This should not throw
             get_awaitable_result(self.provider.load_metadata())
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_redirect_request(self):
+    def test_redirect_request(self) -> None:
         """The redirect request has the right arguments & generates a valid session cookie."""
         req = Mock(spec=["cookies"])
         req.cookies = []
@@ -387,7 +393,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(redirect, "http://client/redirect")
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_callback_error(self):
+    def test_callback_error(self) -> None:
         """Errors from the provider returned in the callback are displayed."""
         request = Mock(args={})
         request.args[b"error"] = [b"invalid_client"]
@@ -399,7 +405,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertRenderedError("invalid_client", "some description")
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_callback(self):
+    def test_callback(self) -> None:
         """Code callback works and display errors if something went wrong.
 
         A lot of scenarios are tested here:
@@ -428,9 +434,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": username,
         }
         expected_user_id = "@%s:%s" % (username, self.hs.hostname)
-        self.provider._exchange_code = simple_async_mock(return_value=token)
-        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
-        self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
+        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
+        self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
 
@@ -468,7 +474,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             self.assertRenderedError("mapping_error")
 
         # Handle ID token errors
-        self.provider._parse_id_token = simple_async_mock(raises=Exception())
+        self.provider._parse_id_token = simple_async_mock(raises=Exception())  # type: ignore[assignment]
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_token")
 
@@ -483,7 +489,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "type": "bearer",
             "access_token": "access_token",
         }
-        self.provider._exchange_code = simple_async_mock(return_value=token)
+        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
@@ -510,8 +516,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
         id_token = {
             "sid": "abcdefgh",
         }
-        self.provider._parse_id_token = simple_async_mock(return_value=id_token)
-        self.provider._exchange_code = simple_async_mock(return_value=token)
+        self.provider._parse_id_token = simple_async_mock(return_value=id_token)  # type: ignore[assignment]
+        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
         auth_handler.complete_sso_login.reset_mock()
         self.provider._fetch_userinfo.reset_mock()
         self.get_success(self.handler.handle_oidc_callback(request))
@@ -531,21 +537,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.render_error.assert_not_called()
 
         # Handle userinfo fetching error
-        self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
+        self.provider._fetch_userinfo = simple_async_mock(raises=Exception())  # type: ignore[assignment]
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("fetch_error")
 
         # Handle code exchange failure
         from synapse.handlers.oidc import OidcError
 
-        self.provider._exchange_code = simple_async_mock(
+        self.provider._exchange_code = simple_async_mock(  # type: ignore[assignment]
             raises=OidcError("invalid_request")
         )
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_request")
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_callback_session(self):
+    def test_callback_session(self) -> None:
         """The callback verifies the session presence and validity"""
         request = Mock(spec=["args", "getCookie", "cookies"])
 
@@ -590,7 +596,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
     @override_config(
         {"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
     )
-    def test_exchange_code(self):
+    def test_exchange_code(self) -> None:
         """Code exchange behaves correctly and handles various error scenarios."""
         token = {"type": "bearer"}
         token_json = json.dumps(token).encode("utf-8")
@@ -686,7 +692,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_exchange_code_jwt_key(self):
+    def test_exchange_code_jwt_key(self) -> None:
         """Test that code exchange works with a JWK client secret."""
         from authlib.jose import jwt
 
@@ -741,7 +747,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_exchange_code_no_auth(self):
+    def test_exchange_code_no_auth(self) -> None:
         """Test that code exchange works with no client secret."""
         token = {"type": "bearer"}
         self.http_client.request = simple_async_mock(
@@ -776,7 +782,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_extra_attributes(self):
+    def test_extra_attributes(self) -> None:
         """
         Login while using a mapping provider that implements get_extra_attributes.
         """
@@ -790,8 +796,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "foo",
             "phone": "1234567",
         }
-        self.provider._exchange_code = simple_async_mock(return_value=token)
-        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
+        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
 
@@ -817,12 +823,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_map_userinfo_to_user(self):
+    def test_map_userinfo_to_user(self) -> None:
         """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
 
-        userinfo = {
+        userinfo: dict = {
             "sub": "test_user",
             "username": "test_user",
         }
@@ -870,7 +876,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
 
     @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
-    def test_map_userinfo_to_existing_user(self):
+    def test_map_userinfo_to_existing_user(self) -> None:
         """Existing users can log in with OpenID Connect when allow_existing_users is True."""
         store = self.hs.get_datastores().main
         user = UserID.from_string("@test_user:test")
@@ -974,7 +980,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_map_userinfo_to_invalid_localpart(self):
+    def test_map_userinfo_to_invalid_localpart(self) -> None:
         """If the mapping provider generates an invalid localpart it should be rejected."""
         self.get_success(
             _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
@@ -991,7 +997,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_map_userinfo_to_user_retries(self):
+    def test_map_userinfo_to_user_retries(self) -> None:
         """The mapping provider can retry generating an MXID if the MXID is already in use."""
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
@@ -1039,7 +1045,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_empty_localpart(self):
+    def test_empty_localpart(self) -> None:
         """Attempts to map onto an empty localpart should be rejected."""
         userinfo = {
             "sub": "tester",
@@ -1058,7 +1064,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_null_localpart(self):
+    def test_null_localpart(self) -> None:
         """Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
         userinfo = {
             "sub": "tester",
@@ -1075,7 +1081,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_attribute_requirements(self):
+    def test_attribute_requirements(self) -> None:
         """The required attributes must be met from the OIDC userinfo response."""
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
@@ -1115,7 +1121,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_attribute_requirements_contains(self):
+    def test_attribute_requirements_contains(self) -> None:
         """Test that auth succeeds if userinfo attribute CONTAINS required value"""
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
@@ -1146,7 +1152,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_attribute_requirements_mismatch(self):
+    def test_attribute_requirements_mismatch(self) -> None:
         """
         Test that auth fails if attributes exist but don't match,
         or are non-string values.
@@ -1154,7 +1160,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
         # userinfo with "test": "not_foobar" attribute should fail
-        userinfo = {
+        userinfo: dict = {
             "sub": "tester",
             "username": "tester",
             "test": "not_foobar",
@@ -1248,9 +1254,9 @@ async def _make_callback_with_userinfo(
 
     handler = hs.get_oidc_handler()
     provider = handler._providers["oidc"]
-    provider._exchange_code = simple_async_mock(return_value={"id_token": ""})
-    provider._parse_id_token = simple_async_mock(return_value=userinfo)
-    provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+    provider._exchange_code = simple_async_mock(return_value={"id_token": ""})  # type: ignore[assignment]
+    provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
+    provider._fetch_userinfo = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
 
     state = "state"
     session = handler._token_generator.generate_oidc_session_token(
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 6ddec9ecf1..b2ed9cbe37 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -331,11 +331,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
 
         # Extract presence update user ID and state information into lists of tuples
         db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]]
-        presence_states = [(ps.user_id, ps.state) for ps in presence_states]
+        presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states]
 
         # Compare what we put into the storage with what we got out.
         # They should be identical.
-        self.assertEqual(presence_states, db_presence_states)
+        self.assertEqual(presence_states_compare, db_presence_states)
 
 
 class PresenceTimeoutTestCase(unittest.TestCase):
@@ -357,6 +357,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
 
         self.assertIsNotNone(new_state)
+        assert new_state is not None
         self.assertEqual(new_state.state, PresenceState.UNAVAILABLE)
         self.assertEqual(new_state.status_msg, status_msg)
 
@@ -380,6 +381,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
 
         self.assertIsNotNone(new_state)
+        assert new_state is not None
         self.assertEqual(new_state.state, PresenceState.BUSY)
         self.assertEqual(new_state.status_msg, status_msg)
 
@@ -399,6 +401,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
 
         self.assertIsNotNone(new_state)
+        assert new_state is not None
         self.assertEqual(new_state.state, PresenceState.OFFLINE)
         self.assertEqual(new_state.status_msg, status_msg)
 
@@ -420,6 +423,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         )
 
         self.assertIsNotNone(new_state)
+        assert new_state is not None
         self.assertEqual(new_state.state, PresenceState.ONLINE)
         self.assertEqual(new_state.status_msg, status_msg)
 
@@ -477,6 +481,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         )
 
         self.assertIsNotNone(new_state)
+        assert new_state is not None
         self.assertEqual(new_state.state, PresenceState.OFFLINE)
         self.assertEqual(new_state.status_msg, status_msg)
 
@@ -653,13 +658,13 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase):
         self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
 
     def _set_presencestate_with_status_msg(
-        self, user_id: str, state: PresenceState, status_msg: Optional[str]
+        self, user_id: str, state: str, status_msg: Optional[str]
     ):
         """Set a PresenceState and status_msg and check the result.
 
         Args:
             user_id: User for that the status is to be set.
-            PresenceState: The new PresenceState.
+            state: The new PresenceState.
             status_msg: Status message that is to be set.
         """
         self.get_success(
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 972cbac6e4..08733a9f2d 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -11,14 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict
+from typing import Any, Awaitable, Callable, Dict
 from unittest.mock import Mock
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.types
 from synapse.api.errors import AuthError, SynapseError
 from synapse.rest import admin
 from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils import make_awaitable
@@ -29,13 +32,15 @@ class ProfileTestCase(unittest.HomeserverTestCase):
 
     servlets = [admin.register_servlets]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.mock_federation = Mock()
         self.mock_registry = Mock()
 
-        self.query_handlers = {}
+        self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
 
-        def register_query_handler(query_type, handler):
+        def register_query_handler(
+            query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
+        ) -> None:
             self.query_handlers[query_type] = handler
 
         self.mock_registry.register_query_handler = register_query_handler
@@ -47,7 +52,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
         return hs
 
-    def prepare(self, reactor, clock, hs: HomeServer):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
 
         self.frank = UserID.from_string("@1234abcd:test")
@@ -58,7 +63,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
 
         self.handler = hs.get_profile_handler()
 
-    def test_get_my_name(self):
+    def test_get_my_name(self) -> None:
         self.get_success(
             self.store.set_profile_displayname(self.frank.localpart, "Frank")
         )
@@ -67,7 +72,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual("Frank", displayname)
 
-    def test_set_my_name(self):
+    def test_set_my_name(self) -> None:
         self.get_success(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
@@ -110,7 +115,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             self.get_success(self.store.get_profile_displayname(self.frank.localpart))
         )
 
-    def test_set_my_name_if_disabled(self):
+    def test_set_my_name_if_disabled(self) -> None:
         self.hs.config.registration.enable_set_displayname = False
 
         # Setting displayname for the first time is allowed
@@ -135,7 +140,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             SynapseError,
         )
 
-    def test_set_my_name_noauth(self):
+    def test_set_my_name_noauth(self) -> None:
         self.get_failure(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
@@ -143,7 +148,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             AuthError,
         )
 
-    def test_get_other_name(self):
+    def test_get_other_name(self) -> None:
         self.mock_federation.make_query.return_value = make_awaitable(
             {"displayname": "Alice"}
         )
@@ -158,7 +163,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             ignore_backoff=True,
         )
 
-    def test_incoming_fed_query(self):
+    def test_incoming_fed_query(self) -> None:
         self.get_success(self.store.create_profile("caroline"))
         self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
 
@@ -174,7 +179,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual({"displayname": "Caroline"}, response)
 
-    def test_get_my_avatar(self):
+    def test_get_my_avatar(self) -> None:
         self.get_success(
             self.store.set_profile_avatar_url(
                 self.frank.localpart, "http://my.server/me.png"
@@ -184,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual("http://my.server/me.png", avatar_url)
 
-    def test_set_my_avatar(self):
+    def test_set_my_avatar(self) -> None:
         self.get_success(
             self.handler.set_avatar_url(
                 self.frank,
@@ -225,7 +230,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
         )
 
-    def test_set_my_avatar_if_disabled(self):
+    def test_set_my_avatar_if_disabled(self) -> None:
         self.hs.config.registration.enable_set_avatar_url = False
 
         # Setting displayname for the first time is allowed
@@ -250,7 +255,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             SynapseError,
         )
 
-    def test_avatar_constraints_no_config(self):
+    def test_avatar_constraints_no_config(self) -> None:
         """Tests that the method to check an avatar against configured constraints skips
         all of its check if no constraint is configured.
         """
@@ -263,7 +268,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         self.assertTrue(res)
 
     @unittest.override_config({"max_avatar_size": 50})
-    def test_avatar_constraints_missing(self):
+    def test_avatar_constraints_missing(self) -> None:
         """Tests that an avatar isn't allowed if the file at the given MXC URI couldn't
         be found.
         """
@@ -273,7 +278,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         self.assertFalse(res)
 
     @unittest.override_config({"max_avatar_size": 50})
-    def test_avatar_constraints_file_size(self):
+    def test_avatar_constraints_file_size(self) -> None:
         """Tests that a file that's above the allowed file size is forbidden but one
         that's below it is allowed.
         """
@@ -295,7 +300,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         self.assertFalse(res)
 
     @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
-    def test_avatar_constraint_mime_type(self):
+    def test_avatar_constraint_mime_type(self) -> None:
         """Tests that a file with an unauthorised MIME type is forbidden but one with
         an authorised content type is allowed.
         """
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 23941abed8..8d4404eda1 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -12,12 +12,16 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 
-from typing import Optional
+from typing import Any, Dict, Optional
 from unittest.mock import Mock
 
 import attr
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.errors import RedirectException
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.test_utils import simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
@@ -81,10 +85,10 @@ class TestRedirectMappingProvider(TestMappingProvider):
 
 
 class SamlHandlerTestCase(HomeserverTestCase):
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
         config["public_baseurl"] = BASE_URL
-        saml_config = {
+        saml_config: Dict[str, Any] = {
             "sp_config": {"metadata": {}},
             # Disable grandfathering.
             "grandfathered_mxid_source_attribute": None,
@@ -98,7 +102,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         return config
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         hs = self.setup_test_homeserver()
 
         self.handler = hs.get_saml_handler()
@@ -114,7 +118,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
     elif not has_xmlsec1:
         skip = "Requires xmlsec1"
 
-    def test_map_saml_response_to_user(self):
+    def test_map_saml_response_to_user(self) -> None:
         """Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
 
         # stub out the auth handler
@@ -140,7 +144,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
         )
 
     @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
-    def test_map_saml_response_to_existing_user(self):
+    def test_map_saml_response_to_existing_user(self) -> None:
         """Existing users can log in with SAML account."""
         store = self.hs.get_datastores().main
         self.get_success(
@@ -186,7 +190,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
             auth_provider_session_id=None,
         )
 
-    def test_map_saml_response_to_invalid_localpart(self):
+    def test_map_saml_response_to_invalid_localpart(self) -> None:
         """If the mapping provider generates an invalid localpart it should be rejected."""
 
         # stub out the auth handler
@@ -207,7 +211,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
         )
         auth_handler.complete_sso_login.assert_not_called()
 
-    def test_map_saml_response_to_user_retries(self):
+    def test_map_saml_response_to_user_retries(self) -> None:
         """The mapping provider can retry generating an MXID if the MXID is already in use."""
 
         # stub out the auth handler and error renderer
@@ -271,7 +275,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_map_saml_response_redirect(self):
+    def test_map_saml_response_redirect(self) -> None:
         """Test a mapping provider that raises a RedirectException"""
 
         saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
@@ -292,7 +296,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
             },
         }
     )
-    def test_attribute_requirements(self):
+    def test_attribute_requirements(self) -> None:
         """The required attributes must be met from the SAML response."""
 
         # stub out the auth handler
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index f91a80b9fa..ffd5c4cb93 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -18,11 +18,14 @@ from typing import Dict
 from unittest.mock import ANY, Mock, call
 
 from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
 from twisted.web.resource import Resource
 
 from synapse.api.errors import AuthError
 from synapse.federation.transport.server import TransportLayerServer
-from synapse.types import UserID, create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID, create_requester
+from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils import make_awaitable
@@ -42,7 +45,9 @@ ROOM_ID = "a-room"
 OTHER_ROOM_ID = "another-room"
 
 
-def _expect_edu_transaction(edu_type, content, origin="test"):
+def _expect_edu_transaction(
+    edu_type: str, content: JsonDict, origin: str = "test"
+) -> JsonDict:
     return {
         "origin": origin,
         "origin_server_ts": 1000000,
@@ -51,12 +56,12 @@ def _expect_edu_transaction(edu_type, content, origin="test"):
     }
 
 
-def _make_edu_transaction_json(edu_type, content):
+def _make_edu_transaction_json(edu_type: str, content: JsonDict) -> bytes:
     return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8")
 
 
 class TypingNotificationsTestCase(unittest.HomeserverTestCase):
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         # we mock out the keyring so as to skip the authentication check on the
         # federation API call.
         mock_keyring = Mock(spec=["verify_json_for_server"])
@@ -83,7 +88,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         d["/_matrix/federation"] = TransportLayerServer(self.hs)
         return d
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         mock_notifier = hs.get_notifier()
         self.on_new_event = mock_notifier.on_new_event
 
@@ -111,24 +116,24 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.room_members = []
 
-        async def check_user_in_room(room_id, user_id):
+        async def check_user_in_room(room_id: str, user_id: str) -> None:
             if user_id not in [u.to_string() for u in self.room_members]:
                 raise AuthError(401, "User is not in the room")
             return None
 
         hs.get_auth().check_user_in_room = check_user_in_room
 
-        async def check_host_in_room(room_id, server_name):
+        async def check_host_in_room(room_id: str, server_name: str) -> bool:
             return room_id == ROOM_ID
 
         hs.get_event_auth_handler().check_host_in_room = check_host_in_room
 
-        def get_joined_hosts_for_room(room_id):
+        def get_joined_hosts_for_room(room_id: str):
             return {member.domain for member in self.room_members}
 
         self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
 
-        async def get_users_in_room(room_id):
+        async def get_users_in_room(room_id: str):
             return {str(u) for u in self.room_members}
 
         self.datastore.get_users_in_room = get_users_in_room
@@ -153,7 +158,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             lambda *args, **kwargs: make_awaitable(None)
         )
 
-    def test_started_typing_local(self):
+    def test_started_typing_local(self) -> None:
         self.room_members = [U_APPLE, U_BANANA]
 
         self.assertEqual(self.event_source.get_current_key(), 0)
@@ -187,7 +192,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         )
 
     @override_config({"send_federation": True})
-    def test_started_typing_remote_send(self):
+    def test_started_typing_remote_send(self) -> None:
         self.room_members = [U_APPLE, U_ONION]
 
         self.get_success(
@@ -217,7 +222,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             try_trailing_slash_on_400=True,
         )
 
-    def test_started_typing_remote_recv(self):
+    def test_started_typing_remote_recv(self) -> None:
         self.room_members = [U_APPLE, U_ONION]
 
         self.assertEqual(self.event_source.get_current_key(), 0)
@@ -256,7 +261,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             ],
         )
 
-    def test_started_typing_remote_recv_not_in_room(self):
+    def test_started_typing_remote_recv_not_in_room(self) -> None:
         self.room_members = [U_APPLE, U_ONION]
 
         self.assertEqual(self.event_source.get_current_key(), 0)
@@ -292,7 +297,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(events[1], 0)
 
     @override_config({"send_federation": True})
-    def test_stopped_typing(self):
+    def test_stopped_typing(self) -> None:
         self.room_members = [U_APPLE, U_BANANA, U_ONION]
 
         # Gut-wrenching
@@ -343,7 +348,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
         )
 
-    def test_typing_timeout(self):
+    def test_typing_timeout(self) -> None:
         self.room_members = [U_APPLE, U_BANANA]
 
         self.assertEqual(self.event_source.get_current_key(), 0)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 6691e07128..ba158f5d93 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -11,15 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 from unittest.mock import Mock
 
 from twisted.internet.defer import Deferred
+from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
 from synapse.logging.context import make_deferred_yieldable
 from synapse.push import PusherConfigException
 from synapse.rest.client import login, push_rule, receipts, room
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase, override_config
 
@@ -35,13 +39,13 @@ class HTTPPusherTests(HomeserverTestCase):
     user_id = True
     hijack_auth = False
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
         config["start_pushers"] = True
         return config
 
-    def make_homeserver(self, reactor, clock):
-        self.push_attempts: List[tuple[Deferred, str, dict]] = []
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        self.push_attempts: List[Tuple[Deferred, str, dict]] = []
 
         m = Mock()
 
@@ -56,7 +60,7 @@ class HTTPPusherTests(HomeserverTestCase):
 
         return hs
 
-    def test_invalid_configuration(self):
+    def test_invalid_configuration(self) -> None:
         """Invalid push configurations should be rejected."""
         # Register the user who gets notified
         user_id = self.register_user("user", "pass")
@@ -68,7 +72,7 @@ class HTTPPusherTests(HomeserverTestCase):
         )
         token_id = user_tuple.token_id
 
-        def test_data(data):
+        def test_data(data: Optional[JsonDict]) -> None:
             self.get_failure(
                 self.hs.get_pusherpool().add_pusher(
                     user_id=user_id,
@@ -95,7 +99,7 @@ class HTTPPusherTests(HomeserverTestCase):
         # A url with an incorrect path isn't accepted.
         test_data({"url": "http://example.com/foo"})
 
-    def test_sends_http(self):
+    def test_sends_http(self) -> None:
         """
         The HTTP pusher will send pushes for each message to a HTTP endpoint
         when configured to do so.
@@ -200,7 +204,7 @@ class HTTPPusherTests(HomeserverTestCase):
         self.assertEqual(len(pushers), 1)
         self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
 
-    def test_sends_high_priority_for_encrypted(self):
+    def test_sends_high_priority_for_encrypted(self) -> None:
         """
         The HTTP pusher will send pushes at high priority if they correspond
         to an encrypted message.
@@ -321,7 +325,7 @@ class HTTPPusherTests(HomeserverTestCase):
         )
         self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
 
-    def test_sends_high_priority_for_one_to_one_only(self):
+    def test_sends_high_priority_for_one_to_one_only(self) -> None:
         """
         The HTTP pusher will send pushes at high priority if they correspond
         to a message in a one-to-one room.
@@ -404,7 +408,7 @@ class HTTPPusherTests(HomeserverTestCase):
         # check that this is low-priority
         self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
 
-    def test_sends_high_priority_for_mention(self):
+    def test_sends_high_priority_for_mention(self) -> None:
         """
         The HTTP pusher will send pushes at high priority if they correspond
         to a message containing the user's display name.
@@ -480,7 +484,7 @@ class HTTPPusherTests(HomeserverTestCase):
         # check that this is low-priority
         self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
 
-    def test_sends_high_priority_for_atroom(self):
+    def test_sends_high_priority_for_atroom(self) -> None:
         """
         The HTTP pusher will send pushes at high priority if they correspond
         to a message that contains @room.
@@ -563,7 +567,7 @@ class HTTPPusherTests(HomeserverTestCase):
         # check that this is low-priority
         self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
 
-    def test_push_unread_count_group_by_room(self):
+    def test_push_unread_count_group_by_room(self) -> None:
         """
         The HTTP pusher will group unread count by number of unread rooms.
         """
@@ -576,7 +580,7 @@ class HTTPPusherTests(HomeserverTestCase):
         self._check_push_attempt(6, 1)
 
     @override_config({"push": {"group_unread_count_by_room": False}})
-    def test_push_unread_count_message_count(self):
+    def test_push_unread_count_message_count(self) -> None:
         """
         The HTTP pusher will send the total unread message count.
         """
@@ -589,7 +593,7 @@ class HTTPPusherTests(HomeserverTestCase):
         # last read receipt
         self._check_push_attempt(6, 3)
 
-    def _test_push_unread_count(self):
+    def _test_push_unread_count(self) -> None:
         """
         Tests that the correct unread count appears in sent push notifications
 
@@ -681,7 +685,7 @@ class HTTPPusherTests(HomeserverTestCase):
 
         self.helper.send(room_id, body="HELLO???", tok=other_access_token)
 
-    def _advance_time_and_make_push_succeed(self, expected_push_attempts):
+    def _advance_time_and_make_push_succeed(self, expected_push_attempts: int) -> None:
         self.pump()
         self.push_attempts[expected_push_attempts - 1][0].callback({})
 
@@ -708,7 +712,9 @@ class HTTPPusherTests(HomeserverTestCase):
             expected_unread_count_last_push,
         )
 
-    def _send_read_request(self, access_token, message_event_id, room_id):
+    def _send_read_request(
+        self, access_token: str, message_event_id: str, room_id: str
+    ) -> None:
         # Now set the user's read receipt position to the first event
         #
         # This will actually trigger a new notification to be sent out so that
@@ -748,7 +754,7 @@ class HTTPPusherTests(HomeserverTestCase):
 
         return user_id, access_token
 
-    def test_dont_notify_rule_overrides_message(self):
+    def test_dont_notify_rule_overrides_message(self) -> None:
         """
         The override push rule will suppress notification
         """
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 3849beb9d6..5dba187076 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict
+from typing import Dict, Optional, Union
 
 import frozendict
 
@@ -20,12 +20,13 @@ from synapse.api.room_versions import RoomVersions
 from synapse.events import FrozenEvent
 from synapse.push import push_rule_evaluator
 from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent
+from synapse.types import JsonDict
 
 from tests import unittest
 
 
 class PushRuleEvaluatorTestCase(unittest.TestCase):
-    def _get_evaluator(self, content):
+    def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluatorForEvent:
         event = FrozenEvent(
             {
                 "event_id": "$event_id",
@@ -39,12 +40,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
         )
         room_member_count = 0
         sender_power_level = 0
-        power_levels = {}
+        power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
         return PushRuleEvaluatorForEvent(
             event, room_member_count, sender_power_level, power_levels
         )
 
-    def test_display_name(self):
+    def test_display_name(self) -> None:
         """Check for a matching display name in the body of the event."""
         evaluator = self._get_evaluator({"body": "foo bar baz"})
 
@@ -71,20 +72,20 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
         self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
 
     def _assert_matches(
-        self, condition: Dict[str, Any], content: Dict[str, Any], msg=None
+        self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
     ) -> None:
         evaluator = self._get_evaluator(content)
         self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg)
 
     def _assert_not_matches(
-        self, condition: Dict[str, Any], content: Dict[str, Any], msg=None
+        self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
     ) -> None:
         evaluator = self._get_evaluator(content)
         self.assertFalse(
             evaluator.matches(condition, "@user:test", "display_name"), msg
         )
 
-    def test_event_match_body(self):
+    def test_event_match_body(self) -> None:
         """Check that event_match conditions on content.body work as expected"""
 
         # if the key is `content.body`, the pattern matches substrings.
@@ -165,7 +166,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
             r"? after \ should match any character",
         )
 
-    def test_event_match_non_body(self):
+    def test_event_match_non_body(self) -> None:
         """Check that event_match conditions on other keys work as expected"""
 
         # if the key is anything other than 'content.body', the pattern must match the
@@ -241,7 +242,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
             "pattern should not match before a newline",
         )
 
-    def test_no_body(self):
+    def test_no_body(self) -> None:
         """Not having a body shouldn't break the evaluator."""
         evaluator = self._get_evaluator({})
 
@@ -250,7 +251,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
         }
         self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
 
-    def test_invalid_body(self):
+    def test_invalid_body(self) -> None:
         """A non-string body should not break the evaluator."""
         condition = {
             "kind": "contains_display_name",
@@ -260,7 +261,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
             evaluator = self._get_evaluator({"body": body})
             self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
 
-    def test_tweaks_for_actions(self):
+    def test_tweaks_for_actions(self) -> None:
         """
         This tests the behaviour of tweaks_for_actions.
         """
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 171f4e97c8..329690f8f7 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -15,7 +15,7 @@
 
 import itertools
 import urllib.parse
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from unittest.mock import patch
 
 from twisted.test.proto_helpers import MemoryReactor
@@ -79,6 +79,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
         content: Optional[dict] = None,
         access_token: Optional[str] = None,
         parent_id: Optional[str] = None,
+        expected_response_code: int = 200,
     ) -> FakeChannel:
         """Helper function to send a relation pointing at `self.parent_id`
 
@@ -115,16 +116,60 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
             content,
             access_token=access_token,
         )
+        self.assertEqual(expected_response_code, channel.code, channel.json_body)
         return channel
 
+    def _get_related_events(self) -> List[str]:
+        """
+        Requests /relations on the parent ID and returns a list of event IDs.
+        """
+        # Request the relations of the event.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        return [ev["event_id"] for ev in channel.json_body["chunk"]]
+
+    def _get_bundled_aggregations(self) -> JsonDict:
+        """
+        Requests /event on the parent ID and returns the m.relations field (from unsigned), if it exists.
+        """
+        # Fetch the bundled aggregations of the event.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        return channel.json_body["unsigned"].get("m.relations", {})
+
+    def _get_aggregations(self) -> List[JsonDict]:
+        """Request /aggregations on the parent ID and includes the returned chunk."""
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        return channel.json_body["chunk"]
+
+    def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
+        """
+        Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
+        """
+        for event in events:
+            if event["event_id"] == self.parent_id:
+                return event
+
+        raise AssertionError(f"Event {self.parent_id} not found in chunk")
+
 
 class RelationsTestCase(BaseRelationsTestCase):
     def test_send_relation(self) -> None:
         """Tests that sending a relation works."""
-
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
-        self.assertEqual(200, channel.code, channel.json_body)
-
         event_id = channel.json_body["event_id"]
 
         channel = self.make_request(
@@ -151,13 +196,13 @@ class RelationsTestCase(BaseRelationsTestCase):
 
     def test_deny_invalid_event(self) -> None:
         """Test that we deny relations on non-existant events"""
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.ANNOTATION,
             EventTypes.Message,
             parent_id="foo",
             content={"body": "foo", "msgtype": "m.text"},
+            expected_response_code=400,
         )
-        self.assertEqual(400, channel.code, channel.json_body)
 
         # Unless that event is referenced from another event!
         self.get_success(
@@ -171,13 +216,12 @@ class RelationsTestCase(BaseRelationsTestCase):
                 desc="test_deny_invalid_event",
             )
         )
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.THREAD,
             EventTypes.Message,
             parent_id="foo",
             content={"body": "foo", "msgtype": "m.text"},
         )
-        self.assertEqual(200, channel.code, channel.json_body)
 
     def test_deny_invalid_room(self) -> None:
         """Test that we deny relations on non-existant events"""
@@ -187,18 +231,20 @@ class RelationsTestCase(BaseRelationsTestCase):
         parent_id = res["event_id"]
 
         # Attempt to send an annotation to that event.
-        channel = self._send_relation(
-            RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A"
+        self._send_relation(
+            RelationTypes.ANNOTATION,
+            "m.reaction",
+            parent_id=parent_id,
+            key="A",
+            expected_response_code=400,
         )
-        self.assertEqual(400, channel.code, channel.json_body)
 
     def test_deny_double_react(self) -> None:
         """Test that we deny relations on membership events"""
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEqual(400, channel.code, channel.json_body)
+        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+        self._send_relation(
+            RelationTypes.ANNOTATION, "m.reaction", "a", expected_response_code=400
+        )
 
     def test_deny_forked_thread(self) -> None:
         """It is invalid to start a thread off a thread."""
@@ -208,316 +254,24 @@ class RelationsTestCase(BaseRelationsTestCase):
             content={"msgtype": "m.text", "body": "foo"},
             parent_id=self.parent_id,
         )
-        self.assertEqual(200, channel.code, channel.json_body)
         parent_id = channel.json_body["event_id"]
 
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.THREAD,
             "m.room.message",
             content={"msgtype": "m.text", "body": "foo"},
             parent_id=parent_id,
+            expected_response_code=400,
         )
-        self.assertEqual(400, channel.code, channel.json_body)
-
-    def test_basic_paginate_relations(self) -> None:
-        """Tests that calling pagination API correctly the latest relations."""
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEqual(200, channel.code, channel.json_body)
-        first_annotation_id = channel.json_body["event_id"]
-
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
-        self.assertEqual(200, channel.code, channel.json_body)
-        second_annotation_id = channel.json_body["event_id"]
-
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        # We expect to get back a single pagination result, which is the latest
-        # full relation event we sent above.
-        self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-        self.assert_dict(
-            {
-                "event_id": second_annotation_id,
-                "sender": self.user_id,
-                "type": "m.reaction",
-            },
-            channel.json_body["chunk"][0],
-        )
-
-        # We also expect to get the original event (the id of which is self.parent_id)
-        self.assertEqual(
-            channel.json_body["original_event"]["event_id"], self.parent_id
-        )
-
-        # Make sure next_batch has something in it that looks like it could be a
-        # valid token.
-        self.assertIsInstance(
-            channel.json_body.get("next_batch"), str, channel.json_body
-        )
-
-        # Request the relations again, but with a different direction.
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/relations"
-            f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        # We expect to get back a single pagination result, which is the earliest
-        # full relation event we sent above.
-        self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-        self.assert_dict(
-            {
-                "event_id": first_annotation_id,
-                "sender": self.user_id,
-                "type": "m.reaction",
-            },
-            channel.json_body["chunk"][0],
-        )
-
-    def test_repeated_paginate_relations(self) -> None:
-        """Test that if we paginate using a limit and tokens then we get the
-        expected events.
-        """
-
-        expected_event_ids = []
-        for idx in range(10):
-            channel = self._send_relation(
-                RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx)
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-            expected_event_ids.append(channel.json_body["event_id"])
-
-        prev_token = ""
-        found_event_ids: List[str] = []
-        for _ in range(20):
-            from_token = ""
-            if prev_token:
-                from_token = "&from=" + prev_token
-
-            channel = self.make_request(
-                "GET",
-                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
-                access_token=self.user_token,
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-
-            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
-            next_batch = channel.json_body.get("next_batch")
-
-            self.assertNotEqual(prev_token, next_batch)
-            prev_token = next_batch
-
-            if not prev_token:
-                break
-
-        # We paginated backwards, so reverse
-        found_event_ids.reverse()
-        self.assertEqual(found_event_ids, expected_event_ids)
-
-    def test_pagination_from_sync_and_messages(self) -> None:
-        """Pagination tokens from /sync and /messages can be used to paginate /relations."""
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
-        self.assertEqual(200, channel.code, channel.json_body)
-        annotation_id = channel.json_body["event_id"]
-        # Send an event after the relation events.
-        self.helper.send(self.room, body="Latest event", tok=self.user_token)
-
-        # Request /sync, limiting it such that only the latest event is returned
-        # (and not the relation).
-        filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}')
-        channel = self.make_request(
-            "GET", f"/sync?filter={filter}", access_token=self.user_token
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
-        sync_prev_batch = room_timeline["prev_batch"]
-        self.assertIsNotNone(sync_prev_batch)
-        # Ensure the relation event is not in the batch returned from /sync.
-        self.assertNotIn(
-            annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
-        )
-
-        # Request /messages, limiting it such that only the latest event is
-        # returned (and not the relation).
-        channel = self.make_request(
-            "GET",
-            f"/rooms/{self.room}/messages?dir=b&limit=1",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        messages_end = channel.json_body["end"]
-        self.assertIsNotNone(messages_end)
-        # Ensure the relation event is not in the chunk returned from /messages.
-        self.assertNotIn(
-            annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
-        )
-
-        # Request /relations with the pagination tokens received from both the
-        # /sync and /messages responses above, in turn.
-        #
-        # This is a tiny bit silly since the client wouldn't know the parent ID
-        # from the requests above; consider the parent ID to be known from a
-        # previous /sync.
-        for from_token in (sync_prev_batch, messages_end):
-            channel = self.make_request(
-                "GET",
-                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
-                access_token=self.user_token,
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-
-            # The relation should be in the returned chunk.
-            self.assertIn(
-                annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
-            )
-
-    def test_aggregation_pagination_groups(self) -> None:
-        """Test that we can paginate annotation groups correctly."""
-
-        # We need to create ten separate users to send each reaction.
-        access_tokens = [self.user_token, self.user2_token]
-        idx = 0
-        while len(access_tokens) < 10:
-            user_id, token = self._create_user("test" + str(idx))
-            idx += 1
-
-            self.helper.join(self.room, user=user_id, tok=token)
-            access_tokens.append(token)
-
-        idx = 0
-        sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1}
-        for key in itertools.chain.from_iterable(
-            itertools.repeat(key, num) for key, num in sent_groups.items()
-        ):
-            channel = self._send_relation(
-                RelationTypes.ANNOTATION,
-                "m.reaction",
-                key=key,
-                access_token=access_tokens[idx],
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-
-            idx += 1
-            idx %= len(access_tokens)
-
-        prev_token: Optional[str] = None
-        found_groups: Dict[str, int] = {}
-        for _ in range(20):
-            from_token = ""
-            if prev_token:
-                from_token = "&from=" + prev_token
-
-            channel = self.make_request(
-                "GET",
-                f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}",
-                access_token=self.user_token,
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-
-            self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
-            for groups in channel.json_body["chunk"]:
-                # We only expect reactions
-                self.assertEqual(groups["type"], "m.reaction", channel.json_body)
-
-                # We should only see each key once
-                self.assertNotIn(groups["key"], found_groups, channel.json_body)
-
-                found_groups[groups["key"]] = groups["count"]
-
-            next_batch = channel.json_body.get("next_batch")
-
-            self.assertNotEqual(prev_token, next_batch)
-            prev_token = next_batch
-
-            if not prev_token:
-                break
-
-        self.assertEqual(sent_groups, found_groups)
-
-    def test_aggregation_pagination_within_group(self) -> None:
-        """Test that we can paginate within an annotation group."""
-
-        # We need to create ten separate users to send each reaction.
-        access_tokens = [self.user_token, self.user2_token]
-        idx = 0
-        while len(access_tokens) < 10:
-            user_id, token = self._create_user("test" + str(idx))
-            idx += 1
-
-            self.helper.join(self.room, user=user_id, tok=token)
-            access_tokens.append(token)
-
-        idx = 0
-        expected_event_ids = []
-        for _ in range(10):
-            channel = self._send_relation(
-                RelationTypes.ANNOTATION,
-                "m.reaction",
-                key="👍",
-                access_token=access_tokens[idx],
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-            expected_event_ids.append(channel.json_body["event_id"])
-
-            idx += 1
-
-        # Also send a different type of reaction so that we test we don't see it
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        prev_token = ""
-        found_event_ids: List[str] = []
-        encoded_key = urllib.parse.quote_plus("👍".encode())
-        for _ in range(20):
-            from_token = ""
-            if prev_token:
-                from_token = "&from=" + prev_token
-
-            channel = self.make_request(
-                "GET",
-                f"/_matrix/client/unstable/rooms/{self.room}"
-                f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
-                f"/m.reaction/{encoded_key}?limit=1{from_token}",
-                access_token=self.user_token,
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-
-            self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
-            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
-
-            next_batch = channel.json_body.get("next_batch")
-
-            self.assertNotEqual(prev_token, next_batch)
-            prev_token = next_batch
-
-            if not prev_token:
-                break
-
-        # We paginated backwards, so reverse
-        found_event_ids.reverse()
-        self.assertEqual(found_event_ids, expected_event_ids)
 
     def test_aggregation(self) -> None:
         """Test that annotations get correctly aggregated."""
 
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self._send_relation(
+        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+        self._send_relation(
             RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
         )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
-        self.assertEqual(200, channel.code, channel.json_body)
+        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
 
         channel = self.make_request(
             "GET",
@@ -547,215 +301,6 @@ class RelationsTestCase(BaseRelationsTestCase):
         )
         self.assertEqual(400, channel.code, channel.json_body)
 
-    @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
-    def test_bundled_aggregations(self) -> None:
-        """
-        Test that annotations, references, and threads get correctly bundled.
-
-        Note that this doesn't test against /relations since only thread relations
-        get bundled via that API. See test_aggregation_get_event_for_thread.
-
-        See test_edit for a similar test for edits.
-        """
-        # Setup by sending a variety of relations.
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self._send_relation(
-            RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
-        self.assertEqual(200, channel.code, channel.json_body)
-        reply_1 = channel.json_body["event_id"]
-
-        channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
-        self.assertEqual(200, channel.code, channel.json_body)
-        reply_2 = channel.json_body["event_id"]
-
-        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
-        self.assertEqual(200, channel.code, channel.json_body)
-        thread_2 = channel.json_body["event_id"]
-
-        def assert_bundle(event_json: JsonDict) -> None:
-            """Assert the expected values of the bundled aggregations."""
-            relations_dict = event_json["unsigned"].get("m.relations")
-
-            # Ensure the fields are as expected.
-            self.assertCountEqual(
-                relations_dict.keys(),
-                (
-                    RelationTypes.ANNOTATION,
-                    RelationTypes.REFERENCE,
-                    RelationTypes.THREAD,
-                ),
-            )
-
-            # Check the values of each field.
-            self.assertEqual(
-                {
-                    "chunk": [
-                        {"type": "m.reaction", "key": "a", "count": 2},
-                        {"type": "m.reaction", "key": "b", "count": 1},
-                    ]
-                },
-                relations_dict[RelationTypes.ANNOTATION],
-            )
-
-            self.assertEqual(
-                {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
-                relations_dict[RelationTypes.REFERENCE],
-            )
-
-            self.assertEqual(
-                2,
-                relations_dict[RelationTypes.THREAD].get("count"),
-            )
-            self.assertTrue(
-                relations_dict[RelationTypes.THREAD].get("current_user_participated")
-            )
-            # The latest thread event has some fields that don't matter.
-            self.assert_dict(
-                {
-                    "content": {
-                        "m.relates_to": {
-                            "event_id": self.parent_id,
-                            "rel_type": RelationTypes.THREAD,
-                        }
-                    },
-                    "event_id": thread_2,
-                    "sender": self.user_id,
-                    "type": "m.room.test",
-                },
-                relations_dict[RelationTypes.THREAD].get("latest_event"),
-            )
-
-        # Request the event directly.
-        channel = self.make_request(
-            "GET",
-            f"/rooms/{self.room}/event/{self.parent_id}",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        assert_bundle(channel.json_body)
-
-        # Request the room messages.
-        channel = self.make_request(
-            "GET",
-            f"/rooms/{self.room}/messages?dir=b",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
-
-        # Request the room context.
-        channel = self.make_request(
-            "GET",
-            f"/rooms/{self.room}/context/{self.parent_id}",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        assert_bundle(channel.json_body["event"])
-
-        # Request sync.
-        channel = self.make_request("GET", "/sync", access_token=self.user_token)
-        self.assertEqual(200, channel.code, channel.json_body)
-        room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
-        self.assertTrue(room_timeline["limited"])
-        assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
-
-        # Request search.
-        channel = self.make_request(
-            "POST",
-            "/search",
-            # Search term matches the parent message.
-            content={"search_categories": {"room_events": {"search_term": "Hi"}}},
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        chunk = [
-            result["result"]
-            for result in channel.json_body["search_categories"]["room_events"][
-                "results"
-            ]
-        ]
-        assert_bundle(self._find_event_in_chunk(chunk))
-
-    def test_aggregation_get_event_for_annotation(self) -> None:
-        """Test that annotations do not get bundled aggregations included
-        when directly requested.
-        """
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEqual(200, channel.code, channel.json_body)
-        annotation_id = channel.json_body["event_id"]
-
-        # Annotate the annotation.
-        channel = self._send_relation(
-            RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self.make_request(
-            "GET",
-            f"/rooms/{self.room}/event/{annotation_id}",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
-
-    def test_aggregation_get_event_for_thread(self) -> None:
-        """Test that threads get bundled aggregations included when directly requested."""
-        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
-        self.assertEqual(200, channel.code, channel.json_body)
-        thread_id = channel.json_body["event_id"]
-
-        # Annotate the annotation.
-        channel = self._send_relation(
-            RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self.make_request(
-            "GET",
-            f"/rooms/{self.room}/event/{thread_id}",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        self.assertEqual(
-            channel.json_body["unsigned"].get("m.relations"),
-            {
-                RelationTypes.ANNOTATION: {
-                    "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
-                },
-            },
-        )
-
-        # It should also be included when the entire thread is requested.
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        self.assertEqual(len(channel.json_body["chunk"]), 1)
-
-        thread_message = channel.json_body["chunk"][0]
-        self.assertEqual(
-            thread_message["unsigned"].get("m.relations"),
-            {
-                RelationTypes.ANNOTATION: {
-                    "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
-                },
-            },
-        )
-
     def test_ignore_invalid_room(self) -> None:
         """Test that we ignore invalid relations over federation."""
         # Create another room and send a message in it.
@@ -877,8 +422,6 @@ class RelationsTestCase(BaseRelationsTestCase):
             "m.room.message",
             content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
         )
-        self.assertEqual(200, channel.code, channel.json_body)
-
         edit_event_id = channel.json_body["event_id"]
 
         def assert_bundle(event_json: JsonDict) -> None:
@@ -954,7 +497,7 @@ class RelationsTestCase(BaseRelationsTestCase):
         shouldn't be allowed, are correctly handled.
         """
 
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.REPLACE,
             "m.room.message",
             content={
@@ -963,7 +506,6 @@ class RelationsTestCase(BaseRelationsTestCase):
                 "m.new_content": {"msgtype": "m.text", "body": "First edit"},
             },
         )
-        self.assertEqual(200, channel.code, channel.json_body)
 
         new_body = {"msgtype": "m.text", "body": "I've been edited!"}
         channel = self._send_relation(
@@ -971,11 +513,9 @@ class RelationsTestCase(BaseRelationsTestCase):
             "m.room.message",
             content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
         )
-        self.assertEqual(200, channel.code, channel.json_body)
-
         edit_event_id = channel.json_body["event_id"]
 
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.REPLACE,
             "m.room.message.WRONG_TYPE",
             content={
@@ -984,7 +524,6 @@ class RelationsTestCase(BaseRelationsTestCase):
                 "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"},
             },
         )
-        self.assertEqual(200, channel.code, channel.json_body)
 
         channel = self.make_request(
             "GET",
@@ -1015,7 +554,6 @@ class RelationsTestCase(BaseRelationsTestCase):
             "m.room.message",
             content={"msgtype": "m.text", "body": "A reply!"},
         )
-        self.assertEqual(200, channel.code, channel.json_body)
         reply = channel.json_body["event_id"]
 
         new_body = {"msgtype": "m.text", "body": "I've been edited!"}
@@ -1025,8 +563,6 @@ class RelationsTestCase(BaseRelationsTestCase):
             content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
             parent_id=reply,
         )
-        self.assertEqual(200, channel.code, channel.json_body)
-
         edit_event_id = channel.json_body["event_id"]
 
         channel = self.make_request(
@@ -1071,17 +607,15 @@ class RelationsTestCase(BaseRelationsTestCase):
             "m.room.message",
             content={"msgtype": "m.text", "body": "A threaded reply!"},
         )
-        self.assertEqual(200, channel.code, channel.json_body)
         threaded_event_id = channel.json_body["event_id"]
 
         new_body = {"msgtype": "m.text", "body": "I've been edited!"}
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.REPLACE,
             "m.room.message",
             content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
             parent_id=threaded_event_id,
         )
-        self.assertEqual(200, channel.code, channel.json_body)
 
         # Fetch the thread root, to get the bundled aggregation for the thread.
         channel = self.make_request(
@@ -1113,11 +647,10 @@ class RelationsTestCase(BaseRelationsTestCase):
                 "m.new_content": new_body,
             },
         )
-        self.assertEqual(200, channel.code, channel.json_body)
         edit_event_id = channel.json_body["event_id"]
 
         # Edit the edit event.
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.REPLACE,
             "m.room.message",
             content={
@@ -1127,7 +660,6 @@ class RelationsTestCase(BaseRelationsTestCase):
             },
             parent_id=edit_event_id,
         )
-        self.assertEqual(200, channel.code, channel.json_body)
 
         # Request the original event.
         channel = self.make_request(
@@ -1154,7 +686,6 @@ class RelationsTestCase(BaseRelationsTestCase):
     def test_unknown_relations(self) -> None:
         """Unknown relations should be accepted."""
         channel = self._send_relation("m.relation.test", "m.room.test")
-        self.assertEqual(200, channel.code, channel.json_body)
         event_id = channel.json_body["event_id"]
 
         channel = self.make_request(
@@ -1195,28 +726,15 @@ class RelationsTestCase(BaseRelationsTestCase):
         self.assertEqual(200, channel.code, channel.json_body)
         self.assertEqual(channel.json_body["chunk"], [])
 
-    def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
-        """
-        Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
-        """
-        for event in events:
-            if event["event_id"] == self.parent_id:
-                return event
-
-        raise AssertionError(f"Event {self.parent_id} not found in chunk")
-
     def test_background_update(self) -> None:
         """Test the event_arbitrary_relations background update."""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
-        self.assertEqual(200, channel.code, channel.json_body)
         annotation_event_id_good = channel.json_body["event_id"]
 
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A")
-        self.assertEqual(200, channel.code, channel.json_body)
         annotation_event_id_bad = channel.json_body["event_id"]
 
         channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
-        self.assertEqual(200, channel.code, channel.json_body)
         thread_event_id = channel.json_body["event_id"]
 
         # Clean-up the table as if the inserts did not happen during event creation.
@@ -1267,6 +785,516 @@ class RelationsTestCase(BaseRelationsTestCase):
             [annotation_event_id_good, thread_event_id],
         )
 
+
+class RelationPaginationTestCase(BaseRelationsTestCase):
+    def test_basic_paginate_relations(self) -> None:
+        """Tests that calling pagination API correctly the latest relations."""
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+        first_annotation_id = channel.json_body["event_id"]
+
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
+        second_annotation_id = channel.json_body["event_id"]
+
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+
+        # We expect to get back a single pagination result, which is the latest
+        # full relation event we sent above.
+        self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+        self.assert_dict(
+            {
+                "event_id": second_annotation_id,
+                "sender": self.user_id,
+                "type": "m.reaction",
+            },
+            channel.json_body["chunk"][0],
+        )
+
+        # We also expect to get the original event (the id of which is self.parent_id)
+        self.assertEqual(
+            channel.json_body["original_event"]["event_id"], self.parent_id
+        )
+
+        # Make sure next_batch has something in it that looks like it could be a
+        # valid token.
+        self.assertIsInstance(
+            channel.json_body.get("next_batch"), str, channel.json_body
+        )
+
+        # Request the relations again, but with a different direction.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/relations"
+            f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+
+        # We expect to get back a single pagination result, which is the earliest
+        # full relation event we sent above.
+        self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+        self.assert_dict(
+            {
+                "event_id": first_annotation_id,
+                "sender": self.user_id,
+                "type": "m.reaction",
+            },
+            channel.json_body["chunk"][0],
+        )
+
+    def test_repeated_paginate_relations(self) -> None:
+        """Test that if we paginate using a limit and tokens then we get the
+        expected events.
+        """
+
+        expected_event_ids = []
+        for idx in range(10):
+            channel = self._send_relation(
+                RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx)
+            )
+            expected_event_ids.append(channel.json_body["event_id"])
+
+        prev_token = ""
+        found_event_ids: List[str] = []
+        for _ in range(20):
+            from_token = ""
+            if prev_token:
+                from_token = "&from=" + prev_token
+
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEqual(200, channel.code, channel.json_body)
+
+            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+            next_batch = channel.json_body.get("next_batch")
+
+            self.assertNotEqual(prev_token, next_batch)
+            prev_token = next_batch
+
+            if not prev_token:
+                break
+
+        # We paginated backwards, so reverse
+        found_event_ids.reverse()
+        self.assertEqual(found_event_ids, expected_event_ids)
+
+    def test_pagination_from_sync_and_messages(self) -> None:
+        """Pagination tokens from /sync and /messages can be used to paginate /relations."""
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
+        annotation_id = channel.json_body["event_id"]
+        # Send an event after the relation events.
+        self.helper.send(self.room, body="Latest event", tok=self.user_token)
+
+        # Request /sync, limiting it such that only the latest event is returned
+        # (and not the relation).
+        filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}')
+        channel = self.make_request(
+            "GET", f"/sync?filter={filter}", access_token=self.user_token
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+        sync_prev_batch = room_timeline["prev_batch"]
+        self.assertIsNotNone(sync_prev_batch)
+        # Ensure the relation event is not in the batch returned from /sync.
+        self.assertNotIn(
+            annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
+        )
+
+        # Request /messages, limiting it such that only the latest event is
+        # returned (and not the relation).
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/messages?dir=b&limit=1",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        messages_end = channel.json_body["end"]
+        self.assertIsNotNone(messages_end)
+        # Ensure the relation event is not in the chunk returned from /messages.
+        self.assertNotIn(
+            annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+        )
+
+        # Request /relations with the pagination tokens received from both the
+        # /sync and /messages responses above, in turn.
+        #
+        # This is a tiny bit silly since the client wouldn't know the parent ID
+        # from the requests above; consider the parent ID to be known from a
+        # previous /sync.
+        for from_token in (sync_prev_batch, messages_end):
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEqual(200, channel.code, channel.json_body)
+
+            # The relation should be in the returned chunk.
+            self.assertIn(
+                annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+            )
+
+    def test_aggregation_pagination_groups(self) -> None:
+        """Test that we can paginate annotation groups correctly."""
+
+        # We need to create ten separate users to send each reaction.
+        access_tokens = [self.user_token, self.user2_token]
+        idx = 0
+        while len(access_tokens) < 10:
+            user_id, token = self._create_user("test" + str(idx))
+            idx += 1
+
+            self.helper.join(self.room, user=user_id, tok=token)
+            access_tokens.append(token)
+
+        idx = 0
+        sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1}
+        for key in itertools.chain.from_iterable(
+            itertools.repeat(key, num) for key, num in sent_groups.items()
+        ):
+            self._send_relation(
+                RelationTypes.ANNOTATION,
+                "m.reaction",
+                key=key,
+                access_token=access_tokens[idx],
+            )
+
+            idx += 1
+            idx %= len(access_tokens)
+
+        prev_token: Optional[str] = None
+        found_groups: Dict[str, int] = {}
+        for _ in range(20):
+            from_token = ""
+            if prev_token:
+                from_token = "&from=" + prev_token
+
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEqual(200, channel.code, channel.json_body)
+
+            self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+            for groups in channel.json_body["chunk"]:
+                # We only expect reactions
+                self.assertEqual(groups["type"], "m.reaction", channel.json_body)
+
+                # We should only see each key once
+                self.assertNotIn(groups["key"], found_groups, channel.json_body)
+
+                found_groups[groups["key"]] = groups["count"]
+
+            next_batch = channel.json_body.get("next_batch")
+
+            self.assertNotEqual(prev_token, next_batch)
+            prev_token = next_batch
+
+            if not prev_token:
+                break
+
+        self.assertEqual(sent_groups, found_groups)
+
+    def test_aggregation_pagination_within_group(self) -> None:
+        """Test that we can paginate within an annotation group."""
+
+        # We need to create ten separate users to send each reaction.
+        access_tokens = [self.user_token, self.user2_token]
+        idx = 0
+        while len(access_tokens) < 10:
+            user_id, token = self._create_user("test" + str(idx))
+            idx += 1
+
+            self.helper.join(self.room, user=user_id, tok=token)
+            access_tokens.append(token)
+
+        idx = 0
+        expected_event_ids = []
+        for _ in range(10):
+            channel = self._send_relation(
+                RelationTypes.ANNOTATION,
+                "m.reaction",
+                key="👍",
+                access_token=access_tokens[idx],
+            )
+            expected_event_ids.append(channel.json_body["event_id"])
+
+            idx += 1
+
+        # Also send a different type of reaction so that we test we don't see it
+        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+
+        prev_token = ""
+        found_event_ids: List[str] = []
+        encoded_key = urllib.parse.quote_plus("👍".encode())
+        for _ in range(20):
+            from_token = ""
+            if prev_token:
+                from_token = "&from=" + prev_token
+
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}"
+                f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
+                f"/m.reaction/{encoded_key}?limit=1{from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEqual(200, channel.code, channel.json_body)
+
+            self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+
+            next_batch = channel.json_body.get("next_batch")
+
+            self.assertNotEqual(prev_token, next_batch)
+            prev_token = next_batch
+
+            if not prev_token:
+                break
+
+        # We paginated backwards, so reverse
+        found_event_ids.reverse()
+        self.assertEqual(found_event_ids, expected_event_ids)
+
+
+class BundledAggregationsTestCase(BaseRelationsTestCase):
+    """
+    See RelationsTestCase.test_edit for a similar test for edits.
+
+    Note that this doesn't test against /relations since only thread relations
+    get bundled via that API. See test_aggregation_get_event_for_thread.
+    """
+
+    def _test_bundled_aggregations(
+        self,
+        relation_type: str,
+        assertion_callable: Callable[[JsonDict], None],
+        expected_db_txn_for_event: int,
+    ) -> None:
+        """
+        Makes requests to various endpoints which should include bundled aggregations
+        and then calls an assertion function on the bundled aggregations.
+
+        Args:
+            relation_type: The field to search for in the `m.relations` field in unsigned.
+            assertion_callable: Called with the contents of unsigned["m.relations"][relation_type]
+                for relation-specific assertions.
+            expected_db_txn_for_event: The number of database transactions which
+                are expected for a call to /event/.
+        """
+
+        def assert_bundle(event_json: JsonDict) -> None:
+            """Assert the expected values of the bundled aggregations."""
+            relations_dict = event_json["unsigned"].get("m.relations")
+
+            # Ensure the fields are as expected.
+            self.assertCountEqual(relations_dict.keys(), (relation_type,))
+            assertion_callable(relations_dict[relation_type])
+
+        # Request the event directly.
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/event/{self.parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        assert_bundle(channel.json_body)
+        assert channel.resource_usage is not None
+        self.assertEqual(channel.resource_usage.db_txn_count, expected_db_txn_for_event)
+
+        # Request the room messages.
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/messages?dir=b",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
+
+        # Request the room context.
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/context/{self.parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        assert_bundle(channel.json_body["event"])
+
+        # Request sync.
+        filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}')
+        channel = self.make_request(
+            "GET", f"/sync?filter={filter}", access_token=self.user_token
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+        self.assertTrue(room_timeline["limited"])
+        assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
+
+        # Request search.
+        channel = self.make_request(
+            "POST",
+            "/search",
+            # Search term matches the parent message.
+            content={"search_categories": {"room_events": {"search_term": "Hi"}}},
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        chunk = [
+            result["result"]
+            for result in channel.json_body["search_categories"]["room_events"][
+                "results"
+            ]
+        ]
+        assert_bundle(self._find_event_in_chunk(chunk))
+
+    @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+    def test_annotation(self) -> None:
+        """
+        Test that annotations get correctly bundled.
+        """
+        # Setup by sending a variety of relations.
+        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+        self._send_relation(
+            RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
+        )
+        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
+
+        def assert_annotations(bundled_aggregations: JsonDict) -> None:
+            self.assertEqual(
+                {
+                    "chunk": [
+                        {"type": "m.reaction", "key": "a", "count": 2},
+                        {"type": "m.reaction", "key": "b", "count": 1},
+                    ]
+                },
+                bundled_aggregations,
+            )
+
+        self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
+
+    @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+    def test_reference(self) -> None:
+        """
+        Test that references get correctly bundled.
+        """
+        channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+        reply_1 = channel.json_body["event_id"]
+
+        channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+        reply_2 = channel.json_body["event_id"]
+
+        def assert_annotations(bundled_aggregations: JsonDict) -> None:
+            self.assertEqual(
+                {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
+                bundled_aggregations,
+            )
+
+        self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
+
+    @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+    def test_thread(self) -> None:
+        """
+        Test that threads get correctly bundled.
+        """
+        self._send_relation(RelationTypes.THREAD, "m.room.test")
+        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+        thread_2 = channel.json_body["event_id"]
+
+        def assert_annotations(bundled_aggregations: JsonDict) -> None:
+            self.assertEqual(2, bundled_aggregations.get("count"))
+            self.assertTrue(bundled_aggregations.get("current_user_participated"))
+            # The latest thread event has some fields that don't matter.
+            self.assert_dict(
+                {
+                    "content": {
+                        "m.relates_to": {
+                            "event_id": self.parent_id,
+                            "rel_type": RelationTypes.THREAD,
+                        }
+                    },
+                    "event_id": thread_2,
+                    "sender": self.user_id,
+                    "type": "m.room.test",
+                },
+                bundled_aggregations.get("latest_event"),
+            )
+
+        self._test_bundled_aggregations(RelationTypes.THREAD, assert_annotations, 9)
+
+    def test_aggregation_get_event_for_annotation(self) -> None:
+        """Test that annotations do not get bundled aggregations included
+        when directly requested.
+        """
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+        annotation_id = channel.json_body["event_id"]
+
+        # Annotate the annotation.
+        self._send_relation(
+            RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id
+        )
+
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/event/{annotation_id}",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
+
+    def test_aggregation_get_event_for_thread(self) -> None:
+        """Test that threads get bundled aggregations included when directly requested."""
+        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+        thread_id = channel.json_body["event_id"]
+
+        # Annotate the annotation.
+        self._send_relation(
+            RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
+        )
+
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/event/{thread_id}",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        self.assertEqual(
+            channel.json_body["unsigned"].get("m.relations"),
+            {
+                RelationTypes.ANNOTATION: {
+                    "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
+                },
+            },
+        )
+
+        # It should also be included when the entire thread is requested.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        self.assertEqual(len(channel.json_body["chunk"]), 1)
+
+        thread_message = channel.json_body["chunk"][0]
+        self.assertEqual(
+            thread_message["unsigned"].get("m.relations"),
+            {
+                RelationTypes.ANNOTATION: {
+                    "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
+                },
+            },
+        )
+
     def test_bundled_aggregations_with_filter(self) -> None:
         """
         If "unsigned" is an omitted field (due to filtering), adding the bundled
@@ -1322,46 +1350,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         )
         self.assertEqual(200, channel.code, channel.json_body)
 
-    def _make_relation_requests(self) -> Tuple[List[str], JsonDict]:
-        """
-        Makes requests and ensures they result in a 200 response, returns a
-        tuple of results:
-
-        1. `/relations` -> Returns a list of event IDs.
-        2. `/event` -> Returns the response's m.relations field (from unsigned),
-            if it exists.
-        """
-
-        # Request the relations of the event.
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}",
-            access_token=self.user_token,
-        )
-        self.assertEquals(200, channel.code, channel.json_body)
-        event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]]
-
-        # Fetch the bundled aggregations of the event.
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}",
-            access_token=self.user_token,
-        )
-        self.assertEquals(200, channel.code, channel.json_body)
-        bundled_relations = channel.json_body["unsigned"].get("m.relations", {})
-
-        return event_ids, bundled_relations
-
-    def _get_aggregations(self) -> List[JsonDict]:
-        """Request /aggregations on the parent ID and includes the returned chunk."""
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        return channel.json_body["chunk"]
-
     def test_redact_relation_annotation(self) -> None:
         """
         Test that annotations of an event are properly handled after the
@@ -1371,17 +1359,16 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         the response to relations.
         """
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEqual(200, channel.code, channel.json_body)
         to_redact_event_id = channel.json_body["event_id"]
 
         channel = self._send_relation(
             RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
         )
-        self.assertEqual(200, channel.code, channel.json_body)
         unredacted_event_id = channel.json_body["event_id"]
 
         # Both relations should exist.
-        event_ids, relations = self._make_relation_requests()
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
         self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id])
         self.assertEquals(
             relations["m.annotation"],
@@ -1396,7 +1383,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         self._redact(to_redact_event_id)
 
         # The unredacted relation should still exist.
-        event_ids, relations = self._make_relation_requests()
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
         self.assertEquals(event_ids, [unredacted_event_id])
         self.assertEquals(
             relations["m.annotation"],
@@ -1419,7 +1407,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
             EventTypes.Message,
             content={"body": "reply 1", "msgtype": "m.text"},
         )
-        self.assertEqual(200, channel.code, channel.json_body)
         unredacted_event_id = channel.json_body["event_id"]
 
         # Note that the *last* event in the thread is redacted, as that gets
@@ -1429,11 +1416,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
             EventTypes.Message,
             content={"body": "reply 2", "msgtype": "m.text"},
         )
-        self.assertEqual(200, channel.code, channel.json_body)
         to_redact_event_id = channel.json_body["event_id"]
 
         # Both relations exist.
-        event_ids, relations = self._make_relation_requests()
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
         self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id])
         self.assertDictContainsSubset(
             {
@@ -1452,7 +1439,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         self._redact(to_redact_event_id)
 
         # The unredacted relation should still exist.
-        event_ids, relations = self._make_relation_requests()
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
         self.assertEquals(event_ids, [unredacted_event_id])
         self.assertDictContainsSubset(
             {
@@ -1472,7 +1460,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         is redacted.
         """
         # Add a relation
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.REPLACE,
             "m.room.message",
             parent_id=self.parent_id,
@@ -1482,10 +1470,10 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
                 "m.new_content": {"msgtype": "m.text", "body": "First edit"},
             },
         )
-        self.assertEqual(200, channel.code, channel.json_body)
 
         # Check the relation is returned
-        event_ids, relations = self._make_relation_requests()
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
         self.assertEqual(len(event_ids), 1)
         self.assertIn(RelationTypes.REPLACE, relations)
 
@@ -1493,7 +1481,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         self._redact(self.parent_id)
 
         # The relations are not returned.
-        event_ids, relations = self._make_relation_requests()
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
         self.assertEqual(len(event_ids), 0)
         self.assertEqual(relations, {})
 
@@ -1503,11 +1492,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         """
         # Add a relation
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
-        self.assertEqual(200, channel.code, channel.json_body)
         related_event_id = channel.json_body["event_id"]
 
         # The relations should exist.
-        event_ids, relations = self._make_relation_requests()
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
         self.assertEqual(len(event_ids), 1)
         self.assertIn(RelationTypes.ANNOTATION, relations)
 
@@ -1519,7 +1508,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         self._redact(self.parent_id)
 
         # The relations are returned.
-        event_ids, relations = self._make_relation_requests()
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
         self.assertEquals(event_ids, [related_event_id])
         self.assertEquals(
             relations["m.annotation"],
@@ -1540,14 +1530,14 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
             EventTypes.Message,
             content={"body": "reply 1", "msgtype": "m.text"},
         )
-        self.assertEqual(200, channel.code, channel.json_body)
         related_event_id = channel.json_body["event_id"]
 
         # Redact one of the reactions.
         self._redact(self.parent_id)
 
         # The unredacted relation should still exist.
-        event_ids, relations = self._make_relation_requests()
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
         self.assertEquals(len(event_ids), 1)
         self.assertDictContainsSubset(
             {
diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py
index 3fb37a2a59..62e308814d 100644
--- a/tests/rest/media/v1/test_html_preview.py
+++ b/tests/rest/media/v1/test_html_preview.py
@@ -16,7 +16,6 @@ from synapse.rest.media.v1.preview_html import (
     _get_html_media_encodings,
     decode_body,
     parse_html_to_open_graph,
-    rebase_url,
     summarize_paragraphs,
 )
 
@@ -161,7 +160,7 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
@@ -177,7 +176,7 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
@@ -196,7 +195,7 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
 
         self.assertEqual(
             og,
@@ -218,7 +217,7 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
@@ -232,7 +231,7 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
 
@@ -247,7 +246,7 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
 
@@ -262,7 +261,7 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
 
@@ -290,7 +289,7 @@ class CalcOgTestCase(unittest.TestCase):
         <head><title>Foo</title></head><body>Some text.</body></html>
         """.strip()
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_invalid_encoding(self) -> None:
@@ -304,7 +303,7 @@ class CalcOgTestCase(unittest.TestCase):
         </html>
         """
         tree = decode_body(html, "http://example.com/test.html", "invalid-encoding")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_invalid_encoding2(self) -> None:
@@ -319,7 +318,7 @@ class CalcOgTestCase(unittest.TestCase):
         </html>
         """
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
         self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
 
     def test_windows_1252(self) -> None:
@@ -333,7 +332,7 @@ class CalcOgTestCase(unittest.TestCase):
         </html>
         """
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
         self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."})
 
 
@@ -448,34 +447,3 @@ class MediaEncodingTestCase(unittest.TestCase):
             'text/html; charset="invalid"',
         )
         self.assertEqual(list(encodings), ["utf-8", "cp1252"])
-
-
-class RebaseUrlTestCase(unittest.TestCase):
-    def test_relative(self) -> None:
-        """Relative URLs should be resolved based on the context of the base URL."""
-        self.assertEqual(
-            rebase_url("subpage", "https://example.com/foo/"),
-            "https://example.com/foo/subpage",
-        )
-        self.assertEqual(
-            rebase_url("sibling", "https://example.com/foo"),
-            "https://example.com/sibling",
-        )
-        self.assertEqual(
-            rebase_url("/bar", "https://example.com/foo/"),
-            "https://example.com/bar",
-        )
-
-    def test_absolute(self) -> None:
-        """Absolute URLs should not be modified."""
-        self.assertEqual(
-            rebase_url("https://alice.com/a/", "https://example.com/foo/"),
-            "https://alice.com/a/",
-        )
-
-    def test_data(self) -> None:
-        """Data URLs should not be modified."""
-        self.assertEqual(
-            rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"),
-            "data:,Hello%2C%20World%21",
-        )
diff --git a/tests/server.py b/tests/server.py
index 82990c2eb9..6ce2a17bf4 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -54,13 +54,18 @@ from twisted.internet.interfaces import (
     ITransport,
 )
 from twisted.python.failure import Failure
-from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
+from twisted.test.proto_helpers import (
+    AccumulatingProtocol,
+    MemoryReactor,
+    MemoryReactorClock,
+)
 from twisted.web.http_headers import Headers
 from twisted.web.resource import IResource
 from twisted.web.server import Request, Site
 
 from synapse.config.database import DatabaseConnectionConfig
 from synapse.http.site import SynapseRequest
+from synapse.logging.context import ContextResourceUsage
 from synapse.server import HomeServer
 from synapse.storage import DataStore
 from synapse.storage.engines import PostgresEngine, create_engine
@@ -88,18 +93,19 @@ class TimedOutException(Exception):
     """
 
 
-@attr.s
+@attr.s(auto_attribs=True)
 class FakeChannel:
     """
     A fake Twisted Web Channel (the part that interfaces with the
     wire).
     """
 
-    site = attr.ib(type=Union[Site, "FakeSite"])
-    _reactor = attr.ib()
-    result = attr.ib(type=dict, default=attr.Factory(dict))
-    _ip = attr.ib(type=str, default="127.0.0.1")
+    site: Union[Site, "FakeSite"]
+    _reactor: MemoryReactor
+    result: dict = attr.Factory(dict)
+    _ip: str = "127.0.0.1"
     _producer: Optional[Union[IPullProducer, IPushProducer]] = None
+    resource_usage: Optional[ContextResourceUsage] = None
 
     @property
     def json_body(self):
@@ -168,6 +174,8 @@ class FakeChannel:
 
     def requestDone(self, _self):
         self.result["done"] = True
+        if isinstance(_self, SynapseRequest):
+            self.resource_usage = _self.logcontext.get_resource_usage()
 
     def getPeer(self):
         # We give an address so that getClientIP returns a non null entry,
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
index 272cd35402..72bf5b3d31 100644
--- a/tests/storage/test_account_data.py
+++ b/tests/storage/test_account_data.py
@@ -47,9 +47,18 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
             expected_ignorer_user_ids,
         )
 
+    def assert_ignored(
+        self, ignorer_user_id: str, expected_ignored_user_ids: Set[str]
+    ) -> None:
+        self.assertEqual(
+            self.get_success(self.store.ignored_users(ignorer_user_id)),
+            expected_ignored_user_ids,
+        )
+
     def test_ignoring_users(self):
         """Basic adding/removing of users from the ignore list."""
         self._update_ignore_list("@other:test", "@another:remote")
+        self.assert_ignored(self.user, {"@other:test", "@another:remote"})
 
         # Check a user which no one ignores.
         self.assert_ignorers("@user:test", set())
@@ -62,6 +71,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
 
         # Add one user, remove one user, and leave one user.
         self._update_ignore_list("@foo:test", "@another:remote")
+        self.assert_ignored(self.user, {"@foo:test", "@another:remote"})
 
         # Check the removed user.
         self.assert_ignorers("@other:test", set())
@@ -76,20 +86,24 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
         """Ensure that caching works properly between different users."""
         # The first user ignores a user.
         self._update_ignore_list("@other:test")
+        self.assert_ignored(self.user, {"@other:test"})
         self.assert_ignorers("@other:test", {self.user})
 
         # The second user ignores them.
         self._update_ignore_list("@other:test", ignorer_user_id="@second:test")
+        self.assert_ignored("@second:test", {"@other:test"})
         self.assert_ignorers("@other:test", {self.user, "@second:test"})
 
         # The first user un-ignores them.
         self._update_ignore_list()
+        self.assert_ignored(self.user, set())
         self.assert_ignorers("@other:test", {"@second:test"})
 
     def test_invalid_data(self):
         """Invalid data ends up clearing out the ignored users list."""
         # Add some data and ensure it is there.
         self._update_ignore_list("@other:test")
+        self.assert_ignored(self.user, {"@other:test"})
         self.assert_ignorers("@other:test", {self.user})
 
         # No ignored_users key.
@@ -102,10 +116,12 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
         )
 
         # No one ignores the user now.
+        self.assert_ignored(self.user, set())
         self.assert_ignorers("@other:test", set())
 
         # Add some data and ensure it is there.
         self._update_ignore_list("@other:test")
+        self.assert_ignored(self.user, {"@other:test"})
         self.assert_ignorers("@other:test", {self.user})
 
         # Invalid data.
@@ -118,4 +134,5 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
         )
 
         # No one ignores the user now.
+        self.assert_ignored(self.user, set())
         self.assert_ignorers("@other:test", set())
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 5cf18b690e..fd619b64d4 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -17,8 +17,12 @@ from unittest.mock import Mock
 import yaml
 
 from twisted.internet.defer import Deferred, ensureDeferred
+from twisted.test.proto_helpers import MemoryReactor
 
+from synapse.server import HomeServer
 from synapse.storage.background_updates import BackgroundUpdater
+from synapse.types import JsonDict
+from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils import make_awaitable, simple_async_mock
@@ -26,7 +30,7 @@ from tests.unittest import override_config
 
 
 class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
         # the base test class should have run the real bg updates for us
         self.assertTrue(
@@ -39,7 +43,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
         )
         self.store = self.hs.get_datastores().main
 
-    async def update(self, progress, count):
+    async def update(self, progress: JsonDict, count: int) -> int:
         duration_ms = 10
         await self.clock.sleep((count * duration_ms) / 1000)
         progress = {"my_key": progress["my_key"] + 1}
@@ -51,7 +55,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
         )
         return count
 
-    def test_do_background_update(self):
+    def test_do_background_update(self) -> None:
         # the time we claim it takes to update one item when running the update
         duration_ms = 10
 
@@ -80,7 +84,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
 
         # second step: complete the update
         # we should now get run with a much bigger number of items to update
-        async def update(progress, count):
+        async def update(progress: JsonDict, count: int) -> int:
             self.assertEqual(progress, {"my_key": 2})
             self.assertAlmostEqual(
                 count,
@@ -110,7 +114,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
             """
         )
     )
-    def test_background_update_default_batch_set_by_config(self):
+    def test_background_update_default_batch_set_by_config(self) -> None:
         """
         Test that the background update is run with the default_batch_size set by the config
         """
@@ -133,7 +137,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
         # on the first call, we should get run with the default background update size specified in the config
         self.update_handler.assert_called_once_with({"my_key": 1}, 20)
 
-    def test_background_update_default_sleep_behavior(self):
+    def test_background_update_default_sleep_behavior(self) -> None:
         """
         Test default background update behavior, which is to sleep
         """
@@ -147,7 +151,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
 
         self.update_handler.side_effect = self.update
         self.update_handler.reset_mock()
-        self.updates.start_doing_background_updates(),
+        self.updates.start_doing_background_updates()
 
         # 2: advance the reactor less than the default sleep duration (1000ms)
         self.reactor.pump([0.5])
@@ -167,7 +171,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
             """
         )
     )
-    def test_background_update_sleep_set_in_config(self):
+    def test_background_update_sleep_set_in_config(self) -> None:
         """
         Test that changing the sleep time in the config changes how long it sleeps
         """
@@ -181,7 +185,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
 
         self.update_handler.side_effect = self.update
         self.update_handler.reset_mock()
-        self.updates.start_doing_background_updates(),
+        self.updates.start_doing_background_updates()
 
         # 2: advance the reactor less than the configured sleep duration (500ms)
         self.reactor.pump([0.45])
@@ -201,7 +205,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
             """
         )
     )
-    def test_disabling_background_update_sleep(self):
+    def test_disabling_background_update_sleep(self) -> None:
         """
         Test that disabling sleep in the config results in bg update not sleeping
         """
@@ -215,7 +219,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
 
         self.update_handler.side_effect = self.update
         self.update_handler.reset_mock()
-        self.updates.start_doing_background_updates(),
+        self.updates.start_doing_background_updates()
 
         # 2: advance the reactor very little
         self.reactor.pump([0.025])
@@ -230,7 +234,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
             """
         )
     )
-    def test_background_update_duration_set_in_config(self):
+    def test_background_update_duration_set_in_config(self) -> None:
         """
         Test that the desired duration set in the config is used in determining batch size
         """
@@ -254,7 +258,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
 
         # the first update was run with the default batch size, this should be run with 500ms as the
         # desired duration
-        async def update(progress, count):
+        async def update(progress: JsonDict, count: int) -> int:
             self.assertEqual(progress, {"my_key": 2})
             self.assertAlmostEqual(
                 count,
@@ -275,7 +279,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
             """
         )
     )
-    def test_background_update_min_batch_set_in_config(self):
+    def test_background_update_min_batch_set_in_config(self) -> None:
         """
         Test that the minimum batch size set in the config is used
         """
@@ -290,7 +294,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
         )
 
         # Run the update with the long-running update item
-        async def update(progress, count):
+        async def update_long(progress: JsonDict, count: int) -> int:
             await self.clock.sleep((count * duration_ms) / 1000)
             progress = {"my_key": progress["my_key"] + 1}
             await self.store.db_pool.runInteraction(
@@ -301,7 +305,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
             )
             return count
 
-        self.update_handler.side_effect = update
+        self.update_handler.side_effect = update_long
         self.update_handler.reset_mock()
         res = self.get_success(
             self.updates.do_next_background_update(False),
@@ -311,25 +315,25 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
 
         # the first update was run with the default batch size, this should be run with minimum batch size
         # as the first items took a very long time
-        async def update(progress, count):
+        async def update_short(progress: JsonDict, count: int) -> int:
             self.assertEqual(progress, {"my_key": 2})
             self.assertEqual(count, 5)
             await self.updates._end_background_update("test_update")
             return count
 
-        self.update_handler.side_effect = update
+        self.update_handler.side_effect = update_short
         self.get_success(self.updates.do_next_background_update(False))
 
 
 class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
         # the base test class should have run the real bg updates for us
         self.assertTrue(
             self.get_success(self.updates.has_completed_background_updates())
         )
 
-        self.update_deferred = Deferred()
+        self.update_deferred: Deferred[int] = Deferred()
         self.update_handler = Mock(return_value=self.update_deferred)
         self.updates.register_background_update_handler(
             "test_update", self.update_handler
@@ -358,7 +362,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
             ),
         )
 
-    def test_controller(self):
+    def test_controller(self) -> None:
         store = self.hs.get_datastores().main
         self.get_success(
             store.db_pool.simple_insert(
@@ -368,7 +372,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
         )
 
         # Set the return value for the context manager.
-        enter_defer = Deferred()
+        enter_defer: Deferred[int] = Deferred()
         self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer)
 
         # Start the background update.
diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py
index 8597867563..a40fc20ef9 100644
--- a/tests/storage/test_database.py
+++ b/tests/storage/test_database.py
@@ -12,7 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.database import make_tuple_comparison_clause
+from typing import Callable, Tuple
+from unittest.mock import Mock, call
+
+from twisted.internet import defer
+from twisted.internet.defer import CancelledError, Deferred
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingTransaction,
+    make_tuple_comparison_clause,
+)
+from synapse.util import Clock
 
 from tests import unittest
 
@@ -22,3 +35,150 @@ class TupleComparisonClauseTestCase(unittest.TestCase):
         clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
         self.assertEqual(clause, "(a,b) > (?,?)")
         self.assertEqual(args, [1, 2])
+
+
+class CallbacksTestCase(unittest.HomeserverTestCase):
+    """Tests for transaction callbacks."""
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+        self.db_pool: DatabasePool = self.store.db_pool
+
+    def _run_interaction(
+        self, func: Callable[[LoggingTransaction], object]
+    ) -> Tuple[Mock, Mock]:
+        """Run the given function in a database transaction, with callbacks registered.
+
+        Args:
+            func: The function to be run in a transaction. The transaction will be
+                retried if `func` raises an `OperationalError`.
+
+        Returns:
+            Two mocks, which were registered as an `after_callback` and an
+            `exception_callback` respectively, on every transaction attempt.
+        """
+        after_callback = Mock()
+        exception_callback = Mock()
+
+        def _test_txn(txn: LoggingTransaction) -> None:
+            txn.call_after(after_callback, 123, 456, extra=789)
+            txn.call_on_exception(exception_callback, 987, 654, extra=321)
+            func(txn)
+
+        try:
+            self.get_success_or_raise(
+                self.db_pool.runInteraction("test_transaction", _test_txn)
+            )
+        except Exception:
+            pass
+
+        return after_callback, exception_callback
+
+    def test_after_callback(self) -> None:
+        """Test that the after callback is called when a transaction succeeds."""
+        after_callback, exception_callback = self._run_interaction(lambda txn: None)
+
+        after_callback.assert_called_once_with(123, 456, extra=789)
+        exception_callback.assert_not_called()
+
+    def test_exception_callback(self) -> None:
+        """Test that the exception callback is called when a transaction fails."""
+        _test_txn = Mock(side_effect=ZeroDivisionError)
+        after_callback, exception_callback = self._run_interaction(_test_txn)
+
+        after_callback.assert_not_called()
+        exception_callback.assert_called_once_with(987, 654, extra=321)
+
+    def test_failed_retry(self) -> None:
+        """Test that the exception callback is called for every failed attempt."""
+        # Always raise an `OperationalError`.
+        _test_txn = Mock(side_effect=self.db_pool.engine.module.OperationalError)
+        after_callback, exception_callback = self._run_interaction(_test_txn)
+
+        after_callback.assert_not_called()
+        exception_callback.assert_has_calls(
+            [
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+            ]
+        )
+        self.assertEqual(exception_callback.call_count, 6)  # no additional calls
+
+    def test_successful_retry(self) -> None:
+        """Test callbacks for a failed transaction followed by a successful attempt."""
+        # Raise an `OperationalError` on the first attempt only.
+        _test_txn = Mock(
+            side_effect=[self.db_pool.engine.module.OperationalError, None]
+        )
+        after_callback, exception_callback = self._run_interaction(_test_txn)
+
+        # Calling both `after_callback`s when the first attempt failed is rather
+        # surprising (#12184). Let's document the behaviour in a test.
+        after_callback.assert_has_calls(
+            [
+                call(123, 456, extra=789),
+                call(123, 456, extra=789),
+            ]
+        )
+        self.assertEqual(after_callback.call_count, 2)  # no additional calls
+        exception_callback.assert_not_called()
+
+
+class CancellationTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+        self.db_pool: DatabasePool = self.store.db_pool
+
+    def test_after_callback(self) -> None:
+        """Test that the after callback is called when a transaction succeeds."""
+        d: "Deferred[None]"
+        after_callback = Mock()
+        exception_callback = Mock()
+
+        def _test_txn(txn: LoggingTransaction) -> None:
+            txn.call_after(after_callback, 123, 456, extra=789)
+            txn.call_on_exception(exception_callback, 987, 654, extra=321)
+            d.cancel()
+
+        d = defer.ensureDeferred(
+            self.db_pool.runInteraction("test_transaction", _test_txn)
+        )
+        self.get_failure(d, CancelledError)
+
+        after_callback.assert_called_once_with(123, 456, extra=789)
+        exception_callback.assert_not_called()
+
+    def test_exception_callback(self) -> None:
+        """Test that the exception callback is called when a transaction fails."""
+        d: "Deferred[None]"
+        after_callback = Mock()
+        exception_callback = Mock()
+
+        def _test_txn(txn: LoggingTransaction) -> None:
+            txn.call_after(after_callback, 123, 456, extra=789)
+            txn.call_on_exception(exception_callback, 987, 654, extra=321)
+            d.cancel()
+            # Simulate a retryable failure on every attempt.
+            raise self.db_pool.engine.module.OperationalError()
+
+        d = defer.ensureDeferred(
+            self.db_pool.runInteraction("test_transaction", _test_txn)
+        )
+        self.get_failure(d, CancelledError)
+
+        after_callback.assert_not_called()
+        exception_callback.assert_has_calls(
+            [
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+            ]
+        )
+        self.assertEqual(exception_callback.call_count, 6)  # no additional calls
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 6ac4b93f98..395396340b 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -13,9 +13,13 @@
 # limitations under the License.
 from typing import List, Optional
 
-from synapse.storage.database import DatabasePool
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.engines import IncorrectDatabaseSetup
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase
 from tests.utils import USE_POSTGRES_FOR_TESTS
@@ -25,13 +29,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
     if not USE_POSTGRES_FOR_TESTS:
         skip = "Requires Postgres"
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
         self.db_pool: DatabasePool = self.store.db_pool
 
         self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
 
-    def _setup_db(self, txn):
+    def _setup_db(self, txn: LoggingTransaction) -> None:
         txn.execute("CREATE SEQUENCE foobar_seq")
         txn.execute(
             """
@@ -59,12 +63,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
 
-    def _insert_rows(self, instance_name: str, number: int):
+    def _insert_rows(self, instance_name: str, number: int) -> None:
         """Insert N rows as the given instance, inserting with stream IDs pulled
         from the postgres sequence.
         """
 
-        def _insert(txn):
+        def _insert(txn: LoggingTransaction) -> None:
             for _ in range(number):
                 txn.execute(
                     "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
@@ -80,12 +84,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
 
-    def _insert_row_with_id(self, instance_name: str, stream_id: int):
+    def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None:
         """Insert one row as the given instance with given stream_id, updating
         the postgres sequence position to match.
         """
 
-        def _insert(txn):
+        def _insert(txn: LoggingTransaction) -> None:
             txn.execute(
                 "INSERT INTO foobar VALUES (?, ?)",
                 (
@@ -104,7 +108,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
 
-    def test_empty(self):
+    def test_empty(self) -> None:
         """Test an ID generator against an empty database gives sensible
         current positions.
         """
@@ -114,7 +118,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         # The table is empty so we expect an empty map for positions
         self.assertEqual(id_gen.get_positions(), {})
 
-    def test_single_instance(self):
+    def test_single_instance(self) -> None:
         """Test that reads and writes from a single process are handled
         correctly.
         """
@@ -130,7 +134,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         # Try allocating a new ID gen and check that we only see position
         # advanced after we leave the context manager.
 
-        async def _get_next_async():
+        async def _get_next_async() -> None:
             async with id_gen.get_next() as stream_id:
                 self.assertEqual(stream_id, 8)
 
@@ -142,7 +146,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen.get_positions(), {"master": 8})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
 
-    def test_out_of_order_finish(self):
+    def test_out_of_order_finish(self) -> None:
         """Test that IDs persisted out of order are correctly handled"""
 
         # Prefill table with 7 rows written by 'master'
@@ -191,7 +195,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen.get_positions(), {"master": 11})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
 
-    def test_multi_instance(self):
+    def test_multi_instance(self) -> None:
         """Test that reads and writes from multiple processes are handled
         correctly.
         """
@@ -215,7 +219,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         # Try allocating a new ID gen and check that we only see position
         # advanced after we leave the context manager.
 
-        async def _get_next_async():
+        async def _get_next_async() -> None:
             async with first_id_gen.get_next() as stream_id:
                 self.assertEqual(stream_id, 8)
 
@@ -233,7 +237,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         # ... but calling `get_next` on the second instance should give a unique
         # stream ID
 
-        async def _get_next_async():
+        async def _get_next_async2() -> None:
             async with second_id_gen.get_next() as stream_id:
                 self.assertEqual(stream_id, 9)
 
@@ -241,7 +245,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
                     second_id_gen.get_positions(), {"first": 3, "second": 7}
                 )
 
-        self.get_success(_get_next_async())
+        self.get_success(_get_next_async2())
 
         self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
 
@@ -249,7 +253,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         second_id_gen.advance("first", 8)
         self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
 
-    def test_get_next_txn(self):
+    def test_get_next_txn(self) -> None:
         """Test that the `get_next_txn` function works correctly."""
 
         # Prefill table with 7 rows written by 'master'
@@ -263,7 +267,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         # Try allocating a new ID gen and check that we only see position
         # advanced after we leave the context manager.
 
-        def _get_next_txn(txn):
+        def _get_next_txn(txn: LoggingTransaction) -> None:
             stream_id = id_gen.get_next_txn(txn)
             self.assertEqual(stream_id, 8)
 
@@ -275,7 +279,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen.get_positions(), {"master": 8})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
 
-    def test_get_persisted_upto_position(self):
+    def test_get_persisted_upto_position(self) -> None:
         """Test that `get_persisted_upto_position` correctly tracks updates to
         positions.
         """
@@ -317,7 +321,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         id_gen.advance("second", 15)
         self.assertEqual(id_gen.get_persisted_upto_position(), 11)
 
-    def test_get_persisted_upto_position_get_next(self):
+    def test_get_persisted_upto_position_get_next(self) -> None:
         """Test that `get_persisted_upto_position` correctly tracks updates to
         positions when `get_next` is called.
         """
@@ -331,7 +335,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         self.assertEqual(id_gen.get_persisted_upto_position(), 5)
 
-        async def _get_next_async():
+        async def _get_next_async() -> None:
             async with id_gen.get_next() as stream_id:
                 self.assertEqual(stream_id, 6)
                 self.assertEqual(id_gen.get_persisted_upto_position(), 5)
@@ -344,7 +348,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         # `persisted_upto_position` in this case, then it will be correct in the
         # other cases that are tested above (since they'll hit the same code).
 
-    def test_restart_during_out_of_order_persistence(self):
+    def test_restart_during_out_of_order_persistence(self) -> None:
         """Test that restarting a process while another process is writing out
         of order updates are handled correctly.
         """
@@ -388,7 +392,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         id_gen_worker.advance("master", 9)
         self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
 
-    def test_writer_config_change(self):
+    def test_writer_config_change(self) -> None:
         """Test that changing the writer config correctly works."""
 
         self._insert_row_with_id("first", 3)
@@ -421,7 +425,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         # Check that we get a sane next stream ID with this new config.
 
-        async def _get_next_async():
+        async def _get_next_async() -> None:
             async with id_gen_3.get_next() as stream_id:
                 self.assertEqual(stream_id, 6)
 
@@ -435,7 +439,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen_5.get_current_token_for_writer("first"), 6)
         self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6)
 
-    def test_sequence_consistency(self):
+    def test_sequence_consistency(self) -> None:
         """Test that we error out if the table and sequence diverges."""
 
         # Prefill with some rows
@@ -458,13 +462,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
     if not USE_POSTGRES_FOR_TESTS:
         skip = "Requires Postgres"
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
         self.db_pool: DatabasePool = self.store.db_pool
 
         self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
 
-    def _setup_db(self, txn):
+    def _setup_db(self, txn: LoggingTransaction) -> None:
         txn.execute("CREATE SEQUENCE foobar_seq")
         txn.execute(
             """
@@ -493,10 +497,10 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         return self.get_success(self.db_pool.runWithConnection(_create))
 
-    def _insert_row(self, instance_name: str, stream_id: int):
+    def _insert_row(self, instance_name: str, stream_id: int) -> None:
         """Insert one row as the given instance with given stream_id."""
 
-        def _insert(txn):
+        def _insert(txn: LoggingTransaction) -> None:
             txn.execute(
                 "INSERT INTO foobar VALUES (?, ?)",
                 (
@@ -514,13 +518,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
 
-    def test_single_instance(self):
+    def test_single_instance(self) -> None:
         """Test that reads and writes from a single process are handled
         correctly.
         """
         id_gen = self._create_id_generator()
 
-        async def _get_next_async():
+        async def _get_next_async() -> None:
             async with id_gen.get_next() as stream_id:
                 self._insert_row("master", stream_id)
 
@@ -530,7 +534,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
         self.assertEqual(id_gen.get_persisted_upto_position(), -1)
 
-        async def _get_next_async2():
+        async def _get_next_async2() -> None:
             async with id_gen.get_next_mult(3) as stream_ids:
                 for stream_id in stream_ids:
                     self._insert_row("master", stream_id)
@@ -548,14 +552,14 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
         self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)
 
-    def test_multiple_instance(self):
+    def test_multiple_instance(self) -> None:
         """Tests that having multiple instances that get advanced over
         federation works corretly.
         """
         id_gen_1 = self._create_id_generator("first", writers=["first", "second"])
         id_gen_2 = self._create_id_generator("second", writers=["first", "second"])
 
-        async def _get_next_async():
+        async def _get_next_async() -> None:
             async with id_gen_1.get_next() as stream_id:
                 self._insert_row("first", stream_id)
                 id_gen_2.advance("first", stream_id)
@@ -567,7 +571,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
         self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
 
-        async def _get_next_async2():
+        async def _get_next_async2() -> None:
             async with id_gen_2.get_next() as stream_id:
                 self._insert_row("second", stream_id)
                 id_gen_1.advance("second", stream_id)
@@ -584,13 +588,13 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
     if not USE_POSTGRES_FOR_TESTS:
         skip = "Requires Postgres"
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
         self.db_pool: DatabasePool = self.store.db_pool
 
         self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
 
-    def _setup_db(self, txn):
+    def _setup_db(self, txn: LoggingTransaction) -> None:
         txn.execute("CREATE SEQUENCE foobar_seq")
         txn.execute(
             """
@@ -642,7 +646,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         from the postgres sequence.
         """
 
-        def _insert(txn):
+        def _insert(txn: LoggingTransaction) -> None:
             for _ in range(number):
                 txn.execute(
                     "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
@@ -659,7 +663,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
 
-    def test_load_existing_stream(self):
+    def test_load_existing_stream(self) -> None:
         """Test creating ID gens with multiple tables that have rows from after
         the position in `stream_positions` table.
         """
diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py
index 38e9f58ac6..5d1aa025d1 100644
--- a/tests/util/test_check_dependencies.py
+++ b/tests/util/test_check_dependencies.py
@@ -12,7 +12,7 @@ from tests.unittest import TestCase
 
 
 class DummyDistribution(metadata.Distribution):
-    def __init__(self, version: str):
+    def __init__(self, version: object):
         self._version = version
 
     @property
@@ -30,6 +30,7 @@ old = DummyDistribution("0.1.2")
 old_release_candidate = DummyDistribution("0.1.2rc3")
 new = DummyDistribution("1.2.3")
 new_release_candidate = DummyDistribution("1.2.3rc4")
+distribution_with_no_version = DummyDistribution(None)
 
 # could probably use stdlib TestCase --- no need for twisted here
 
@@ -67,6 +68,18 @@ class TestDependencyChecker(TestCase):
                 # should not raise
                 check_requirements()
 
+    def test_version_reported_as_none(self) -> None:
+        """Complain if importlib.metadata.version() returns None.
+
+        This shouldn't normally happen, but it was seen in the wild (#12223).
+        """
+        with patch(
+            "synapse.util.check_dependencies.metadata.requires",
+            return_value=["dummypkg >= 1"],
+        ):
+            with self.mock_installed_package(distribution_with_no_version):
+                self.assertRaises(DependencyException, check_requirements)
+
     def test_checks_ignore_dev_dependencies(self) -> None:
         """Bot generic and per-extra checks should ignore dev dependencies."""
         with patch(