summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-01-18 11:14:37 +0000
committerErik Johnston <erik@matrix.org>2021-01-18 11:14:37 +0000
commitf5ab7d83067fb3e1027f61bef838994022db5514 (patch)
tree44bffab26e96b5d64bacb2e9ac14762e984f5606 /synapse
parentMerge remote-tracking branch 'origin/develop' into matrix-org-hotfixes (diff)
parentEnsure the user ID is serialized in the payload instead of used as an instanc... (diff)
downloadsynapse-f5ab7d83067fb3e1027f61bef838994022db5514.tar.xz
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
Diffstat (limited to 'synapse')
-rw-r--r--synapse/config/cas.py2
-rw-r--r--synapse/config/oidc_config.py329
-rw-r--r--synapse/config/registration.py11
-rw-r--r--synapse/events/__init__.py3
-rw-r--r--synapse/handlers/devicemessage.py2
-rw-r--r--synapse/handlers/oidc_handler.py27
-rw-r--r--synapse/http/client.py12
-rw-r--r--synapse/rest/admin/media.py64
-rw-r--r--synapse/rest/media/v1/_base.py76
-rw-r--r--synapse/rest/media/v1/config_resource.py14
-rw-r--r--synapse/rest/media/v1/download_resource.py18
-rw-r--r--synapse/rest/media/v1/filepath.py50
-rw-r--r--synapse/rest/media/v1/media_repository.py50
-rw-r--r--synapse/rest/media/v1/media_storage.py12
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py77
-rw-r--r--synapse/rest/media/v1/storage_provider.py37
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py81
-rw-r--r--synapse/rest/media/v1/thumbnailer.py18
-rw-r--r--synapse/rest/media/v1/upload_resource.py14
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py329
-rw-r--r--synapse/storage/databases/main/media_repository.py3
-rw-r--r--synapse/storage/databases/main/pusher.py5
-rw-r--r--synapse/util/stringutils.py19
23 files changed, 761 insertions, 492 deletions
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index 2f97e6d258..c7877b4095 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -40,7 +40,7 @@ class CasConfig(Config):
             self.cas_required_attributes = {}
 
     def generate_config_section(self, config_dir_path, server_name, **kwargs):
-        return """
+        return """\
         # Enable Central Authentication Service (CAS) for registration and login.
         #
         cas_config:
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index fddca19223..c7fa749377 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import string
-from typing import Optional, Type
+from typing import Iterable, Optional, Type
 
 import attr
 
@@ -33,16 +33,8 @@ class OIDCConfig(Config):
     section = "oidc"
 
     def read_config(self, config, **kwargs):
-        validate_config(MAIN_CONFIG_SCHEMA, config, ())
-
-        self.oidc_provider = None  # type: Optional[OidcProviderConfig]
-
-        oidc_config = config.get("oidc_config")
-        if oidc_config and oidc_config.get("enabled", False):
-            validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
-            self.oidc_provider = _parse_oidc_config_dict(oidc_config)
-
-        if not self.oidc_provider:
+        self.oidc_providers = tuple(_parse_oidc_provider_configs(config))
+        if not self.oidc_providers:
             return
 
         try:
@@ -58,144 +50,153 @@ class OIDCConfig(Config):
     @property
     def oidc_enabled(self) -> bool:
         # OIDC is enabled if we have a provider
-        return bool(self.oidc_provider)
+        return bool(self.oidc_providers)
 
     def generate_config_section(self, config_dir_path, server_name, **kwargs):
         return """\
-        # Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login.
+        # List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
+        # and login.
+        #
+        # Options for each entry include:
+        #
+        #   idp_id: a unique identifier for this identity provider. Used internally
+        #       by Synapse; should be a single word such as 'github'.
+        #
+        #       Note that, if this is changed, users authenticating via that provider
+        #       will no longer be recognised as the same user!
+        #
+        #   idp_name: A user-facing name for this identity provider, which is used to
+        #       offer the user a choice of login mechanisms.
+        #
+        #   discover: set to 'false' to disable the use of the OIDC discovery mechanism
+        #       to discover endpoints. Defaults to true.
+        #
+        #   issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery
+        #       is enabled) to discover the provider's endpoints.
+        #
+        #   client_id: Required. oauth2 client id to use.
+        #
+        #   client_secret: Required. oauth2 client secret to use.
+        #
+        #   client_auth_method: auth method to use when exchanging the token. Valid
+        #       values are 'client_secret_basic' (default), 'client_secret_post' and
+        #       'none'.
+        #
+        #   scopes: list of scopes to request. This should normally include the "openid"
+        #       scope. Defaults to ["openid"].
+        #
+        #   authorization_endpoint: the oauth2 authorization endpoint. Required if
+        #       provider discovery is disabled.
+        #
+        #   token_endpoint: the oauth2 token endpoint. Required if provider discovery is
+        #       disabled.
+        #
+        #   userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is
+        #       disabled and the 'openid' scope is not requested.
+        #
+        #   jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and
+        #       the 'openid' scope is used.
+        #
+        #   skip_verification: set to 'true' to skip metadata verification. Use this if
+        #       you are connecting to a provider that is not OpenID Connect compliant.
+        #       Defaults to false. Avoid this in production.
+        #
+        #   user_profile_method: Whether to fetch the user profile from the userinfo
+        #       endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
+        #
+        #       Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
+        #       included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
+        #       userinfo endpoint.
+        #
+        #   allow_existing_users: set to 'true' to allow a user logging in via OIDC to
+        #       match a pre-existing account instead of failing. This could be used if
+        #       switching from password logins to OIDC. Defaults to false.
+        #
+        #   user_mapping_provider: Configuration for how attributes returned from an OIDC
+        #       provider are mapped onto a matrix user. This setting has the following
+        #       sub-properties:
+        #
+        #       module: The class name of a custom mapping module. Default is
+        #           {mapping_provider!r}.
+        #           See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
+        #           for information on implementing a custom mapping provider.
+        #
+        #       config: Configuration for the mapping provider module. This section will
+        #           be passed as a Python dictionary to the user mapping provider
+        #           module's `parse_config` method.
+        #
+        #           For the default provider, the following settings are available:
+        #
+        #             sub: name of the claim containing a unique identifier for the
+        #                 user. Defaults to 'sub', which OpenID Connect compliant
+        #                 providers should provide.
+        #
+        #             localpart_template: Jinja2 template for the localpart of the MXID.
+        #                 If this is not set, the user will be prompted to choose their
+        #                 own username.
+        #
+        #             display_name_template: Jinja2 template for the display name to set
+        #                 on first login. If unset, no displayname will be set.
+        #
+        #             extra_attributes: a map of Jinja2 templates for extra attributes
+        #                 to send back to the client during login.
+        #                 Note that these are non-standard and clients will ignore them
+        #                 without modifications.
+        #
+        #           When rendering, the Jinja2 templates are given a 'user' variable,
+        #           which is set to the claims returned by the UserInfo Endpoint and/or
+        #           in the ID Token.
         #
         # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
-        # for some example configurations.
+        # for information on how to configure these options.
         #
-        oidc_config:
-          # Uncomment the following to enable authorization against an OpenID Connect
-          # server. Defaults to false.
-          #
-          #enabled: true
-
-          # Uncomment the following to disable use of the OIDC discovery mechanism to
-          # discover endpoints. Defaults to true.
-          #
-          #discover: false
-
-          # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
-          # discover the provider's endpoints.
-          #
-          # Required if 'enabled' is true.
-          #
-          #issuer: "https://accounts.example.com/"
-
-          # oauth2 client id to use.
-          #
-          # Required if 'enabled' is true.
-          #
-          #client_id: "provided-by-your-issuer"
-
-          # oauth2 client secret to use.
-          #
-          # Required if 'enabled' is true.
-          #
-          #client_secret: "provided-by-your-issuer"
-
-          # auth method to use when exchanging the token.
-          # Valid values are 'client_secret_basic' (default), 'client_secret_post' and
-          # 'none'.
-          #
-          #client_auth_method: client_secret_post
-
-          # list of scopes to request. This should normally include the "openid" scope.
-          # Defaults to ["openid"].
-          #
-          #scopes: ["openid", "profile"]
-
-          # the oauth2 authorization endpoint. Required if provider discovery is disabled.
-          #
-          #authorization_endpoint: "https://accounts.example.com/oauth2/auth"
-
-          # the oauth2 token endpoint. Required if provider discovery is disabled.
-          #
-          #token_endpoint: "https://accounts.example.com/oauth2/token"
-
-          # the OIDC userinfo endpoint. Required if discovery is disabled and the
-          # "openid" scope is not requested.
-          #
-          #userinfo_endpoint: "https://accounts.example.com/userinfo"
-
-          # URI where to fetch the JWKS. Required if discovery is disabled and the
-          # "openid" scope is used.
-          #
-          #jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
-
-          # Uncomment to skip metadata verification. Defaults to false.
-          #
-          # Use this if you are connecting to a provider that is not OpenID Connect
-          # compliant.
-          # Avoid this in production.
-          #
-          #skip_verification: true
-
-          # Whether to fetch the user profile from the userinfo endpoint. Valid
-          # values are: "auto" or "userinfo_endpoint".
+        # For backwards compatibility, it is also possible to configure a single OIDC
+        # provider via an 'oidc_config' setting. This is now deprecated and admins are
+        # advised to migrate to the 'oidc_providers' format.
+        #
+        oidc_providers:
+          # Generic example
           #
-          # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
-          # in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
+          #- idp_id: my_idp
+          #  idp_name: "My OpenID provider"
+          #  discover: false
+          #  issuer: "https://accounts.example.com/"
+          #  client_id: "provided-by-your-issuer"
+          #  client_secret: "provided-by-your-issuer"
+          #  client_auth_method: client_secret_post
+          #  scopes: ["openid", "profile"]
+          #  authorization_endpoint: "https://accounts.example.com/oauth2/auth"
+          #  token_endpoint: "https://accounts.example.com/oauth2/token"
+          #  userinfo_endpoint: "https://accounts.example.com/userinfo"
+          #  jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
+          #  skip_verification: true
+
+          # For use with Keycloak
           #
-          #user_profile_method: "userinfo_endpoint"
-
-          # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
-          # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
-          #
-          #allow_existing_users: true
-
-          # An external module can be provided here as a custom solution to mapping
-          # attributes returned from a OIDC provider onto a matrix user.
+          #- idp_id: keycloak
+          #  idp_name: Keycloak
+          #  issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name"
+          #  client_id: "synapse"
+          #  client_secret: "copy secret generated in Keycloak UI"
+          #  scopes: ["openid", "profile"]
+
+          # For use with Github
           #
-          user_mapping_provider:
-            # The custom module's class. Uncomment to use a custom module.
-            # Default is {mapping_provider!r}.
-            #
-            # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
-            # for information on implementing a custom mapping provider.
-            #
-            #module: mapping_provider.OidcMappingProvider
-
-            # Custom configuration values for the module. This section will be passed as
-            # a Python dictionary to the user mapping provider module's `parse_config`
-            # method.
-            #
-            # The examples below are intended for the default provider: they should be
-            # changed if using a custom provider.
-            #
-            config:
-              # name of the claim containing a unique identifier for the user.
-              # Defaults to `sub`, which OpenID Connect compliant providers should provide.
-              #
-              #subject_claim: "sub"
-
-              # Jinja2 template for the localpart of the MXID.
-              #
-              # When rendering, this template is given the following variables:
-              #   * user: The claims returned by the UserInfo Endpoint and/or in the ID
-              #     Token
-              #
-              # If this is not set, the user will be prompted to choose their
-              # own username.
-              #
-              #localpart_template: "{{{{ user.preferred_username }}}}"
-
-              # Jinja2 template for the display name to set on first login.
-              #
-              # If unset, no displayname will be set.
-              #
-              #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}"
-
-              # Jinja2 templates for extra attributes to send back to the client during
-              # login.
-              #
-              # Note that these are non-standard and clients will ignore them without modifications.
-              #
-              #extra_attributes:
-                #birthdate: "{{{{ user.birthdate }}}}"
+          #- idp_id: github
+          #  idp_name: Github
+          #  discover: false
+          #  issuer: "https://github.com/"
+          #  client_id: "your-client-id" # TO BE FILLED
+          #  client_secret: "your-client-secret" # TO BE FILLED
+          #  authorization_endpoint: "https://github.com/login/oauth/authorize"
+          #  token_endpoint: "https://github.com/login/oauth/access_token"
+          #  userinfo_endpoint: "https://api.github.com/user"
+          #  scopes: ["read:user"]
+          #  user_mapping_provider:
+          #    config:
+          #      subject_claim: "id"
+          #      localpart_template: "{{ user.login }}"
+          #      display_name_template: "{{ user.name }}"
         """.format(
             mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
         )
