summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/10776.feature1
-rw-r--r--changelog.d/10777.misc1
-rw-r--r--changelog.d/10785.misc1
-rw-r--r--changelog.d/10810.bugfix1
-rw-r--r--changelog.d/10812.misc1
-rw-r--r--changelog.d/10815.misc1
-rw-r--r--changelog.d/10816.misc1
-rw-r--r--changelog.d/10817.misc1
-rw-r--r--changelog.d/10823.misc1
-rw-r--r--changelog.d/10834.misc1
-rw-r--r--mypy.ini4
-rw-r--r--synapse/api/auth.py4
-rw-r--r--synapse/api/auth_blocking.py16
-rw-r--r--synapse/crypto/context_factory.py8
-rw-r--r--synapse/crypto/keyring.py2
-rw-r--r--synapse/federation/federation_server.py2
-rw-r--r--synapse/federation/sender/__init__.py2
-rw-r--r--synapse/handlers/auth.py2
-rw-r--r--synapse/handlers/initial_sync.py2
-rw-r--r--synapse/handlers/presence.py12
-rw-r--r--synapse/handlers/sync.py2
-rw-r--r--synapse/http/client.py7
-rw-r--r--synapse/http/matrixfederationclient.py17
-rw-r--r--synapse/push/httppusher.py2
-rw-r--r--synapse/push/mailer.py10
-rw-r--r--synapse/push/pusher.py8
-rw-r--r--synapse/push/pusherpool.py2
-rw-r--r--synapse/rest/__init__.py11
-rw-r--r--synapse/rest/admin/devices.py2
-rw-r--r--synapse/rest/admin/server_notice_servlet.py2
-rw-r--r--synapse/rest/admin/users.py2
-rw-r--r--synapse/rest/client/room_batch.py42
-rw-r--r--synapse/rest/consent/consent_resource.py39
-rw-r--r--synapse/rest/health.py3
-rw-r--r--synapse/rest/key/v2/__init__.py7
-rw-r--r--synapse/rest/key/v2/local_key_resource.py15
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py30
-rw-r--r--synapse/rest/media/v1/_base.py108
-rw-r--r--synapse/rest/media/v1/filepath.py6
-rw-r--r--synapse/rest/media/v1/media_repository.py48
-rw-r--r--synapse/rest/media/v1/media_storage.py68
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py5
-rw-r--r--synapse/rest/media/v1/storage_provider.py4
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py77
-rw-r--r--synapse/rest/media/v1/thumbnailer.py2
-rw-r--r--synapse/rest/synapse/client/new_user_consent.py6
-rw-r--r--synapse/rest/synapse/client/oidc/__init__.py6
-rw-r--r--synapse/rest/synapse/client/oidc/callback_resource.py5
-rw-r--r--synapse/rest/synapse/client/pick_username.py9
-rw-r--r--synapse/rest/synapse/client/saml2/__init__.py6
-rw-r--r--synapse/rest/synapse/client/saml2/metadata_resource.py9
-rw-r--r--synapse/rest/synapse/client/saml2/response_resource.py7
-rw-r--r--synapse/rest/well_known.py20
-rw-r--r--synapse/server.py16
-rw-r--r--synapse/storage/databases/main/__init__.py2
-rw-r--r--synapse/storage/databases/main/room_batch.py36
-rw-r--r--synapse/storage/databases/state/bg_updates.py60
-rw-r--r--synapse/storage/databases/state/store.py136
-rw-r--r--synapse/storage/relations.py4
-rw-r--r--synapse/storage/state.py48
-rw-r--r--synapse/types.py6
-rw-r--r--synapse/util/caches/dictionary_cache.py4
-rw-r--r--tests/replication/test_multi_media_repo.py18
-rw-r--r--tests/rest/admin/test_admin.py23
-rw-r--r--tests/rest/admin/test_media.py34
-rw-r--r--tests/rest/admin/test_statistics.py12
-rw-r--r--tests/rest/admin/test_user.py19
-rw-r--r--tests/rest/media/v1/test_media_storage.py18
-rw-r--r--tests/storage/test_state.py46
-rw-r--r--tests/test_utils/__init__.py14
70 files changed, 686 insertions, 461 deletions
diff --git a/changelog.d/10776.feature b/changelog.d/10776.feature
new file mode 100644
index 0000000000..aec0685a3d
--- /dev/null
+++ b/changelog.d/10776.feature
@@ -0,0 +1 @@
+Only allow the [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send?chunk_id=xxx` endpoint to connect to an already existing insertion event.
diff --git a/changelog.d/10777.misc b/changelog.d/10777.misc
new file mode 100644
index 0000000000..aed78a16f5
--- /dev/null
+++ b/changelog.d/10777.misc
@@ -0,0 +1 @@
+Split out [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) meta events to their own fields in the `/batch_send` response.
diff --git a/changelog.d/10785.misc b/changelog.d/10785.misc
new file mode 100644
index 0000000000..39a37b90b1
--- /dev/null
+++ b/changelog.d/10785.misc
@@ -0,0 +1 @@
+Add missing type hints to REST servlets.
diff --git a/changelog.d/10810.bugfix b/changelog.d/10810.bugfix
new file mode 100644
index 0000000000..43e91f1f51
--- /dev/null
+++ b/changelog.d/10810.bugfix
@@ -0,0 +1 @@
+Fix a case where logging contexts would go missing when federation requests time out.
diff --git a/changelog.d/10812.misc b/changelog.d/10812.misc
new file mode 100644
index 0000000000..586a0b3a96
--- /dev/null
+++ b/changelog.d/10812.misc
@@ -0,0 +1 @@
+Use direct references to config flags.
diff --git a/changelog.d/10815.misc b/changelog.d/10815.misc
new file mode 100644
index 0000000000..fc2534dc14
--- /dev/null
+++ b/changelog.d/10815.misc
@@ -0,0 +1 @@
+Specify the type of token in generic "Invalid token" error messages.
\ No newline at end of file
diff --git a/changelog.d/10816.misc b/changelog.d/10816.misc
new file mode 100644
index 0000000000..2ca55b334a
--- /dev/null
+++ b/changelog.d/10816.misc
@@ -0,0 +1 @@
+Make `StateFilter` frozen so it is hashable.
diff --git a/changelog.d/10817.misc b/changelog.d/10817.misc
new file mode 100644
index 0000000000..39a37b90b1
--- /dev/null
+++ b/changelog.d/10817.misc
@@ -0,0 +1 @@
+Add missing type hints to REST servlets.
diff --git a/changelog.d/10823.misc b/changelog.d/10823.misc
new file mode 100644
index 0000000000..0532969900
--- /dev/null
+++ b/changelog.d/10823.misc
@@ -0,0 +1 @@
+Add type hints to the state database.
diff --git a/changelog.d/10834.misc b/changelog.d/10834.misc
new file mode 100644
index 0000000000..037695e6e9
--- /dev/null
+++ b/changelog.d/10834.misc
@@ -0,0 +1 @@
+Factor out PNG image data to a constant to be used in several tests.
diff --git a/mypy.ini b/mypy.ini
index 09ffdda1b9..b21e1555ab 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -60,6 +60,7 @@ files =
   synapse/storage/databases/main/session.py,
   synapse/storage/databases/main/stream.py,
   synapse/storage/databases/main/ui_auth.py,
+  synapse/storage/databases/state,
   synapse/storage/database.py,
   synapse/storage/engines,
   synapse/storage/keys.py,
@@ -86,10 +87,11 @@ files =
   tests/handlers/test_sync.py,
   tests/rest/client/test_login.py,
   tests/rest/client/test_auth.py,
+  tests/storage/test_state.py,
   tests/util/test_itertools.py,
   tests/util/test_stream_change_cache.py
 
-[mypy-synapse.rest.client.*]
+[mypy-synapse.rest.*]
 disallow_untyped_defs = True
 
 [mypy-synapse.util.batching_queue]
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 05699714ee..e6ca9232ee 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -70,8 +70,8 @@ class Auth:
 
         self._auth_blocking = AuthBlocking(self.hs)
 
-        self._track_appservice_user_ips = hs.config.track_appservice_user_ips
-        self._macaroon_secret_key = hs.config.macaroon_secret_key
+        self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
+        self._macaroon_secret_key = hs.config.key.macaroon_secret_key
         self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
 
     async def check_user_in_room(
diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
index e6bced93d5..a3b95f4de0 100644
--- a/synapse/api/auth_blocking.py
+++ b/synapse/api/auth_blocking.py
@@ -30,13 +30,15 @@ class AuthBlocking:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
 
-        self._server_notices_mxid = hs.config.server_notices_mxid
-        self._hs_disabled = hs.config.hs_disabled
-        self._hs_disabled_message = hs.config.hs_disabled_message
-        self._admin_contact = hs.config.admin_contact
-        self._max_mau_value = hs.config.max_mau_value
-        self._limit_usage_by_mau = hs.config.limit_usage_by_mau
-        self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
+        self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
+        self._hs_disabled = hs.config.server.hs_disabled
+        self._hs_disabled_message = hs.config.server.hs_disabled_message
+        self._admin_contact = hs.config.server.admin_contact
+        self._max_mau_value = hs.config.server.max_mau_value
+        self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau
+        self._mau_limits_reserved_threepids = (
+            hs.config.server.mau_limits_reserved_threepids
+        )
         self._server_name = hs.hostname
         self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
 
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index c644b4dfc5..d310976fe3 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -102,7 +102,7 @@ class FederationPolicyForHTTPS:
         self._config = config
 
         # Check if we're using a custom list of a CA certificates
-        trust_root = config.federation_ca_trust_root
+        trust_root = config.tls.federation_ca_trust_root
         if trust_root is None:
             # Use CA root certs provided by OpenSSL
             trust_root = platformTrust()
@@ -113,7 +113,7 @@ class FederationPolicyForHTTPS:
         # moving to TLS 1.2 by default, we want to respect the config option if
         # it is set to 1.0 (which the alternate option, raiseMinimumTo, will not
         # let us do).
-        minTLS = _TLS_VERSION_MAP[config.federation_client_minimum_tls_version]
+        minTLS = _TLS_VERSION_MAP[config.tls.federation_client_minimum_tls_version]
 
         _verify_ssl = CertificateOptions(
             trustRoot=trust_root, insecurelyLowerMinimumTo=minTLS
@@ -125,10 +125,10 @@ class FederationPolicyForHTTPS:
         self._no_verify_ssl_context = _no_verify_ssl.getContext()
         self._no_verify_ssl_context.set_info_callback(_context_info_cb)
 
-        self._should_verify = self._config.federation_verify_certificates
+        self._should_verify = self._config.tls.federation_verify_certificates
 
         self._federation_certificate_verification_whitelist = (
-            self._config.federation_certificate_verification_whitelist
+            self._config.tls.federation_certificate_verification_whitelist
         )
 
     def get_options(self, host: bytes):
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 9e9b1c1c86..e1e13a2412 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -572,7 +572,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         super().__init__(hs)
         self.clock = hs.get_clock()
         self.client = hs.get_federation_http_client()
-        self.key_servers = self.config.key_servers
+        self.key_servers = self.config.key.key_servers
 
     async def _fetch_keys(
         self, keys_to_fetch: List[_FetchKeyRequest]
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 214ee948fa..638959cbec 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1237,7 +1237,7 @@ class FederationHandlerRegistry:
         self._edu_type_to_instance[edu_type] = instance_names
 
     async def on_edu(self, edu_type: str, origin: str, content: dict) -> None:
-        if not self.config.use_presence and edu_type == EduTypes.Presence:
+        if not self.config.server.use_presence and edu_type == EduTypes.Presence:
             return
 
         # Check if we have a handler on this instance
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 4671ac0242..720d7bd74d 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -594,7 +594,7 @@ class FederationSender(AbstractFederationSender):
         destinations (list[str])
         """
 
-        if not states or not self.hs.config.use_presence:
+        if not states or not self.hs.config.server.use_presence:
             # No-op if presence is disabled.
             return
 
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index fbbf6fd834..3ea6270083 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1347,7 +1347,7 @@ class AuthHandler(BaseHandler):
         try:
             res = self.macaroon_gen.verify_short_term_login_token(login_token)
         except Exception:
-            raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
+            raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
 
         await self.auth.check_auth_blocking(res.user_id)
         return res
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 4e8f7f1d85..0b24b40eb9 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -413,7 +413,7 @@ class InitialSyncHandler(BaseHandler):
 
         async def get_presence():
             # If presence is disabled, return an empty list
-            if not self.hs.config.use_presence:
+            if not self.hs.config.server.use_presence:
                 return []
 
             states = await presence_handler.get_states(
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 39b39cd3e2..4ab962a84b 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -374,7 +374,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
 
         self._presence_writer_instance = hs.config.worker.writers.presence[0]
 
-        self._presence_enabled = hs.config.use_presence
+        self._presence_enabled = hs.config.server.use_presence
 
         # Route presence EDUs to the right worker
         hs.get_federation_registry().register_instances_for_edu(
@@ -584,7 +584,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
         user_id = target_user.to_string()
 
         # If presence is disabled, no-op
-        if not self.hs.config.use_presence:
+        if not self.hs.config.server.use_presence:
             return
 
         # Proxy request to instance that writes presence
@@ -601,7 +601,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
         with the app.
         """
         # If presence is disabled, no-op
-        if not self.hs.config.use_presence:
+        if not self.hs.config.server.use_presence:
             return
 
         # Proxy request to instance that writes presence
@@ -618,7 +618,7 @@ class PresenceHandler(BasePresenceHandler):
         self.server_name = hs.hostname
         self.wheel_timer: WheelTimer[str] = WheelTimer()
         self.notifier = hs.get_notifier()
-        self._presence_enabled = hs.config.use_presence
+        self._presence_enabled = hs.config.server.use_presence
 
         federation_registry = hs.get_federation_registry()
 
@@ -916,7 +916,7 @@ class PresenceHandler(BasePresenceHandler):
         with the app.
         """
         # If presence is disabled, no-op
-        if not self.hs.config.use_presence:
+        if not self.hs.config.server.use_presence:
             return
 
         user_id = user.to_string()
@@ -949,7 +949,7 @@ class PresenceHandler(BasePresenceHandler):
         """
         # Override if it should affect the user's presence, if presence is
         # disabled.
-        if not self.hs.config.use_presence:
+        if not self.hs.config.server.use_presence:
             affect_presence = False
 
         if affect_presence:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index edfdb99cbd..7523d8e839 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1090,7 +1090,7 @@ class SyncHandler:
         block_all_presence_data = (
             since_token is None and sync_config.filter_collection.blocks_all_presence()
         )
-        if self.hs_config.use_presence and not block_all_presence_data:
+        if self.hs_config.server.use_presence and not block_all_presence_data:
             logger.debug("Fetching presence data")
             await self._generate_sync_entry_for_presence(
                 sync_result_builder,
diff --git a/synapse/http/client.py b/synapse/http/client.py
index c2ea51ee16..5204c3d08c 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -321,8 +321,11 @@ class SimpleHttpClient:
 
         self.user_agent = hs.version_string
         self.clock = hs.get_clock()
-        if hs.config.user_agent_suffix:
-            self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix)
+        if hs.config.server.user_agent_suffix:
+            self.user_agent = "%s %s" % (
+                self.user_agent,
+                hs.config.server.user_agent_suffix,
+            )
 
         # We use this for our body producers to ensure that they use the correct
         # reactor.
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 2e9898997c..ef10ec0937 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -66,7 +66,7 @@ from synapse.http.client import (
 )
 from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
 from synapse.logging import opentracing
-from synapse.logging.context import make_deferred_yieldable
+from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import set_tag, start_active_span, tags
 from synapse.types import JsonDict
 from synapse.util import json_decoder
@@ -553,20 +553,29 @@ class MatrixFederationHttpClient:
                         with Measure(self.clock, "outbound_request"):
                             # we don't want all the fancy cookie and redirect handling
                             # that treq.request gives: just use the raw Agent.
-                            request_deferred = self.agent.request(
+
+                            # To preserve the logging context, the timeout is treated
+                            # in a similar way to `defer.gatherResults`:
+                            # * Each logging context-preserving fork is wrapped in
+                            #   `run_in_background`. In this case there is only one,
+                            #   since the timeout fork is not logging-context aware.
+                            # * The `Deferred` that joins the forks back together is
+                            #   wrapped in `make_deferred_yieldable` to restore the
+                            #   logging context regardless of the path taken.
+                            request_deferred = run_in_background(
+                                self.agent.request,
                                 method_bytes,
                                 url_bytes,
                                 headers=Headers(headers_dict),
                                 bodyProducer=producer,
                             )
-
                             request_deferred = timeout_deferred(
                                 request_deferred,
                                 timeout=_sec_timeout,
                                 reactor=self.reactor,
                             )
 
-                            response = await request_deferred
+                            response = await make_deferred_yieldable(request_deferred)
                     except DNSLookupError as e:
                         raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e
                     except Exception as e:
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 36aabd8422..065948f982 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -365,7 +365,7 @@ class HttpPusher(Pusher):
         if event.type == "m.room.member" and event.is_state():
             d["notification"]["membership"] = event.content["membership"]
             d["notification"]["user_is_target"] = event.state_key == self.user_id
-        if self.hs.config.push_include_content and event.content:
+        if self.hs.config.push.push_include_content and event.content:
             d["notification"]["content"] = event.content
 
         # We no longer send aliases separately, instead, we send the human
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index b89c6e6f2b..e38e3c5d44 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -110,7 +110,7 @@ class Mailer:
         self.state_handler = self.hs.get_state_handler()
         self.storage = hs.get_storage()
         self.app_name = app_name
-        self.email_subjects: EmailSubjectConfig = hs.config.email_subjects
+        self.email_subjects: EmailSubjectConfig = hs.config.email.email_subjects
 
         logger.info("Created Mailer for app_name %s" % app_name)
 
@@ -796,8 +796,8 @@ class Mailer:
         Returns:
              A link to open a room in the web client.
         """
-        if self.hs.config.email_riot_base_url:
-            base_url = "%s/#/room" % (self.hs.config.email_riot_base_url)
+        if self.hs.config.email.email_riot_base_url:
+            base_url = "%s/#/room" % (self.hs.config.email.email_riot_base_url)
         elif self.app_name == "Vector":
             # need /beta for Universal Links to work on iOS
             base_url = "https://vector.im/beta/#/room"
@@ -815,9 +815,9 @@ class Mailer:
         Returns:
              A link to open the notification in the web client.
         """
-        if self.hs.config.email_riot_base_url:
+        if self.hs.config.email.email_riot_base_url:
             return "%s/#/room/%s/%s" % (
-                self.hs.config.email_riot_base_url,
+                self.hs.config.email.email_riot_base_url,
                 notif["room_id"],
                 notif["event_id"],
             )
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index 021275437c..29ed346d37 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -35,12 +35,12 @@ class PusherFactory:
             "http": HttpPusher
         }
 
-        logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
-        if hs.config.email_enable_notifs:
+        logger.info("email enable notifs: %r", hs.config.email.email_enable_notifs)
+        if hs.config.email.email_enable_notifs:
             self.mailers: Dict[str, Mailer] = {}
 
-            self._notif_template_html = hs.config.email_notif_template_html
-            self._notif_template_text = hs.config.email_notif_template_text
+            self._notif_template_html = hs.config.email.email_notif_template_html
+            self._notif_template_text = hs.config.email.email_notif_template_text
 
             self.pusher_types["email"] = self._create_email_pusher
 
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index a1436f3930..26735447a6 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -62,7 +62,7 @@ class PusherPool:
         self.clock = self.hs.get_clock()
 
         # We shard the handling of push notifications by user ID.
-        self._pusher_shard_config = hs.config.push.pusher_shard_config
+        self._pusher_shard_config = hs.config.worker.pusher_shard_config
         self._instance_name = hs.get_instance_name()
         self._should_start_pushers = (
             self._instance_name in self._pusher_shard_config.instances
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 3adc576124..e04af705eb 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -12,7 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from synapse.http.server import JsonResource
+from typing import TYPE_CHECKING
+
+from synapse.http.server import HttpServer, JsonResource
 from synapse.rest import admin
 from synapse.rest.client import (
     account,
@@ -57,6 +59,9 @@ from synapse.rest.client import (
     voip,
 )
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 
 class ClientRestResource(JsonResource):
     """Matrix Client API REST resource.
@@ -68,12 +73,12 @@ class ClientRestResource(JsonResource):
        * etc
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         JsonResource.__init__(self, hs, canonical_json=False)
         self.register_servlets(self, hs)
 
     @staticmethod
-    def register_servlets(client_resource, hs):
+    def register_servlets(client_resource: HttpServer, hs: "HomeServer") -> None:
         versions.register_servlets(hs, client_resource)
 
         # Deprecated in r0
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index 5715190a78..a6fa03c90f 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -47,7 +47,7 @@ class DeviceRestServlet(RestServlet):
         self.store = hs.get_datastore()
 
     async def on_GET(
-        self, request: SynapseRequest, user_id, device_id: str
+        self, request: SynapseRequest, user_id: str, device_id: str
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index f5a38c2670..19f84f33f2 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -57,7 +57,7 @@ class SendServerNoticeServlet(RestServlet):
         self.admin_handler = hs.get_admin_handler()
         self.txns = HttpTransactionCache(hs)
 
-    def register(self, json_resource: HttpServer):
+    def register(self, json_resource: HttpServer) -> None:
         PATTERN = "/send_server_notice"
         json_resource.register_paths(
             "POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index c1a1ba645e..681e491826 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -419,7 +419,7 @@ class UserRegisterServlet(RestServlet):
         self.nonces: Dict[str, int] = {}
         self.hs = hs
 
-    def _clear_old_nonces(self):
+    def _clear_old_nonces(self) -> None:
         """
         Clear out old nonces that are older than NONCE_TIMEOUT.
         """
diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index ed96978448..d466edeec2 100644
--- a/synapse/rest/client/room_batch.py
+++ b/synapse/rest/client/room_batch.py
@@ -14,6 +14,7 @@
 
 import logging
 import re
+from http import HTTPStatus
 from typing import TYPE_CHECKING, Awaitable, List, Tuple
 
 from twisted.web.server import Request
@@ -179,7 +180,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
 
         if not requester.app_service:
             raise AuthError(
-                403,
+                HTTPStatus.FORBIDDEN,
                 "Only application services can use the /batchsend endpoint",
             )
 
@@ -192,7 +193,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
 
         if prev_events_from_query is None:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "prev_event query parameter is required when inserting historical messages back in time",
                 errcode=Codes.MISSING_PARAM,
             )
@@ -213,7 +214,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
         prev_state_ids = list(prev_state_map.values())
         auth_event_ids = prev_state_ids
 
-        state_events_at_start = []
+        state_event_ids_at_start = []
         for state_event in body["state_events_at_start"]:
             assert_params_in_dict(
                 state_event, ["type", "origin_server_ts", "content", "sender"]
@@ -279,7 +280,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
                 )
                 event_id = event.event_id
 
-            state_events_at_start.append(event_id)
+            state_event_ids_at_start.append(event_id)
             auth_event_ids.append(event_id)
 
         events_to_create = body["events"]
@@ -299,7 +300,18 @@ class RoomBatchSendEventRestServlet(RestServlet):
             #  event, which causes the HS to ask for the state at the start of
             #  the chunk later.
             prev_event_ids = [fake_prev_event_id]
-            # TODO: Verify the chunk_id_from_query corresponds to an insertion event
+
+            # Verify the chunk_id_from_query corresponds to an actual insertion event
+            # and have the chunk connected.
+            corresponding_insertion_event_id = (
+                await self.store.get_insertion_event_by_chunk_id(chunk_id_from_query)
+            )
+            if corresponding_insertion_event_id is None:
+                raise SynapseError(
+                    400,
+                    "No insertion event corresponds to the given ?chunk_id",
+                    errcode=Codes.INVALID_PARAM,
+                )
             pass
         # Otherwise, create an insertion event to act as a starting point.
         #
@@ -424,20 +436,26 @@ class RoomBatchSendEventRestServlet(RestServlet):
                 context=context,
             )
 
-        # Add the base_insertion_event to the bottom of the list we return
-        if base_insertion_event is not None:
-            event_ids.append(base_insertion_event.event_id)
+        insertion_event_id = event_ids[0]
+        chunk_event_id = event_ids[-1]
+        historical_event_ids = event_ids[1:-1]
 
-        return 200, {
-            "state_events": state_events_at_start,
-            "events": event_ids,
+        response_dict = {
+            "state_event_ids": state_event_ids_at_start,
+            "event_ids": historical_event_ids,
             "next_chunk_id": insertion_event["content"][
                 EventContentFields.MSC2716_NEXT_CHUNK_ID
             ],
+            "insertion_event_id": insertion_event_id,
+            "chunk_event_id": chunk_event_id,
         }
+        if base_insertion_event is not None:
+            response_dict["base_insertion_event_id"] = base_insertion_event.event_id
+
+        return HTTPStatus.OK, response_dict
 
     def on_GET(self, request: Request, room_id: str) -> Tuple[int, str]:
-        return 501, "Not implemented"
+        return HTTPStatus.NOT_IMPLEMENTED, "Not implemented"
 
     def on_PUT(
         self, request: SynapseRequest, room_id: str
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 11f7320832..06e0fbde22 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -17,17 +17,22 @@ import logging
 from hashlib import sha256
 from http import HTTPStatus
 from os import path
-from typing import Dict, List
+from typing import TYPE_CHECKING, Any, Dict, List
 
 import jinja2
 from jinja2 import TemplateNotFound
 
+from twisted.web.server import Request
+
 from synapse.api.errors import NotFoundError, StoreError, SynapseError
 from synapse.config import ConfigError
 from synapse.http.server import DirectServeHtmlResource, respond_with_html
 from synapse.http.servlet import parse_bytes_from_args, parse_string
 from synapse.types import UserID
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 # language to use for the templates. TODO: figure this out from Accept-Language
 TEMPLATE_LANGUAGE = "en"
 
@@ -69,11 +74,7 @@ class ConsentResource(DirectServeHtmlResource):
            against the user.
     """
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): homeserver
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
 
         self.hs = hs
@@ -106,18 +107,14 @@ class ConsentResource(DirectServeHtmlResource):
 
         self._hmac_secret = hs.config.form_secret.encode("utf-8")
 
-    async def _async_render_GET(self, request):
-        """
-        Args:
-            request (twisted.web.http.Request):
-        """
+    async def _async_render_GET(self, request: Request) -> None:
         version = parse_string(request, "v", default=self._default_consent_version)
         username = parse_string(request, "u", default="")
         userhmac = None
         has_consented = False
         public_version = username == ""
         if not public_version:
-            args: Dict[bytes, List[bytes]] = request.args
+            args: Dict[bytes, List[bytes]] = request.args  # type: ignore
             userhmac_bytes = parse_bytes_from_args(args, "h", required=True)
 
             self._check_hash(username, userhmac_bytes)
@@ -147,14 +144,10 @@ class ConsentResource(DirectServeHtmlResource):
         except TemplateNotFound:
             raise NotFoundError("Unknown policy version")
 
-    async def _async_render_POST(self, request):
-        """
-        Args:
-            request (twisted.web.http.Request):
-        """
+    async def _async_render_POST(self, request: Request) -> None:
         version = parse_string(request, "v", required=True)
         username = parse_string(request, "u", required=True)
-        args: Dict[bytes, List[bytes]] = request.args
+        args: Dict[bytes, List[bytes]] = request.args  # type: ignore
         userhmac = parse_bytes_from_args(args, "h", required=True)
 
         self._check_hash(username, userhmac)
@@ -177,7 +170,9 @@ class ConsentResource(DirectServeHtmlResource):
         except TemplateNotFound:
             raise NotFoundError("success.html not found")
 
-    def _render_template(self, request, template_name, **template_args):
+    def _render_template(
+        self, request: Request, template_name: str, **template_args: Any
+    ) -> None:
         # get_template checks for ".." so we don't need to worry too much
         # about path traversal here.
         template_html = self._jinja_env.get_template(
@@ -186,11 +181,11 @@ class ConsentResource(DirectServeHtmlResource):
         html = template_html.render(**template_args)
         respond_with_html(request, 200, html)
 
-    def _check_hash(self, userid, userhmac):
+    def _check_hash(self, userid: str, userhmac: bytes) -> None:
         """
         Args:
-            userid (unicode):
-            userhmac (bytes):
+            userid:
+            userhmac:
 
         Raises:
               SynapseError if the hash doesn't match
diff --git a/synapse/rest/health.py b/synapse/rest/health.py
index 4487b54abf..78df7af2cf 100644
--- a/synapse/rest/health.py
+++ b/synapse/rest/health.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from twisted.web.resource import Resource
+from twisted.web.server import Request
 
 
 class HealthResource(Resource):
@@ -25,6 +26,6 @@ class HealthResource(Resource):
 
     isLeaf = 1
 
-    def render_GET(self, request):
+    def render_GET(self, request: Request) -> bytes:
         request.setHeader(b"Content-Type", b"text/plain")
         return b"OK"
diff --git a/synapse/rest/key/v2/__init__.py b/synapse/rest/key/v2/__init__.py
index c6c63073ea..7f8c1de1ff 100644
--- a/synapse/rest/key/v2/__init__.py
+++ b/synapse/rest/key/v2/__init__.py
@@ -12,14 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import TYPE_CHECKING
+
 from twisted.web.resource import Resource
 
 from .local_key_resource import LocalKey
 from .remote_key_resource import RemoteKey
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 
 class KeyApiV2Resource(Resource):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         Resource.__init__(self)
         self.putChild(b"server", LocalKey(hs))
         self.putChild(b"query", RemoteKey(hs))
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index 25f6eb842f..ebe243bcfd 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -12,16 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import logging
+from typing import TYPE_CHECKING
 
 from canonicaljson import encode_canonical_json
 from signedjson.sign import sign_json
 from unpaddedbase64 import encode_base64
 
 from twisted.web.resource import Resource
+from twisted.web.server import Request
 
 from synapse.http.server import respond_with_json_bytes
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -58,18 +63,18 @@ class LocalKey(Resource):
 
     isLeaf = True
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.config = hs.config
         self.clock = hs.get_clock()
         self.update_response_body(self.clock.time_msec())
         Resource.__init__(self)
 
-    def update_response_body(self, time_now_msec):
+    def update_response_body(self, time_now_msec: int) -> None:
         refresh_interval = self.config.key_refresh_interval
         self.valid_until_ts = int(time_now_msec + refresh_interval)
         self.response_body = encode_canonical_json(self.response_json_object())
 
-    def response_json_object(self):
+    def response_json_object(self) -> JsonDict:
         verify_keys = {}
         for key in self.config.signing_key:
             verify_key_bytes = key.verify_key.encode()
@@ -94,7 +99,7 @@ class LocalKey(Resource):
             json_object = sign_json(json_object, self.config.server.server_name, key)
         return json_object
 
-    def render_GET(self, request):
+    def render_GET(self, request: Request) -> int:
         time_now = self.clock.time_msec()
         # Update the expiry time if less than half the interval remains.
         if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts:
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 744360e5fd..d8fd7938a4 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,17 +13,23 @@
 # limitations under the License.
 
 import logging
-from typing import Dict
+from typing import TYPE_CHECKING, Dict
 
 from signedjson.sign import sign_json
 
+from twisted.web.server import Request
+
 from synapse.api.errors import Codes, SynapseError
 from synapse.crypto.keyring import ServerKeyFetcher
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.servlet import parse_integer, parse_json_object_from_request
+from synapse.types import JsonDict
 from synapse.util import json_decoder
 from synapse.util.async_helpers import yieldable_gather_results
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -85,7 +91,7 @@ class RemoteKey(DirectServeJsonResource):
 
     isLeaf = True
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
 
         self.fetcher = ServerKeyFetcher(hs)
@@ -94,7 +100,8 @@ class RemoteKey(DirectServeJsonResource):
         self.federation_domain_whitelist = hs.config.federation_domain_whitelist
         self.config = hs.config
 
-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
+        assert request.postpath is not None
         if len(request.postpath) == 1:
             (server,) = request.postpath
             query: dict = {server.decode("ascii"): {}}
@@ -110,14 +117,19 @@ class RemoteKey(DirectServeJsonResource):
 
         await self.query_keys(request, query, query_remote_on_cache_miss=True)
 
-    async def _async_render_POST(self, request):
+    async def _async_render_POST(self, request: Request) -> None:
         content = parse_json_object_from_request(request)
 
         query = content["server_keys"]
 
         await self.query_keys(request, query, query_remote_on_cache_miss=True)
 
-    async def query_keys(self, request, query, query_remote_on_cache_miss=False):
+    async def query_keys(
+        self,
+        request: Request,
+        query: JsonDict,
+        query_remote_on_cache_miss: bool = False,
+    ) -> None:
         logger.info("Handling query for keys %r", query)
 
         store_queries = []
@@ -142,8 +154,8 @@ class RemoteKey(DirectServeJsonResource):
 
         # Note that the value is unused.
         cache_misses: Dict[str, Dict[str, int]] = {}
-        for (server_name, key_id, _), results in cached.items():
-            results = [(result["ts_added_ms"], result) for result in results]
+        for (server_name, key_id, _), key_results in cached.items():
+            results = [(result["ts_added_ms"], result) for result in key_results]
 
             if not results and key_id is not None:
                 cache_misses.setdefault(server_name, {})[key_id] = 0
@@ -230,6 +242,6 @@ class RemoteKey(DirectServeJsonResource):
 
                 signed_keys.append(key_json)
 
-            results = {"server_keys": signed_keys}
+            response = {"server_keys": signed_keys}
 
-            respond_with_json(request, 200, results, canonical_json=True)
+            respond_with_json(request, 200, response, canonical_json=True)
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 90364ebcf7..7c881f2bdb 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -16,7 +16,10 @@
 import logging
 import os
 import urllib
-from typing import Awaitable, Dict, Generator, List, Optional, Tuple
+from types import TracebackType
+from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type
+
+import attr
 
 from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender
@@ -120,7 +123,7 @@ def add_file_headers(
         upload_name: The name of the requested file, if any.
     """
 
-    def _quote(x):
+    def _quote(x: str) -> str:
         return urllib.parse.quote(x.encode("utf-8"))
 
     # Default to a UTF-8 charset for text content types.
@@ -280,51 +283,74 @@ class Responder:
         """
         pass
 
-    def __enter__(self):
+    def __enter__(self) -> None:
         pass
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
         pass
 
 
-class FileInfo:
-    """Details about a requested/uploaded file.
-
-    Attributes:
-        server_name (str): The server name where the media originated from,
-            or None if local.
-        file_id (str): The local ID of the file. For local files this is the
-            same as the media_id
-        url_cache (bool): If the file is for the url preview cache
-        thumbnail (bool): Whether the file is a thumbnail or not.
-        thumbnail_width (int)
-        thumbnail_height (int)
-        thumbnail_method (str)
-        thumbnail_type (str): Content type of thumbnail, e.g. image/png
-        thumbnail_length (int): The size of the media file, in bytes.
-    """
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class ThumbnailInfo:
+    """Details about a generated thumbnail."""
 
-    def __init__(
-        self,
-        server_name,
-        file_id,
-        url_cache=False,
-        thumbnail=False,
-        thumbnail_width=None,
-        thumbnail_height=None,
-        thumbnail_method=None,
-        thumbnail_type=None,
-        thumbnail_length=None,
-    ):
-        self.server_name = server_name
-        self.file_id = file_id
-        self.url_cache = url_cache
-        self.thumbnail = thumbnail
-        self.thumbnail_width = thumbnail_width
-        self.thumbnail_height = thumbnail_height
-        self.thumbnail_method = thumbnail_method
-        self.thumbnail_type = thumbnail_type
-        self.thumbnail_length = thumbnail_length
+    width: int
+    height: int
+    method: str
+    # Content type of thumbnail, e.g. image/png
+    type: str
+    # The size of the media file, in bytes.
+    length: Optional[int] = None
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class FileInfo:
+    """Details about a requested/uploaded file."""
+
+    # The server name where the media originated from, or None if local.
+    server_name: Optional[str]
+    # The local ID of the file. For local files this is the same as the media_id
+    file_id: str
+    # If the file is for the url preview cache
+    url_cache: bool = False
+    # Whether the file is a thumbnail or not.
+    thumbnail: Optional[ThumbnailInfo] = None
+
+    # The below properties exist to maintain compatibility with third-party modules.
+    @property
+    def thumbnail_width(self) -> Optional[int]:
+        if not self.thumbnail:
+            return None
+        return self.thumbnail.width
+
+    @property
+    def thumbnail_height(self) -> Optional[int]:
+        if not self.thumbnail:
+            return None
+        return self.thumbnail.height
+
+    @property
+    def thumbnail_method(self) -> Optional[str]:
+        if not self.thumbnail:
+            return None
+        return self.thumbnail.method
+
+    @property
+    def thumbnail_type(self) -> Optional[str]:
+        if not self.thumbnail:
+            return None
+        return self.thumbnail.type
+
+    @property
+    def thumbnail_length(self) -> Optional[int]:
+        if not self.thumbnail:
+            return None
+        return self.thumbnail.length
 
 
 def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index 09531ebf54..39bbe4e874 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -16,7 +16,7 @@
 import functools
 import os
 import re
-from typing import Callable, List
+from typing import Any, Callable, List
 
 NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
 
@@ -27,7 +27,7 @@ def _wrap_in_base_path(func: Callable[..., str]) -> Callable[..., str]:
     """
 
     @functools.wraps(func)
-    def _wrapped(self, *args, **kwargs):
+    def _wrapped(self: "MediaFilePaths", *args: Any, **kwargs: Any) -> str:
         path = func(self, *args, **kwargs)
         return os.path.join(self.base_path, path)
 
@@ -129,7 +129,7 @@ class MediaFilePaths:
     # using the new path.
     def remote_media_thumbnail_rel_legacy(
         self, server_name: str, file_id: str, width: int, height: int, content_type: str
-    ):
+    ) -> str:
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
         return os.path.join(
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 0f5ce41ff8..50e4c9e29f 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -21,6 +21,7 @@ from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
 import twisted.internet.error
 import twisted.web.http
+from twisted.internet.defer import Deferred
 from twisted.web.resource import Resource
 from twisted.web.server import Request
 
@@ -32,6 +33,7 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.config._base import ConfigError
+from synapse.config.repository import ThumbnailRequirement
 from synapse.logging.context import defer_to_thread
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import UserID
@@ -42,6 +44,7 @@ from synapse.util.stringutils import random_string
 from ._base import (
     FileInfo,
     Responder,
+    ThumbnailInfo,
     get_filename_from_headers,
     respond_404,
     respond_with_responder,
@@ -113,7 +116,7 @@ class MediaRepository:
             self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS
         )
 
-    def _start_update_recently_accessed(self):
+    def _start_update_recently_accessed(self) -> Deferred:
         return run_as_background_process(
             "update_recently_accessed_media", self._update_recently_accessed
         )
@@ -210,7 +213,7 @@ class MediaRepository:
         upload_name = name if name else media_info["upload_name"]
         url_cache = media_info["url_cache"]
 
-        file_info = FileInfo(None, media_id, url_cache=url_cache)
+        file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
 
         responder = await self.media_storage.fetch_media(file_info)
         await respond_with_responder(
@@ -468,7 +471,9 @@ class MediaRepository:
 
         return media_info
 
-    def _get_thumbnail_requirements(self, media_type):
+    def _get_thumbnail_requirements(
+        self, media_type: str
+    ) -> Tuple[ThumbnailRequirement, ...]:
         scpos = media_type.find(";")
         if scpos > 0:
             media_type = media_type[:scpos]
@@ -514,7 +519,7 @@ class MediaRepository:
         t_height: int,
         t_method: str,
         t_type: str,
-        url_cache: Optional[str],
+        url_cache: bool,
     ) -> Optional[str]:
         input_path = await self.media_storage.ensure_media_is_in_local_cache(
             FileInfo(None, media_id, url_cache=url_cache)
@@ -548,11 +553,12 @@ class MediaRepository:
                     server_name=None,
                     file_id=media_id,
                     url_cache=url_cache,
-                    thumbnail=True,
-                    thumbnail_width=t_width,
-                    thumbnail_height=t_height,
-                    thumbnail_method=t_method,
-                    thumbnail_type=t_type,
+                    thumbnail=ThumbnailInfo(
+                        width=t_width,
+                        height=t_height,
+                        method=t_method,
+                        type=t_type,
+                    ),
                 )
 
                 output_path = await self.media_storage.store_file(
@@ -585,7 +591,7 @@ class MediaRepository:
         t_type: str,
     ) -> Optional[str]:
         input_path = await self.media_storage.ensure_media_is_in_local_cache(
-            FileInfo(server_name, file_id, url_cache=False)
+            FileInfo(server_name, file_id)
         )
 
         try:
@@ -616,11 +622,12 @@ class MediaRepository:
                 file_info = FileInfo(
                     server_name=server_name,
                     file_id=file_id,
-                    thumbnail=True,
-                    thumbnail_width=t_width,
-                    thumbnail_height=t_height,
-                    thumbnail_method=t_method,
-                    thumbnail_type=t_type,
+                    thumbnail=ThumbnailInfo(
+                        width=t_width,
+                        height=t_height,
+                        method=t_method,
+                        type=t_type,
+                    ),
                 )
 
                 output_path = await self.media_storage.store_file(
@@ -742,12 +749,13 @@ class MediaRepository:
             file_info = FileInfo(
                 server_name=server_name,
                 file_id=file_id,
-                thumbnail=True,
-                thumbnail_width=t_width,
-                thumbnail_height=t_height,
-                thumbnail_method=t_method,
-                thumbnail_type=t_type,
                 url_cache=url_cache,
+                thumbnail=ThumbnailInfo(
+                    width=t_width,
+                    height=t_height,
+                    method=t_method,
+                    type=t_type,
+                ),
             )
 
             with self.media_storage.store_into_file(file_info) as (f, fname, finish):
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 56cdc1b4ed..01fada8fb5 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -15,7 +15,20 @@ import contextlib
 import logging
 import os
 import shutil
-from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Sequence
+from types import TracebackType
+from typing import (
+    IO,
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    BinaryIO,
+    Callable,
+    Generator,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+)
 
 import attr
 
@@ -83,12 +96,14 @@ class MediaStorage:
 
         return fname
 
-    async def write_to_file(self, source: IO, output: IO):
+    async def write_to_file(self, source: IO, output: IO) -> None:
         """Asynchronously write the `source` to `output`."""
         await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
 
     @contextlib.contextmanager
-    def store_into_file(self, file_info: FileInfo):
+    def store_into_file(
+        self, file_info: FileInfo
+    ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
         """Context manager used to get a file like object to write into, as
         described by file_info.
 
@@ -125,7 +140,7 @@ class MediaStorage:
         try:
             with open(fname, "wb") as f:
 
-                async def finish():
+                async def finish() -> None:
                     # Ensure that all writes have been flushed and close the
                     # file.
                     f.flush()
@@ -176,9 +191,9 @@ class MediaStorage:
                 self.filepaths.remote_media_thumbnail_rel_legacy(
                     server_name=file_info.server_name,
                     file_id=file_info.file_id,
-                    width=file_info.thumbnail_width,
-                    height=file_info.thumbnail_height,
-                    content_type=file_info.thumbnail_type,
+                    width=file_info.thumbnail.width,
+                    height=file_info.thumbnail.height,
+                    content_type=file_info.thumbnail.type,
                 )
             )
 
@@ -220,9 +235,9 @@ class MediaStorage:
             legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
                 server_name=file_info.server_name,
                 file_id=file_info.file_id,
-                width=file_info.thumbnail_width,
-                height=file_info.thumbnail_height,
-                content_type=file_info.thumbnail_type,
+                width=file_info.thumbnail.width,
+                height=file_info.thumbnail.height,
+                content_type=file_info.thumbnail.type,
             )
             legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
             if os.path.exists(legacy_local_path):
@@ -255,10 +270,10 @@ class MediaStorage:
             if file_info.thumbnail:
                 return self.filepaths.url_cache_thumbnail_rel(
                     media_id=file_info.file_id,
-                    width=file_info.thumbnail_width,
-                    height=file_info.thumbnail_height,
-                    content_type=file_info.thumbnail_type,
-                    method=file_info.thumbnail_method,
+                    width=file_info.thumbnail.width,
+                    height=file_info.thumbnail.height,
+                    content_type=file_info.thumbnail.type,
+                    method=file_info.thumbnail.method,
                 )
             return self.filepaths.url_cache_filepath_rel(file_info.file_id)
 
@@ -267,10 +282,10 @@ class MediaStorage:
                 return self.filepaths.remote_media_thumbnail_rel(
                     server_name=file_info.server_name,
                     file_id=file_info.file_id,
-                    width=file_info.thumbnail_width,
-                    height=file_info.thumbnail_height,
-                    content_type=file_info.thumbnail_type,
-                    method=file_info.thumbnail_method,
+                    width=file_info.thumbnail.width,
+                    height=file_info.thumbnail.height,
+                    content_type=file_info.thumbnail.type,
+                    method=file_info.thumbnail.method,
                 )
             return self.filepaths.remote_media_filepath_rel(
                 file_info.server_name, file_info.file_id
@@ -279,10 +294,10 @@ class MediaStorage:
         if file_info.thumbnail:
             return self.filepaths.local_media_thumbnail_rel(
                 media_id=file_info.file_id,
-                width=file_info.thumbnail_width,
-                height=file_info.thumbnail_height,
-                content_type=file_info.thumbnail_type,
-                method=file_info.thumbnail_method,
+                width=file_info.thumbnail.width,
+                height=file_info.thumbnail.height,
+                content_type=file_info.thumbnail.type,
+                method=file_info.thumbnail.method,
             )
         return self.filepaths.local_media_filepath_rel(file_info.file_id)
 
@@ -315,7 +330,12 @@ class FileResponder(Responder):
             FileSender().beginFileTransfer(self.open_file, consumer)
         )
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
         self.open_file.close()
 
 
@@ -339,7 +359,7 @@ class ReadableFileWrapper:
     clock = attr.ib(type=Clock)
     path = attr.ib(type=str)
 
-    async def write_chunks_to(self, callback: Callable[[bytes], None]):
+    async def write_chunks_to(self, callback: Callable[[bytes], None]) -> None:
         """Reads the file in chunks and calls the callback with each chunk."""
 
         with open(self.path, "rb") as file:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index f108da05db..fe0627d9b0 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -27,6 +27,7 @@ from urllib import parse as urlparse
 
 import attr
 
+from twisted.internet.defer import Deferred
 from twisted.internet.error import DNSLookupError
 from twisted.web.server import Request
 
@@ -473,7 +474,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             etag=etag,
         )
 
-    def _start_expire_url_cache_data(self):
+    def _start_expire_url_cache_data(self) -> Deferred:
         return run_as_background_process(
             "expire_url_cache_data", self._expire_url_cache_data
         )
@@ -782,7 +783,7 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
 
 
 def _iterate_over_text(
-    tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
+    tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
 ) -> Generator[str, None, None]:
     """Iterate over the tree returning text nodes in a depth first fashion,
     skipping text nodes inside certain tags.
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 0ff6ad3c0c..6c9969e55f 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -99,7 +99,7 @@ class StorageProviderWrapper(StorageProvider):
             await maybe_awaitable(self.backend.store_file(path, file_info))  # type: ignore
         else:
             # TODO: Handle errors.
-            async def store():
+            async def store() -> None:
                 try:
                     return await maybe_awaitable(
                         self.backend.store_file(path, file_info)
@@ -128,7 +128,7 @@ class FileStorageProviderBackend(StorageProvider):
         self.cache_directory = hs.config.media_store_path
         self.base_directory = config
 
-    def __str__(self):
+    def __str__(self) -> str:
         return "FileStorageProviderBackend[%s]" % (self.base_directory,)
 
     async def store_file(self, path: str, file_info: FileInfo) -> None:
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 12bd745cb2..22f43d8531 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -26,6 +26,7 @@ from synapse.rest.media.v1.media_storage import MediaStorage
 
 from ._base import (
     FileInfo,
+    ThumbnailInfo,
     parse_media_id,
     respond_404,
     respond_with_file,
@@ -114,7 +115,7 @@ class ThumbnailResource(DirectServeJsonResource):
             thumbnail_infos,
             media_id,
             media_id,
-            url_cache=media_info["url_cache"],
+            url_cache=bool(media_info["url_cache"]),
             server_name=None,
         )
 
@@ -149,11 +150,12 @@ class ThumbnailResource(DirectServeJsonResource):
                     server_name=None,
                     file_id=media_id,
                     url_cache=media_info["url_cache"],
-                    thumbnail=True,
-                    thumbnail_width=info["thumbnail_width"],
-                    thumbnail_height=info["thumbnail_height"],
-                    thumbnail_type=info["thumbnail_type"],
-                    thumbnail_method=info["thumbnail_method"],
+                    thumbnail=ThumbnailInfo(
+                        width=info["thumbnail_width"],
+                        height=info["thumbnail_height"],
+                        type=info["thumbnail_type"],
+                        method=info["thumbnail_method"],
+                    ),
                 )
 
                 t_type = file_info.thumbnail_type
@@ -173,7 +175,7 @@ class ThumbnailResource(DirectServeJsonResource):
             desired_height,
             desired_method,
             desired_type,
-            url_cache=media_info["url_cache"],
+            url_cache=bool(media_info["url_cache"]),
         )
 
         if file_path:
@@ -210,11 +212,12 @@ class ThumbnailResource(DirectServeJsonResource):
                 file_info = FileInfo(
                     server_name=server_name,
                     file_id=media_info["filesystem_id"],
-                    thumbnail=True,
-                    thumbnail_width=info["thumbnail_width"],
-                    thumbnail_height=info["thumbnail_height"],
-                    thumbnail_type=info["thumbnail_type"],
-                    thumbnail_method=info["thumbnail_method"],
+                    thumbnail=ThumbnailInfo(
+                        width=info["thumbnail_width"],
+                        height=info["thumbnail_height"],
+                        type=info["thumbnail_type"],
+                        method=info["thumbnail_method"],
+                    ),
                 )
 
                 t_type = file_info.thumbnail_type
@@ -271,7 +274,7 @@ class ThumbnailResource(DirectServeJsonResource):
             thumbnail_infos,
             media_id,
             media_info["filesystem_id"],
-            url_cache=None,
+            url_cache=False,
             server_name=server_name,
         )
 
@@ -285,7 +288,7 @@ class ThumbnailResource(DirectServeJsonResource):
         thumbnail_infos: List[Dict[str, Any]],
         media_id: str,
         file_id: str,
-        url_cache: Optional[str] = None,
+        url_cache: bool,
         server_name: Optional[str] = None,
     ) -> None:
         """
@@ -299,7 +302,7 @@ class ThumbnailResource(DirectServeJsonResource):
             desired_type: The desired content-type of the thumbnail.
             thumbnail_infos: A list of dictionaries of candidate thumbnails.
             file_id: The ID of the media that a thumbnail is being requested for.
-            url_cache: The URL cache value.
+            url_cache: True if this is from a URL cache.
             server_name: The server name, if this is a remote thumbnail.
         """
         if thumbnail_infos:
@@ -318,13 +321,16 @@ class ThumbnailResource(DirectServeJsonResource):
                 respond_404(request)
                 return
 
+            # The thumbnail property must exist.
+            assert file_info.thumbnail is not None
+
             responder = await self.media_storage.fetch_media(file_info)
             if responder:
                 await respond_with_responder(
                     request,
                     responder,
-                    file_info.thumbnail_type,
-                    file_info.thumbnail_length,
+                    file_info.thumbnail.type,
+                    file_info.thumbnail.length,
                 )
                 return
 
@@ -351,18 +357,18 @@ class ThumbnailResource(DirectServeJsonResource):
                     server_name,
                     file_id=file_id,
                     media_id=media_id,
-                    t_width=file_info.thumbnail_width,
-                    t_height=file_info.thumbnail_height,
-                    t_method=file_info.thumbnail_method,
-                    t_type=file_info.thumbnail_type,
+                    t_width=file_info.thumbnail.width,
+                    t_height=file_info.thumbnail.height,
+                    t_method=file_info.thumbnail.method,
+                    t_type=file_info.thumbnail.type,
                 )
             else:
                 await self.media_repo.generate_local_exact_thumbnail(
                     media_id=media_id,
-                    t_width=file_info.thumbnail_width,
-                    t_height=file_info.thumbnail_height,
-                    t_method=file_info.thumbnail_method,
-                    t_type=file_info.thumbnail_type,
+                    t_width=file_info.thumbnail.width,
+                    t_height=file_info.thumbnail.height,
+                    t_method=file_info.thumbnail.method,
+                    t_type=file_info.thumbnail.type,
                     url_cache=url_cache,
                 )
 
@@ -370,8 +376,8 @@ class ThumbnailResource(DirectServeJsonResource):
             await respond_with_responder(
                 request,
                 responder,
-                file_info.thumbnail_type,
-                file_info.thumbnail_length,
+                file_info.thumbnail.type,
+                file_info.thumbnail.length,
             )
         else:
             logger.info("Failed to find any generated thumbnails")
@@ -385,7 +391,7 @@ class ThumbnailResource(DirectServeJsonResource):
         desired_type: str,
         thumbnail_infos: List[Dict[str, Any]],
         file_id: str,
-        url_cache: Optional[str],
+        url_cache: bool,
         server_name: Optional[str],
     ) -> Optional[FileInfo]:
         """
@@ -398,7 +404,7 @@ class ThumbnailResource(DirectServeJsonResource):
             desired_type: The desired content-type of the thumbnail.
             thumbnail_infos: A list of dictionaries of candidate thumbnails.
             file_id: The ID of the media that a thumbnail is being requested for.
-            url_cache: The URL cache value.
+            url_cache: True if this is from a URL cache.
             server_name: The server name, if this is a remote thumbnail.
 
         Returns:
@@ -495,12 +501,13 @@ class ThumbnailResource(DirectServeJsonResource):
                 file_id=file_id,
                 url_cache=url_cache,
                 server_name=server_name,
-                thumbnail=True,
-                thumbnail_width=thumbnail_info["thumbnail_width"],
-                thumbnail_height=thumbnail_info["thumbnail_height"],
-                thumbnail_type=thumbnail_info["thumbnail_type"],
-                thumbnail_method=thumbnail_info["thumbnail_method"],
-                thumbnail_length=thumbnail_info["thumbnail_length"],
+                thumbnail=ThumbnailInfo(
+                    width=thumbnail_info["thumbnail_width"],
+                    height=thumbnail_info["thumbnail_height"],
+                    type=thumbnail_info["thumbnail_type"],
+                    method=thumbnail_info["thumbnail_method"],
+                    length=thumbnail_info["thumbnail_length"],
+                ),
             )
 
         # No matching thumbnail was found.
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index a65e9e1802..df54a40649 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -41,7 +41,7 @@ class Thumbnailer:
     FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
 
     @staticmethod
-    def set_limits(max_image_pixels: int):
+    def set_limits(max_image_pixels: int) -> None:
         Image.MAX_IMAGE_PIXELS = max_image_pixels
 
     def __init__(self, input_path: str):
diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py
index 67c1ed1f5f..1c1c7b3613 100644
--- a/synapse/rest/synapse/client/new_user_consent.py
+++ b/synapse/rest/synapse/client/new_user_consent.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Generator
 
 from twisted.web.server import Request
 
@@ -45,7 +45,7 @@ class NewUserConsentResource(DirectServeHtmlResource):
         self._server_name = hs.hostname
         self._consent_version = hs.config.consent.user_consent_version
 
-        def template_search_dirs():
+        def template_search_dirs() -> Generator[str, None, None]:
             if hs.config.server.custom_template_directory:
                 yield hs.config.server.custom_template_directory
             if hs.config.sso.sso_template_dir:
@@ -88,7 +88,7 @@ class NewUserConsentResource(DirectServeHtmlResource):
         html = template.render(template_params)
         respond_with_html(request, 200, html)
 
-    async def _async_render_POST(self, request: Request):
+    async def _async_render_POST(self, request: Request) -> None:
         try:
             session_id = get_username_mapping_session_cookie_from_request(request)
         except SynapseError as e:
diff --git a/synapse/rest/synapse/client/oidc/__init__.py b/synapse/rest/synapse/client/oidc/__init__.py
index 36ba401656..81fec39659 100644
--- a/synapse/rest/synapse/client/oidc/__init__.py
+++ b/synapse/rest/synapse/client/oidc/__init__.py
@@ -13,16 +13,20 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 from twisted.web.resource import Resource
 
 from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class OIDCResource(Resource):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         Resource.__init__(self)
         self.putChild(b"callback", OIDCCallbackResource(hs))
 
diff --git a/synapse/rest/synapse/client/oidc/callback_resource.py b/synapse/rest/synapse/client/oidc/callback_resource.py
index 7785f17e90..4f375cb74c 100644
--- a/synapse/rest/synapse/client/oidc/callback_resource.py
+++ b/synapse/rest/synapse/client/oidc/callback_resource.py
@@ -16,6 +16,7 @@ import logging
 from typing import TYPE_CHECKING
 
 from synapse.http.server import DirectServeHtmlResource
+from synapse.http.site import SynapseRequest
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -30,10 +31,10 @@ class OIDCCallbackResource(DirectServeHtmlResource):
         super().__init__()
         self._oidc_handler = hs.get_oidc_handler()
 
-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: SynapseRequest) -> None:
         await self._oidc_handler.handle_oidc_callback(request)
 
-    async def _async_render_POST(self, request):
+    async def _async_render_POST(self, request: SynapseRequest) -> None:
         # the auth response can be returned via an x-www-form-urlencoded form instead
         # of GET params, as per
         # https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html.
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py
index d30b478b98..28ae083497 100644
--- a/synapse/rest/synapse/client/pick_username.py
+++ b/synapse/rest/synapse/client/pick_username.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, Generator, List, Tuple
 
 from twisted.web.resource import Resource
 from twisted.web.server import Request
@@ -27,6 +27,7 @@ from synapse.http.server import (
 )
 from synapse.http.servlet import parse_boolean, parse_string
 from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
 from synapse.util.templates import build_jinja_env
 
 if TYPE_CHECKING:
@@ -57,7 +58,7 @@ class AvailabilityCheckResource(DirectServeJsonResource):
         super().__init__()
         self._sso_handler = hs.get_sso_handler()
 
-    async def _async_render_GET(self, request: Request):
+    async def _async_render_GET(self, request: Request) -> Tuple[int, JsonDict]:
         localpart = parse_string(request, "username", required=True)
 
         session_id = get_username_mapping_session_cookie_from_request(request)
@@ -73,7 +74,7 @@ class AccountDetailsResource(DirectServeHtmlResource):
         super().__init__()
         self._sso_handler = hs.get_sso_handler()
 
-        def template_search_dirs():
+        def template_search_dirs() -> Generator[str, None, None]:
             if hs.config.server.custom_template_directory:
                 yield hs.config.server.custom_template_directory
             if hs.config.sso.sso_template_dir:
@@ -104,7 +105,7 @@ class AccountDetailsResource(DirectServeHtmlResource):
         html = template.render(template_params)
         respond_with_html(request, 200, html)
 
-    async def _async_render_POST(self, request: SynapseRequest):
+    async def _async_render_POST(self, request: SynapseRequest) -> None:
         # This will always be set by the time Twisted calls us.
         assert request.args is not None
 
diff --git a/synapse/rest/synapse/client/saml2/__init__.py b/synapse/rest/synapse/client/saml2/__init__.py
index 781ccb237c..3f247e6a2c 100644
--- a/synapse/rest/synapse/client/saml2/__init__.py
+++ b/synapse/rest/synapse/client/saml2/__init__.py
@@ -13,17 +13,21 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 from twisted.web.resource import Resource
 
 from synapse.rest.synapse.client.saml2.metadata_resource import SAML2MetadataResource
 from synapse.rest.synapse.client.saml2.response_resource import SAML2ResponseResource
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class SAML2Resource(Resource):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         Resource.__init__(self)
         self.putChild(b"metadata.xml", SAML2MetadataResource(hs))
         self.putChild(b"authn_response", SAML2ResponseResource(hs))
diff --git a/synapse/rest/synapse/client/saml2/metadata_resource.py b/synapse/rest/synapse/client/saml2/metadata_resource.py
index b37c7083dc..64378ed57b 100644
--- a/synapse/rest/synapse/client/saml2/metadata_resource.py
+++ b/synapse/rest/synapse/client/saml2/metadata_resource.py
@@ -12,10 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import TYPE_CHECKING
 
 import saml2.metadata
 
 from twisted.web.resource import Resource
+from twisted.web.server import Request
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
 
 
 class SAML2MetadataResource(Resource):
@@ -23,11 +28,11 @@ class SAML2MetadataResource(Resource):
 
     isLeaf = 1
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         Resource.__init__(self)
         self.sp_config = hs.config.saml2_sp_config
 
-    def render_GET(self, request):
+    def render_GET(self, request: Request) -> bytes:
         metadata_xml = saml2.metadata.create_metadata_string(
             configfile=None, config=self.sp_config
         )
diff --git a/synapse/rest/synapse/client/saml2/response_resource.py b/synapse/rest/synapse/client/saml2/response_resource.py
index 774ccd870f..47d2a6a229 100644
--- a/synapse/rest/synapse/client/saml2/response_resource.py
+++ b/synapse/rest/synapse/client/saml2/response_resource.py
@@ -15,7 +15,10 @@
 
 from typing import TYPE_CHECKING
 
+from twisted.web.server import Request
+
 from synapse.http.server import DirectServeHtmlResource
+from synapse.http.site import SynapseRequest
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -31,7 +34,7 @@ class SAML2ResponseResource(DirectServeHtmlResource):
         self._saml_handler = hs.get_saml_handler()
         self._sso_handler = hs.get_sso_handler()
 
-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
         # We're not expecting any GET request on that resource if everything goes right,
         # but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
         # In this case, just tell the user that something went wrong and they should
@@ -40,5 +43,5 @@ class SAML2ResponseResource(DirectServeHtmlResource):
             request, "unexpected_get", "Unexpected GET request on /saml2/authn_response"
         )
 
-    async def _async_render_POST(self, request):
+    async def _async_render_POST(self, request: SynapseRequest) -> None:
         await self._saml_handler.handle_saml_response(request)
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index 6a66a88c53..c80a3a99aa 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -13,26 +13,26 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Optional
 
 from twisted.web.resource import Resource
+from twisted.web.server import Request
 
 from synapse.http.server import set_cors_headers
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class WellKnownBuilder:
-    """Utility to construct the well-known response
-
-    Args:
-        hs (synapse.server.HomeServer):
-    """
-
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self._config = hs.config
 
-    def get_well_known(self):
+    def get_well_known(self) -> Optional[JsonDict]:
         # if we don't have a public_baseurl, we can't help much here.
         if self._config.server.public_baseurl is None:
             return None
@@ -52,11 +52,11 @@ class WellKnownResource(Resource):
 
     isLeaf = 1
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         Resource.__init__(self)
         self._well_known_builder = WellKnownBuilder(hs)
 
-    def render_GET(self, request):
+    def render_GET(self, request: Request) -> bytes:
         set_cors_headers(request)
         r = self._well_known_builder.get_well_known()
         if not r:
diff --git a/synapse/server.py b/synapse/server.py
index 4777ef585d..637eb15b78 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -392,7 +392,7 @@ class HomeServer(metaclass=abc.ABCMeta):
 
     @cache_in_self
     def get_http_client_context_factory(self) -> IPolicyForHTTPS:
-        if self.config.use_insecure_ssl_client_just_for_testing_do_not_use:
+        if self.config.tls.use_insecure_ssl_client_just_for_testing_do_not_use:
             return InsecureInterceptableContextFactory()
         return RegularPolicyForHTTPS()
 
@@ -418,8 +418,8 @@ class HomeServer(metaclass=abc.ABCMeta):
         """
         return SimpleHttpClient(
             self,
-            ip_whitelist=self.config.ip_range_whitelist,
-            ip_blacklist=self.config.ip_range_blacklist,
+            ip_whitelist=self.config.server.ip_range_whitelist,
+            ip_blacklist=self.config.server.ip_range_blacklist,
             use_proxy=True,
         )
 
@@ -801,18 +801,18 @@ class HomeServer(metaclass=abc.ABCMeta):
 
         logger.info(
             "Connecting to redis (host=%r port=%r) for external cache",
-            self.config.redis_host,
-            self.config.redis_port,
+            self.config.redis.redis_host,
+            self.config.redis.redis_port,
         )
 
         return lazyConnection(
             hs=self,
-            host=self.config.redis_host,
-            port=self.config.redis_port,
+            host=self.config.redis.redis_host,
+            port=self.config.redis.redis_port,
             password=self.config.redis.redis_password,
             reconnect=True,
         )
 
     def should_send_federation(self) -> bool:
         "Should this server be sending federation traffic directly?"
-        return self.config.send_federation
+        return self.config.worker.send_federation
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 1dc347f0c9..5c21402dea 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -61,6 +61,7 @@ from .registration import RegistrationStore
 from .rejections import RejectionsStore
 from .relations import RelationsStore
 from .room import RoomStore
+from .room_batch import RoomBatchStore
 from .roommember import RoomMemberStore
 from .search import SearchStore
 from .session import SessionStore
@@ -81,6 +82,7 @@ class DataStore(
     EventsBackgroundUpdatesStore,
     RoomMemberStore,
     RoomStore,
+    RoomBatchStore,
     RegistrationStore,
     StreamStore,
     ProfileStore,
diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py
new file mode 100644
index 0000000000..54fa361d3e
--- /dev/null
+++ b/synapse/storage/databases/main/room_batch.py
@@ -0,0 +1,36 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from synapse.storage._base import SQLBaseStore
+
+
+class RoomBatchStore(SQLBaseStore):
+    async def get_insertion_event_by_chunk_id(self, chunk_id: str) -> Optional[str]:
+        """Retrieve a insertion event ID.
+
+        Args:
+            chunk_id: The chunk ID of the insertion event to retrieve.
+
+        Returns:
+            The event_id of an insertion event, or None if there is no known
+            insertion event for the given insertion event.
+        """
+        return await self.db_pool.simple_select_one_onecol(
+            table="insertion_events",
+            keyvalues={"next_chunk_id": chunk_id},
+            retcol="event_id",
+            allow_none=True,
+        )
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index c2891cb07f..eb1118d2cb 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -13,12 +13,20 @@
 # limitations under the License.
 
 import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union
 
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.state import StateFilter
+from synapse.types import MutableStateMap, StateMap
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -31,7 +39,9 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
     updates.
     """
 
-    def _count_state_group_hops_txn(self, txn, state_group):
+    def _count_state_group_hops_txn(
+        self, txn: LoggingTransaction, state_group: int
+    ) -> int:
         """Given a state group, count how many hops there are in the tree.
 
         This is used to ensure the delta chains don't get too long.
@@ -56,7 +66,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
         else:
             # We don't use WITH RECURSIVE on sqlite3 as there are distributions
             # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
-            next_group = state_group
+            next_group: Optional[int] = state_group
             count = 0
 
             while next_group:
@@ -73,11 +83,14 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
             return count
 
     def _get_state_groups_from_groups_txn(
-        self, txn, groups, state_filter: Optional[StateFilter] = None
-    ):
+        self,
+        txn: LoggingTransaction,
+        groups: List[int],
+        state_filter: Optional[StateFilter] = None,
+    ) -> Mapping[int, StateMap[str]]:
         state_filter = state_filter or StateFilter.all()
 
-        results = {group: {} for group in groups}
+        results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}
 
         where_clause, where_args = state_filter.make_sql_filter_clause()
 
@@ -117,7 +130,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
             """
 
             for group in groups:
-                args = [group]
+                args: List[Union[int, str]] = [group]
                 args.extend(where_args)
 
                 txn.execute(sql % (where_clause,), args)
@@ -131,7 +144,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
             # We don't use WITH RECURSIVE on sqlite3 as there are distributions
             # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
             for group in groups:
-                next_group = group
+                next_group: Optional[int] = group
 
                 while next_group:
                     # We did this before by getting the list of group ids, and
@@ -173,6 +186,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
                         allow_none=True,
                     )
 
+        # The results shouldn't be considered mutable.
         return results
 
 
@@ -182,7 +196,12 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
     STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
     STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         self.db_pool.updates.register_background_update_handler(
             self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
@@ -198,7 +217,9 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
             columns=["room_id"],
         )
 
-    async def _background_deduplicate_state(self, progress, batch_size):
+    async def _background_deduplicate_state(
+        self, progress: dict, batch_size: int
+    ) -> int:
         """This background update will slowly deduplicate state by reencoding
         them as deltas.
         """
@@ -218,7 +239,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
             )
             max_group = rows[0][0]
 
-        def reindex_txn(txn):
+        def reindex_txn(txn: LoggingTransaction) -> Tuple[bool, int]:
             new_last_state_group = last_state_group
             for count in range(batch_size):
                 txn.execute(
@@ -251,7 +272,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
                     " WHERE id < ? AND room_id = ?",
                     (state_group, room_id),
                 )
-                (prev_group,) = txn.fetchone()
+                # There will be a result due to the coalesce.
+                (prev_group,) = txn.fetchone()  # type: ignore
                 new_last_state_group = state_group
 
                 if prev_group:
@@ -261,15 +283,15 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
                         # otherwise read performance degrades.
                         continue
 
-                    prev_state = self._get_state_groups_from_groups_txn(
+                    prev_state_by_group = self._get_state_groups_from_groups_txn(
                         txn, [prev_group]
                     )
-                    prev_state = prev_state[prev_group]
+                    prev_state = prev_state_by_group[prev_group]
 
-                    curr_state = self._get_state_groups_from_groups_txn(
+                    curr_state_by_group = self._get_state_groups_from_groups_txn(
                         txn, [state_group]
                     )
-                    curr_state = curr_state[state_group]
+                    curr_state = curr_state_by_group[state_group]
 
                     if not set(prev_state.keys()) - set(curr_state.keys()):
                         # We can only do a delta if the current has a strict super set
@@ -340,8 +362,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
 
         return result * BATCH_SIZE_SCALE_FACTOR
 
-    async def _background_index_state(self, progress, batch_size):
-        def reindex_txn(conn):
+    async def _background_index_state(self, progress: dict, batch_size: int) -> int:
+        def reindex_txn(conn: LoggingDatabaseConnection) -> None:
             conn.rollback()
             if isinstance(self.database_engine, PostgresEngine):
                 # postgres insists on autocommit for the index
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index f839c0c24f..f1e3a27e63 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -13,43 +13,56 @@
 # limitations under the License.
 
 import logging
-from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+
+import attr
 
 from synapse.api.constants import EventTypes
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
 from synapse.storage.state import StateFilter
 from synapse.storage.types import Cursor
 from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateKey, StateMap
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.dictionary_cache import DictionaryCache
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 MAX_STATE_DELTA_HOPS = 100
 
 
-class _GetStateGroupDelta(
-    namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
-):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _GetStateGroupDelta:
     """Return type of get_state_group_delta that implements __len__, which lets
-    us use the itrable flag when caching
+    us use the iterable flag when caching
     """
 
-    __slots__ = []
+    prev_group: Optional[int]
+    delta_ids: Optional[StateMap[str]]
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.delta_ids) if self.delta_ids else 0
 
 
 class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
     """A data store for fetching/storing state groups."""
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         # Originally the state store used a single DictionaryCache to cache the
@@ -81,19 +94,21 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         # We size the non-members cache to be smaller than the members cache as the
         # vast majority of state in Matrix (today) is member events.
 
-        self._state_group_cache = DictionaryCache(
+        self._state_group_cache: DictionaryCache[int, StateKey, str] = DictionaryCache(
             "*stateGroupCache*",
             # TODO: this hasn't been tuned yet
             50000,
         )
-        self._state_group_members_cache = DictionaryCache(
+        self._state_group_members_cache: DictionaryCache[
+            int, StateKey, str
+        ] = DictionaryCache(
             "*stateGroupMembersCache*",
             500000,
         )
 
-        def get_max_state_group_txn(txn: Cursor):
+        def get_max_state_group_txn(txn: Cursor) -> int:
             txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
-            return txn.fetchone()[0]
+            return txn.fetchone()[0]  # type: ignore
 
         self._state_group_seq_gen = build_sequence_generator(
             db_conn,
@@ -105,15 +120,15 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         )
 
     @cached(max_entries=10000, iterable=True)
-    async def get_state_group_delta(self, state_group):
+    async def get_state_group_delta(self, state_group: int) -> _GetStateGroupDelta:
         """Given a state group try to return a previous group and a delta between
         the old and the new.
 
         Returns:
-            (prev_group, delta_ids), where both may be None.
+            _GetStateGroupDelta containing prev_group and delta_ids, where both may be None.
         """
 
-        def _get_state_group_delta_txn(txn):
+        def _get_state_group_delta_txn(txn: LoggingTransaction) -> _GetStateGroupDelta:
             prev_group = self.db_pool.simple_select_one_onecol_txn(
                 txn,
                 table="state_group_edges",
@@ -154,7 +169,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         Returns:
             Dict of state group to state map.
         """
-        results = {}
+        results: Dict[int, StateMap[str]] = {}
 
         chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
         for chunk in chunks:
@@ -168,19 +183,24 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
         return results
 
-    def _get_state_for_group_using_cache(self, cache, group, state_filter):
+    def _get_state_for_group_using_cache(
+        self,
+        cache: DictionaryCache[int, StateKey, str],
+        group: int,
+        state_filter: StateFilter,
+    ) -> Tuple[MutableStateMap[str], bool]:
         """Checks if group is in cache. See `_get_state_for_groups`
 
         Args:
-            cache(DictionaryCache): the state group cache to use
-            group(int): The state group to lookup
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
+            cache: the state group cache to use
+            group: The state group to lookup
+            state_filter: The state filter used to fetch state from the database.
 
-        Returns 2-tuple (`state_dict`, `got_all`).
-        `got_all` is a bool indicating if we successfully retrieved all
-        requests state from the cache, if False we need to query the DB for the
-        missing state.
+        Returns:
+             2-tuple (`state_dict`, `got_all`).
+                `got_all` is a bool indicating if we successfully retrieved all
+                requests state from the cache, if False we need to query the DB for the
+                missing state.
         """
         cache_entry = cache.get(group)
         state_dict_ids = cache_entry.value
@@ -277,8 +297,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         return state
 
     def _get_state_for_groups_using_cache(
-        self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter
-    ) -> Tuple[Dict[int, StateMap[str]], Set[int]]:
+        self,
+        groups: Iterable[int],
+        cache: DictionaryCache[int, StateKey, str],
+        state_filter: StateFilter,
+    ) -> Tuple[Dict[int, MutableStateMap[str]], Set[int]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key, querying from a specific cache.
 
@@ -310,21 +333,21 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
     def _insert_into_cache(
         self,
-        group_to_state_dict,
-        state_filter,
-        cache_seq_num_members,
-        cache_seq_num_non_members,
-    ):
+        group_to_state_dict: Dict[int, StateMap[str]],
+        state_filter: StateFilter,
+        cache_seq_num_members: int,
+        cache_seq_num_non_members: int,
+    ) -> None:
         """Inserts results from querying the database into the relevant cache.
 
         Args:
-            group_to_state_dict (dict): The new entries pulled from database.
+            group_to_state_dict: The new entries pulled from database.
                 Map from state group to state dict
-            state_filter (StateFilter): The state filter used to fetch state
+            state_filter: The state filter used to fetch state
                 from the database.
-            cache_seq_num_members (int): Sequence number of member cache since
+            cache_seq_num_members: Sequence number of member cache since
                 last lookup in cache
-            cache_seq_num_non_members (int): Sequence number of member cache since
+            cache_seq_num_non_members: Sequence number of member cache since
                 last lookup in cache
         """
 
@@ -395,7 +418,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             The state group ID
         """
 
-        def _store_state_group_txn(txn):
+        def _store_state_group_txn(txn: LoggingTransaction) -> int:
             if current_state_ids is None:
                 # AFAIK, this can never happen
                 raise Exception("current_state_ids cannot be None")
@@ -426,6 +449,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
                 potential_hops = self._count_state_group_hops_txn(txn, prev_group)
             if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
+                assert delta_ids is not None
+
                 self.db_pool.simple_insert_txn(
                     txn,
                     table="state_group_edges",
@@ -498,7 +523,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         )
 
     async def purge_unreferenced_state_groups(
-        self, room_id: str, state_groups_to_delete
+        self, room_id: str, state_groups_to_delete: Collection[int]
     ) -> None:
         """Deletes no longer referenced state groups and de-deltas any state
         groups that reference them.
@@ -506,8 +531,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         Args:
             room_id: The room the state groups belong to (must all be in the
                 same room).
-            state_groups_to_delete (Collection[int]): Set of all state groups
-                to delete.
+            state_groups_to_delete: Set of all state groups to delete.
         """
 
         await self.db_pool.runInteraction(
@@ -517,7 +541,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             state_groups_to_delete,
         )
 
-    def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete):
+    def _purge_unreferenced_state_groups(
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        state_groups_to_delete: Collection[int],
+    ) -> None:
         logger.info(
             "[purge] found %i state groups to delete", len(state_groups_to_delete)
         )
@@ -546,8 +575,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         # groups to non delta versions.
         for sg in remaining_state_groups:
             logger.info("[purge] de-delta-ing remaining state group %s", sg)
-            curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
-            curr_state = curr_state[sg]
+            curr_state_by_group = self._get_state_groups_from_groups_txn(txn, [sg])
+            curr_state = curr_state_by_group[sg]
 
             self.db_pool.simple_delete_txn(
                 txn, table="state_groups_state", keyvalues={"state_group": sg}
@@ -605,12 +634,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
         return {row["state_group"]: row["prev_state_group"] for row in rows}
 
-    async def purge_room_state(self, room_id, state_groups_to_delete):
+    async def purge_room_state(
+        self, room_id: str, state_groups_to_delete: Collection[int]
+    ) -> None:
         """Deletes all record of a room from state tables
 
         Args:
-            room_id (str):
-            state_groups_to_delete (list[int]): State groups to delete
+            room_id:
+            state_groups_to_delete: State groups to delete
         """
 
         await self.db_pool.runInteraction(
@@ -620,7 +651,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             state_groups_to_delete,
         )
 
-    def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete):
+    def _purge_room_state_txn(
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        state_groups_to_delete: Collection[int],
+    ) -> None:
         # first we have to delete the state groups states
         logger.info("[purge] removing %s from state_groups_state", room_id)
 
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index c552dbf04c..10a46b5e82 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -73,7 +73,7 @@ class RelationPaginationToken:
             t, s = string.split("-")
             return RelationPaginationToken(int(t), int(s))
         except ValueError:
-            raise SynapseError(400, "Invalid token")
+            raise SynapseError(400, "Invalid relation pagination token")
 
     def to_string(self) -> str:
         return "%d-%d" % (self.topological, self.stream)
@@ -103,7 +103,7 @@ class AggregationPaginationToken:
             c, s = string.split("-")
             return AggregationPaginationToken(int(c), int(s))
         except ValueError:
-            raise SynapseError(400, "Invalid token")
+            raise SynapseError(400, "Invalid aggregation pagination token")
 
     def to_string(self) -> str:
         return "%d-%d" % (self.count, self.stream)
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index e5400d681a..5e86befde4 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -25,12 +25,15 @@ from typing import (
 )
 
 import attr
+from frozendict import frozendict
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
 from synapse.types import MutableStateMap, StateMap
 
 if TYPE_CHECKING:
+    from typing import FrozenSet  # noqa: used within quoted type hint; flake8 sad
+
     from synapse.server import HomeServer
     from synapse.storage.databases import Databases
 
@@ -40,7 +43,7 @@ logger = logging.getLogger(__name__)
 T = TypeVar("T")
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, frozen=True)
 class StateFilter:
     """A filter used when querying for state.
 
@@ -53,14 +56,19 @@ class StateFilter:
             appear in `types`.
     """
 
-    types = attr.ib(type=Dict[str, Optional[Set[str]]])
+    types = attr.ib(type="frozendict[str, Optional[FrozenSet[str]]]")
     include_others = attr.ib(default=False, type=bool)
 
     def __attrs_post_init__(self):
         # If `include_others` is set we canonicalise the filter by removing
         # wildcards from the types dictionary
         if self.include_others:
-            self.types = {k: v for k, v in self.types.items() if v is not None}
+            # this is needed to work around the fact that StateFilter is frozen
+            object.__setattr__(
+                self,
+                "types",
+                frozendict({k: v for k, v in self.types.items() if v is not None}),
+            )
 
     @staticmethod
     def all() -> "StateFilter":
@@ -69,7 +77,7 @@ class StateFilter:
         Returns:
             The new state filter.
         """
-        return StateFilter(types={}, include_others=True)
+        return StateFilter(types=frozendict(), include_others=True)
 
     @staticmethod
     def none() -> "StateFilter":
@@ -78,7 +86,7 @@ class StateFilter:
         Returns:
             The new state filter.
         """
-        return StateFilter(types={}, include_others=False)
+        return StateFilter(types=frozendict(), include_others=False)
 
     @staticmethod
     def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
@@ -103,7 +111,12 @@ class StateFilter:
 
             type_dict.setdefault(typ, set()).add(s)  # type: ignore
 
-        return StateFilter(types=type_dict)
+        return StateFilter(
+            types=frozendict(
+                (k, frozenset(v) if v is not None else None)
+                for k, v in type_dict.items()
+            )
+        )
 
     @staticmethod
     def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
@@ -116,7 +129,10 @@ class StateFilter:
         Returns:
             The new state filter
         """
-        return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
+        return StateFilter(
+            types=frozendict({EventTypes.Member: frozenset(members)}),
+            include_others=True,
+        )
 
     def return_expanded(self) -> "StateFilter":
         """Creates a new StateFilter where type wild cards have been removed
@@ -173,7 +189,7 @@ class StateFilter:
             # We want to return all non-members, but only particular
             # memberships
             return StateFilter(
-                types={EventTypes.Member: self.types[EventTypes.Member]},
+                types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
                 include_others=True,
             )
 
@@ -245,14 +261,15 @@ class StateFilter:
 
         return len(self.concrete_types())
 
-    def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]:
-        """Returns the state filtered with by this StateFilter
+    def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]:
+        """Returns the state filtered with by this StateFilter.
 
         Args:
             state: The state map to filter
 
         Returns:
-            The filtered state map
+            The filtered state map.
+            This is a copy, so it's safe to mutate.
         """
         if self.is_full():
             return dict(state_dict)
@@ -324,14 +341,16 @@ class StateFilter:
             if state_keys is None:
                 member_filter = StateFilter.all()
             else:
-                member_filter = StateFilter({EventTypes.Member: state_keys})
+                member_filter = StateFilter(frozendict({EventTypes.Member: state_keys}))
         elif self.include_others:
             member_filter = StateFilter.all()
         else:
             member_filter = StateFilter.none()
 
         non_member_filter = StateFilter(
-            types={k: v for k, v in self.types.items() if k != EventTypes.Member},
+            types=frozendict(
+                {k: v for k, v in self.types.items() if k != EventTypes.Member}
+            ),
             include_others=self.include_others,
         )
 
@@ -358,7 +377,8 @@ class StateGroupStorage:
             make up the delta between the old and new state groups.
         """
 
-        return await self.stores.state.get_state_group_delta(state_group)
+        state_group_delta = await self.stores.state.get_state_group_delta(state_group)
+        return state_group_delta.prev_group, state_group_delta.delta_ids
 
     async def get_state_groups_ids(
         self, _room_id: str, event_ids: Iterable[str]
diff --git a/synapse/types.py b/synapse/types.py
index d4759b2dfd..90168ce8fa 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -511,7 +511,7 @@ class RoomStreamToken:
                 )
         except Exception:
             pass
-        raise SynapseError(400, "Invalid token %r" % (string,))
+        raise SynapseError(400, "Invalid room stream token %r" % (string,))
 
     @classmethod
     def parse_stream_token(cls, string: str) -> "RoomStreamToken":
@@ -520,7 +520,7 @@ class RoomStreamToken:
                 return cls(topological=None, stream=int(string[1:]))
         except Exception:
             pass
-        raise SynapseError(400, "Invalid token %r" % (string,))
+        raise SynapseError(400, "Invalid room stream token %r" % (string,))
 
     def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken":
         """Return a new token such that if an event is after both this token and
@@ -619,7 +619,7 @@ class StreamToken:
                 await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:])
             )
         except Exception:
-            raise SynapseError(400, "Invalid Token")
+            raise SynapseError(400, "Invalid stream token")
 
     async def to_string(self, store: "DataStore") -> str:
         return self._SEPARATOR.join(
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index ade088aae2..485ddb1893 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -130,7 +130,7 @@ class DictionaryCache(Generic[KT, DKT, DV]):
         sequence: int,
         key: KT,
         value: Dict[DKT, DV],
-        fetched_keys: Optional[Set[DKT]] = None,
+        fetched_keys: Optional[Iterable[DKT]] = None,
     ) -> None:
         """Updates the entry in the cache
 
@@ -155,7 +155,7 @@ class DictionaryCache(Generic[KT, DKT, DV]):
                 self._update_or_insert(key, value, fetched_keys)
 
     def _update_or_insert(
-        self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]
+        self, key: KT, value: Dict[DKT, DV], known_absent: Iterable[DKT]
     ) -> None:
         # We pop and reinsert as we need to tell the cache the size may have
         # changed
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index ac419f0db3..01b1b0d4a0 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -1,4 +1,4 @@
-# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,7 +13,6 @@
 # limitations under the License.
 import logging
 import os
-from binascii import unhexlify
 from typing import Optional, Tuple
 
 from twisted.internet.protocol import Factory
@@ -28,6 +27,7 @@ from synapse.server import HomeServer
 from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
 from tests.replication._base import BaseMultiWorkerStreamTestCase
 from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
+from tests.test_utils import SMALL_PNG
 
 logger = logging.getLogger(__name__)
 
@@ -190,31 +190,25 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1")
         channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1")
 
-        png_data = unhexlify(
-            b"89504e470d0a1a0a0000000d4948445200000001000000010806"
-            b"0000001f15c4890000000a49444154789c63000100000500010d"
-            b"0a2db40000000049454e44ae426082"
-        )
-
         request1.setResponseCode(200)
         request1.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"])
-        request1.write(png_data)
+        request1.write(SMALL_PNG)
         request1.finish()
 
         self.pump(0.1)
 
         self.assertEqual(channel1.code, 200, channel1.result["body"])
-        self.assertEqual(channel1.result["body"], png_data)
+        self.assertEqual(channel1.result["body"], SMALL_PNG)
 
         request2.setResponseCode(200)
         request2.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"])
-        request2.write(png_data)
+        request2.write(SMALL_PNG)
         request2.finish()
 
         self.pump(0.1)
 
         self.assertEqual(channel2.code, 200, channel2.result["body"])
-        self.assertEqual(channel2.result["body"], png_data)
+        self.assertEqual(channel2.result["body"], SMALL_PNG)
 
         # We expect only three new thumbnails to have been persisted.
         self.assertEqual(start_count + 3, self._count_remote_thumbnails())
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index bfa638fb4b..febd40b656 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -1,4 +1,4 @@
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,7 +15,6 @@
 import json
 import os
 import urllib.parse
-from binascii import unhexlify
 from unittest.mock import Mock
 
 from twisted.internet.defer import Deferred
@@ -28,6 +27,7 @@ from synapse.rest.client import groups, login, room
 
 from tests import unittest
 from tests.server import FakeSite, make_request
+from tests.test_utils import SMALL_PNG
 
 
 class VersionTestCase(unittest.HomeserverTestCase):
@@ -150,11 +150,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         self.media_repo = hs.get_media_repository_resource()
         self.download_resource = self.media_repo.children[b"download"]
         self.upload_resource = self.media_repo.children[b"upload"]
-        self.image_data = unhexlify(
-            b"89504e470d0a1a0a0000000d4948445200000001000000010806"
-            b"0000001f15c4890000000a49444154789c63000100000500010d"
-            b"0a2db40000000049454e44ae426082"
-        )
 
     def make_homeserver(self, reactor, clock):
 
@@ -266,7 +261,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
 
         # Upload some media into the room
         response = self.helper.upload_media(
-            self.upload_resource, self.image_data, tok=admin_user_tok
+            self.upload_resource, SMALL_PNG, tok=admin_user_tok
         )
 
         # Extract media ID from the response
@@ -314,10 +309,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
 
         # Upload some media
         response_1 = self.helper.upload_media(
-            self.upload_resource, self.image_data, tok=non_admin_user_tok
+            self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
         )
         response_2 = self.helper.upload_media(
-            self.upload_resource, self.image_data, tok=non_admin_user_tok
+            self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
         )
 
         # Extract mxcs
@@ -381,10 +376,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
 
         # Upload some media
         response_1 = self.helper.upload_media(
-            self.upload_resource, self.image_data, tok=non_admin_user_tok
+            self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
         )
         response_2 = self.helper.upload_media(
-            self.upload_resource, self.image_data, tok=non_admin_user_tok
+            self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
         )
 
         # Extract media IDs
@@ -421,10 +416,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
 
         # Upload some media
         response_1 = self.helper.upload_media(
-            self.upload_resource, self.image_data, tok=non_admin_user_tok
+            self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
         )
         response_2 = self.helper.upload_media(
-            self.upload_resource, self.image_data, tok=non_admin_user_tok
+            self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
         )
 
         # Extract media IDs
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 972d60570c..2f02934e72 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -1,4 +1,5 @@
 # Copyright 2020 Dirk Klimpel
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,7 +15,6 @@
 
 import json
 import os
-from binascii import unhexlify
 
 from parameterized import parameterized
 
@@ -25,6 +25,7 @@ from synapse.rest.media.v1.filepath import MediaFilePaths
 
 from tests import unittest
 from tests.server import FakeSite, make_request
+from tests.test_utils import SMALL_PNG
 
 
 class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
@@ -110,15 +111,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
 
         download_resource = self.media_repo.children[b"download"]
         upload_resource = self.media_repo.children[b"upload"]
-        image_data = unhexlify(
-            b"89504e470d0a1a0a0000000d4948445200000001000000010806"
-            b"0000001f15c4890000000a49444154789c63000100000500010d"
-            b"0a2db40000000049454e44ae426082"
-        )
 
         # Upload some media into the room
         response = self.helper.upload_media(
-            upload_resource, image_data, tok=self.admin_user_tok, expect_code=200
+            upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
         )
         # Extract media ID from the response
         server_and_media_id = response["content_uri"][6:]  # Cut off 'mxc://'
@@ -504,16 +500,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         Create a media and return media_id and server_and_media_id
         """
         upload_resource = self.media_repo.children[b"upload"]
-        # file size is 67 Byte
-        image_data = unhexlify(
-            b"89504e470d0a1a0a0000000d4948445200000001000000010806"
-            b"0000001f15c4890000000a49444154789c63000100000500010d"
-            b"0a2db40000000049454e44ae426082"
-        )
 
         # Upload some media into the room
         response = self.helper.upload_media(
-            upload_resource, image_data, tok=self.admin_user_tok, expect_code=200
+            upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
         )
         # Extract media ID from the response
         server_and_media_id = response["content_uri"][6:]  # Cut off 'mxc://'
@@ -584,16 +574,10 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
 
         # Create media
         upload_resource = media_repo.children[b"upload"]
-        # file size is 67 Byte
-        image_data = unhexlify(
-            b"89504e470d0a1a0a0000000d4948445200000001000000010806"
-            b"0000001f15c4890000000a49444154789c63000100000500010d"
-            b"0a2db40000000049454e44ae426082"
-        )
 
         # Upload some media into the room
         response = self.helper.upload_media(
-            upload_resource, image_data, tok=self.admin_user_tok, expect_code=200
+            upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
         )
         # Extract media ID from the response
         server_and_media_id = response["content_uri"][6:]  # Cut off 'mxc://'
@@ -711,16 +695,10 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
 
         # Create media
         upload_resource = media_repo.children[b"upload"]
-        # file size is 67 Byte
-        image_data = unhexlify(
-            b"89504e470d0a1a0a0000000d4948445200000001000000010806"
-            b"0000001f15c4890000000a49444154789c63000100000500010d"
-            b"0a2db40000000049454e44ae426082"
-        )
 
         # Upload some media into the room
         response = self.helper.upload_media(
-            upload_resource, image_data, tok=self.admin_user_tok, expect_code=200
+            upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
         )
         # Extract media ID from the response
         server_and_media_id = response["content_uri"][6:]  # Cut off 'mxc://'
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index 5cd82209c4..ece89a65ac 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -1,4 +1,5 @@
 # Copyright 2020 Dirk Klimpel
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,7 +14,6 @@
 # limitations under the License.
 
 import json
-from binascii import unhexlify
 from typing import Any, Dict, List, Optional
 
 import synapse.rest.admin
@@ -21,6 +21,7 @@ from synapse.api.errors import Codes
 from synapse.rest.client import login
 
 from tests import unittest
+from tests.test_utils import SMALL_PNG
 
 
 class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
@@ -468,16 +469,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         """
         upload_resource = self.media_repo.children[b"upload"]
         for _ in range(number_media):
-            # file size is 67 Byte
-            image_data = unhexlify(
-                b"89504e470d0a1a0a0000000d4948445200000001000000010806"
-                b"0000001f15c4890000000a49444154789c63000100000500010d"
-                b"0a2db40000000049454e44ae426082"
-            )
-
             # Upload some media into the room
             self.helper.upload_media(
-                upload_resource, image_data, tok=user_token, expect_code=200
+                upload_resource, SMALL_PNG, tok=user_token, expect_code=200
             )
 
     def _check_fields(self, content: List[Dict[str, Any]]):
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index ee204c404b..cc3f16c62a 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1,4 +1,4 @@
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -33,7 +33,7 @@ from synapse.types import JsonDict, UserID
 
 from tests import unittest
 from tests.server import FakeSite, make_request
-from tests.test_utils import make_awaitable
+from tests.test_utils import SMALL_PNG, make_awaitable
 from tests.unittest import override_config
 
 
@@ -2835,11 +2835,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         other_user_tok = self.login("user", "pass")
 
         # Resolution: 1×1, MIME type: image/png, Extension: png, Size: 67 B
-        image_data1 = unhexlify(
-            b"89504e470d0a1a0a0000000d4948445200000001000000010806"
-            b"0000001f15c4890000000a49444154789c63000100000500010d"
-            b"0a2db40000000049454e44ae426082"
-        )
+        image_data1 = SMALL_PNG
         # Resolution: 1×1, MIME type: image/gif, Extension: gif, Size: 35 B
         image_data2 = unhexlify(
             b"47494638376101000100800100000000"
@@ -2943,14 +2939,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         """
         media_ids = []
         for _ in range(number_media):
-            # file size is 67 Byte
-            image_data = unhexlify(
-                b"89504e470d0a1a0a0000000d4948445200000001000000010806"
-                b"0000001f15c4890000000a49444154789c63000100000500010d"
-                b"0a2db40000000049454e44ae426082"
-            )
-
-            media_ids.append(self._create_media_and_access(user_token, image_data))
+            media_ids.append(self._create_media_and_access(user_token, SMALL_PNG))
 
         return media_ids
 
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 2f7eebfe69..9ea1c2bf25 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -1,4 +1,4 @@
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -38,6 +38,7 @@ from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
 
 from tests import unittest
 from tests.server import FakeSite, make_request
+from tests.test_utils import SMALL_PNG
 from tests.utils import default_config
 
 
@@ -134,11 +135,7 @@ class _TestImage:
         # smoll png
         (
             _TestImage(
-                unhexlify(
-                    b"89504e470d0a1a0a0000000d4948445200000001000000010806"
-                    b"0000001f15c4890000000a49444154789c63000100000500010d"
-                    b"0a2db40000000049454e44ae426082"
-                ),
+                SMALL_PNG,
                 b"image/png",
                 b".png",
                 unhexlify(
@@ -593,15 +590,8 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
 
     def test_upload_innocent(self):
         """Attempt to upload some innocent data that should be allowed."""
-
-        image_data = unhexlify(
-            b"89504e470d0a1a0a0000000d4948445200000001000000010806"
-            b"0000001f15c4890000000a49444154789c63000100000500010d"
-            b"0a2db40000000049454e44ae426082"
-        )
-
         self.helper.upload_media(
-            self.upload_resource, image_data, tok=self.tok, expect_code=200
+            self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
         )
 
     def test_upload_ban(self):
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 8695264595..32060f2abd 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -14,6 +14,8 @@
 
 import logging
 
+from frozendict import frozendict
+
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.storage.state import StateFilter
@@ -183,7 +185,9 @@ class StateStoreTestCase(HomeserverTestCase):
             self.storage.state.get_state_for_event(
                 e5.event_id,
                 state_filter=StateFilter(
-                    types={EventTypes.Member: {self.u_alice.to_string()}},
+                    types=frozendict(
+                        {EventTypes.Member: frozenset({self.u_alice.to_string()})}
+                    ),
                     include_others=True,
                 ),
             )
@@ -203,7 +207,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.storage.state.get_state_for_event(
                 e5.event_id,
                 state_filter=StateFilter(
-                    types={EventTypes.Member: set()}, include_others=True
+                    types=frozendict({EventTypes.Member: frozenset()}),
+                    include_others=True,
                 ),
             )
         )
@@ -228,7 +233,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: set()}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset()}), include_others=True
             ),
         )
 
@@ -245,7 +250,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: set()}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset()}), include_others=True
             ),
         )
 
@@ -258,7 +263,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: None}, include_others=True
+                types=frozendict({EventTypes.Member: None}), include_others=True
             ),
         )
 
@@ -275,7 +280,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: None}, include_others=True
+                types=frozendict({EventTypes.Member: None}), include_others=True
             ),
         )
 
@@ -295,7 +300,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+                include_others=True,
             ),
         )
 
@@ -312,7 +318,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+                include_others=True,
             ),
         )
 
@@ -325,7 +332,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=False
+                types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+                include_others=False,
             ),
         )
 
@@ -375,7 +383,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: set()}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset()}), include_others=True
             ),
         )
 
@@ -387,7 +395,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: set()}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset()}), include_others=True
             ),
         )
 
@@ -400,7 +408,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: None}, include_others=True
+                types=frozendict({EventTypes.Member: None}), include_others=True
             ),
         )
 
@@ -411,7 +419,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: None}, include_others=True
+                types=frozendict({EventTypes.Member: None}), include_others=True
             ),
         )
 
@@ -430,7 +438,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+                include_others=True,
             ),
         )
 
@@ -441,7 +450,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+                include_others=True,
             ),
         )
 
@@ -454,7 +464,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=False
+                types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+                include_others=False,
             ),
         )
 
@@ -465,7 +476,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=False
+                types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+                include_others=False,
             ),
         )
 
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index be6302d170..15ac2bfeba 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -1,5 +1,4 @@
-# Copyright 2019 New Vector Ltd
-# Copyright 2020 The Matrix.org Foundation C.I.C
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,6 +18,7 @@ Utilities for running the unit tests
 import sys
 import warnings
 from asyncio import Future
+from binascii import unhexlify
 from typing import Any, Awaitable, Callable, TypeVar
 from unittest.mock import Mock
 
@@ -117,3 +117,13 @@ class FakeResponse:
     def deliverBody(self, protocol):
         protocol.dataReceived(self.body)
         protocol.connectionLost(Failure(ResponseDone()))
+
+
+# A small image used in some tests.
+#
+# Resolution: 1×1, MIME type: image/png, Extension: png, Size: 67 B
+SMALL_PNG = unhexlify(
+    b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+    b"0000001f15c4890000000a49444154789c63000100000500010d"
+    b"0a2db40000000049454e44ae426082"
+)