@@ -234,7 +235,22 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
     },
 }
 
-# the `oidc_config` setting can either be None (as it is in the default
+# the same as OIDC_PROVIDER_CONFIG_SCHEMA, but with compulsory idp_id and idp_name
+OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA = {
+    "allOf": [OIDC_PROVIDER_CONFIG_SCHEMA, {"required": ["idp_id", "idp_name"]}]
+}
+
+
+# the `oidc_providers` list can either be None (as it is in the default config), or
+# a list of provider configs, each of which requires an explicit ID and name.
+OIDC_PROVIDER_LIST_SCHEMA = {
+    "oneOf": [
+        {"type": "null"},
+        {"type": "array", "items": OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA},
+    ]
+}
+
+# the `oidc_config` setting can either be None (which it used to be in the default
 # config), or an object. If an object, it is ignored unless it has an "enabled: True"
 # property.
 #
@@ -243,12 +259,41 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
 # additional checks in the code.
 OIDC_CONFIG_SCHEMA = {"oneOf": [{"type": "null"}, {"type": "object"}]}
 
+# the top-level schema can contain an "oidc_config" and/or an "oidc_providers".
 MAIN_CONFIG_SCHEMA = {
     "type": "object",
-    "properties": {"oidc_config": OIDC_CONFIG_SCHEMA},
+    "properties": {
+        "oidc_config": OIDC_CONFIG_SCHEMA,
+        "oidc_providers": OIDC_PROVIDER_LIST_SCHEMA,
+    },
 }
 
 
+def _parse_oidc_provider_configs(config: JsonDict) -> Iterable["OidcProviderConfig"]:
+    """extract and parse the OIDC provider configs from the config dict
+
+    The configuration may contain either a single `oidc_config` object with an
+    `enabled: True` property, or a list of provider configurations under
+    `oidc_providers`, *or both*.
+
+    Returns a generator which yields the OidcProviderConfig objects
+    """
+    validate_config(MAIN_CONFIG_SCHEMA, config, ())
+
+    for p in config.get("oidc_providers") or []:
+        yield _parse_oidc_config_dict(p)
+
+    # for backwards-compatibility, it is also possible to provide a single "oidc_config"
+    # object with an "enabled: True" property.
+    oidc_config = config.get("oidc_config")
+    if oidc_config and oidc_config.get("enabled", False):
+        # MAIN_CONFIG_SCHEMA checks that `oidc_config` is an object, but not that
+        # it matches OIDC_PROVIDER_CONFIG_SCHEMA (see the comments on OIDC_CONFIG_SCHEMA
+        # above), so now we need to validate it.
+        validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
+        yield _parse_oidc_config_dict(oidc_config)
+
+
 def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig":
     """Take the configuration dict and parse it into an OidcProviderConfig
 
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index cc5f75123c..740c3fc1b1 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -14,14 +14,13 @@
 # limitations under the License.
 
 import os
-from distutils.util import strtobool
 
 import pkg_resources
 
 from synapse.api.constants import RoomCreationPreset
 from synapse.config._base import Config, ConfigError
 from synapse.types import RoomAlias, UserID
-from synapse.util.stringutils import random_string_with_symbols
+from synapse.util.stringutils import random_string_with_symbols, strtobool
 
 
 class AccountValidityConfig(Config):
@@ -86,12 +85,12 @@ class RegistrationConfig(Config):
     section = "registration"
 
     def read_config(self, config, **kwargs):
-        self.enable_registration = bool(
-            strtobool(str(config.get("enable_registration", False)))
+        self.enable_registration = strtobool(
+            str(config.get("enable_registration", False))
         )
         if "disable_registration" in config:
-            self.enable_registration = not bool(
-                strtobool(str(config["disable_registration"]))
+            self.enable_registration = not strtobool(
+                str(config["disable_registration"])
             )
 
         self.account_validity = AccountValidityConfig(
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 8028663fa8..3ec4120f85 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -17,7 +17,6 @@
 
 import abc
 import os
-from distutils.util import strtobool
 from typing import Dict, Optional, Tuple, Type
 
 from unpaddedbase64 import encode_base64
@@ -26,6 +25,7 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVers
 from synapse.types import JsonDict, RoomStreamToken
 from synapse.util.caches import intern_dict
 from synapse.util.frozenutils import freeze
+from synapse.util.stringutils import strtobool
 
 # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
 # bugs where we accidentally share e.g. signature dicts. However, converting a
@@ -34,6 +34,7 @@ from synapse.util.frozenutils import freeze
 # NOTE: This is overridden by the configuration by the Synapse worker apps, but
 # for the sake of tests, it is set here while it cannot be configured on the
 # homeserver object itself.
+
 USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
 
 
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 109dc7932f..37a678b6ce 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -163,7 +163,7 @@ class DeviceMessageHandler:
             await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
 
             # Immediately attempt a resync in the background
-            run_in_background(self._user_device_resync, sender_user_id)
+            run_in_background(self._user_device_resync, user_id=sender_user_id)
 
     async def send_device_message(
         self,
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index f63a90ec5c..5e5fda7b2f 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -78,21 +78,28 @@ class OidcHandler:
     def __init__(self, hs: "HomeServer"):
         self._sso_handler = hs.get_sso_handler()
 
-        provider_conf = hs.config.oidc.oidc_provider
+        provider_confs = hs.config.oidc.oidc_providers
         # we should not have been instantiated if there is no configured provider.
-        assert provider_conf is not None
+        assert provider_confs
 
         self._token_generator = OidcSessionTokenGenerator(hs)
-
-        self._provider = OidcProvider(hs, self._token_generator, provider_conf)
+        self._providers = {
+            p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
+        }
 
     async def load_metadata(self) -> None:
         """Validate the config and load the metadata from the remote endpoint.
 
         Called at startup to ensure we have everything we need.
         """
-        await self._provider.load_metadata()
-        await self._provider.load_jwks()
+        for idp_id, p in self._providers.items():
+            try:
+                await p.load_metadata()
+                await p.load_jwks()
+            except Exception as e:
+                raise Exception(
+                    "Error while initialising OIDC provider %r" % (idp_id,)
+                ) from e
 
     async def handle_oidc_callback(self, request: SynapseRequest) -> None:
         """Handle an incoming request to /_synapse/oidc/callback
@@ -184,6 +191,12 @@ class OidcHandler:
             self._sso_handler.render_error(request, "mismatching_session", str(e))
             return
 
+        oidc_provider = self._providers.get(session_data.idp_id)
+        if not oidc_provider:
+            logger.error("OIDC session uses unknown IdP %r", oidc_provider)
+            self._sso_handler.render_error(request, "unknown_idp", "Unknown IdP")
+            return
+
         if b"code" not in request.args:
             logger.info("Code parameter is missing")
             self._sso_handler.render_error(
@@ -193,7 +206,7 @@ class OidcHandler:
 
         code = request.args[b"code"][0].decode()
 
-        await self._provider.handle_oidc_callback(request, session_data, code)
+        await oidc_provider.handle_oidc_callback(request, session_data, code)
 
 
 class OidcError(Exception):
diff --git a/synapse/http/client.py b/synapse/http/client.py
index dc4b81ca60..df498c8645 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -766,14 +766,24 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
         self.max_size = max_size
 
     def dataReceived(self, data: bytes) -> None:
+        # If the deferred was called, bail early.
+        if self.deferred.called:
+            return
+
         self.stream.write(data)
         self.length += len(data)
+        # The first time the maximum size is exceeded, error and cancel the
+        # connection. dataReceived might be called again if data was received
+        # in the meantime.
         if self.max_size is not None and self.length >= self.max_size:
             self.deferred.errback(BodyExceededMaxSize())
-            self.deferred = defer.Deferred()
             self.transport.loseConnection()
 
     def connectionLost(self, reason: Failure) -> None:
+        # If the maximum size was already exceeded, there's nothing to do.
+        if self.deferred.called:
+            return
+
         if reason.check(ResponseDone):
             self.deferred.callback(self.length)
         elif reason.check(PotentialDataLoss):
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index c82b4f87d6..8720b1401f 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -15,6 +15,9 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.http import Request
 
 from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
@@ -23,6 +26,10 @@ from synapse.rest.admin._base import (
     assert_requester_is_admin,
     assert_user_is_admin,
 )
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -39,11 +46,11 @@ class QuarantineMediaInRoom(RestServlet):
         admin_patterns("/quarantine_media/(?P<room_id>[^/]+)")
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, room_id: str):
+    async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -64,11 +71,11 @@ class QuarantineMediaByUser(RestServlet):
 
     PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, user_id: str):
+    async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -91,11 +98,13 @@ class QuarantineMediaByID(RestServlet):
         "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, server_name: str, media_id: str):
+    async def on_POST(
+        self, request: Request, server_name: str, media_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -109,17 +118,39 @@ class QuarantineMediaByID(RestServlet):
         return 200, {}
 
 
+class ProtectMediaByID(RestServlet):
+    """Protect local media from being quarantined.
+    """
+
+    PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
+
+    def __init__(self, hs: "HomeServer"):
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        logging.info("Protecting local media by ID: %s", media_id)
+
+        # Quarantine this media id
+        await self.store.mark_local_media_as_safe(media_id)
+
+        return 200, {}
+
+
 class ListMediaInRoom(RestServlet):
     """Lists all of the media in a given room.
     """
 
     PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request, room_id):
+    async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         is_admin = await self.auth.is_server_admin(requester.user)
         if not is_admin:
@@ -133,11 +164,11 @@ class ListMediaInRoom(RestServlet):
 class PurgeMediaCacheRestServlet(RestServlet):
     PATTERNS = admin_patterns("/purge_media_cache")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.media_repository = hs.get_media_repository()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         before_ts = parse_integer(request, "before_ts", required=True)
@@ -154,13 +185,15 @@ class DeleteMediaByID(RestServlet):
 
     PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.server_name = hs.hostname
         self.media_repository = hs.get_media_repository()
 
-    async def on_DELETE(self, request, server_name: str, media_id: str):
+    async def on_DELETE(
+        self, request: Request, server_name: str, media_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         if self.server_name != server_name:
@@ -182,13 +215,13 @@ class DeleteMediaByDateSize(RestServlet):
 
     PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.server_name = hs.hostname
         self.media_repository = hs.get_media_repository()
 
-    async def on_POST(self, request, server_name: str):
+    async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         before_ts = parse_integer(request, "before_ts", required=True)
@@ -222,7 +255,7 @@ class DeleteMediaByDateSize(RestServlet):
         return 200, {"deleted_media": deleted_media, "total": total}
 
 
-def register_servlets_for_media_repo(hs, http_server):
+def register_servlets_for_media_repo(hs: "HomeServer", http_server):
     """
     Media repo specific APIs.
     """
@@ -230,6 +263,7 @@ def register_servlets_for_media_repo(hs, http_server):
     QuarantineMediaInRoom(hs).register(http_server)
     QuarantineMediaByID(hs).register(http_server)
     QuarantineMediaByUser(hs).register(http_server)
+    ProtectMediaByID(hs).register(http_server)
     ListMediaInRoom(hs).register(http_server)
     DeleteMediaByID(hs).register(http_server)
     DeleteMediaByDateSize(hs).register(http_server)
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 47c2b44bff..31a41e4a27 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2019 New Vector Ltd
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,10 +17,11 @@
 import logging
 import os
 import urllib
-from typing import Awaitable
+from typing import Awaitable, Dict, Generator, List, Optional, Tuple
 
 from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender
+from twisted.web.http import Request
 
 from synapse.api.errors import Codes, SynapseError, cs_error
 from synapse.http.server import finish_request, respond_with_json
@@ -46,7 +47,7 @@ TEXT_CONTENT_TYPES = [
 ]
 
 
-def parse_media_id(request):
+def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
     try:
         # This allows users to append e.g. /test.png to the URL. Useful for
         # clients that parse the URL to see content type.
@@ -69,7 +70,7 @@ def parse_media_id(request):
         )
 
 
-def respond_404(request):
+def respond_404(request: Request) -> None:
     respond_with_json(
         request,
         404,
@@ -79,8 +80,12 @@ def respond_404(request):
 
 
 async def respond_with_file(
-    request, media_type, file_path, file_size=None, upload_name=None
-):
+    request: Request,
+    media_type: str,
+    file_path: str,
+    file_size: Optional[int] = None,
+    upload_name: Optional[str] = None,
+) -> None:
     logger.debug("Responding with %r", file_path)
 
     if os.path.isfile(file_path):
@@ -98,15 +103,20 @@ async def respond_with_file(
         respond_404(request)
 
 
-def add_file_headers(request, media_type, file_size, upload_name):
+def add_file_headers(
+    request: Request,
+    media_type: str,
+    file_size: Optional[int],
+    upload_name: Optional[str],
+) -> None:
     """Adds the correct response headers in preparation for responding with the
     media.
 
     Args:
-        request (twisted.web.http.Request)
-        media_type (str): The media/content type.
-        file_size (int): Size in bytes of the media, if known.
-        upload_name (str): The name of the requested file, if any.
+        request
+        media_type: The media/content type.
+        file_size: Size in bytes of the media, if known.
+        upload_name: The name of the requested file, if any.
     """
 
     def _quote(x):
@@ -153,7 +163,8 @@ def add_file_headers(request, media_type, file_size, upload_name):
     # select private. don't bother setting Expires as all our
     # clients are smart enough to be happy with Cache-Control
     request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
-    request.setHeader(b"Content-Length", b"%d" % (file_size,))
+    if file_size is not None:
+        request.setHeader(b"Content-Length", b"%d" % (file_size,))
 
     # Tell web crawlers to not index, archive, or follow links in media. This
     # should help to prevent things in the media repo from showing up in web
@@ -184,7 +195,7 @@ _FILENAME_SEPARATOR_CHARS = {
 }
 
 
-def _can_encode_filename_as_token(x):
+def _can_encode_filename_as_token(x: str) -> bool:
     for c in x:
         # from RFC2616:
         #
@@ -206,17 +217,21 @@ def _can_encode_filename_as_token(x):
 
 
 async def respond_with_responder(
-    request, responder, media_type, file_size, upload_name=None
-):
+    request: Request,
+    responder: "Optional[Responder]",
+    media_type: str,
+    file_size: Optional[int],
+    upload_name: Optional[str] = None,
+) -> None:
     """Responds to the request with given responder. If responder is None then
     returns 404.
 
     Args:
-        request (twisted.web.http.Request)
-        responder (Responder|None)
-        media_type (str): The media/content type.
-        file_size (int|None): Size in bytes of the media. If not known it should be None
-        upload_name (str|None): The name of the requested file, if any.
+        request
+        responder
+        media_type: The media/content type.
+        file_size: Size in bytes of the media. If not known it should be None
+        upload_name: The name of the requested file, if any.
     """
     if request._disconnected:
         logger.warning(
@@ -308,22 +323,22 @@ class FileInfo:
         self.thumbnail_type = thumbnail_type
 
 
-def get_filename_from_headers(headers):
+def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
     """
     Get the filename of the downloaded file by inspecting the
     Content-Disposition HTTP header.
 
     Args:
-        headers (dict[bytes, list[bytes]]): The HTTP request headers.
+        headers: The HTTP request headers.
 
     Returns:
-        A Unicode string of the filename, or None.
+        The filename, or None.
     """
     content_disposition = headers.get(b"Content-Disposition", [b""])
 
     # No header, bail out.
     if not content_disposition[0]:
-        return
+        return None
 
     _, params = _parse_header(content_disposition[0])
 
@@ -356,17 +371,16 @@ def get_filename_from_headers(headers):
     return upload_name
 
 
-def _parse_header(line):
+def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]:
     """Parse a Content-type like header.
 
     Cargo-culted from `cgi`, but works on bytes rather than strings.
 
     Args:
-        line (bytes): header to be parsed
+        line: header to be parsed
 
     Returns:
-        Tuple[bytes, dict[bytes, bytes]]:
-            the main content-type, followed by the parameter dictionary
+        The main content-type, followed by the parameter dictionary
     """
     parts = _parseparam(b";" + line)
     key = next(parts)
@@ -386,16 +400,16 @@ def _parse_header(line):
     return key, pdict
 
 
-def _parseparam(s):
+def _parseparam(s: bytes) -> Generator[bytes, None, None]:
     """Generator which splits the input on ;, respecting double-quoted sequences
 
     Cargo-culted from `cgi`, but works on bytes rather than strings.
 
     Args:
-        s (bytes): header to be parsed
+        s: header to be parsed
 
     Returns:
-        Iterable[bytes]: the split input
+        The split input
     """
     while s[:1] == b";":
         s = s[1:]
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
index 68dd2a1c8a..4e4c6971f7 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/v1/config_resource.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2018 Will Hunt <will@half-shot.uk>
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,22 +15,29 @@
 # limitations under the License.
 #
 
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
+
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 
 class MediaConfigResource(DirectServeJsonResource):
     isLeaf = True
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         config = hs.get_config()
         self.clock = hs.get_clock()
         self.auth = hs.get_auth()
         self.limits_dict = {"m.upload.size": config.max_upload_size}
 
-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
         await self.auth.get_user_by_req(request)
         respond_with_json(request, 200, self.limits_dict, send_cors=True)
 
-    async def _async_render_OPTIONS(self, request):
+    async def _async_render_OPTIONS(self, request: Request) -> None:
         respond_with_json(request, 200, {}, send_cors=True)
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index d3d8457303..3ed219ae43 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,24 +14,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
 
-import synapse.http.servlet
 from synapse.http.server import DirectServeJsonResource, set_cors_headers
+from synapse.http.servlet import parse_boolean
 
 from ._base import parse_media_id, respond_404
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+    from synapse.rest.media.v1.media_repository import MediaRepository
+
 logger = logging.getLogger(__name__)
 
 
 class DownloadResource(DirectServeJsonResource):
     isLeaf = True
 
-    def __init__(self, hs, media_repo):
+    def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
         super().__init__()
         self.media_repo = media_repo
         self.server_name = hs.hostname
 
-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
         set_cors_headers(request)
         request.setHeader(
             b"Content-Security-Policy",
@@ -49,9 +57,7 @@ class DownloadResource(DirectServeJsonResource):
         if server_name == self.server_name:
             await self.media_repo.get_local_media(request, media_id, name)
         else:
-            allow_remote = synapse.http.servlet.parse_boolean(
-                request, "allow_remote", default=True
-            )
+            allow_remote = parse_boolean(request, "allow_remote", default=True)
             if not allow_remote:
                 logger.info(
                     "Rejecting request for remote media %s/%s due to allow_remote",
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index 9e079f672f..7792f26e78 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,11 +17,12 @@
 import functools
 import os
 import re
+from typing import Callable, List
 
 NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
 
 
-def _wrap_in_base_path(func):
+def _wrap_in_base_path(func: "Callable[..., str]") -> "Callable[..., str]":
     """Takes a function that returns a relative path and turns it into an
     absolute path based on the location of the primary media store
     """
@@ -41,12 +43,18 @@ class MediaFilePaths:
     to write to the backup media store (when one is configured)
     """
 
-    def __init__(self, primary_base_path):
+    def __init__(self, primary_base_path: str):
         self.base_path = primary_base_path
 
     def default_thumbnail_rel(
-        self, default_top_level, default_sub_type, width, height, content_type, method
-    ):
+        self,
+        default_top_level: str,
+        default_sub_type: str,
+        width: int,
+        height: int,
+        content_type: str,
+        method: str,
+    ) -> str:
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
@@ -55,12 +63,14 @@ class MediaFilePaths:
 
     default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
 
-    def local_media_filepath_rel(self, media_id):
+    def local_media_filepath_rel(self, media_id: str) -> str:
         return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])
 
     local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
 
-    def local_media_thumbnail_rel(self, media_id, width, height, content_type, method):
+    def local_media_thumbnail_rel(
+        self, media_id: str, width: int, height: int, content_type: str, method: str
+    ) -> str:
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
@@ -86,7 +96,7 @@ class MediaFilePaths:
             media_id[4:],
         )
 
-    def remote_media_filepath_rel(self, server_name, file_id):
+    def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
         return os.path.join(
             "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
         )
@@ -94,8 +104,14 @@ class MediaFilePaths:
     remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
 
     def remote_media_thumbnail_rel(
-        self, server_name, file_id, width, height, content_type, method
-    ):
+        self,
+        server_name: str,
+        file_id: str,
+        width: int,
+        height: int,
+        content_type: str,
+        method: str,
+    ) -> str:
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
@@ -113,7 +129,7 @@ class MediaFilePaths:
     # Should be removed after some time, when most of the thumbnails are stored
     # using the new path.
     def remote_media_thumbnail_rel_legacy(
-        self, server_name, file_id, width, height, content_type
+        self, server_name: str, file_id: str, width: int, height: int, content_type: str
     ):
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
@@ -126,7 +142,7 @@ class MediaFilePaths:
             file_name,
         )
 
-    def remote_media_thumbnail_dir(self, server_name, file_id):
+    def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
         return os.path.join(
             self.base_path,
             "remote_thumbnail",
@@ -136,7 +152,7 @@ class MediaFilePaths:
             file_id[4:],
         )
 
-    def url_cache_filepath_rel(self, media_id):
+    def url_cache_filepath_rel(self, media_id: str) -> str:
         if NEW_FORMAT_ID_RE.match(media_id):
             # Media id is of the form <DATE><RANDOM_STRING>
             # E.g.: 2017-09-28-fsdRDt24DS234dsf
@@ -146,7 +162,7 @@ class MediaFilePaths:
 
     url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
 
-    def url_cache_filepath_dirs_to_delete(self, media_id):
+    def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
         "The dirs to try and remove if we delete the media_id file"
         if NEW_FORMAT_ID_RE.match(media_id):
             return [os.path.join(self.base_path, "url_cache", media_id[:10])]
@@ -156,7 +172,9 @@ class MediaFilePaths:
                 os.path.join(self.base_path, "url_cache", media_id[0:2]),
             ]
 
-    def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method):
+    def url_cache_thumbnail_rel(
+        self, media_id: str, width: int, height: int, content_type: str, method: str
+    ) -> str:
         # Media id is of the form <DATE><RANDOM_STRING>
         # E.g.: 2017-09-28-fsdRDt24DS234dsf
 
@@ -178,7 +196,7 @@ class MediaFilePaths:
 
     url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
 
-    def url_cache_thumbnail_directory(self, media_id):
+    def url_cache_thumbnail_directory(self, media_id: str) -> str:
         # Media id is of the form <DATE><RANDOM_STRING>
         # E.g.: 2017-09-28-fsdRDt24DS234dsf
 
@@ -195,7 +213,7 @@ class MediaFilePaths:
                 media_id[4:],
             )
 
-    def url_cache_thumbnail_dirs_to_delete(self, media_id):
+    def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
         "The dirs to try and remove if we delete the media_id thumbnails"
         # Media id is of the form <DATE><RANDOM_STRING>
         # E.g.: 2017-09-28-fsdRDt24DS234dsf
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 83beb02b05..4c9946a616 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,12 +13,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import errno
 import logging
 import os
 import shutil
-from typing import IO, Dict, List, Optional, Tuple
+from io import BytesIO
+from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
 import twisted.internet.error
 import twisted.web.http
@@ -56,6 +56,9 @@ from .thumbnail_resource import ThumbnailResource
 from .thumbnailer import Thumbnailer, ThumbnailError
 from .upload_resource import UploadResource
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -63,7 +66,7 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
 
 
 class MediaRepository:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.client = hs.get_federation_http_client()
@@ -73,16 +76,16 @@ class MediaRepository:
         self.max_upload_size = hs.config.max_upload_size
         self.max_image_pixels = hs.config.max_image_pixels
 
-        self.primary_base_path = hs.config.media_store_path
-        self.filepaths = MediaFilePaths(self.primary_base_path)
+        self.primary_base_path = hs.config.media_store_path  # type: str
+        self.filepaths = MediaFilePaths(self.primary_base_path)  # type: MediaFilePaths
 
         self.dynamic_thumbnails = hs.config.dynamic_thumbnails
         self.thumbnail_requirements = hs.config.thumbnail_requirements
 
         self.remote_media_linearizer = Linearizer(name="media_remote")
 
-        self.recently_accessed_remotes = set()
-        self.recently_accessed_locals = set()
+        self.recently_accessed_remotes = set()  # type: Set[Tuple[str, str]]
+        self.recently_accessed_locals = set()  # type: Set[str]
 
         self.federation_domain_whitelist = hs.config.federation_domain_whitelist
 
@@ -113,7 +116,7 @@ class MediaRepository:
             "update_recently_accessed_media", self._update_recently_accessed
         )
 
-    async def _update_recently_accessed(self):
+    async def _update_recently_accessed(self) -> None:
         remote_media = self.recently_accessed_remotes
         self.recently_accessed_remotes = set()
 
@@ -124,12 +127,12 @@ class MediaRepository:
             local_media, remote_media, self.clock.time_msec()
         )
 
-    def mark_recently_accessed(self, server_name, media_id):
+    def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None:
         """Mark the given media as recently accessed.
 
         Args:
-            server_name (str|None): Origin server of media, or None if local
-            media_id (str): The media ID of the content
+            server_name: Origin server of media, or None if local
+            media_id: The media ID of the content
         """
         if server_name:
             self.recently_accessed_remotes.add((server_name, media_id))
@@ -459,7 +462,14 @@ class MediaRepository:
     def _get_thumbnail_requirements(self, media_type):
         return self.thumbnail_requirements.get(media_type, ())
 
-    def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type):
+    def _generate_thumbnail(
+        self,
+        thumbnailer: Thumbnailer,
+        t_width: int,
+        t_height: int,
+        t_method: str,
+        t_type: str,
+    ) -> Optional[BytesIO]:
         m_width = thumbnailer.width
         m_height = thumbnailer.height
 
@@ -470,22 +480,20 @@ class MediaRepository:
                 m_height,
                 self.max_image_pixels,
             )
-            return
+            return None
 
         if thumbnailer.transpose_method is not None:
             m_width, m_height = thumbnailer.transpose()
 
         if t_method == "crop":
-            t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
+            return thumbnailer.crop(t_width, t_height, t_type)
         elif t_method == "scale":
             t_width, t_height = thumbnailer.aspect(t_width, t_height)
             t_width = min(m_width, t_width)
             t_height = min(m_height, t_height)
-            t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
-        else:
-            t_byte_source = None
+            return thumbnailer.scale(t_width, t_height, t_type)
 
-        return t_byte_source
+        return None
 
     async def generate_local_exact_thumbnail(
         self,
@@ -776,7 +784,7 @@ class MediaRepository:
 
         return {"width": m_width, "height": m_height}
 
-    async def delete_old_remote_media(self, before_ts):
+    async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:
         old_media = await self.store.get_remote_media_before(before_ts)
 
         deleted = 0
@@ -928,7 +936,7 @@ class MediaRepositoryResource(Resource):
     within a given rectangle.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         # If we're not configured to use it, raise if we somehow got here.
         if not hs.config.can_load_media_repo:
             raise ConfigError("Synapse is not configured to use a media repo.")
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 268e0c8f50..89cdd605aa 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2018 New Vecotr Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,6 +18,8 @@ import os
 import shutil
 from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
 
+from twisted.internet.defer import Deferred
+from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender
 
 from synapse.logging.context import defer_to_thread, make_deferred_yieldable
@@ -270,7 +272,7 @@ class MediaStorage:
         return self.filepaths.local_media_filepath_rel(file_info.file_id)
 
 
-def _write_file_synchronously(source, dest):
+def _write_file_synchronously(source: IO, dest: IO) -> None:
     """Write `source` to the file like `dest` synchronously. Should be called
     from a thread.
 
@@ -286,14 +288,14 @@ class FileResponder(Responder):
     """Wraps an open file that can be sent to a request.
 
     Args:
-        open_file (file): A file like object to be streamed ot the client,
+        open_file: A file like object to be streamed ot the client,
             is closed when finished streaming.
     """
 
-    def __init__(self, open_file):
+    def __init__(self, open_file: IO):
         self.open_file = open_file
 
-    def write_to_consumer(self, consumer):
+    def write_to_consumer(self, consumer: IConsumer) -> Deferred:
         return make_deferred_yieldable(
             FileSender().beginFileTransfer(self.open_file, consumer)
         )
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 1082389d9b..a632099167 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import datetime
 import errno
 import fnmatch
@@ -23,12 +23,13 @@ import re
 import shutil
 import sys
 import traceback
-from typing import Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union
 from urllib import parse as urlparse
 
 import attr
 
 from twisted.internet.error import DNSLookupError
+from twisted.web.http import Request
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.client import SimpleHttpClient
@@ -41,6 +42,7 @@ from synapse.http.servlet import parse_integer, parse_string
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.media.v1._base import get_filename_from_headers
+from synapse.rest.media.v1.media_storage import MediaStorage
 from synapse.util import json_encoder
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -48,6 +50,12 @@ from synapse.util.stringutils import random_string
 
 from ._base import FileInfo
 
+if TYPE_CHECKING:
+    from lxml import etree
+
+    from synapse.app.homeserver import HomeServer
+    from synapse.rest.media.v1.media_repository import MediaRepository
+
 logger = logging.getLogger(__name__)
 
 _charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
@@ -119,7 +127,12 @@ class OEmbedError(Exception):
 class PreviewUrlResource(DirectServeJsonResource):
     isLeaf = True
 
-    def __init__(self, hs, media_repo, media_storage):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        media_repo: "MediaRepository",
+        media_storage: MediaStorage,
+    ):
         super().__init__()
 
         self.auth = hs.get_auth()
@@ -166,11 +179,11 @@ class PreviewUrlResource(DirectServeJsonResource):
                 self._start_expire_url_cache_data, 10 * 1000
             )
 
-    async def _async_render_OPTIONS(self, request):
+    async def _async_render_OPTIONS(self, request: Request) -> None:
         request.setHeader(b"Allow", b"OPTIONS, GET")
         respond_with_json(request, 200, {}, send_cors=True)
 
-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
 
         # XXX: if get_user_by_req fails, what should we do in an async render?
         requester = await self.auth.get_user_by_req(request)
@@ -450,7 +463,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
             raise OEmbedError() from e
 
-    async def _download_url(self, url: str, user):
+    async def _download_url(self, url: str, user: str) -> Dict[str, Any]:
         # TODO: we should probably honour robots.txt... except in practice
         # we're most likely being explicitly triggered by a human rather than a
         # bot, so are we really a robot?
@@ -580,7 +593,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             "expire_url_cache_data", self._expire_url_cache_data
         )
 
-    async def _expire_url_cache_data(self):
+    async def _expire_url_cache_data(self) -> None:
         """Clean up expired url cache content, media and thumbnails.
         """
         # TODO: Delete from backup media store
@@ -676,7 +689,9 @@ class PreviewUrlResource(DirectServeJsonResource):
             logger.debug("No media removed from url cache")
 
 
-def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]:
+def decode_and_calc_og(
+    body: bytes, media_uri: str, request_encoding: Optional[str] = None
+) -> Dict[str, Optional[str]]:
     # If there's no body, nothing useful is going to be found.
     if not body:
         return {}
@@ -697,7 +712,7 @@ def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]
     return og
 
 
-def _calc_og(tree, media_uri):
+def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]:
     # suck our tree into lxml and define our OG response.
 
     # if we see any image URLs in the OG response, then spider them
@@ -801,7 +816,9 @@ def _calc_og(tree, media_uri):
                 for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
             )
             og["og:description"] = summarize_paragraphs(text_nodes)
-    else:
+    elif og["og:description"]:
+        # This must be a non-empty string at this point.
+        assert isinstance(og["og:description"], str)
         og["og:description"] = summarize_paragraphs([og["og:description"]])
 
     # TODO: delete the url downloads to stop diskfilling,
@@ -809,7 +826,9 @@ def _calc_og(tree, media_uri):
     return og
 
 
-def _iterate_over_text(tree, *tags_to_ignore):
+def _iterate_over_text(
+    tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
+) -> Generator[str, None, None]:
     """Iterate over the tree returning text nodes in a depth first fashion,
     skipping text nodes inside certain tags.
     """
@@ -843,32 +862,32 @@ def _iterate_over_text(tree, *tags_to_ignore):
             )
 
 
-def _rebase_url(url, base):
-    base = list(urlparse.urlparse(base))
-    url = list(urlparse.urlparse(url))
-    if not url[0]:  # fix up schema
-        url[0] = base[0] or "http"
-    if not url[1]:  # fix up hostname
-        url[1] = base[1]
-        if not url[2].startswith("/"):
-            url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2]
-    return urlparse.urlunparse(url)
+def _rebase_url(url: str, base: str) -> str:
+    base_parts = list(urlparse.urlparse(base))
+    url_parts = list(urlparse.urlparse(url))
+    if not url_parts[0]:  # fix up schema
+        url_parts[0] = base_parts[0] or "http"
+    if not url_parts[1]:  # fix up hostname
+        url_parts[1] = base_parts[1]
+        if not url_parts[2].startswith("/"):
+            url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
+    return urlparse.urlunparse(url_parts)
 
 
-def _is_media(content_type):
-    if content_type.lower().startswith("image/"):
-        return True
+def _is_media(content_type: str) -> bool:
+    return content_type.lower().startswith("image/")
 
 
-def _is_html(content_type):
+def _is_html(content_type: str) -> bool:
     content_type = content_type.lower()
-    if content_type.startswith("text/html") or content_type.startswith(
+    return content_type.startswith("text/html") or content_type.startswith(
         "application/xhtml"
-    ):
-        return True
+    )
 
 
-def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
+def summarize_paragraphs(
+    text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
+) -> Optional[str]:
     # Try to get a summary of between 200 and 500 words, respecting
     # first paragraph and then word boundaries.
     # TODO: Respect sentences?
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 67f67efde7..e92006faa9 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,10 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
 import logging
 import os
 import shutil
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from synapse.config._base import Config
 from synapse.logging.context import defer_to_thread, run_in_background
@@ -27,13 +28,17 @@ from .media_storage import FileResponder
 
 logger = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
-class StorageProvider:
+
+class StorageProvider(metaclass=abc.ABCMeta):
     """A storage provider is a service that can store uploaded media and
     retrieve them.
     """
 
-    async def store_file(self, path: str, file_info: FileInfo):
+    @abc.abstractmethod
+    async def store_file(self, path: str, file_info: FileInfo) -> None:
         """Store the file described by file_info. The actual contents can be
         retrieved by reading the file in file_info.upload_path.
 
@@ -42,6 +47,7 @@ class StorageProvider:
             file_info: The metadata of the file.
         """
 
+    @abc.abstractmethod
     async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
         """Attempt to fetch the file described by file_info and stream it
         into writer.
@@ -78,10 +84,10 @@ class StorageProviderWrapper(StorageProvider):
         self.store_synchronous = store_synchronous
         self.store_remote = store_remote
 
-    def __str__(self):
+    def __str__(self) -> str:
         return "StorageProviderWrapper[%s]" % (self.backend,)
 
-    async def store_file(self, path, file_info):
+    async def store_file(self, path: str, file_info: FileInfo) -> None:
         if not file_info.server_name and not self.store_local:
             return None
 
@@ -91,7 +97,7 @@ class StorageProviderWrapper(StorageProvider):
         if self.store_synchronous:
             # store_file is supposed to return an Awaitable, but guard
             # against improper implementations.
-            return await maybe_awaitable(self.backend.store_file(path, file_info))
+            await maybe_awaitable(self.backend.store_file(path, file_info))  # type: ignore
         else:
             # TODO: Handle errors.
             async def store():
@@ -103,9 +109,8 @@ class StorageProviderWrapper(StorageProvider):
                     logger.exception("Error storing file")
 
             run_in_background(store)
-            return None
 
-    async def fetch(self, path, file_info):
+    async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
         # store_file is supposed to return an Awaitable, but guard
         # against improper implementations.
         return await maybe_awaitable(self.backend.fetch(path, file_info))
@@ -115,11 +120,11 @@ class FileStorageProviderBackend(StorageProvider):
     """A storage provider that stores files in a directory on a filesystem.
 
     Args:
-        hs (HomeServer)
+        hs
         config: The config returned by `parse_config`.
     """
 
-    def __init__(self, hs, config):
+    def __init__(self, hs: "HomeServer", config: str):
         self.hs = hs
         self.cache_directory = hs.config.media_store_path
         self.base_directory = config
@@ -127,7 +132,7 @@ class FileStorageProviderBackend(StorageProvider):
     def __str__(self):
         return "FileStorageProviderBackend[%s]" % (self.base_directory,)
 
-    async def store_file(self, path, file_info):
+    async def store_file(self, path: str, file_info: FileInfo) -> None:
         """See StorageProvider.store_file"""
 
         primary_fname = os.path.join(self.cache_directory, path)
@@ -137,19 +142,21 @@ class FileStorageProviderBackend(StorageProvider):
         if not os.path.exists(dirname):
             os.makedirs(dirname)
 
-        return await defer_to_thread(
+        await defer_to_thread(
             self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
         )
 
-    async def fetch(self, path, file_info):
+    async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
         """See StorageProvider.fetch"""
 
         backup_fname = os.path.join(self.base_directory, path)
         if os.path.isfile(backup_fname):
             return FileResponder(open(backup_fname, "rb"))
 
+        return None
+
     @staticmethod
-    def parse_config(config):
+    def parse_config(config: dict) -> str:
         """Called on startup to parse config supplied. This should parse
         the config and raise if there is a problem.
 
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 30421b663a..d6880f2e6e 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,10 +16,14 @@
 
 
 import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
 
 from synapse.api.errors import SynapseError
 from synapse.http.server import DirectServeJsonResource, set_cors_headers
 from synapse.http.servlet import parse_integer, parse_string
+from synapse.rest.media.v1.media_storage import MediaStorage
 
 from ._base import (
     FileInfo,
@@ -28,13 +33,22 @@ from ._base import (
     respond_with_responder,
 )
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+    from synapse.rest.media.v1.media_repository import MediaRepository
+
 logger = logging.getLogger(__name__)
 
 
 class ThumbnailResource(DirectServeJsonResource):
     isLeaf = True
 
-    def __init__(self, hs, media_repo, media_storage):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        media_repo: "MediaRepository",
+        media_storage: MediaStorage,
+    ):
         super().__init__()
 
         self.store = hs.get_datastore()
@@ -43,7 +57,7 @@ class ThumbnailResource(DirectServeJsonResource):
         self.dynamic_thumbnails = hs.config.dynamic_thumbnails
         self.server_name = hs.hostname
 
-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
         set_cors_headers(request)
         server_name, media_id, _ = parse_media_id(request)
         width = parse_integer(request, "width", required=True)
@@ -73,8 +87,14 @@ class ThumbnailResource(DirectServeJsonResource):
             self.media_repo.mark_recently_accessed(server_name, media_id)
 
     async def _respond_local_thumbnail(
-        self, request, media_id, width, height, method, m_type
-    ):
+        self,
+        request: Request,
+        media_id: str,
+        width: int,
+        height: int,
+        method: str,
+        m_type: str,
+    ) -> None:
         media_info = await self.store.get_local_media(media_id)
 
         if not media_info:
@@ -114,13 +134,13 @@ class ThumbnailResource(DirectServeJsonResource):
 
     async def _select_or_generate_local_thumbnail(
         self,
-        request,
-        media_id,
-        desired_width,
-        desired_height,
-        desired_method,
-        desired_type,
-    ):
+        request: Request,
+        media_id: str,
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+    ) -> None:
         media_info = await self.store.get_local_media(media_id)
 
         if not media_info:
@@ -178,14 +198,14 @@ class ThumbnailResource(DirectServeJsonResource):
 
     async def _select_or_generate_remote_thumbnail(
         self,
-        request,
-        server_name,
-        media_id,
-        desired_width,
-        desired_height,
-        desired_method,
-        desired_type,
-    ):
+        request: Request,
+        server_name: str,
+        media_id: str,
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+    ) -> None:
         media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
 
         thumbnail_infos = await self.store.get_remote_media_thumbnails(
@@ -239,8 +259,15 @@ class ThumbnailResource(DirectServeJsonResource):
             raise SynapseError(400, "Failed to generate thumbnail.")
 
     async def _respond_remote_thumbnail(
-        self, request, server_name, media_id, width, height, method, m_type
-    ):
+        self,
+        request: Request,
+        server_name: str,
+        media_id: str,
+        width: int,
+        height: int,
+        method: str,
+        m_type: str,
+    ) -> None:
         # TODO: Don't download the whole remote file
         # We should proxy the thumbnail from the remote server instead of
         # downloading the remote file and generating our own thumbnails.
@@ -275,12 +302,12 @@ class ThumbnailResource(DirectServeJsonResource):
 
     def _select_thumbnail(
         self,
-        desired_width,
-        desired_height,
-        desired_method,
-        desired_type,
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
         thumbnail_infos,
-    ):
+    ) -> dict:
         d_w = desired_width
         d_h = desired_height
 
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 32a8e4f960..07903e4017 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,6 +15,7 @@
 # limitations under the License.
 import logging
 from io import BytesIO
+from typing import Tuple
 
 from PIL import Image
 
@@ -39,7 +41,7 @@ class Thumbnailer:
 
     FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
 
-    def __init__(self, input_path):
+    def __init__(self, input_path: str):
         try:
             self.image = Image.open(input_path)
         except OSError as e:
@@ -59,11 +61,11 @@ class Thumbnailer:
             # A lot of parsing errors can happen when parsing EXIF
             logger.info("Error parsing image EXIF information: %s", e)
 
-    def transpose(self):
+    def transpose(self) -> Tuple[int, int]:
         """Transpose the image using its EXIF Orientation tag
 
         Returns:
-            Tuple[int, int]: (width, height) containing the new image size in pixels.
+            A tuple containing the new image size in pixels as (width, height).
         """
         if self.transpose_method is not None:
             self.image = self.image.transpose(self.transpose_method)
@@ -73,7 +75,7 @@ class Thumbnailer:
             self.image.info["exif"] = None
         return self.image.size
 
-    def aspect(self, max_width, max_height):
+    def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]:
         """Calculate the largest size that preserves aspect ratio which
         fits within the given rectangle::
 
@@ -91,7 +93,7 @@ class Thumbnailer:
         else:
             return (max_height * self.width) // self.height, max_height
 
-    def _resize(self, width, height):
+    def _resize(self, width: int, height: int) -> Image:
         # 1-bit or 8-bit color palette images need converting to RGB
         # otherwise they will be scaled using nearest neighbour which
         # looks awful
@@ -99,7 +101,7 @@ class Thumbnailer:
             self.image = self.image.convert("RGB")
         return self.image.resize((width, height), Image.ANTIALIAS)
 
-    def scale(self, width, height, output_type):
+    def scale(self, width: int, height: int, output_type: str) -> BytesIO:
         """Rescales the image to the given dimensions.
 
         Returns:
@@ -108,7 +110,7 @@ class Thumbnailer:
         scaled = self._resize(width, height)
         return self._encode_image(scaled, output_type)
 
-    def crop(self, width, height, output_type):
+    def crop(self, width: int, height: int, output_type: str) -> BytesIO:
         """Rescales and crops the image to the given dimensions preserving
         aspect::
             (w_in / h_in) = (w_scaled / h_scaled)
@@ -136,7 +138,7 @@ class Thumbnailer:
             cropped = scaled_image.crop((crop_left, 0, crop_right, height))
         return self._encode_image(cropped, output_type)
 
-    def _encode_image(self, output_image, output_type):
+    def _encode_image(self, output_image: Image, output_type: str) -> BytesIO:
         output_bytes_io = BytesIO()
         fmt = self.FORMATS[output_type]
         if fmt == "JPEG":
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 42febc9afc..6da76ae994 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,18 +15,25 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.servlet import parse_string
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+    from synapse.rest.media.v1.media_repository import MediaRepository
+
 logger = logging.getLogger(__name__)
 
 
 class UploadResource(DirectServeJsonResource):
     isLeaf = True
 
-    def __init__(self, hs, media_repo):
+    def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
         super().__init__()
 
         self.media_repo = media_repo
@@ -37,10 +45,10 @@ class UploadResource(DirectServeJsonResource):
         self.max_upload_size = hs.config.max_upload_size
         self.clock = hs.get_clock()
 
-    async def _async_render_OPTIONS(self, request):
+    async def _async_render_OPTIONS(self, request: Request) -> None:
         respond_with_json(request, 200, {}, send_cors=True)
 
-    async def _async_render_POST(self, request):
+    async def _async_render_POST(self, request: Request) -> None:
         requester = await self.auth.get_user_by_req(request)
         # TODO: The checks here are a bit late. The content will have
         # already been uploaded to a tmp file at this point
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 7128dc1742..e46e44ba54 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -16,6 +16,8 @@
 import logging
 from typing import Dict, List, Optional, Tuple
 
+import attr
+
 from synapse.api.constants import EventContentFields
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import make_event_from_dict
@@ -28,6 +30,25 @@ from synapse.types import JsonDict
 logger = logging.getLogger(__name__)
 
 
+@attr.s(slots=True, frozen=True)
+class _CalculateChainCover:
+    """Return value for _calculate_chain_cover_txn.
+    """
+
+    # The last room_id/depth/stream processed.
+    room_id = attr.ib(type=str)
+    depth = attr.ib(type=int)
+    stream = attr.ib(type=int)
+
+    # Number of rows processed
+    processed_count = attr.ib(type=int)
+
+    # Map from room_id to last depth/stream processed for each room that we have
+    # processed all events for (i.e. the rooms we can flip the
+    # `has_auth_chain_index` for)
+    finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]])
+
+
 class EventsBackgroundUpdatesStore(SQLBaseStore):
 
     EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
@@ -719,138 +740,29 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
         current_room_id = progress.get("current_room_id", "")
 
-        # Have we finished processing the current room.
-        finished = progress.get("finished", True)
-
         # Where we've processed up to in the room, defaults to the start of the
         # room.
         last_depth = progress.get("last_depth", -1)
         last_stream = progress.get("last_stream", -1)
 
-        # Have we set the `has_auth_chain_index` for the room yet.
-        has_set_room_has_chain_index = progress.get(
-            "has_set_room_has_chain_index", False
+        result = await self.db_pool.runInteraction(
+            "_chain_cover_index",
+            self._calculate_chain_cover_txn,
+            current_room_id,
+            last_depth,
+            last_stream,
+            batch_size,
+            single_room=False,
         )
 
-        if finished:
-            # If we've finished with the previous room (or its our first
-            # iteration) we move on to the next room.
-
-            def _get_next_room(txn: Cursor) -> Optional[str]:
-                sql = """
-                    SELECT room_id FROM rooms
-                    WHERE room_id > ?
-                        AND (
-                            NOT has_auth_chain_index
-                            OR has_auth_chain_index IS NULL
-                        )
-                    ORDER BY room_id
-                    LIMIT 1
-                """
-                txn.execute(sql, (current_room_id,))
-                row = txn.fetchone()
-                if row:
-                    return row[0]
+        finished = result.processed_count == 0
 
-                return None
-
-            current_room_id = await self.db_pool.runInteraction(
-                "_chain_cover_index", _get_next_room
-            )
-            if not current_room_id:
-                await self.db_pool.updates._end_background_update("chain_cover")
-                return 0
-
-            logger.debug("Adding chain cover to %s", current_room_id)
-
-        def _calculate_auth_chain(
-            txn: Cursor, last_depth: int, last_stream: int
-        ) -> Tuple[int, int, int]:
-            # Get the next set of events in the room (that we haven't already
-            # computed chain cover for). We do this in topological order.
-
-            # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
-            # comparison, but that is not supported on older SQLite versions
-            tuple_clause, tuple_args = make_tuple_comparison_clause(
-                self.database_engine,
-                [
-                    ("topological_ordering", last_depth),
-                    ("stream_ordering", last_stream),
-                ],
-            )
+        total_rows_processed = result.processed_count
+        current_room_id = result.room_id
+        last_depth = result.depth
+        last_stream = result.stream
 
-            sql = """
-                SELECT
-                    event_id, state_events.type, state_events.state_key,
-                    topological_ordering, stream_ordering
-                FROM events
-                INNER JOIN state_events USING (event_id)
-                LEFT JOIN event_auth_chains USING (event_id)
-                LEFT JOIN event_auth_chain_to_calculate USING (event_id)
-                WHERE events.room_id = ?
-                    AND event_auth_chains.event_id IS NULL
-                    AND event_auth_chain_to_calculate.event_id IS NULL
-                    AND %(tuple_cmp)s
-                ORDER BY topological_ordering, stream_ordering
-                LIMIT ?
-            """ % {
-                "tuple_cmp": tuple_clause,
-            }
-
-            args = [current_room_id]
-            args.extend(tuple_args)
-            args.append(batch_size)
-
-            txn.execute(sql, args)
-            rows = txn.fetchall()
-
-            # Put the results in the necessary format for
-            # `_add_chain_cover_index`
-            event_to_room_id = {row[0]: current_room_id for row in rows}
-            event_to_types = {row[0]: (row[1], row[2]) for row in rows}
-
-            new_last_depth = rows[-1][3] if rows else last_depth  # type: int
-            new_last_stream = rows[-1][4] if rows else last_stream  # type: int
-
-            count = len(rows)
-
-            # We also need to fetch the auth events for them.
-            auth_events = self.db_pool.simple_select_many_txn(
-                txn,
-                table="event_auth",
-                column="event_id",
-                iterable=event_to_room_id,
-                keyvalues={},
-                retcols=("event_id", "auth_id"),
-            )
-
-            event_to_auth_chain = {}  # type: Dict[str, List[str]]
-            for row in auth_events:
-                event_to_auth_chain.setdefault(row["event_id"], []).append(
-                    row["auth_id"]
-                )
-
-            # Calculate and persist the chain cover index for this set of events.
-            #
-            # Annoyingly we need to gut wrench into the persit event store so that
-            # we can reuse the function to calculate the chain cover for rooms.
-            PersistEventsStore._add_chain_cover_index(
-                txn,
-                self.db_pool,
-                event_to_room_id,
-                event_to_types,
-                event_to_auth_chain,
-            )
-
-            return new_last_depth, new_last_stream, count
-
-        last_depth, last_stream, count = await self.db_pool.runInteraction(
-            "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
-        )
-
-        total_rows_processed = count
-
-        if count < batch_size and not has_set_room_has_chain_index:
+        for room_id, (depth, stream) in result.finished_room_map.items():
             # If we've done all the events in the room we flip the
             # `has_auth_chain_index` in the DB. Note that its possible for
             # further events to be persisted between the above and setting the
@@ -860,42 +772,159 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             await self.db_pool.simple_update(
                 table="rooms",
-                keyvalues={"room_id": current_room_id},
+                keyvalues={"room_id": room_id},
                 updatevalues={"has_auth_chain_index": True},
                 desc="_chain_cover_index",
             )
-            has_set_room_has_chain_index = True
 
             # Handle any events that might have raced with us flipping the
             # bit above.
-            last_depth, last_stream, count = await self.db_pool.runInteraction(
-                "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
+            result = await self.db_pool.runInteraction(
+                "_chain_cover_index",
+                self._calculate_chain_cover_txn,
+                room_id,
+                depth,
+                stream,
+                batch_size=None,
+                single_room=True,
             )
 
-            total_rows_processed += count
+            total_rows_processed += result.processed_count
 
-            # Note that at this point its technically possible that more events
-            # than our `batch_size` have been persisted without their chain
-            # cover, so we need to continue processing this room if the last
-            # count returned was equal to the `batch_size`.
+        if finished:
+            await self.db_pool.updates._end_background_update("chain_cover")
+            return total_rows_processed
 
-        if count < batch_size:
-            # We've finished calculating the index for this room, move on to the
-            # next room.
-            await self.db_pool.updates._background_update_progress(
-                "chain_cover", {"current_room_id": current_room_id, "finished": True},
-            )
-        else:
-            # We still have outstanding events to calculate the index for.
-            await self.db_pool.updates._background_update_progress(
-                "chain_cover",
-                {
-                    "current_room_id": current_room_id,
-                    "last_depth": last_depth,
-                    "last_stream": last_stream,
-                    "has_auth_chain_index": has_set_room_has_chain_index,
-                    "finished": False,
-                },
-            )
+        await self.db_pool.updates._background_update_progress(
+            "chain_cover",
+            {
+                "current_room_id": current_room_id,
+                "last_depth": last_depth,
+                "last_stream": last_stream,
+            },
+        )
 
         return total_rows_processed
+
+    def _calculate_chain_cover_txn(
+        self,
+        txn: Cursor,
+        last_room_id: str,
+        last_depth: int,
+        last_stream: int,
+        batch_size: Optional[int],
+        single_room: bool,
+    ) -> _CalculateChainCover:
+        """Calculate the chain cover for `batch_size` events, ordered by
+        `(room_id, depth, stream)`.
+
+        Args:
+            txn,
+            last_room_id, last_depth, last_stream: The `(room_id, depth, stream)`
+                tuple to fetch results after.
+            batch_size: The maximum number of events to process. If None then
+                no limit.
+            single_room: Whether to calculate the index for just the given
+                room.
+        """
+
+        # Get the next set of events in the room (that we haven't already
+        # computed chain cover for). We do this in topological order.
+
+        # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
+        # comparison, but that is not supported on older SQLite versions
+        tuple_clause, tuple_args = make_tuple_comparison_clause(
+            self.database_engine,
+            [
+                ("events.room_id", last_room_id),
+                ("topological_ordering", last_depth),
+                ("stream_ordering", last_stream),
+            ],
+        )
+
+        extra_clause = ""
+        if single_room:
+            extra_clause = "AND events.room_id = ?"
+            tuple_args.append(last_room_id)
+
+        sql = """
+            SELECT
+                event_id, state_events.type, state_events.state_key,
+                topological_ordering, stream_ordering,
+                events.room_id
+            FROM events
+            INNER JOIN state_events USING (event_id)
+            LEFT JOIN event_auth_chains USING (event_id)
+            LEFT JOIN event_auth_chain_to_calculate USING (event_id)
+            WHERE event_auth_chains.event_id IS NULL
+                AND event_auth_chain_to_calculate.event_id IS NULL
+                AND %(tuple_cmp)s
+                %(extra)s
+            ORDER BY events.room_id, topological_ordering, stream_ordering
+            %(limit)s
+        """ % {
+            "tuple_cmp": tuple_clause,
+            "limit": "LIMIT ?" if batch_size is not None else "",
+            "extra": extra_clause,
+        }
+
+        if batch_size is not None:
+            tuple_args.append(batch_size)
+
+        txn.execute(sql, tuple_args)
+        rows = txn.fetchall()
+
+        # Put the results in the necessary format for
+        # `_add_chain_cover_index`
+        event_to_room_id = {row[0]: row[5] for row in rows}
+        event_to_types = {row[0]: (row[1], row[2]) for row in rows}
+
+        # Calculate the new last position we've processed up to.
+        new_last_depth = rows[-1][3] if rows else last_depth  # type: int
+        new_last_stream = rows[-1][4] if rows else last_stream  # type: int
+        new_last_room_id = rows[-1][5] if rows else ""  # type: str
+
+        # Map from room_id to last depth/stream_ordering processed for the room,
+        # excluding the last room (which we're likely still processing). We also
+        # need to include the room passed in if it's not included in the result
+        # set (as we then know we've processed all events in said room).
+        #
+        # This is the set of rooms that we can now safely flip the
+        # `has_auth_chain_index` bit for.
+        finished_rooms = {
+            row[5]: (row[3], row[4]) for row in rows if row[5] != new_last_room_id
+        }
+        if last_room_id not in finished_rooms and last_room_id != new_last_room_id:
+            finished_rooms[last_room_id] = (last_depth, last_stream)
+
+        count = len(rows)
+
+        # We also need to fetch the auth events for them.
+        auth_events = self.db_pool.simple_select_many_txn(
+            txn,
+            table="event_auth",
+            column="event_id",
+            iterable=event_to_room_id,
+            keyvalues={},
+            retcols=("event_id", "auth_id"),
+        )
+
+        event_to_auth_chain = {}  # type: Dict[str, List[str]]
+        for row in auth_events:
+            event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
+
+        # Calculate and persist the chain cover index for this set of events.
+        #
+        # Annoyingly we need to gut wrench into the persist event store so that
+        # we can reuse the function to calculate the chain cover for rooms.
+        PersistEventsStore._add_chain_cover_index(
+            txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
+        )
+
+        return _CalculateChainCover(
+            room_id=new_last_room_id,
+            depth=new_last_depth,
+            stream=new_last_stream,
+            processed_count=count,
+            finished_room_map=finished_rooms,
+        )
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 4b2f224718..283c8a5e22 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -169,7 +170,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
     async def get_local_media_before(
         self, before_ts: int, size_gt: int, keep_profiles: bool,
-    ) -> Optional[List[str]]:
+    ) -> List[str]:
 
         # to find files that have never been accessed (last_access_ts IS NULL)
         # compare with `created_ts`
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 77ba9d819e..bc7621b8d6 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -17,14 +17,13 @@
 import logging
 from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
 
-from canonicaljson import encode_canonical_json
-
 from synapse.push import PusherConfig, ThrottleParams
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
 from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict
+from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 
 if TYPE_CHECKING:
@@ -315,7 +314,7 @@ class PusherStore(PusherWorkerStore):
                     "device_display_name": device_display_name,
                     "ts": pushkey_ts,
                     "lang": lang,
-                    "data": bytearray(encode_canonical_json(data)),
+                    "data": json_encoder.encode(data),
                     "last_stream_ordering": last_stream_ordering,
                     "profile_tag": profile_tag,
                     "id": stream_id,
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 61d96a6c28..b103c8694c 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -75,3 +75,22 @@ def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
     if len(items) <= maxitems:
         return str(items)
     return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
+
+
+def strtobool(val: str) -> bool:
+    """Convert a string representation of truth to True or False
+
+    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
+    are 'n', 'no', 'f', 'false', 'off', and '0'.  Raises ValueError if
+    'val' is anything else.
+
+    This is lifted from distutils.util.strtobool, with the exception that it actually
+    returns a bool, rather than an int.
+    """
+    val = val.lower()
+    if val in ("y", "yes", "t", "true", "on", "1"):
+        return True
+    elif val in ("n", "no", "f", "false", "off", "0"):
+        return False
+    else:
+        raise ValueError("invalid truth value %r" % (val,))