Diffstat (limited to 'synapse')
-rw-r--r--  synapse/api/constants.py  |   12
-rw-r--r--  synapse/app/_base.py  |   17
-rw-r--r--  synapse/app/generic_worker.py  |    5
-rw-r--r--  synapse/app/homeserver.py  |    5
-rw-r--r--  synapse/config/homeserver.py  |    2
-rw-r--r--  synapse/config/oembed.py  |  180
-rw-r--r--  synapse/config/server.py  |   87
-rw-r--r--  synapse/event_auth.py  |   15
-rw-r--r--  synapse/handlers/_base.py  |   68
-rw-r--r--  synapse/handlers/federation.py  |    1
-rw-r--r--  synapse/handlers/federation_event.py  |  180
-rw-r--r--  synapse/handlers/message.py  |   55
-rw-r--r--  synapse/handlers/room.py  |   12
-rw-r--r--  synapse/handlers/room_list.py  |   12
-rw-r--r--  synapse/handlers/room_member.py  |   65
-rw-r--r--  synapse/handlers/room_summary.py  |   57
-rw-r--r--  synapse/handlers/stats.py  |    6
-rw-r--r--  synapse/handlers/sync.py  |   67
-rw-r--r--  synapse/http/servlet.py  |   19
-rw-r--r--  synapse/push/mailer.py  |   24
-rw-r--r--  synapse/res/providers.json  |   17
-rw-r--r--  synapse/rest/admin/server_notice_servlet.py  |    6
-rw-r--r--  synapse/rest/client/_base.py  |   11
-rw-r--r--  synapse/rest/client/account.py  |   82
-rw-r--r--  synapse/rest/client/account_data.py  |   37
-rw-r--r--  synapse/rest/client/groups.py  |   22
-rw-r--r--  synapse/rest/client/knock.py  |    6
-rw-r--r--  synapse/rest/client/push_rule.py  |  112
-rw-r--r--  synapse/rest/client/receipts.py  |   15
-rw-r--r--  synapse/rest/client/register.py  |   92
-rw-r--r--  synapse/rest/client/relations.py  |   80
-rw-r--r--  synapse/rest/client/report_event.py  |   15
-rw-r--r--  synapse/rest/client/room.py  |  233
-rw-r--r--  synapse/rest/client/room_batch.py  |   31
-rw-r--r--  synapse/rest/client/room_keys.py  |   53
-rw-r--r--  synapse/rest/client/sendtodevice.py  |   27
-rw-r--r--  synapse/rest/client/sync.py  |   16
-rw-r--r--  synapse/rest/client/transactions.py  |   52
-rw-r--r--  synapse/rest/media/v1/oembed.py  |  135
-rw-r--r--  synapse/rest/media/v1/preview_url_resource.py  |  256
-rw-r--r--  synapse/storage/database.py  |   61
-rw-r--r--  synapse/storage/databases/main/directory.py  |    4
-rw-r--r--  synapse/storage/databases/main/events.py  |   62
-rw-r--r--  synapse/storage/databases/main/presence.py  |   23
-rw-r--r--  synapse/storage/databases/main/room.py  |   97
-rw-r--r--  synapse/storage/schema/main/delta/63/02populate-rooms-creator.sql  |   17
-rw-r--r--  synapse/storage/schema/main/delta/63/04add_presence_stream_not_offline_index.sql  |   18
-rw-r--r--  synapse/storage/util/id_generators.py  |    5
-rw-r--r--  synapse/util/manhole.py  |   15
49 files changed, 1727 insertions(+), 762 deletions(-)
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 829061c870..5f0f34119b 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -198,6 +198,12 @@ class EventContentFields:
     # cf https://github.com/matrix-org/matrix-doc/pull/1772
     ROOM_TYPE = "type"
 
+    # The creator of the room, as used in `m.room.create` events.
+    ROOM_CREATOR = "creator"
+
+    # Used in m.room.guest_access events.
+    GUEST_ACCESS = "guest_access"
+
     # Used on normal messages to indicate they were historically imported after the fact
     MSC2716_HISTORICAL = "org.matrix.msc2716.historical"
     # For "insertion" events to indicate what the next chunk ID should be in
@@ -232,5 +238,11 @@ class HistoryVisibility:
     WORLD_READABLE = "world_readable"
 
 
+class GuestAccess:
+    CAN_JOIN = "can_join"
+    # anything that is not "can_join" is considered "forbidden", but for completeness:
+    FORBIDDEN = "forbidden"
+
+
 class ReadReceiptEventFields:
     MSC2285_HIDDEN = "org.matrix.msc2285.hidden"
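
To make the new constants concrete: the guest-access checks added elsewhere in this changeset treat anything other than an explicit "can_join" as forbidden. A minimal standalone sketch of that rule (illustrative only, not code from this diff):

from typing import Optional


class GuestAccess:
    CAN_JOIN = "can_join"
    FORBIDDEN = "forbidden"


def guests_allowed(guest_access: Optional[str]) -> bool:
    # Only an explicit "can_join" keeps guests in the room; a missing field,
    # "forbidden", or any unknown value all count as forbidden.
    return guest_access == GuestAccess.CAN_JOIN


assert guests_allowed(GuestAccess.CAN_JOIN)
assert not guests_allowed(GuestAccess.FORBIDDEN)
assert not guests_allowed(None)
assert not guests_allowed("some_future_value")
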
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 39e28aff9f..89bda00090 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import atexit
 import gc
 import logging
 import os
@@ -36,6 +37,7 @@ from synapse.api.constants import MAX_PDU_SIZE
 from synapse.app import check_bind_error
 from synapse.app.phone_stats_home import start_phone_stats_home
 from synapse.config.homeserver import HomeServerConfig
+from synapse.config.server import ManholeConfig
 from synapse.crypto import context_factory
 from synapse.events.presence_router import load_legacy_presence_router
 from synapse.events.spamcheck import load_legacy_spam_checkers
@@ -229,7 +231,12 @@ def listen_metrics(bind_addresses, port):
         start_http_server(port, addr=host, registry=RegistryProxy)
 
 
-def listen_manhole(bind_addresses: Iterable[str], port: int, manhole_globals: dict):
+def listen_manhole(
+    bind_addresses: Iterable[str],
+    port: int,
+    manhole_settings: ManholeConfig,
+    manhole_globals: dict,
+):
     # twisted.conch.manhole 21.1.0 uses "int_from_bytes", which produces a confusing
     # warning. It's fixed by https://github.com/twisted/twisted/pull/1522, so
     # suppress the warning for now.
@@ -244,7 +251,7 @@ def listen_manhole(bind_addresses: Iterable[str], port: int, manhole_globals: di
     listen_tcp(
         bind_addresses,
         port,
-        manhole(username="matrix", password="rabbithole", globals=manhole_globals),
+        manhole(settings=manhole_settings, globals=manhole_globals),
     )
 
 
@@ -403,6 +410,12 @@ async def start(hs: "HomeServer"):
         gc.collect()
         gc.freeze()
 
+    # Speed up shutdowns by freezing all allocated objects. This moves everything
+    # into the permanent generation and excludes them from the final GC.
+    # Unfortunately this only works on Python 3.7 or later.
+    if platform.python_implementation() == "CPython" and sys.version_info >= (3, 7):
+        atexit.register(gc.freeze)
+
 
 def setup_sentry(hs):
     """Enable sentry integration, if enabled in configuration
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 9b71dd75e6..2eb8d5a79c 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -395,7 +395,10 @@ class GenericWorkerServer(HomeServer):
                 self._listen_http(listener)
             elif listener.type == "manhole":
                 _base.listen_manhole(
-                    listener.bind_addresses, listener.port, manhole_globals={"hs": self}
+                    listener.bind_addresses,
+                    listener.port,
+                    manhole_settings=self.config.server.manhole_settings,
+                    manhole_globals={"hs": self},
                 )
             elif listener.type == "metrics":
                 if not self.config.enable_metrics:
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 7dae163c1a..708db86f5d 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -291,7 +291,10 @@ class SynapseHomeServer(HomeServer):
                 )
             elif listener.type == "manhole":
                 _base.listen_manhole(
-                    listener.bind_addresses, listener.port, manhole_globals={"hs": self}
+                    listener.bind_addresses,
+                    listener.port,
+                    manhole_settings=self.config.server.manhole_settings,
+                    manhole_globals={"hs": self},
                 )
             elif listener.type == "replication":
                 services = listen_tcp(
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 1f42a51857..442f1b9ac0 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -30,6 +30,7 @@ from .key import KeyConfig
 from .logger import LoggingConfig
 from .metrics import MetricsConfig
 from .modules import ModulesConfig
+from .oembed import OembedConfig
 from .oidc import OIDCConfig
 from .password_auth_providers import PasswordAuthProviderConfig
 from .push import PushConfig
@@ -65,6 +66,7 @@ class HomeServerConfig(RootConfig):
         LoggingConfig,
         RatelimitConfig,
         ContentRepositoryConfig,
+        OembedConfig,
         CaptchaConfig,
         VoipConfig,
         RegistrationConfig,
diff --git a/synapse/config/oembed.py b/synapse/config/oembed.py
new file mode 100644
index 0000000000..09267b5eef
--- /dev/null
+++ b/synapse/config/oembed.py
@@ -0,0 +1,180 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import re
+from typing import Any, Dict, Iterable, List, Pattern
+from urllib import parse as urlparse
+
+import attr
+import pkg_resources
+
+from synapse.types import JsonDict
+
+from ._base import Config, ConfigError
+from ._util import validate_config
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class OEmbedEndpointConfig:
+    # The API endpoint to fetch.
+    api_endpoint: str
+    # The patterns to match.
+    url_patterns: List[Pattern]
+
+
+class OembedConfig(Config):
+    """oEmbed Configuration"""
+
+    section = "oembed"
+
+    def read_config(self, config, **kwargs):
+        oembed_config: Dict[str, Any] = config.get("oembed") or {}
+
+        # A list of oEmbed endpoint configurations, each with the URL patterns it matches.
+        self.oembed_patterns: List[OEmbedEndpointConfig] = list(
+            self._parse_and_validate_providers(oembed_config)
+        )
+
+    def _parse_and_validate_providers(
+        self, oembed_config: dict
+    ) -> Iterable[OEmbedEndpointConfig]:
+        """Extract and parse the oEmbed providers from the given JSON file.
+
+        Returns a generator which yields the OEmbedEndpointConfig objects
+        """
+        # Include the packaged providers.json file unless it has been disabled.
+        if not oembed_config.get("disable_default_providers", False):
+            providers = json.load(
+                pkg_resources.resource_stream("synapse", "res/providers.json")
+            )
+            yield from self._parse_and_validate_provider(
+                providers, config_path=("oembed",)
+            )
+
+        # The JSON files which include additional provider information.
+        for i, file in enumerate(oembed_config.get("additional_providers") or []):
+            # TODO Error checking.
+            with open(file) as f:
+                providers = json.load(f)
+
+            yield from self._parse_and_validate_provider(
+                providers,
+                config_path=(
+                    "oembed",
+                    "additional_providers",
+                    f"<item {i}>",
+                ),
+            )
+
+    def _parse_and_validate_provider(
+        self, providers: List[JsonDict], config_path: Iterable[str]
+    ) -> Iterable[OEmbedEndpointConfig]:
+        # Ensure it is in the proper form.
+        validate_config(
+            _OEMBED_PROVIDER_SCHEMA,
+            providers,
+            config_path=config_path,
+        )
+
+        # Parse it and yield each result.
+        for provider in providers:
+            # Each provider might have multiple API endpoints, each of which
+            # might have multiple patterns to match.
+            for endpoint in provider["endpoints"]:
+                api_endpoint = endpoint["url"]
+                patterns = [
+                    self._glob_to_pattern(glob, config_path)
+                    for glob in endpoint["schemes"]
+                ]
+                yield OEmbedEndpointConfig(api_endpoint, patterns)
+
+    def _glob_to_pattern(self, glob: str, config_path: Iterable[str]) -> Pattern:
+        """
+        Convert the glob into a sane regular expression to match against. The
+        rules followed will be slightly different for the domain portion vs.
+        the rest.
+
+        1. The scheme must be one of HTTP / HTTPS (and have no globs).
+        2. The domain can have globs, but we limit it to characters that can
+           reasonably be a domain part.
+           TODO: This does not attempt to handle Unicode domain names.
+           TODO: The domain should not allow wildcard TLDs.
+        3. Other parts allow a glob to be any one, or more, characters.
+        """
+        results = urlparse.urlparse(glob)
+
+        # Ensure the scheme does not have wildcards (and is a sane scheme).
+        if results.scheme not in {"http", "https"}:
+            raise ConfigError(f"Insecure oEmbed scheme: {results.scheme}", config_path)
+
+        pattern = urlparse.urlunparse(
+            [
+                results.scheme,
+                re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"),
+            ]
+            + [re.escape(part).replace("\\*", ".+") for part in results[2:]]
+        )
+        return re.compile(pattern)
+
+    def generate_config_section(self, **kwargs):
+        return """\
+        # oEmbed allows for easier embedding of content from a website. It can
+        # be used for generating URL previews of services which support it.
+        #
+        oembed:
+          # A default list of oEmbed providers is included with Synapse.
+          #
+          # Uncomment the following to disable using these default oEmbed URLs.
+          # Defaults to 'false'.
+          #
+          #disable_default_providers: true
+
+          # Additional files with oEmbed configuration (each should be in the
+          # form of providers.json).
+          #
+          # By default, this list is empty (so only the default providers.json
+          # is used).
+          #
+          #additional_providers:
+          #  - oembed/my_providers.json
+        """
+
+
+_OEMBED_PROVIDER_SCHEMA = {
+    "type": "array",
+    "items": {
+        "type": "object",
+        "properties": {
+            "provider_name": {"type": "string"},
+            "provider_url": {"type": "string"},
+            "endpoints": {
+                "type": "array",
+                "items": {
+                    "type": "object",
+                    "properties": {
+                        "schemes": {
+                            "type": "array",
+                            "items": {"type": "string"},
+                        },
+                        "url": {"type": "string"},
+                        "formats": {"type": "array", "items": {"type": "string"}},
+                        "discovery": {"type": "boolean"},
+                    },
+                    "required": ["schemes", "url"],
+                },
+            },
+        },
+        "required": ["provider_name", "provider_url", "endpoints"],
+    },
+}
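
To illustrate how the scheme globs accepted by _glob_to_pattern behave, here is a standalone sketch; the provider entry and URLs are invented and are not part of the bundled providers.json:

import re
from urllib import parse as urlparse

# An invented provider entry, in the shape required by _OEMBED_PROVIDER_SCHEMA.
example_provider = {
    "provider_name": "Example",
    "provider_url": "https://example.com",
    "endpoints": [
        {
            "schemes": [
                "https://example.com/clips/*",
                "https://*.example.com/v/*",
            ],
            "url": "https://example.com/oembed",
        }
    ],
}


def glob_to_pattern(glob: str):
    results = urlparse.urlparse(glob)
    # Only plain http/https schemes are accepted; no wildcards here.
    assert results.scheme in {"http", "https"}
    return re.compile(
        urlparse.urlunparse(
            [
                results.scheme,
                # A glob in the domain is limited to domain-ish characters...
                re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"),
            ]
            # ...while a glob anywhere else matches one or more of anything.
            + [re.escape(part).replace("\\*", ".+") for part in results[2:]]
        )
    )


patterns = [glob_to_pattern(g) for g in example_provider["endpoints"][0]["schemes"]]
assert patterns[0].match("https://example.com/clips/abc123")
assert patterns[1].match("https://media.example.com/v/xyz")
assert not patterns[1].match("https://media.example.org/v/xyz")
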
diff --git a/synapse/config/server.py b/synapse/config/server.py
index d2c900f50c..7b9109a592 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -25,11 +25,14 @@ import attr
 import yaml
 from netaddr import AddrFormatError, IPNetwork, IPSet
 
+from twisted.conch.ssh.keys import Key
+
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.util.module_loader import load_module
 from synapse.util.stringutils import parse_and_validate_server_name
 
 from ._base import Config, ConfigError
+from ._util import validate_config
 
 logger = logging.Logger(__name__)
 
@@ -216,6 +219,16 @@ class ListenerConfig:
     http_options = attr.ib(type=Optional[HttpListenerConfig], default=None)
 
 
+@attr.s(frozen=True)
+class ManholeConfig:
+    """Object describing the configuration of the manhole"""
+
+    username = attr.ib(type=str, validator=attr.validators.instance_of(str))
+    password = attr.ib(type=str, validator=attr.validators.instance_of(str))
+    priv_key = attr.ib(type=Optional[Key])
+    pub_key = attr.ib(type=Optional[Key])
+
+
 class ServerConfig(Config):
     section = "server"
 
@@ -649,6 +662,41 @@ class ServerConfig(Config):
                 )
             )
 
+        manhole_settings = config.get("manhole_settings") or {}
+        validate_config(
+            _MANHOLE_SETTINGS_SCHEMA, manhole_settings, ("manhole_settings",)
+        )
+
+        manhole_username = manhole_settings.get("username", "matrix")
+        manhole_password = manhole_settings.get("password", "rabbithole")
+        manhole_priv_key_path = manhole_settings.get("ssh_priv_key_path")
+        manhole_pub_key_path = manhole_settings.get("ssh_pub_key_path")
+
+        manhole_priv_key = None
+        if manhole_priv_key_path is not None:
+            try:
+                manhole_priv_key = Key.fromFile(manhole_priv_key_path)
+            except Exception as e:
+                raise ConfigError(
+                    f"Failed to read manhole private key file {manhole_priv_key_path}"
+                ) from e
+
+        manhole_pub_key = None
+        if manhole_pub_key_path is not None:
+            try:
+                manhole_pub_key = Key.fromFile(manhole_pub_key_path)
+            except Exception as e:
+                raise ConfigError(
+                    f"Failed to read manhole public key file {manhole_pub_key_path}"
+                ) from e
+
+        self.manhole_settings = ManholeConfig(
+            username=manhole_username,
+            password=manhole_password,
+            priv_key=manhole_priv_key,
+            pub_key=manhole_pub_key,
+        )
+
         metrics_port = config.get("metrics_port")
         if metrics_port:
             logger.warning(METRICS_PORT_WARNING)
@@ -715,7 +763,7 @@ class ServerConfig(Config):
         if not isinstance(templates_config, dict):
             raise ConfigError("The 'templates' section must be a dictionary")
 
-        self.custom_template_directory = templates_config.get(
+        self.custom_template_directory: Optional[str] = templates_config.get(
             "custom_template_directory"
         )
         if self.custom_template_directory is not None and not isinstance(
@@ -727,7 +775,13 @@ class ServerConfig(Config):
         return any(listener.tls for listener in self.listeners)
 
     def generate_config_section(
-        self, server_name, data_dir_path, open_private_ports, listeners, **kwargs
+        self,
+        server_name,
+        data_dir_path,
+        open_private_ports,
+        listeners,
+        config_dir_path,
+        **kwargs,
     ):
         ip_range_blacklist = "\n".join(
             "        #  - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST
@@ -1068,6 +1122,24 @@ class ServerConfig(Config):
           #  bind_addresses: ['::1', '127.0.0.1']
           #  type: manhole
 
+        # Connection settings for the manhole
+        #
+        manhole_settings:
+          # The username for the manhole. This defaults to 'matrix'.
+          #
+          #username: manhole
+
+          # The password for the manhole. This defaults to 'rabbithole'.
+          #
+          #password: mypassword
+
+          # The private and public SSH key pair used to encrypt the manhole traffic.
+          # If these are left unset, then hardcoded and non-secret keys are used,
+          # which could allow traffic to be intercepted if sent over a public network.
+          #
+          #ssh_priv_key_path: %(config_dir_path)s/id_rsa
+          #ssh_pub_key_path: %(config_dir_path)s/id_rsa.pub
+
         # Forward extremities can build up in a room due to networking delays between
         # homeservers. Once this happens in a large room, calculation of the state of
         # that room can become quite expensive. To mitigate this, once the number of
@@ -1436,3 +1508,14 @@ def _warn_if_webclient_configured(listeners: Iterable[ListenerConfig]) -> None:
                 if name == "webclient":
                     logger.warning(NO_MORE_WEB_CLIENT_WARNING)
                     return
+
+
+_MANHOLE_SETTINGS_SCHEMA = {
+    "type": "object",
+    "properties": {
+        "username": {"type": "string"},
+        "password": {"type": "string"},
+        "ssh_priv_key_path": {"type": "string"},
+        "ssh_pub_key_path": {"type": "string"},
+    },
+}
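
A standalone sketch of how the new manhole_settings block is interpreted; the config values here are invented, and jsonschema.validate() stands in for Synapse's internal validate_config() helper:

from typing import Optional

import jsonschema

MANHOLE_SETTINGS_SCHEMA = {
    "type": "object",
    "properties": {
        "username": {"type": "string"},
        "password": {"type": "string"},
        "ssh_priv_key_path": {"type": "string"},
        "ssh_pub_key_path": {"type": "string"},
    },
}

config = {"manhole_settings": {"username": "admin", "password": "s3cret"}}

manhole_settings = config.get("manhole_settings") or {}
jsonschema.validate(manhole_settings, MANHOLE_SETTINGS_SCHEMA)

# Fall back to the historical defaults when a field is omitted.
username: str = manhole_settings.get("username", "matrix")
password: str = manhole_settings.get("password", "rabbithole")
priv_key_path: Optional[str] = manhole_settings.get("ssh_priv_key_path")
pub_key_path: Optional[str] = manhole_settings.get("ssh_pub_key_path")

# With no key paths configured, hardcoded and non-secret SSH keys are used, so
# manhole traffic over anything but loopback should be treated as readable.
print(username, password, priv_key_path, pub_key_path)
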
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index c3a0c10499..b63a1afe93 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -216,21 +216,18 @@ def check(
 
 
 def _check_size_limits(event: EventBase) -> None:
-    def too_big(field):
-        raise EventSizeError("%s too large" % (field,))
-
     if len(event.user_id) > 255:
-        too_big("user_id")
+        raise EventSizeError("'user_id' too large")
     if len(event.room_id) > 255:
-        too_big("room_id")
+        raise EventSizeError("'room_id' too large")
     if event.is_state() and len(event.state_key) > 255:
-        too_big("state_key")
+        raise EventSizeError("'state_key' too large")
     if len(event.type) > 255:
-        too_big("type")
+        raise EventSizeError("'type' too large")
     if len(event.event_id) > 255:
-        too_big("event_id")
+        raise EventSizeError("'event_id' too large")
     if len(encode_canonical_json(event.get_pdu_json())) > MAX_PDU_SIZE:
-        too_big("event")
+        raise EventSizeError("event too large")
 
 
 def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 6a05a65305..955cfa2207 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -15,10 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Optional
 
-import synapse.types
-from synapse.api.constants import EventTypes, Membership
 from synapse.api.ratelimiting import Ratelimiter
-from synapse.types import UserID
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -115,68 +112,3 @@ class BaseHandler:
                 burst_count=burst_count,
                 update=update,
             )
-
-    async def maybe_kick_guest_users(self, event, context=None):
-        # Technically this function invalidates current_state by changing it.
-        # Hopefully this isn't that important to the caller.
-        if event.type == EventTypes.GuestAccess:
-            guest_access = event.content.get("guest_access", "forbidden")
-            if guest_access != "can_join":
-                if context:
-                    current_state_ids = await context.get_current_state_ids()
-                    current_state_dict = await self.store.get_events(
-                        list(current_state_ids.values())
-                    )
-                    current_state = list(current_state_dict.values())
-                else:
-                    current_state_map = await self.state_handler.get_current_state(
-                        event.room_id
-                    )
-                    current_state = list(current_state_map.values())
-
-                logger.info("maybe_kick_guest_users %r", current_state)
-                await self.kick_guest_users(current_state)
-
-    async def kick_guest_users(self, current_state):
-        for member_event in current_state:
-            try:
-                if member_event.type != EventTypes.Member:
-                    continue
-
-                target_user = UserID.from_string(member_event.state_key)
-                if not self.hs.is_mine(target_user):
-                    continue
-
-                if member_event.content["membership"] not in {
-                    Membership.JOIN,
-                    Membership.INVITE,
-                }:
-                    continue
-
-                if (
-                    "kind" not in member_event.content
-                    or member_event.content["kind"] != "guest"
-                ):
-                    continue
-
-                # We make the user choose to leave, rather than have the
-                # event-sender kick them. This is partially because we don't
-                # need to worry about power levels, and partially because guest
-                # users are a concept which doesn't hugely work over federation,
-                # and having homeservers have their own users leave keeps more
-                # of that decision-making and control local to the guest-having
-                # homeserver.
-                requester = synapse.types.create_requester(
-                    target_user, is_guest=True, authenticated_entity=self.server_name
-                )
-                handler = self.hs.get_room_member_handler()
-                await handler.update_membership(
-                    requester,
-                    target_user,
-                    member_event.room_id,
-                    "leave",
-                    ratelimit=False,
-                    require_consent=False,
-                )
-            except Exception as e:
-                logger.exception("Error kicking guest user: %s" % (e,))
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index daf1d3bfb3..77df9185f6 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -507,6 +507,7 @@ class FederationHandler(BaseHandler):
             await self.store.upsert_room_on_join(
                 room_id=room_id,
                 room_version=room_version_obj,
+                auth_events=auth_chain,
             )
 
             max_stream_id = await self._persist_auth_tree(
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 9f055f00cf..69f8287b2b 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -36,6 +36,7 @@ from synapse import event_auth
 from synapse.api.constants import (
     EventContentFields,
     EventTypes,
+    GuestAccess,
     Membership,
     RejectedReason,
     RoomEncryptionAlgorithms,
@@ -53,7 +54,6 @@ from synapse.event_auth import auth_types_for_event
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.federation.federation_client import InvalidResponseError
-from synapse.handlers._base import BaseHandler
 from synapse.logging.context import (
     make_deferred_yieldable,
     nested_logging_context,
@@ -116,7 +116,7 @@ class _NewEventInfo:
     claimed_auth_event_map: StateMap[EventBase]
 
 
-class FederationEventHandler(BaseHandler):
+class FederationEventHandler:
     """Handles events that originated from federation.
 
     Responsible for handling incoming events and passing them on to the rest
@@ -124,26 +124,28 @@ class FederationEventHandler(BaseHandler):
     """
 
     def __init__(self, hs: "HomeServer"):
-        super().__init__(hs)
+        self._store = hs.get_datastore()
+        self._storage = hs.get_storage()
+        self._state_store = self._storage.state
 
-        self.store = hs.get_datastore()
-        self.storage = hs.get_storage()
-        self.state_store = self.storage.state
-
-        self.state_handler = hs.get_state_handler()
-        self.event_creation_handler = hs.get_event_creation_handler()
+        self._state_handler = hs.get_state_handler()
+        self._event_creation_handler = hs.get_event_creation_handler()
         self._event_auth_handler = hs.get_event_auth_handler()
         self._message_handler = hs.get_message_handler()
-        self.action_generator = hs.get_action_generator()
+        self._action_generator = hs.get_action_generator()
         self._state_resolution_handler = hs.get_state_resolution_handler()
+        # avoid a circular dependency by deferring execution here
+        self._get_room_member_handler = hs.get_room_member_handler
 
-        self.federation_client = hs.get_federation_client()
-        self.third_party_event_rules = hs.get_third_party_event_rules()
+        self._federation_client = hs.get_federation_client()
+        self._third_party_event_rules = hs.get_third_party_event_rules()
+        self._notifier = hs.get_notifier()
 
-        self.is_mine_id = hs.is_mine_id
+        self._is_mine_id = hs.is_mine_id
+        self._server_name = hs.hostname
         self._instance_name = hs.get_instance_name()
 
-        self.config = hs.config
+        self._config = hs.config
         self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
 
         self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
@@ -175,7 +177,7 @@ class FederationEventHandler(BaseHandler):
         event_id = pdu.event_id
 
         # We reprocess pdus when we have seen them only as outliers
-        existing = await self.store.get_event(
+        existing = await self._store.get_event(
             event_id, allow_none=True, allow_rejected=True
         )
 
@@ -221,7 +223,7 @@ class FederationEventHandler(BaseHandler):
         # Note that if we were never in the room then we would have already
         # dropped the event, since we wouldn't know the room version.
         is_in_room = await self._event_auth_handler.check_host_in_room(
-            room_id, self.server_name
+            room_id, self._server_name
         )
         if not is_in_room:
             logger.info(
@@ -238,7 +240,7 @@ class FederationEventHandler(BaseHandler):
         #  - Fetching state if we have a hole in the graph
         if not pdu.internal_metadata.is_outlier():
             prevs = set(pdu.prev_event_ids())
-            seen = await self.store.have_events_in_timeline(prevs)
+            seen = await self._store.have_events_in_timeline(prevs)
             missing_prevs = prevs - seen
 
             if missing_prevs:
@@ -272,7 +274,7 @@ class FederationEventHandler(BaseHandler):
 
                     # Update the set of things we've seen after trying to
                     # fetch the missing stuff
-                    seen = await self.store.have_events_in_timeline(prevs)
+                    seen = await self._store.have_events_in_timeline(prevs)
                     missing_prevs = prevs - seen
 
                     if not missing_prevs:
@@ -361,7 +363,7 @@ class FederationEventHandler(BaseHandler):
         # the room, so we send it on their behalf.
         event.internal_metadata.send_on_behalf_of = origin
 
-        context = await self.state_handler.compute_event_context(event)
+        context = await self._state_handler.compute_event_context(event)
         context = await self._check_event_auth(origin, event, context)
         if context.rejected:
             raise SynapseError(
@@ -375,7 +377,7 @@ class FederationEventHandler(BaseHandler):
         # for knock events, we run the third-party event rules. It's not entirely clear
         # why we don't do this for other sorts of membership events.
         if event.membership == Membership.KNOCK:
-            event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
+            event_allowed, _ = await self._third_party_event_rules.check_event_allowed(
                 event, context
             )
             if not event_allowed:
@@ -404,7 +406,7 @@ class FederationEventHandler(BaseHandler):
         prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
         prev_member_event = None
         if prev_member_event_id:
-            prev_member_event = await self.store.get_event(prev_member_event_id)
+            prev_member_event = await self._store.get_event(prev_member_event_id)
 
         # Check if the member should be allowed access via membership in a space.
         await self._event_auth_handler.check_restricted_join_rules(
@@ -434,10 +436,10 @@ class FederationEventHandler(BaseHandler):
         server from invalid events (there is probably no point in trying to
         re-fetch invalid events from every other HS in the room.)
         """
-        if dest == self.server_name:
+        if dest == self._server_name:
             raise SynapseError(400, "Can't backfill from self.")
 
-        events = await self.federation_client.backfill(
+        events = await self._federation_client.backfill(
             dest, room_id, limit=limit, extremities=extremities
         )
 
@@ -469,12 +471,12 @@ class FederationEventHandler(BaseHandler):
         room_id = pdu.room_id
         event_id = pdu.event_id
 
-        seen = await self.store.have_events_in_timeline(prevs)
+        seen = await self._store.have_events_in_timeline(prevs)
 
         if not prevs - seen:
             return
 
-        latest_list = await self.store.get_latest_event_ids_in_room(room_id)
+        latest_list = await self._store.get_latest_event_ids_in_room(room_id)
 
         # We add the prev events that we have seen to the latest
         # list to ensure the remote server doesn't give them to us
@@ -536,7 +538,7 @@ class FederationEventHandler(BaseHandler):
         # All that said: Let's try increasing the timeout to 60s and see what happens.
 
         try:
-            missing_events = await self.federation_client.get_missing_events(
+            missing_events = await self._federation_client.get_missing_events(
                 origin,
                 room_id,
                 earliest_events_ids=list(latest),
@@ -609,7 +611,7 @@ class FederationEventHandler(BaseHandler):
 
         event_id = event.event_id
 
-        existing = await self.store.get_event(
+        existing = await self._store.get_event(
             event_id, allow_none=True, allow_rejected=True
         )
         if existing:
@@ -674,7 +676,7 @@ class FederationEventHandler(BaseHandler):
         event_id = event.event_id
 
         prevs = set(event.prev_event_ids())
-        seen = await self.store.have_events_in_timeline(prevs)
+        seen = await self._store.have_events_in_timeline(prevs)
         missing_prevs = prevs - seen
 
         if not missing_prevs:
@@ -691,7 +693,7 @@ class FederationEventHandler(BaseHandler):
         event_map = {event_id: event}
         try:
             # Get the state of the events we know about
-            ours = await self.state_store.get_state_groups_ids(room_id, seen)
+            ours = await self._state_store.get_state_groups_ids(room_id, seen)
 
             # state_maps is a list of mappings from (type, state_key) to event_id
             state_maps: List[StateMap[str]] = list(ours.values())
@@ -720,13 +722,13 @@ class FederationEventHandler(BaseHandler):
                     for x in remote_state:
                         event_map[x.event_id] = x
 
-            room_version = await self.store.get_room_version_id(room_id)
+            room_version = await self._store.get_room_version_id(room_id)
             state_map = await self._state_resolution_handler.resolve_events_with_store(
                 room_id,
                 room_version,
                 state_maps,
                 event_map,
-                state_res_store=StateResolutionStore(self.store),
+                state_res_store=StateResolutionStore(self._store),
             )
 
             # We need to give _process_received_pdu the actual state events
@@ -734,7 +736,7 @@ class FederationEventHandler(BaseHandler):
 
             # First though we need to fetch all the events that are in
             # state_map, so we can build up the state below.
-            evs = await self.store.get_events(
+            evs = await self._store.get_events(
                 list(state_map.values()),
                 get_prev_content=False,
                 redact_behaviour=EventRedactBehaviour.AS_IS,
@@ -774,7 +776,7 @@ class FederationEventHandler(BaseHandler):
         (
             state_event_ids,
             auth_event_ids,
-        ) = await self.federation_client.get_room_state_ids(
+        ) = await self._federation_client.get_room_state_ids(
             destination, room_id, event_id=event_id
         )
 
@@ -788,7 +790,7 @@ class FederationEventHandler(BaseHandler):
         desired_events = set(state_event_ids)
         desired_events.add(event_id)
         logger.debug("Fetching %i events from cache/store", len(desired_events))
-        fetched_events = await self.store.get_events(
+        fetched_events = await self._store.get_events(
             desired_events, allow_rejected=True
         )
 
@@ -809,7 +811,7 @@ class FederationEventHandler(BaseHandler):
 
         missing_auth_events = set(auth_event_ids) - fetched_events.keys()
         missing_auth_events.difference_update(
-            await self.store.have_seen_events(room_id, missing_auth_events)
+            await self._store.have_seen_events(room_id, missing_auth_events)
         )
         logger.debug("We are also missing %i auth events", len(missing_auth_events))
 
@@ -822,7 +824,7 @@ class FederationEventHandler(BaseHandler):
         # we need to make sure we re-load from the database to get the rejected
         # state correct.
         fetched_events.update(
-            await self.store.get_events(missing_desired_events, allow_rejected=True)
+            await self._store.get_events(missing_desired_events, allow_rejected=True)
         )
 
         # check for events which were in the wrong room.
@@ -901,7 +903,7 @@ class FederationEventHandler(BaseHandler):
         logger.debug("Processing event: %s", event)
 
         try:
-            context = await self.state_handler.compute_event_context(
+            context = await self._state_handler.compute_event_context(
                 event, old_state=state
             )
             await self._auth_and_persist_event(
@@ -919,7 +921,7 @@ class FederationEventHandler(BaseHandler):
             device_id = event.content.get("device_id")
             sender_key = event.content.get("sender_key")
 
-            cached_devices = await self.store.get_cached_devices_for_user(event.sender)
+            cached_devices = await self._store.get_cached_devices_for_user(event.sender)
 
             resync = False  # Whether we should resync device lists.
 
@@ -995,10 +997,10 @@ class FederationEventHandler(BaseHandler):
         """
 
         try:
-            await self.store.mark_remote_user_device_cache_as_stale(sender)
+            await self._store.mark_remote_user_device_cache_as_stale(sender)
 
             # Immediately attempt a resync in the background
-            if self.config.worker_app:
+            if self._config.worker_app:
                 await self._user_device_resync(user_id=sender)
             else:
                 await self._device_list_updater.user_device_resync(sender)
@@ -1023,9 +1025,15 @@ class FederationEventHandler(BaseHandler):
             return
 
         # Skip processing a marker event if the room version doesn't
-        # support it.
-        room_version = await self.store.get_room_version(marker_event.room_id)
-        if not room_version.msc2716_historical:
+        # support it or the event is not from the room creator.
+        room_version = await self._store.get_room_version(marker_event.room_id)
+        create_event = await self._store.get_create_event_for_room(marker_event.room_id)
+        room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR)
+        if (
+            not room_version.msc2716_historical
+            or not self._config.experimental.msc2716_enabled
+            or marker_event.sender != room_creator
+        ):
             return
 
         logger.debug("_handle_marker_event: received %s", marker_event)
@@ -1048,7 +1056,7 @@ class FederationEventHandler(BaseHandler):
             [insertion_event_id],
         )
 
-        insertion_event = await self.store.get_event(
+        insertion_event = await self._store.get_event(
             insertion_event_id, allow_none=True
         )
         if insertion_event is None:
@@ -1066,7 +1074,7 @@ class FederationEventHandler(BaseHandler):
             marker_event,
         )
 
-        await self.store.insert_insertion_extremity(
+        await self._store.insert_insertion_extremity(
             insertion_event_id, marker_event.room_id
         )
 
@@ -1088,14 +1096,14 @@ class FederationEventHandler(BaseHandler):
         Logs a warning if we can't find the given event.
         """
 
-        room_version = await self.store.get_room_version(room_id)
+        room_version = await self._store.get_room_version(room_id)
 
         event_map: Dict[str, EventBase] = {}
 
         async def get_event(event_id: str):
             with nested_logging_context(event_id):
                 try:
-                    event = await self.federation_client.get_pdu(
+                    event = await self._federation_client.get_pdu(
                         [destination],
                         event_id,
                         room_version,
@@ -1131,7 +1139,7 @@ class FederationEventHandler(BaseHandler):
             for aid in event.auth_event_ids()
             if aid not in event_map
         ]
-        persisted_events = await self.store.get_events(
+        persisted_events = await self._store.get_events(
             auth_events,
             allow_rejected=True,
         )
@@ -1175,7 +1183,7 @@ class FederationEventHandler(BaseHandler):
         async def prep(ev_info: _NewEventInfo):
             event = ev_info.event
             with nested_logging_context(suffix=event.event_id):
-                res = await self.state_handler.compute_event_context(event)
+                res = await self._state_handler.compute_event_context(event)
                 res = await self._check_event_auth(
                     origin,
                     event,
@@ -1278,7 +1286,7 @@ class FederationEventHandler(BaseHandler):
         Returns:
             The updated context object.
         """
-        room_version = await self.store.get_room_version_id(event.room_id)
+        room_version = await self._store.get_room_version_id(event.room_id)
         room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
 
         if claimed_auth_event_map:
@@ -1291,7 +1299,7 @@ class FederationEventHandler(BaseHandler):
             auth_events_ids = self._event_auth_handler.compute_auth_events(
                 event, prev_state_ids, for_verification=True
             )
-            auth_events_x = await self.store.get_events(auth_events_ids)
+            auth_events_x = await self._store.get_events(auth_events_ids)
             auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
 
         try:
@@ -1321,19 +1329,29 @@ class FederationEventHandler(BaseHandler):
 
         if not context.rejected:
             await self._check_for_soft_fail(event, state, backfilled, origin=origin)
-
-        if event.type == EventTypes.GuestAccess and not context.rejected:
-            await self.maybe_kick_guest_users(event)
+            await self._maybe_kick_guest_users(event)
 
         # If we are going to send this event over federation we precalculate
         # the joined hosts.
         if event.internal_metadata.get_send_on_behalf_of():
-            await self.event_creation_handler.cache_joined_hosts_for_event(
+            await self._event_creation_handler.cache_joined_hosts_for_event(
                 event, context
             )
 
         return context
 
+    async def _maybe_kick_guest_users(self, event: EventBase) -> None:
+        if event.type != EventTypes.GuestAccess:
+            return
+
+        guest_access = event.content.get(EventContentFields.GUEST_ACCESS)
+        if guest_access == GuestAccess.CAN_JOIN:
+            return
+
+        current_state_map = await self._state_handler.get_current_state(event.room_id)
+        current_state = list(current_state_map.values())
+        await self._get_room_member_handler().kick_guest_users(current_state)
+
     async def _check_for_soft_fail(
         self,
         event: EventBase,
@@ -1356,7 +1374,7 @@ class FederationEventHandler(BaseHandler):
         if backfilled or event.internal_metadata.is_outlier():
             return
 
-        extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id)
+        extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
         extrem_ids = set(extrem_ids_list)
         prev_event_ids = set(event.prev_event_ids())
 
@@ -1365,7 +1383,7 @@ class FederationEventHandler(BaseHandler):
             # state at the event, so no point rechecking auth for soft fail.
             return
 
-        room_version = await self.store.get_room_version_id(event.room_id)
+        room_version = await self._store.get_room_version_id(event.room_id)
         room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
 
         # Calculate the "current state".
@@ -1382,19 +1400,19 @@ class FederationEventHandler(BaseHandler):
             # given state at the event. This should correctly handle cases
             # like bans, especially with state res v2.
 
-            state_sets_d = await self.state_store.get_state_groups(
+            state_sets_d = await self._state_store.get_state_groups(
                 event.room_id, extrem_ids
             )
             state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
             state_sets.append(state)
-            current_states = await self.state_handler.resolve_events(
+            current_states = await self._state_handler.resolve_events(
                 room_version, state_sets, event
             )
             current_state_ids: StateMap[str] = {
                 k: e.event_id for k, e in current_states.items()
             }
         else:
-            current_state_ids = await self.state_handler.get_current_state_ids(
+            current_state_ids = await self._state_handler.get_current_state_ids(
                 event.room_id, latest_event_ids=extrem_ids
             )
 
@@ -1410,7 +1428,7 @@ class FederationEventHandler(BaseHandler):
             e for k, e in current_state_ids.items() if k in auth_types
         ]
 
-        auth_events_map = await self.store.get_events(current_state_ids_list)
+        auth_events_map = await self._store.get_events(current_state_ids_list)
         current_auth_events = {
             (e.type, e.state_key): e for e in auth_events_map.values()
         }
@@ -1481,7 +1499,9 @@ class FederationEventHandler(BaseHandler):
         #
         # we start by checking if they are in the store, and then try calling /event_auth/.
         if missing_auth:
-            have_events = await self.store.have_seen_events(event.room_id, missing_auth)
+            have_events = await self._store.have_seen_events(
+                event.room_id, missing_auth
+            )
             logger.debug("Events %s are in the store", have_events)
             missing_auth.difference_update(have_events)
 
@@ -1490,7 +1510,7 @@ class FederationEventHandler(BaseHandler):
             logger.info("auth_events contains unknown events: %s", missing_auth)
             try:
                 try:
-                    remote_auth_chain = await self.federation_client.get_event_auth(
+                    remote_auth_chain = await self._federation_client.get_event_auth(
                         origin, event.room_id, event.event_id
                     )
                 except RequestSendFailed as e1:
@@ -1499,7 +1519,7 @@ class FederationEventHandler(BaseHandler):
                     logger.info("Failed to get event auth from remote: %s", e1)
                     return context, auth_events
 
-                seen_remotes = await self.store.have_seen_events(
+                seen_remotes = await self._store.have_seen_events(
                     event.room_id, [e.event_id for e in remote_auth_chain]
                 )
 
@@ -1525,7 +1545,7 @@ class FederationEventHandler(BaseHandler):
                             e.event_id,
                         )
                         missing_auth_event_context = (
-                            await self.state_handler.compute_event_context(e)
+                            await self._state_handler.compute_event_context(e)
                         )
                         await self._auth_and_persist_event(
                             origin,
@@ -1566,7 +1586,7 @@ class FederationEventHandler(BaseHandler):
 
         # XXX: currently this checks for redactions but I'm not convinced that is
         # necessary?
-        different_events = await self.store.get_events_as_list(different_auth)
+        different_events = await self._store.get_events_as_list(different_auth)
 
         for d in different_events:
             if d.room_id != event.room_id:
@@ -1592,8 +1612,8 @@ class FederationEventHandler(BaseHandler):
         remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
         remote_state = remote_auth_events.values()
 
-        room_version = await self.store.get_room_version_id(event.room_id)
-        new_state = await self.state_handler.resolve_events(
+        room_version = await self._store.get_room_version_id(event.room_id)
+        new_state = await self._state_handler.resolve_events(
             room_version, (local_state, remote_state), event
         )
 
@@ -1651,7 +1671,7 @@ class FederationEventHandler(BaseHandler):
 
         # create a new state group as a delta from the existing one.
         prev_group = context.state_group
-        state_group = await self.state_store.store_state_group(
+        state_group = await self._state_store.store_state_group(
             event.event_id,
             event.room_id,
             prev_group=prev_group,
@@ -1683,9 +1703,9 @@ class FederationEventHandler(BaseHandler):
                 not event.internal_metadata.is_outlier()
                 and not backfilled
                 and not context.rejected
-                and (await self.store.get_min_depth(event.room_id)) <= event.depth
+                and (await self._store.get_min_depth(event.room_id)) <= event.depth
             ):
-                await self.action_generator.handle_push_actions_for_event(
+                await self._action_generator.handle_push_actions_for_event(
                     event, context
                 )
 
@@ -1694,7 +1714,7 @@ class FederationEventHandler(BaseHandler):
             )
         except Exception:
             run_in_background(
-                self.store.remove_push_actions_from_staging, event.event_id
+                self._store.remove_push_actions_from_staging, event.event_id
             )
             raise
 
@@ -1719,27 +1739,27 @@ class FederationEventHandler(BaseHandler):
             The stream ID after which all events have been persisted.
         """
         if not event_and_contexts:
-            return self.store.get_current_events_token()
+            return self._store.get_current_events_token()
 
-        instance = self.config.worker.events_shard_config.get_instance(room_id)
+        instance = self._config.worker.events_shard_config.get_instance(room_id)
         if instance != self._instance_name:
             # Limit the number of events sent over replication. We choose 200
             # here as that is what we default to in `max_request_body_size(..)`
             for batch in batch_iter(event_and_contexts, 200):
                 result = await self._send_events(
                     instance_name=instance,
-                    store=self.store,
+                    store=self._store,
                     room_id=room_id,
                     event_and_contexts=batch,
                     backfilled=backfilled,
                 )
             return result["max_stream_id"]
         else:
-            assert self.storage.persistence
+            assert self._storage.persistence
 
             # Note that this returns the events that were persisted, which may not be
             # the same as were passed in if some were deduplicated due to transaction IDs.
-            events, max_stream_token = await self.storage.persistence.persist_events(
+            events, max_stream_token = await self._storage.persistence.persist_events(
                 event_and_contexts, backfilled=backfilled
             )
 
@@ -1773,7 +1793,7 @@ class FederationEventHandler(BaseHandler):
             # users
             if event.internal_metadata.is_outlier():
                 if event.membership != Membership.INVITE:
-                    if not self.is_mine_id(target_user_id):
+                    if not self._is_mine_id(target_user_id):
                         return
 
             target_user = UserID.from_string(target_user_id)
@@ -1787,7 +1807,7 @@ class FederationEventHandler(BaseHandler):
         event_pos = PersistedEventPosition(
             self._instance_name, event.internal_metadata.stream_ordering
         )
-        self.notifier.on_new_room_event(
+        self._notifier.on_new_room_event(
             event, event_pos, max_stream_token, extra_users=extra_users
         )
 
@@ -1822,4 +1842,4 @@ class FederationEventHandler(BaseHandler):
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
 
     async def get_min_depth_for_context(self, context: str) -> int:
-        return await self.store.get_min_depth(context)
+        return await self._store.get_min_depth(context)
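
The constructor above stores hs.get_room_member_handler itself rather than calling it, deferring the lookup to avoid a circular dependency at construction time. A minimal sketch of that pattern, using invented stand-in classes rather than the real Synapse ones:

from typing import Optional


class RoomMemberHandler:
    def __init__(self, hs: "HomeServer") -> None:
        self.hs = hs

    def kick_guest_users(self, current_state: list) -> None:
        print("kicking guests based on %d state events" % len(current_state))


class FederationEventHandler:
    def __init__(self, hs: "HomeServer") -> None:
        # Storing the getter (not its result) means RoomMemberHandler does not
        # need to exist yet when this handler is constructed, which is what
        # breaks the construction-order cycle between the two handlers.
        self._get_room_member_handler = hs.get_room_member_handler

    def on_guest_access_revoked(self, current_state: list) -> None:
        # Resolved lazily, on first use.
        self._get_room_member_handler().kick_guest_users(current_state)


class HomeServer:
    def __init__(self) -> None:
        self._room_member_handler: Optional[RoomMemberHandler] = None
        self.federation_event_handler = FederationEventHandler(self)

    def get_room_member_handler(self) -> RoomMemberHandler:
        if self._room_member_handler is None:
            self._room_member_handler = RoomMemberHandler(self)
        return self._room_member_handler


hs = HomeServer()
hs.federation_event_handler.on_guest_access_revoked([{"type": "m.room.member"}])
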
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 101a29c6d3..bf0fef1510 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -27,6 +27,7 @@ from synapse import event_auth
 from synapse.api.constants import (
     EventContentFields,
     EventTypes,
+    GuestAccess,
     Membership,
     RelationTypes,
     UserTypes,
@@ -426,7 +427,7 @@ class EventCreationHandler:
 
         self.send_event = ReplicationSendEventRestServlet.make_client(hs)
 
-        # This is only used to get at ratelimit function, and maybe_kick_guest_users
+        # This is only used to get at the ratelimit function
         self.base_handler = BaseHandler(hs)
 
         # We arbitrarily limit concurrent event creation for a room to 5.
@@ -1306,7 +1307,7 @@ class EventCreationHandler:
                 requester, is_admin_redaction=is_admin_redaction
             )
 
-        await self.base_handler.maybe_kick_guest_users(event, context)
+        await self._maybe_kick_guest_users(event, context)
 
         if event.type == EventTypes.CanonicalAlias:
             # Validate a newly added alias or newly added alt_aliases.
@@ -1393,6 +1394,9 @@ class EventCreationHandler:
                 allow_none=True,
             )
 
+            room_version = await self.store.get_room_version_id(event.room_id)
+            room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+
             # we can make some additional checks now if we have the original event.
             if original_event:
                 if original_event.type == EventTypes.Create:
@@ -1404,6 +1408,28 @@ class EventCreationHandler:
                 if original_event.type == EventTypes.ServerACL:
                     raise AuthError(403, "Redacting server ACL events is not permitted")
 
+                # Add a little safety stop-gap to prevent people from trying to
+                # redact MSC2716-related events when they're in a room version
+                # which does not support it yet. We allow people to use MSC2716
+                # events in existing room versions, but only from the room
+                # creator, since that does not require any changes to the auth
+                # rules or, in effect, to the redaction algorithm. In the
+                # supported room version, we add the `historical` power level to
+                # auth the MSC2716-related events and adjust the redaction
+                # algorithm to keep the `historical` field around (redacting an
+                # event should only strip fields which don't affect the
+                # structural protocol level).
+                is_msc2716_event = (
+                    original_event.type == EventTypes.MSC2716_INSERTION
+                    or original_event.type == EventTypes.MSC2716_CHUNK
+                    or original_event.type == EventTypes.MSC2716_MARKER
+                )
+                if not room_version_obj.msc2716_historical and is_msc2716_event:
+                    raise AuthError(
+                        403,
+                        "Redacting MSC2716 events is not supported in this room version",
+                    )
+
             prev_state_ids = await context.get_prev_state_ids()
             auth_events_ids = self._event_auth_handler.compute_auth_events(
                 event, prev_state_ids, for_verification=True
@@ -1411,9 +1437,6 @@ class EventCreationHandler:
             auth_events_map = await self.store.get_events(auth_events_ids)
             auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()}
 
-            room_version = await self.store.get_room_version_id(event.room_id)
-            room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
-
             if event_auth.check_redaction(
                 room_version_obj, event, auth_events=auth_events
             ):
@@ -1471,6 +1494,28 @@ class EventCreationHandler:
 
         return event
 
+    async def _maybe_kick_guest_users(
+        self, event: EventBase, context: EventContext
+    ) -> None:
+        if event.type != EventTypes.GuestAccess:
+            return
+
+        guest_access = event.content.get(EventContentFields.GUEST_ACCESS)
+        if guest_access == GuestAccess.CAN_JOIN:
+            return
+
+        current_state_ids = await context.get_current_state_ids()
+
+        # since this is a client-generated event, it cannot be an outlier and we must
+        # therefore have the state ids.
+        assert current_state_ids is not None
+        current_state_dict = await self.store.get_events(
+            list(current_state_ids.values())
+        )
+        current_state = list(current_state_dict.values())
+        logger.info("maybe_kick_guest_users %r", current_state)
+        await self.hs.get_room_member_handler().kick_guest_users(current_state)
+
     async def _bump_active_time(self, user: UserID) -> None:
         try:
             presence = self.hs.get_presence_handler()
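
For illustration, the guest-access check that `_maybe_kick_guest_users` relies on reduces to a single predicate over the new constants. A minimal sketch (not part of the patch; the helper name is made up) assuming `synapse.api.constants` is importable:

    from synapse.api.constants import EventContentFields, EventTypes, GuestAccess

    def _guests_allowed(event) -> bool:
        # Guests are allowed only while an m.room.guest_access state event with
        # content {"guest_access": "can_join"} is present; any other value, or a
        # missing field, is treated as forbidden and triggers the kick above.
        return (
            event.type == EventTypes.GuestAccess
            and event.content.get(EventContentFields.GUEST_ACCESS)
            == GuestAccess.CAN_JOIN
        )
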
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index b33fe09f77..0235fd09b4 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -25,7 +25,9 @@ from collections import OrderedDict
 from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
 
 from synapse.api.constants import (
+    EventContentFields,
     EventTypes,
+    GuestAccess,
     HistoryVisibility,
     JoinRules,
     Membership,
@@ -909,7 +911,12 @@ class RoomCreationHandler(BaseHandler):
             )
             return last_stream_id
 
-        config = self._presets_dict[preset_config]
+        try:
+            config = self._presets_dict[preset_config]
+        except KeyError:
+            raise SynapseError(
+                400, f"'{preset_config}' is not a valid preset", errcode=Codes.BAD_JSON
+            )
 
         creation_content.update({"creator": creator_id})
         await send(etype=EventTypes.Create, content=creation_content)
@@ -988,7 +995,8 @@ class RoomCreationHandler(BaseHandler):
         if config["guest_can_join"]:
             if (EventTypes.GuestAccess, "") not in initial_state:
                 last_sent_stream_id = await send(
-                    etype=EventTypes.GuestAccess, content={"guest_access": "can_join"}
+                    etype=EventTypes.GuestAccess,
+                    content={EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN},
                 )
 
         for (etype, state_key), content in initial_state.items():
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 6d433fad41..92bb75c848 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -19,7 +19,13 @@ from typing import TYPE_CHECKING, Optional, Tuple
 import msgpack
 from unpaddedbase64 import decode_base64, encode_base64
 
-from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
+from synapse.api.constants import (
+    EventContentFields,
+    EventTypes,
+    GuestAccess,
+    HistoryVisibility,
+    JoinRules,
+)
 from synapse.api.errors import (
     Codes,
     HttpResponseException,
@@ -336,8 +342,8 @@ class RoomListHandler(BaseHandler):
         guest_event = current_state.get((EventTypes.GuestAccess, ""))
         guest = None
         if guest_event:
-            guest = guest_event.content.get("guest_access", None)
-        result["guest_can_join"] = guest == "can_join"
+            guest = guest_event.content.get(EventContentFields.GUEST_ACCESS)
+        result["guest_can_join"] = guest == GuestAccess.CAN_JOIN
 
         avatar_event = current_state.get(("m.room.avatar", ""))
         if avatar_event:
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 401b84aad1..4390201641 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -23,6 +23,7 @@ from synapse.api.constants import (
     AccountDataTypes,
     EventContentFields,
     EventTypes,
+    GuestAccess,
     Membership,
 )
 from synapse.api.errors import (
@@ -44,6 +45,7 @@ from synapse.types import (
     RoomID,
     StateMap,
     UserID,
+    create_requester,
     get_domain_from_id,
 )
 from synapse.util.async_helpers import Linearizer
@@ -70,6 +72,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         self.auth = hs.get_auth()
         self.state_handler = hs.get_state_handler()
         self.config = hs.config
+        self._server_name = hs.hostname
 
         self.federation_handler = hs.get_federation_handler()
         self.directory_handler = hs.get_directory_handler()
@@ -115,9 +118,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
         )
 
-        # This is only used to get at ratelimit function, and
-        # maybe_kick_guest_users. It's fine there are multiple of these as
-        # it doesn't store state.
+        # This is only used to get at the ratelimit function. It's fine there are
+        # multiple of these as it doesn't store state.
         self.base_handler = BaseHandler(hs)
 
     @abc.abstractmethod
@@ -1095,10 +1097,62 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         return bool(
             guest_access
             and guest_access.content
-            and "guest_access" in guest_access.content
-            and guest_access.content["guest_access"] == "can_join"
+            and guest_access.content.get(EventContentFields.GUEST_ACCESS)
+            == GuestAccess.CAN_JOIN
         )
 
+    async def kick_guest_users(self, current_state: Iterable[EventBase]) -> None:
+        """Kick any local guest users from the room.
+
+        This is called when the room state changes from guests allowed to not-allowed.
+
+        Args:
+            current_state: the current state of the room. We will iterate this to look
+                for guest users to kick.
+        """
+        for member_event in current_state:
+            try:
+                if member_event.type != EventTypes.Member:
+                    continue
+
+                if not self.hs.is_mine_id(member_event.state_key):
+                    continue
+
+                if member_event.content["membership"] not in {
+                    Membership.JOIN,
+                    Membership.INVITE,
+                }:
+                    continue
+
+                if (
+                    "kind" not in member_event.content
+                    or member_event.content["kind"] != "guest"
+                ):
+                    continue
+
+                # We make the user choose to leave, rather than have the
+                # event-sender kick them. This is partially because we don't
+                # need to worry about power levels, and partially because guest
+                # users are a concept which doesn't hugely work over federation,
+                # and having homeservers have their own users leave keeps more
+                # of that decision-making and control local to the guest-having
+                # homeserver.
+                target_user = UserID.from_string(member_event.state_key)
+                requester = create_requester(
+                    target_user, is_guest=True, authenticated_entity=self._server_name
+                )
+                handler = self.hs.get_room_member_handler()
+                await handler.update_membership(
+                    requester,
+                    target_user,
+                    member_event.room_id,
+                    "leave",
+                    ratelimit=False,
+                    require_consent=False,
+                )
+            except Exception as e:
+                logger.exception("Error kicking guest user: %s" % (e,))
+
     async def lookup_room_alias(
         self, room_alias: RoomAlias
     ) -> Tuple[RoomID, List[str]]:
@@ -1352,7 +1406,6 @@ class RoomMemberMasterHandler(RoomMemberHandler):
 
         self.distributor = hs.get_distributor()
         self.distributor.declare("user_left_room")
-        self._server_name = hs.hostname
 
     async def _is_remote_room_too_complex(
         self, room_id: str, remote_room_hosts: List[str]
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 906985c754..781da9e811 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -28,9 +28,15 @@ from synapse.api.constants import (
     Membership,
     RoomTypes,
 )
-from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
+from synapse.api.errors import (
+    AuthError,
+    Codes,
+    NotFoundError,
+    StoreError,
+    SynapseError,
+    UnsupportedRoomVersionError,
+)
 from synapse.events import EventBase
-from synapse.events.utils import format_event_for_client_v2
 from synapse.types import JsonDict
 from synapse.util.caches.response_cache import ResponseCache
 
@@ -82,7 +88,6 @@ class RoomSummaryHandler:
     _PAGINATION_SESSION_VALIDITY_PERIOD_MS = 5 * 60 * 1000
 
     def __init__(self, hs: "HomeServer"):
-        self._clock = hs.get_clock()
         self._event_auth_handler = hs.get_event_auth_handler()
         self._store = hs.get_datastore()
         self._event_serializer = hs.get_event_client_serializer()
@@ -641,18 +646,18 @@ class RoomSummaryHandler:
         if max_children is None or max_children > MAX_ROOMS_PER_SPACE:
             max_children = MAX_ROOMS_PER_SPACE
 
-        now = self._clock.time_msec()
-        events_result: List[JsonDict] = []
-        for edge_event in itertools.islice(child_events, max_children):
-            events_result.append(
-                await self._event_serializer.serialize_event(
-                    edge_event,
-                    time_now=now,
-                    event_format=format_event_for_client_v2,
-                )
-            )
-
-        return _RoomEntry(room_id, room_entry, events_result)
+        stripped_events: List[JsonDict] = [
+            {
+                "type": e.type,
+                "state_key": e.state_key,
+                "content": e.content,
+                "room_id": e.room_id,
+                "sender": e.sender,
+                "origin_server_ts": e.origin_server_ts,
+            }
+            for e in itertools.islice(child_events, max_children)
+        ]
+        return _RoomEntry(room_id, room_entry, stripped_events)
 
     async def _summarize_remote_room(
         self,
@@ -814,7 +819,12 @@ class RoomSummaryHandler:
             logger.info("room %s is unknown, omitting from summary", room_id)
             return False
 
-        room_version = await self._store.get_room_version(room_id)
+        try:
+            room_version = await self._store.get_room_version(room_id)
+        except UnsupportedRoomVersionError:
+            # If a room with an unsupported room version is encountered, ignore
+            # it to avoid breaking the entire summary response.
+            return False
 
         # Include the room if it has join rules of public or knock.
         join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""))
@@ -1139,25 +1149,26 @@ def _is_suggested_child_event(edge_event: EventBase) -> bool:
 _INVALID_ORDER_CHARS_RE = re.compile(r"[^\x20-\x7E]")
 
 
-def _child_events_comparison_key(child: EventBase) -> Tuple[bool, Optional[str], str]:
+def _child_events_comparison_key(
+    child: EventBase,
+) -> Tuple[bool, Optional[str], int, str]:
     """
     Generate a value for comparing two child events for ordering.
 
-    The rules for ordering are supposed to be:
+    The rules for ordering are:
 
     1. The 'order' key, if it is valid.
-    2. The 'origin_server_ts' of the 'm.room.create' event.
+    2. The 'origin_server_ts' of the 'm.space.child' event.
     3. The 'room_id'.
 
-    But we skip step 2 since we may not have any state from the room.
-
     Args:
         child: The event for generating a comparison key.
 
     Returns:
         The comparison key as a tuple of:
             False if the ordering is valid.
-            The ordering field.
+            The 'order' field or None if it is not given or invalid.
+            The 'origin_server_ts' field.
             The room ID.
     """
     order = child.content.get("order")
@@ -1168,4 +1179,4 @@ def _child_events_comparison_key(child: EventBase) -> Tuple[bool, Optional[str],
         order = None
 
     # Items without an order come last.
-    return (order is None, order, child.room_id)
+    return (order is None, order, child.origin_server_ts, child.room_id)
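
For illustration, the effect of the new comparison key can be seen by sorting a few stand-in child events with it. This is a minimal, simplified sketch (the real validity check also rejects 'order' values with invalid characters via `_INVALID_ORDER_CHARS_RE` above), not the code from the patch:

    from types import SimpleNamespace

    def _key(child):
        # Mirrors _child_events_comparison_key: a valid 'order' sorts first,
        # then the m.space.child event's origin_server_ts, then room_id.
        order = child.content.get("order")
        if not isinstance(order, str):
            order = None
        return (order is None, order, child.origin_server_ts, child.room_id)

    children = [
        SimpleNamespace(content={}, origin_server_ts=200, room_id="!b:example.org"),
        SimpleNamespace(content={"order": "aaa"}, origin_server_ts=300, room_id="!c:example.org"),
        SimpleNamespace(content={}, origin_server_ts=100, room_id="!a:example.org"),
    ]
    # The explicitly ordered child wins despite its later timestamp; the rest
    # now fall back to origin_server_ts instead of jumping straight to room_id.
    print([c.room_id for c in sorted(children, key=_key)])
    # -> ['!c:example.org', '!a:example.org', '!b:example.org']
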
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 3fd89af2a4..3a4c41c9ff 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
 
 from typing_extensions import Counter as CounterType
 
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventContentFields, EventTypes, Membership
 from synapse.metrics import event_processing_positions
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import JsonDict
@@ -273,7 +273,9 @@ class StatsHandler:
             elif typ == EventTypes.CanonicalAlias:
                 room_state["canonical_alias"] = event_content.get("alias")
             elif typ == EventTypes.GuestAccess:
-                room_state["guest_access"] = event_content.get("guest_access")
+                room_state["guest_access"] = event_content.get(
+                    EventContentFields.GUEST_ACCESS
+                )
 
         for room_id, state in room_to_state_updates.items():
             logger.debug("Updating room_stats_state for %s: %s", room_id, state)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 86c3c7f0df..e017b28cd2 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -505,10 +505,13 @@ class SyncHandler:
             else:
                 limited = False
 
+            log_kv({"limited": limited})
+
             if potential_recents:
                 recents = sync_config.filter_collection.filter_room_timeline(
                     potential_recents
                 )
+                log_kv({"recents_after_sync_filtering": len(recents)})
 
                 # We check if there are any state events, if there are then we pass
                 # all current state events to the filter_events function. This is to
@@ -526,6 +529,7 @@ class SyncHandler:
                     recents,
                     always_include_ids=current_state_ids,
                 )
+                log_kv({"recents_after_visibility_filtering": len(recents)})
             else:
                 recents = []
 
@@ -566,10 +570,15 @@ class SyncHandler:
                     events, end_key = await self.store.get_recent_events_for_room(
                         room_id, limit=load_limit + 1, end_token=end_key
                     )
+
+                log_kv({"loaded_recents": len(events)})
+
                 loaded_recents = sync_config.filter_collection.filter_room_timeline(
                     events
                 )
 
+                log_kv({"loaded_recents_after_sync_filtering": len(loaded_recents)})
+
                 # We check if there are any state events, if there are then we pass
                 # all current state events to the filter_events function. This is to
                 # ensure that we always include current state in the timeline
@@ -586,6 +595,9 @@ class SyncHandler:
                     loaded_recents,
                     always_include_ids=current_state_ids,
                 )
+
+                log_kv({"loaded_recents_after_client_filtering": len(loaded_recents)})
+
                 loaded_recents.extend(recents)
                 recents = loaded_recents
 
@@ -1116,6 +1128,8 @@ class SyncHandler:
         logger.debug("Fetching group data")
         await self._generate_sync_entry_for_groups(sync_result_builder)
 
+        num_events = 0
+
         # debug for https://github.com/matrix-org/synapse/issues/4422
         for joined_room in sync_result_builder.joined:
             room_id = joined_room.room_id
@@ -1123,6 +1137,14 @@ class SyncHandler:
                 issue4422_logger.debug(
                     "Sync result for newly joined room %s: %r", room_id, joined_room
                 )
+            num_events += len(joined_room.timeline.events)
+
+        log_kv(
+            {
+                "joined_rooms_in_result": len(sync_result_builder.joined),
+                "events_in_result": num_events,
+            }
+        )
 
         logger.debug("Sync response calculation complete")
         return SyncResult(
@@ -1467,6 +1489,7 @@ class SyncHandler:
         if not sync_result_builder.full_state:
             if since_token and not ephemeral_by_room and not account_data_by_room:
                 have_changed = await self._have_rooms_changed(sync_result_builder)
+                log_kv({"rooms_have_changed": have_changed})
                 if not have_changed:
                     tags_by_room = await self.store.get_updated_tags(
                         user_id, since_token.account_data_key
@@ -1501,25 +1524,30 @@ class SyncHandler:
 
             tags_by_room = await self.store.get_tags_for_user(user_id)
 
+        log_kv({"rooms_changed": len(room_changes.room_entries)})
+
         room_entries = room_changes.room_entries
         invited = room_changes.invited
         knocked = room_changes.knocked
         newly_joined_rooms = room_changes.newly_joined_rooms
         newly_left_rooms = room_changes.newly_left_rooms
 
-        async def handle_room_entries(room_entry):
-            logger.debug("Generating room entry for %s", room_entry.room_id)
-            res = await self._generate_room_entry(
-                sync_result_builder,
-                ignored_users,
-                room_entry,
-                ephemeral=ephemeral_by_room.get(room_entry.room_id, []),
-                tags=tags_by_room.get(room_entry.room_id),
-                account_data=account_data_by_room.get(room_entry.room_id, {}),
-                always_include=sync_result_builder.full_state,
-            )
-            logger.debug("Generated room entry for %s", room_entry.room_id)
-            return res
+        async def handle_room_entries(room_entry: "RoomSyncResultBuilder"):
+            with start_active_span("generate_room_entry"):
+                set_tag("room_id", room_entry.room_id)
+                log_kv({"events": len(room_entry.events or [])})
+                logger.debug("Generating room entry for %s", room_entry.room_id)
+                res = await self._generate_room_entry(
+                    sync_result_builder,
+                    ignored_users,
+                    room_entry,
+                    ephemeral=ephemeral_by_room.get(room_entry.room_id, []),
+                    tags=tags_by_room.get(room_entry.room_id),
+                    account_data=account_data_by_room.get(room_entry.room_id, {}),
+                    always_include=sync_result_builder.full_state,
+                )
+                logger.debug("Generated room entry for %s", room_entry.room_id)
+                return res
 
         await concurrently_execute(handle_room_entries, room_entries, 10)
 
@@ -1932,6 +1960,12 @@ class SyncHandler:
         room_id = room_builder.room_id
         since_token = room_builder.since_token
         upto_token = room_builder.upto_token
+        log_kv(
+            {
+                "since_token": since_token,
+                "upto_token": upto_token,
+            }
+        )
 
         batch = await self._load_filtered_recents(
             room_id,
@@ -1941,6 +1975,13 @@ class SyncHandler:
             potential_recents=events,
             newly_joined_room=newly_joined,
         )
+        log_kv(
+            {
+                "batch_events": len(batch.events),
+                "prev_batch": batch.prev_batch,
+                "batch_limited": batch.limited,
+            }
+        )
 
         # Note: `batch` can be both empty and limited here in the case where
         # `_load_filtered_recents` can't find any events the user should see
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index a12fa30bfd..91ba93372c 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -572,6 +572,25 @@ def parse_string_from_args(
     return strings[0]
 
 
+@overload
+def parse_json_value_from_request(request: Request) -> JsonDict:
+    ...
+
+
+@overload
+def parse_json_value_from_request(
+    request: Request, allow_empty_body: Literal[False]
+) -> JsonDict:
+    ...
+
+
+@overload
+def parse_json_value_from_request(
+    request: Request, allow_empty_body: bool = False
+) -> Optional[JsonDict]:
+    ...
+
+
 def parse_json_value_from_request(
     request: Request, allow_empty_body: bool = False
 ) -> Optional[JsonDict]:
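
The three `@overload` stubs exist purely for the type checker: a call that omits `allow_empty_body` (or passes a literal `False`) is known to return a `JsonDict`, while only the general `bool` form may return `None`. A minimal sketch of the same pattern on a made-up `parse` helper, not the Synapse function itself:

    import json
    from typing import Literal, Optional, overload

    @overload
    def parse(body: bytes) -> dict: ...
    @overload
    def parse(body: bytes, allow_empty: Literal[False]) -> dict: ...
    @overload
    def parse(body: bytes, allow_empty: bool = False) -> Optional[dict]: ...

    def parse(body: bytes, allow_empty: bool = False) -> Optional[dict]:
        # Empty bodies are only tolerated when the caller opts in.
        if not body:
            if allow_empty:
                return None
            raise ValueError("empty body")
        return json.loads(body)

    always_a_dict = parse(b"{}")      # mypy infers dict, no None check needed
    maybe_none = parse(b"", True)     # mypy infers Optional[dict]
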
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 941fb238b7..b0834720ad 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -258,7 +258,7 @@ class Mailer:
         # actually sort our so-called rooms_in_order list, most recent room first
         rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0))
 
-        rooms = []
+        rooms: List[Dict[str, Any]] = []
 
         for r in rooms_in_order:
             roomvars = await self._get_room_vars(
@@ -362,6 +362,7 @@ class Mailer:
             "notifs": [],
             "invite": is_invite,
             "link": self._make_room_link(room_id),
+            "avatar_url": await self._get_room_avatar(room_state_ids),
         }
 
         if not is_invite:
@@ -393,6 +394,27 @@ class Mailer:
 
         return room_vars
 
+    async def _get_room_avatar(
+        self,
+        room_state_ids: StateMap[str],
+    ) -> Optional[str]:
+        """
+        Retrieve the avatar URL for this room, if it exists.
+
+        Args:
+            room_state_ids: The event IDs of the current room state.
+
+        Returns:
+            The room's avatar URL if it is present and a string; otherwise None.
+        """
+        event_id = room_state_ids.get((EventTypes.RoomAvatar, ""))
+        if event_id:
+            ev = await self.store.get_event(event_id)
+            url = ev.content.get("url")
+            if isinstance(url, str):
+                return url
+        return None
+
     async def _get_notif_vars(
         self,
         notif: Dict[str, Any],
diff --git a/synapse/res/providers.json b/synapse/res/providers.json
new file mode 100644
index 0000000000..f1838f9559
--- /dev/null
+++ b/synapse/res/providers.json
@@ -0,0 +1,17 @@
+[
+    {
+        "provider_name": "Twitter",
+        "provider_url": "http://www.twitter.com/",
+        "endpoints": [
+            {
+                "schemes": [
+                    "https://twitter.com/*/status/*",
+                    "https://*.twitter.com/*/status/*",
+                    "https://twitter.com/*/moments/*",
+                    "https://*.twitter.com/*/moments/*"
+                ],
+                "url": "https://publish.twitter.com/oembed"
+            }
+        ]
+    }
+]
\ No newline at end of file
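
The new bundled providers.json maps URL glob patterns to an oEmbed endpoint. For illustration only (the matching code itself is not in this hunk), a minimal sketch of how such a file could be consulted, assuming fnmatch-style wildcards:

    import fnmatch
    import json

    with open("synapse/res/providers.json") as f:
        providers = json.load(f)

    def oembed_endpoint_for(url: str):
        # Return the endpoint URL of the first provider whose scheme globs
        # match the given page URL, or None if no provider claims it.
        for provider in providers:
            for endpoint in provider["endpoints"]:
                if any(fnmatch.fnmatch(url, scheme) for scheme in endpoint["schemes"]):
                    return endpoint["url"]
        return None

    print(oembed_endpoint_for("https://twitter.com/matrixdotorg/status/1234"))
    # -> https://publish.twitter.com/oembed
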
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index 42201afc86..f5a38c2670 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Optional, Tuple
 
 from synapse.api.constants import EventTypes
 from synapse.api.errors import NotFoundError, SynapseError
@@ -101,7 +101,9 @@ class SendServerNoticeServlet(RestServlet):
 
         return 200, {"event_id": event.event_id}
 
-    def on_PUT(self, request: SynapseRequest, txn_id: str) -> Tuple[int, JsonDict]:
+    def on_PUT(
+        self, request: SynapseRequest, txn_id: str
+    ) -> Awaitable[Tuple[int, JsonDict]]:
         return self.txns.fetch_or_execute_request(
             request, self.on_POST, request, txn_id
         )
diff --git a/synapse/rest/client/_base.py b/synapse/rest/client/_base.py
index 0443f4571c..a0971ce994 100644
--- a/synapse/rest/client/_base.py
+++ b/synapse/rest/client/_base.py
@@ -16,7 +16,7 @@
 """
 import logging
 import re
-from typing import Iterable, Pattern
+from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar, cast
 
 from synapse.api.errors import InteractiveAuthIncompleteError
 from synapse.api.urls import CLIENT_API_PREFIX
@@ -76,7 +76,10 @@ def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int)
         )
 
 
-def interactive_auth_handler(orig):
+C = TypeVar("C", bound=Callable[..., Awaitable[Tuple[int, JsonDict]]])
+
+
+def interactive_auth_handler(orig: C) -> C:
     """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
 
     Takes a on_POST method which returns an Awaitable (errcode, body) response
@@ -91,10 +94,10 @@ def interactive_auth_handler(orig):
         await self.auth_handler.check_auth
     """
 
-    async def wrapped(*args, **kwargs):
+    async def wrapped(*args: Any, **kwargs: Any) -> Tuple[int, JsonDict]:
         try:
             return await orig(*args, **kwargs)
         except InteractiveAuthIncompleteError as e:
             return 401, e.result
 
-    return wrapped
+    return cast(C, wrapped)
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index fb5ad2906e..aefaaa8ae8 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -16,9 +16,11 @@
 import logging
 import random
 from http import HTTPStatus
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional, Tuple
 from urllib.parse import urlparse
 
+from twisted.web.server import Request
+
 from synapse.api.constants import LoginType
 from synapse.api.errors import (
     Codes,
@@ -28,15 +30,17 @@ from synapse.api.errors import (
 )
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
-from synapse.http.server import finish_request, respond_with_html
+from synapse.http.server import HttpServer, finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
     parse_json_object_from_request,
     parse_string,
 )
+from synapse.http.site import SynapseRequest
 from synapse.metrics import threepid_send_requests
 from synapse.push.mailer import Mailer
+from synapse.types import JsonDict
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.stringutils import assert_valid_client_secret, random_string
 from synapse.util.threepids import check_3pid_allowed, validate_email
@@ -68,7 +72,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
                 template_text=self.config.email_password_reset_template_text,
             )
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
             if self.config.local_threepid_handling_disabled_due_to_email_config:
                 logger.warning(
@@ -159,7 +163,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
 class PasswordRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/password$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
@@ -169,7 +173,7 @@ class PasswordRestServlet(RestServlet):
         self._set_password_handler = hs.get_set_password_handler()
 
     @interactive_auth_handler
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         body = parse_json_object_from_request(request)
 
         # we do basic sanity checks here because the auth layer will store these
@@ -190,6 +194,7 @@ class PasswordRestServlet(RestServlet):
         #
         # In the second case, we require a password to confirm their identity.
 
+        requester = None
         if self.auth.has_access_token(request):
             requester = await self.auth.get_user_by_req(request)
             try:
@@ -206,16 +211,15 @@ class PasswordRestServlet(RestServlet):
                 # If a password is available now, hash the provided password and
                 # store it for later.
                 if new_password:
-                    password_hash = await self.auth_handler.hash(new_password)
+                    new_password_hash = await self.auth_handler.hash(new_password)
                     await self.auth_handler.set_session_data(
                         e.session_id,
                         UIAuthSessionDataConstants.PASSWORD_HASH,
-                        password_hash,
+                        new_password_hash,
                     )
                 raise
             user_id = requester.user.to_string()
         else:
-            requester = None
             try:
                 result, params, session_id = await self.auth_handler.check_ui_auth(
                     [[LoginType.EMAIL_IDENTITY]],
@@ -230,11 +234,11 @@ class PasswordRestServlet(RestServlet):
                 # If a password is available now, hash the provided password and
                 # store it for later.
                 if new_password:
-                    password_hash = await self.auth_handler.hash(new_password)
+                    new_password_hash = await self.auth_handler.hash(new_password)
                     await self.auth_handler.set_session_data(
                         e.session_id,
                         UIAuthSessionDataConstants.PASSWORD_HASH,
-                        password_hash,
+                        new_password_hash,
                     )
                 raise
 
@@ -264,7 +268,7 @@ class PasswordRestServlet(RestServlet):
         # If we have a password in this request, prefer it. Otherwise, use the
         # password hash from an earlier request.
         if new_password:
-            password_hash = await self.auth_handler.hash(new_password)
+            password_hash: Optional[str] = await self.auth_handler.hash(new_password)
         elif session_id is not None:
             password_hash = await self.auth_handler.get_session_data(
                 session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
@@ -288,7 +292,7 @@ class PasswordRestServlet(RestServlet):
 class DeactivateAccountRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/deactivate$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
@@ -296,7 +300,7 @@ class DeactivateAccountRestServlet(RestServlet):
         self._deactivate_account_handler = hs.get_deactivate_account_handler()
 
     @interactive_auth_handler
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         body = parse_json_object_from_request(request)
         erase = body.get("erase", False)
         if not isinstance(erase, bool):
@@ -338,7 +342,7 @@ class DeactivateAccountRestServlet(RestServlet):
 class EmailThreepidRequestTokenRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid/email/requestToken$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.config = hs.config
@@ -353,7 +357,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
                 template_text=self.config.email_add_threepid_template_text,
             )
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
             if self.config.local_threepid_handling_disabled_due_to_email_config:
                 logger.warning(
@@ -449,7 +453,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
         self.store = self.hs.get_datastore()
         self.identity_handler = hs.get_identity_handler()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         body = parse_json_object_from_request(request)
         assert_params_in_dict(
             body, ["client_secret", "country", "phone_number", "send_attempt"]
@@ -525,11 +529,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
         "/add_threepid/email/submit_token$", releases=(), unstable=True
     )
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.config = hs.config
         self.clock = hs.get_clock()
@@ -539,7 +539,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
                 self.config.email_add_threepid_template_failure_html
             )
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: Request) -> None:
         if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
             if self.config.local_threepid_handling_disabled_due_to_email_config:
                 logger.warning(
@@ -596,18 +596,14 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
         "/add_threepid/msisdn/submit_token$", releases=(), unstable=True
     )
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.config = hs.config
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.identity_handler = hs.get_identity_handler()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
         if not self.config.account_threepid_delegate_msisdn:
             raise SynapseError(
                 400,
@@ -632,7 +628,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
 class ThreepidRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.identity_handler = hs.get_identity_handler()
@@ -640,14 +636,14 @@ class ThreepidRestServlet(RestServlet):
         self.auth_handler = hs.get_auth_handler()
         self.datastore = self.hs.get_datastore()
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
         threepids = await self.datastore.user_get_threepids(requester.user.to_string())
 
         return 200, {"threepids": threepids}
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         if not self.hs.config.enable_3pid_changes:
             raise SynapseError(
                 400, "3PID changes are disabled on this server", Codes.FORBIDDEN
@@ -688,7 +684,7 @@ class ThreepidRestServlet(RestServlet):
 class ThreepidAddRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid/add$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.identity_handler = hs.get_identity_handler()
@@ -696,7 +692,7 @@ class ThreepidAddRestServlet(RestServlet):
         self.auth_handler = hs.get_auth_handler()
 
     @interactive_auth_handler
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         if not self.hs.config.enable_3pid_changes:
             raise SynapseError(
                 400, "3PID changes are disabled on this server", Codes.FORBIDDEN
@@ -738,13 +734,13 @@ class ThreepidAddRestServlet(RestServlet):
 class ThreepidBindRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid/bind$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.identity_handler = hs.get_identity_handler()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         body = parse_json_object_from_request(request)
 
         assert_params_in_dict(body, ["id_server", "sid", "client_secret"])
@@ -767,14 +763,14 @@ class ThreepidBindRestServlet(RestServlet):
 class ThreepidUnbindRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid/unbind$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.identity_handler = hs.get_identity_handler()
         self.auth = hs.get_auth()
         self.datastore = self.hs.get_datastore()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         """Unbind the given 3pid from a specific identity server, or identity servers that are
         known to have this 3pid bound
         """
@@ -798,13 +794,13 @@ class ThreepidUnbindRestServlet(RestServlet):
 class ThreepidDeleteRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid/delete$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         if not self.hs.config.enable_3pid_changes:
             raise SynapseError(
                 400, "3PID changes are disabled on this server", Codes.FORBIDDEN
@@ -835,7 +831,7 @@ class ThreepidDeleteRestServlet(RestServlet):
         return 200, {"id_server_unbind_result": id_server_unbind_result}
 
 
-def assert_valid_next_link(hs: "HomeServer", next_link: str):
+def assert_valid_next_link(hs: "HomeServer", next_link: str) -> None:
     """
     Raises a SynapseError if a given next_link value is invalid
 
@@ -877,11 +873,11 @@ def assert_valid_next_link(hs: "HomeServer", next_link: str):
 class WhoamiRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/whoami$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
         response = {"user_id": requester.user.to_string()}
@@ -894,7 +890,7 @@ class WhoamiRestServlet(RestServlet):
         return 200, response
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     EmailPasswordRequestTokenRestServlet(hs).register(http_server)
     PasswordRestServlet(hs).register(http_server)
     DeactivateAccountRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py
index 7517e9304e..d1badbdf3b 100644
--- a/synapse/rest/client/account_data.py
+++ b/synapse/rest/client/account_data.py
@@ -13,12 +13,19 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import AuthError, NotFoundError, SynapseError
+from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -32,13 +39,15 @@ class AccountDataServlet(RestServlet):
         "/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.handler = hs.get_account_data_handler()
 
-    async def on_PUT(self, request, user_id, account_data_type):
+    async def on_PUT(
+        self, request: SynapseRequest, user_id: str, account_data_type: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
@@ -49,7 +58,9 @@ class AccountDataServlet(RestServlet):
 
         return 200, {}
 
-    async def on_GET(self, request, user_id, account_data_type):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str, account_data_type: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot get account data for other users.")
@@ -76,13 +87,19 @@ class RoomAccountDataServlet(RestServlet):
         "/account_data/(?P<account_data_type>[^/]*)"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.handler = hs.get_account_data_handler()
 
-    async def on_PUT(self, request, user_id, room_id, account_data_type):
+    async def on_PUT(
+        self,
+        request: SynapseRequest,
+        user_id: str,
+        room_id: str,
+        account_data_type: str,
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
@@ -102,7 +119,13 @@ class RoomAccountDataServlet(RestServlet):
 
         return 200, {}
 
-    async def on_GET(self, request, user_id, room_id, account_data_type):
+    async def on_GET(
+        self,
+        request: SynapseRequest,
+        user_id: str,
+        room_id: str,
+        account_data_type: str,
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot get account data for other users.")
@@ -117,6 +140,6 @@ class RoomAccountDataServlet(RestServlet):
         return 200, event
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     AccountDataServlet(hs).register(http_server)
     RoomAccountDataServlet(hs).register(http_server)
diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py
index c3667ff8aa..a7e9aa3e9b 100644
--- a/synapse/rest/client/groups.py
+++ b/synapse/rest/client/groups.py
@@ -15,7 +15,7 @@
 
 import logging
 from functools import wraps
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple
 
 from twisted.web.server import Request
 
@@ -43,14 +43,18 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-def _validate_group_id(f):
+def _validate_group_id(
+    f: Callable[..., Awaitable[Tuple[int, JsonDict]]]
+) -> Callable[..., Awaitable[Tuple[int, JsonDict]]]:
     """Wrapper to validate the form of the group ID.
 
     Can be applied to any on_FOO methods that accepts a group ID as a URL parameter.
     """
 
     @wraps(f)
-    def wrapper(self, request: Request, group_id: str, *args, **kwargs):
+    def wrapper(
+        self: RestServlet, request: Request, group_id: str, *args: Any, **kwargs: Any
+    ) -> Awaitable[Tuple[int, JsonDict]]:
         if not GroupID.is_valid(group_id):
             raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
 
@@ -156,7 +160,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
         group_id: str,
         category_id: Optional[str],
         room_id: str,
-    ):
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -188,7 +192,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
     @_validate_group_id
     async def on_DELETE(
         self, request: SynapseRequest, group_id: str, category_id: str, room_id: str
-    ):
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -451,7 +455,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
     @_validate_group_id
     async def on_DELETE(
         self, request: SynapseRequest, group_id: str, role_id: str, user_id: str
-    ):
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -674,7 +678,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
     @_validate_group_id
     async def on_PUT(
         self, request: SynapseRequest, group_id: str, room_id: str, config_key: str
-    ):
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -706,7 +710,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
 
     @_validate_group_id
     async def on_PUT(
-        self, request: SynapseRequest, group_id, user_id
+        self, request: SynapseRequest, group_id: str, user_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -738,7 +742,7 @@ class GroupAdminUsersKickServlet(RestServlet):
 
     @_validate_group_id
     async def on_PUT(
-        self, request: SynapseRequest, group_id, user_id
+        self, request: SynapseRequest, group_id: str, user_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py
index 68fb08d0ba..0152a0c66a 100644
--- a/synapse/rest/client/knock.py
+++ b/synapse/rest/client/knock.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple
 
 from twisted.web.server import Request
 
@@ -96,7 +96,9 @@ class KnockRoomAliasServlet(RestServlet):
 
         return 200, {"room_id": room_id}
 
-    def on_PUT(self, request: Request, room_identifier: str, txn_id: str):
+    def on_PUT(
+        self, request: Request, room_identifier: str, txn_id: str
+    ) -> Awaitable[Tuple[int, JsonDict]]:
         set_tag("txn_id", txn_id)
 
         return self.txns.fetch_or_execute_request(
diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py
index 702b351d18..fb3211bf3a 100644
--- a/synapse/rest/client/push_rule.py
+++ b/synapse/rest/client/push_rule.py
@@ -12,22 +12,40 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Union
+
+import attr
+
 from synapse.api.errors import (
     NotFoundError,
     StoreError,
     SynapseError,
     UnrecognizedRequestError,
 )
+from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     RestServlet,
     parse_json_value_from_request,
     parse_string,
 )
+from synapse.http.site import SynapseRequest
 from synapse.push.baserules import BASE_RULE_IDS, NEW_RULE_IDS
 from synapse.push.clientformat import format_push_rules_for_user
 from synapse.push.rulekinds import PRIORITY_CLASS_MAP
 from synapse.rest.client._base import client_patterns
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RuleSpec:
+    scope: str
+    template: str
+    rule_id: str
+    attr: Optional[str]
 
 
 class PushRuleRestServlet(RestServlet):
@@ -36,7 +54,7 @@ class PushRuleRestServlet(RestServlet):
         "Unrecognised request: You probably wanted a trailing slash"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
@@ -45,7 +63,7 @@ class PushRuleRestServlet(RestServlet):
 
         self._users_new_default_push_rules = hs.config.users_new_default_push_rules
 
-    async def on_PUT(self, request, path):
+    async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
         if self._is_worker:
             raise Exception("Cannot handle PUT /push_rules on worker")
 
@@ -57,25 +75,25 @@ class PushRuleRestServlet(RestServlet):
 
         requester = await self.auth.get_user_by_req(request)
 
-        if "/" in spec["rule_id"] or "\\" in spec["rule_id"]:
+        if "/" in spec.rule_id or "\\" in spec.rule_id:
             raise SynapseError(400, "rule_id may not contain slashes")
 
         content = parse_json_value_from_request(request)
 
         user_id = requester.user.to_string()
 
-        if "attr" in spec:
+        if spec.attr:
             await self.set_rule_attr(user_id, spec, content)
             self.notify_user(user_id)
             return 200, {}
 
-        if spec["rule_id"].startswith("."):
+        if spec.rule_id.startswith("."):
             # Rule ids starting with '.' are reserved for server default rules.
             raise SynapseError(400, "cannot add new rule_ids that start with '.'")
 
         try:
             (conditions, actions) = _rule_tuple_from_request_object(
-                spec["template"], spec["rule_id"], content
+                spec.template, spec.rule_id, content
             )
         except InvalidRuleException as e:
             raise SynapseError(400, str(e))
@@ -106,7 +124,9 @@ class PushRuleRestServlet(RestServlet):
 
         return 200, {}
 
-    async def on_DELETE(self, request, path):
+    async def on_DELETE(
+        self, request: SynapseRequest, path: str
+    ) -> Tuple[int, JsonDict]:
         if self._is_worker:
             raise Exception("Cannot handle DELETE /push_rules on worker")
 
@@ -127,7 +147,7 @@ class PushRuleRestServlet(RestServlet):
             else:
                 raise
 
-    async def on_GET(self, request, path):
+    async def on_GET(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         user_id = requester.user.to_string()
 
@@ -138,40 +158,42 @@ class PushRuleRestServlet(RestServlet):
 
         rules = format_push_rules_for_user(requester.user, rules)
 
-        path = path.split("/")[1:]
+        path_parts = path.split("/")[1:]
 
-        if path == []:
+        if path_parts == []:
             # we're a reference impl: pedantry is our job.
             raise UnrecognizedRequestError(
                 PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
             )
 
-        if path[0] == "":
+        if path_parts[0] == "":
             return 200, rules
-        elif path[0] == "global":
-            result = _filter_ruleset_with_path(rules["global"], path[1:])
+        elif path_parts[0] == "global":
+            result = _filter_ruleset_with_path(rules["global"], path_parts[1:])
             return 200, result
         else:
             raise UnrecognizedRequestError()
 
-    def notify_user(self, user_id):
+    def notify_user(self, user_id: str) -> None:
         stream_id = self.store.get_max_push_rules_stream_id()
         self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
 
-    async def set_rule_attr(self, user_id, spec, val):
-        if spec["attr"] not in ("enabled", "actions"):
+    async def set_rule_attr(
+        self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict]
+    ) -> None:
+        if spec.attr not in ("enabled", "actions"):
             # for the sake of potential future expansion, shouldn't report
             # 404 in the case of an unknown request so check it corresponds to
             # a known attribute first.
             raise UnrecognizedRequestError()
 
         namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
-        rule_id = spec["rule_id"]
+        rule_id = spec.rule_id
         is_default_rule = rule_id.startswith(".")
         if is_default_rule:
             if namespaced_rule_id not in BASE_RULE_IDS:
                 raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,))
-        if spec["attr"] == "enabled":
+        if spec.attr == "enabled":
             if isinstance(val, dict) and "enabled" in val:
                 val = val["enabled"]
             if not isinstance(val, bool):
@@ -179,14 +201,18 @@ class PushRuleRestServlet(RestServlet):
                 # This should *actually* take a dict, but many clients pass
                 # bools directly, so let's not break them.
                 raise SynapseError(400, "Value for 'enabled' must be boolean")
-            return await self.store.set_push_rule_enabled(
+            await self.store.set_push_rule_enabled(
                 user_id, namespaced_rule_id, val, is_default_rule
             )
-        elif spec["attr"] == "actions":
+        elif spec.attr == "actions":
+            if not isinstance(val, dict):
+                raise SynapseError(400, "Value must be a dict")
             actions = val.get("actions")
+            if not isinstance(actions, list):
+                raise SynapseError(400, "Value for 'actions' must be dict")
             _check_actions(actions)
             namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
-            rule_id = spec["rule_id"]
+            rule_id = spec.rule_id
             is_default_rule = rule_id.startswith(".")
             if is_default_rule:
                 if user_id in self._users_new_default_push_rules:
@@ -196,22 +222,21 @@ class PushRuleRestServlet(RestServlet):
 
                 if namespaced_rule_id not in rule_ids:
                     raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
-            return await self.store.set_push_rule_actions(
+            await self.store.set_push_rule_actions(
                 user_id, namespaced_rule_id, actions, is_default_rule
             )
         else:
             raise UnrecognizedRequestError()
 
 
-def _rule_spec_from_path(path):
+def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec:
     """Turn a sequence of path components into a rule spec
 
     Args:
-        path (sequence[unicode]): the URL path components.
+        path: the URL path components.
 
     Returns:
-        dict: rule spec dict, containing scope/template/rule_id entries,
-            and possibly attr.
+        A rule spec, containing scope/template/rule_id fields, and possibly attr.
 
     Raises:
         UnrecognizedRequestError if the path components cannot be parsed.
@@ -237,17 +262,18 @@ def _rule_spec_from_path(path):
 
     rule_id = path[0]
 
-    spec = {"scope": scope, "template": template, "rule_id": rule_id}
-
     path = path[1:]
 
+    attr = None
     if len(path) > 0 and len(path[0]) > 0:
-        spec["attr"] = path[0]
+        attr = path[0]
 
-    return spec
+    return RuleSpec(scope, template, rule_id, attr)
 
 
-def _rule_tuple_from_request_object(rule_template, rule_id, req_obj):
+def _rule_tuple_from_request_object(
+    rule_template: str, rule_id: str, req_obj: JsonDict
+) -> Tuple[List[JsonDict], List[Union[str, JsonDict]]]:
     if rule_template in ["override", "underride"]:
         if "conditions" not in req_obj:
             raise InvalidRuleException("Missing 'conditions'")
@@ -277,7 +303,7 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj):
     return conditions, actions
 
 
-def _check_actions(actions):
+def _check_actions(actions: List[Union[str, JsonDict]]) -> None:
     if not isinstance(actions, list):
         raise InvalidRuleException("No actions found")
 
@@ -290,7 +316,7 @@ def _check_actions(actions):
             raise InvalidRuleException("Unrecognised action")
 
 
-def _filter_ruleset_with_path(ruleset, path):
+def _filter_ruleset_with_path(ruleset: JsonDict, path: List[str]) -> JsonDict:
     if path == []:
         raise UnrecognizedRequestError(
             PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
@@ -315,7 +341,7 @@ def _filter_ruleset_with_path(ruleset, path):
         if r["rule_id"] == rule_id:
             the_rule = r
     if the_rule is None:
-        raise NotFoundError
+        raise NotFoundError()
 
     path = path[1:]
     if len(path) == 0:
@@ -330,25 +356,25 @@ def _filter_ruleset_with_path(ruleset, path):
         raise UnrecognizedRequestError()
 
 
-def _priority_class_from_spec(spec):
-    if spec["template"] not in PRIORITY_CLASS_MAP.keys():
-        raise InvalidRuleException("Unknown template: %s" % (spec["template"]))
-    pc = PRIORITY_CLASS_MAP[spec["template"]]
+def _priority_class_from_spec(spec: RuleSpec) -> int:
+    if spec.template not in PRIORITY_CLASS_MAP.keys():
+        raise InvalidRuleException("Unknown template: %s" % (spec.template))
+    pc = PRIORITY_CLASS_MAP[spec.template]
 
     return pc
 
 
-def _namespaced_rule_id_from_spec(spec):
-    return _namespaced_rule_id(spec, spec["rule_id"])
+def _namespaced_rule_id_from_spec(spec: RuleSpec) -> str:
+    return _namespaced_rule_id(spec, spec.rule_id)
 
 
-def _namespaced_rule_id(spec, rule_id):
-    return "global/%s/%s" % (spec["template"], rule_id)
+def _namespaced_rule_id(spec: RuleSpec, rule_id: str) -> str:
+    return "global/%s/%s" % (spec.template, rule_id)
 
 
 class InvalidRuleException(Exception):
     pass
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     PushRuleRestServlet(hs).register(http_server)
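
As a quick illustration of the RuleSpec refactor above: a request such as
PUT /_matrix/client/r0/pushrules/global/override/.m.rule.suppress_notices/enabled
now parses into a spec whose attribute accesses replace the old dict lookups. A
minimal sketch follows, assuming RuleSpec is the plain scope/template/rule_id/attr
container introduced earlier in this patch (the NamedTuple below is a stand-in,
not the actual class):

    from typing import NamedTuple, Optional

    class RuleSpec(NamedTuple):
        # Stand-in for the container added earlier in this patch.
        scope: str
        template: str
        rule_id: str
        attr: Optional[str]

    # Components of .../pushrules/global/override/.m.rule.suppress_notices/enabled
    spec = RuleSpec("global", "override", ".m.rule.suppress_notices", "enabled")

    # The namespaced rule id built by _namespaced_rule_id_from_spec:
    assert "global/%s/%s" % (spec.template, spec.rule_id) == \
        "global/override/.m.rule.suppress_notices"
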
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index d9ab836cd8..9770413c61 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -13,13 +13,20 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.constants import ReadReceiptEventFields
 from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -30,14 +37,16 @@ class ReceiptRestServlet(RestServlet):
         "/(?P<event_id>[^/]*)$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.receipts_handler = hs.get_receipts_handler()
         self.presence_handler = hs.get_presence_handler()
 
-    async def on_POST(self, request, room_id, receipt_type, event_id):
+    async def on_POST(
+        self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
         if receipt_type != "m.read":
@@ -67,5 +76,5 @@ class ReceiptRestServlet(RestServlet):
         return 200, {}
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     ReceiptRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 7b5f49d635..8f3dd2a101 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -14,7 +14,9 @@
 # limitations under the License.
 import logging
 import random
-from typing import List, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from twisted.web.server import Request
 
 import synapse
 import synapse.api.auth
@@ -29,15 +31,13 @@ from synapse.api.errors import (
 )
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.config import ConfigError
-from synapse.config.captcha import CaptchaConfig
-from synapse.config.consent import ConsentConfig
 from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.config.homeserver import HomeServerConfig
 from synapse.config.ratelimiting import FederationRateLimitConfig
-from synapse.config.registration import RegistrationConfig
 from synapse.config.server import is_threepid_reserved
 from synapse.handlers.auth import AuthHandler
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
-from synapse.http.server import finish_request, respond_with_html
+from synapse.http.server import HttpServer, finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
@@ -45,6 +45,7 @@ from synapse.http.servlet import (
     parse_json_object_from_request,
     parse_string,
 )
+from synapse.http.site import SynapseRequest
 from synapse.metrics import threepid_send_requests
 from synapse.push.mailer import Mailer
 from synapse.types import JsonDict
@@ -59,17 +60,16 @@ from synapse.util.threepids import (
 
 from ._base import client_patterns, interactive_auth_handler
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class EmailRegisterRequestTokenRestServlet(RestServlet):
     PATTERNS = client_patterns("/register/email/requestToken$")
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.identity_handler = hs.get_identity_handler()
@@ -83,7 +83,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
                 template_text=self.config.email_registration_template_text,
             )
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
             if self.hs.config.local_threepid_handling_disabled_due_to_email_config:
                 logger.warning(
@@ -171,16 +171,12 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
 class MsisdnRegisterRequestTokenRestServlet(RestServlet):
     PATTERNS = client_patterns("/register/msisdn/requestToken$")
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.identity_handler = hs.get_identity_handler()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         body = parse_json_object_from_request(request)
 
         assert_params_in_dict(
@@ -255,11 +251,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
         "/registration/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True
     )
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
@@ -272,7 +264,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
                 self.config.email_registration_template_failure_html
             )
 
-    async def on_GET(self, request, medium):
+    async def on_GET(self, request: Request, medium: str) -> None:
         if medium != "email":
             raise SynapseError(
                 400, "This medium is currently not supported for registration"
@@ -326,11 +318,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
 class UsernameAvailabilityRestServlet(RestServlet):
     PATTERNS = client_patterns("/register/available")
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.registration_handler = hs.get_registration_handler()
@@ -350,7 +338,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
             ),
         )
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
         if not self.hs.config.enable_registration:
             raise SynapseError(
                 403, "Registration has been disabled", errcode=Codes.FORBIDDEN
@@ -387,11 +375,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
         unstable=True,
     )
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.store = hs.get_datastore()
@@ -402,7 +386,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
             burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count,
         )
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
         await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
 
         if not self.hs.config.enable_registration:
@@ -419,11 +403,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
 class RegisterRestServlet(RestServlet):
     PATTERNS = client_patterns("/register$")
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
 
         self.hs = hs
@@ -445,23 +425,21 @@ class RegisterRestServlet(RestServlet):
         )
 
     @interactive_auth_handler
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         body = parse_json_object_from_request(request)
 
         client_addr = request.getClientIP()
 
         await self.ratelimiter.ratelimit(None, client_addr, update=False)
 
-        kind = b"user"
-        if b"kind" in request.args:
-            kind = request.args[b"kind"][0]
+        kind = parse_string(request, "kind", default="user")
 
-        if kind == b"guest":
+        if kind == "guest":
             ret = await self._do_guest_registration(body, address=client_addr)
             return ret
-        elif kind != b"user":
+        elif kind != "user":
             raise UnrecognizedRequestError(
-                "Do not understand membership kind: %s" % (kind.decode("utf8"),)
+                f"Do not understand membership kind: {kind}",
             )
 
         if self._msc2918_enabled:
@@ -748,8 +726,12 @@ class RegisterRestServlet(RestServlet):
         return 200, return_dict
 
     async def _do_appservice_registration(
-        self, username, as_token, body, should_issue_refresh_token: bool = False
-    ):
+        self,
+        username: str,
+        as_token: str,
+        body: JsonDict,
+        should_issue_refresh_token: bool = False,
+    ) -> JsonDict:
         user_id = await self.registration_handler.appservice_register(
             username, as_token
         )
@@ -766,7 +748,7 @@ class RegisterRestServlet(RestServlet):
         params: JsonDict,
         is_appservice_ghost: bool = False,
         should_issue_refresh_token: bool = False,
-    ):
+    ) -> JsonDict:
         """Complete registration of newly-registered user
 
         Allocates device_id if one was not given; also creates access_token.
@@ -810,7 +792,9 @@ class RegisterRestServlet(RestServlet):
 
         return result
 
-    async def _do_guest_registration(self, params, address=None):
+    async def _do_guest_registration(
+        self, params: JsonDict, address: Optional[str] = None
+    ) -> Tuple[int, JsonDict]:
         if not self.hs.config.allow_guest_access:
             raise SynapseError(403, "Guest access is disabled")
         user_id = await self.registration_handler.register_user(
@@ -848,9 +832,7 @@ class RegisterRestServlet(RestServlet):
 
 
 def _calculate_registration_flows(
-    # technically `config` has to provide *all* of these interfaces, not just one
-    config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig],
-    auth_handler: AuthHandler,
+    config: HomeServerConfig, auth_handler: AuthHandler
 ) -> List[List[str]]:
     """Get a suitable flows list for registration
 
@@ -929,7 +911,7 @@ def _calculate_registration_flows(
     return flows
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     EmailRegisterRequestTokenRestServlet(hs).register(http_server)
     MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
     UsernameAvailabilityRestServlet(hs).register(http_server)
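
The RegisterRestServlet change above replaces the raw request.args bytes lookup
with parse_string, so the kind query parameter arrives as a str and defaults to
"user". A minimal sketch of the resulting behaviour, using only the standard
library rather than Synapse's parse_string:

    from urllib.parse import parse_qs

    def registration_kind(query_string: str) -> str:
        # Mirrors parse_string(request, "kind", default="user"): a missing
        # parameter falls back to "user"; otherwise the value is returned as str.
        args = parse_qs(query_string)
        return args.get("kind", ["user"])[0]

    assert registration_kind("kind=guest") == "guest"
    assert registration_kind("") == "user"
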
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 0821cd285f..0b0711c03c 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -19,25 +19,32 @@ any time to reflect changes in the MSC.
 """
 
 import logging
+from typing import TYPE_CHECKING, Awaitable, Optional, Tuple
 
 from synapse.api.constants import EventTypes, RelationTypes
 from synapse.api.errors import ShadowBanError, SynapseError
+from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     RestServlet,
     parse_integer,
     parse_json_object_from_request,
     parse_string,
 )
+from synapse.http.site import SynapseRequest
 from synapse.rest.client.transactions import HttpTransactionCache
 from synapse.storage.relations import (
     AggregationPaginationToken,
     PaginationChunk,
     RelationPaginationToken,
 )
+from synapse.types import JsonDict
 from synapse.util.stringutils import random_string
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -59,13 +66,13 @@ class RelationSendServlet(RestServlet):
         "/(?P<parent_id>[^/]*)/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.event_creation_handler = hs.get_event_creation_handler()
         self.txns = HttpTransactionCache(hs)
 
-    def register(self, http_server):
+    def register(self, http_server: HttpServer) -> None:
         http_server.register_paths(
             "POST",
             client_patterns(self.PATTERN + "$", releases=()),
@@ -79,14 +86,35 @@ class RelationSendServlet(RestServlet):
             self.__class__.__name__,
         )
 
-    def on_PUT(self, request, *args, **kwargs):
+    def on_PUT(
+        self,
+        request: SynapseRequest,
+        room_id: str,
+        parent_id: str,
+        relation_type: str,
+        event_type: str,
+        txn_id: Optional[str] = None,
+    ) -> Awaitable[Tuple[int, JsonDict]]:
         return self.txns.fetch_or_execute_request(
-            request, self.on_PUT_or_POST, request, *args, **kwargs
+            request,
+            self.on_PUT_or_POST,
+            request,
+            room_id,
+            parent_id,
+            relation_type,
+            event_type,
+            txn_id,
         )
 
     async def on_PUT_or_POST(
-        self, request, room_id, parent_id, relation_type, event_type, txn_id=None
-    ):
+        self,
+        request: SynapseRequest,
+        room_id: str,
+        parent_id: str,
+        relation_type: str,
+        event_type: str,
+        txn_id: Optional[str] = None,
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         if event_type == EventTypes.Member:
@@ -136,7 +164,7 @@ class RelationPaginationServlet(RestServlet):
         releases=(),
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
@@ -145,8 +173,13 @@ class RelationPaginationServlet(RestServlet):
         self.event_handler = hs.get_event_handler()
 
     async def on_GET(
-        self, request, room_id, parent_id, relation_type=None, event_type=None
-    ):
+        self,
+        request: SynapseRequest,
+        room_id: str,
+        parent_id: str,
+        relation_type: Optional[str] = None,
+        event_type: Optional[str] = None,
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         await self.auth.check_user_in_room_or_world_readable(
@@ -156,6 +189,8 @@ class RelationPaginationServlet(RestServlet):
         # This gets the original event and checks that a) the event exists and
         # b) the user is allowed to view it.
         event = await self.event_handler.get_event(requester.user, room_id, parent_id)
+        if event is None:
+            raise SynapseError(404, "Unknown parent event.")
 
         limit = parse_integer(request, "limit", default=5)
         from_token_str = parse_string(request, "from")
@@ -233,15 +268,20 @@ class RelationAggregationPaginationServlet(RestServlet):
         releases=(),
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.event_handler = hs.get_event_handler()
 
     async def on_GET(
-        self, request, room_id, parent_id, relation_type=None, event_type=None
-    ):
+        self,
+        request: SynapseRequest,
+        room_id: str,
+        parent_id: str,
+        relation_type: Optional[str] = None,
+        event_type: Optional[str] = None,
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         await self.auth.check_user_in_room_or_world_readable(
@@ -253,6 +293,8 @@ class RelationAggregationPaginationServlet(RestServlet):
         # This checks that a) the event exists and b) the user is allowed to
         # view it.
         event = await self.event_handler.get_event(requester.user, room_id, parent_id)
+        if event is None:
+            raise SynapseError(404, "Unknown parent event.")
 
         if relation_type not in (RelationTypes.ANNOTATION, None):
             raise SynapseError(400, "Relation type must be 'annotation'")
@@ -315,7 +357,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
         releases=(),
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
@@ -323,7 +365,15 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
         self._event_serializer = hs.get_event_client_serializer()
         self.event_handler = hs.get_event_handler()
 
-    async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
+    async def on_GET(
+        self,
+        request: SynapseRequest,
+        room_id: str,
+        parent_id: str,
+        relation_type: str,
+        event_type: str,
+        key: str,
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         await self.auth.check_user_in_room_or_world_readable(
@@ -374,7 +424,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
         return 200, return_value
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     RelationSendServlet(hs).register(http_server)
     RelationPaginationServlet(hs).register(http_server)
     RelationAggregationPaginationServlet(hs).register(http_server)
diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py
index 07ea39a8a3..d4a4adb50c 100644
--- a/synapse/rest/client/report_event.py
+++ b/synapse/rest/client/report_event.py
@@ -14,26 +14,35 @@
 
 import logging
 from http import HTTPStatus
+from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class ReportEventRestServlet(RestServlet):
     PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
 
-    async def on_POST(self, request, room_id, event_id):
+    async def on_POST(
+        self, request: SynapseRequest, room_id: str, event_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         user_id = requester.user.to_string()
 
@@ -64,5 +73,5 @@ class ReportEventRestServlet(RestServlet):
         return 200, {}
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     ReportEventRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index c5c54564be..9b0c546505 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -16,9 +16,11 @@
 """ This module contains REST servlets to do with rooms: /rooms/<paths> """
 import logging
 import re
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple
 from urllib import parse as urlparse
 
+from twisted.web.server import Request
+
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import (
     AuthError,
@@ -30,6 +32,7 @@ from synapse.api.errors import (
 )
 from synapse.api.filtering import Filter
 from synapse.events.utils import format_event_for_client_v2
+from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     ResolveRoomIdMixin,
     RestServlet,
@@ -57,7 +60,7 @@ logger = logging.getLogger(__name__)
 
 
 class TransactionRestServlet(RestServlet):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.txns = HttpTransactionCache(hs)
 
@@ -65,20 +68,22 @@ class TransactionRestServlet(RestServlet):
 class RoomCreateRestServlet(TransactionRestServlet):
     # No PATTERN; we have custom dispatch rules here
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self._room_creation_handler = hs.get_room_creation_handler()
         self.auth = hs.get_auth()
 
-    def register(self, http_server):
+    def register(self, http_server: HttpServer) -> None:
         PATTERNS = "/createRoom"
         register_txn_path(self, PATTERNS, http_server)
 
-    def on_PUT(self, request, txn_id):
+    def on_PUT(
+        self, request: SynapseRequest, txn_id: str
+    ) -> Awaitable[Tuple[int, JsonDict]]:
         set_tag("txn_id", txn_id)
         return self.txns.fetch_or_execute_request(request, self.on_POST, request)
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
         info, _ = await self._room_creation_handler.create_room(
@@ -87,21 +92,21 @@ class RoomCreateRestServlet(TransactionRestServlet):
 
         return 200, info
 
-    def get_room_config(self, request):
+    def get_room_config(self, request: Request) -> JsonDict:
         user_supplied_config = parse_json_object_from_request(request)
         return user_supplied_config
 
 
 # TODO: Needs unit testing for generic events
 class RoomStateEventRestServlet(TransactionRestServlet):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.event_creation_handler = hs.get_event_creation_handler()
         self.room_member_handler = hs.get_room_member_handler()
         self.message_handler = hs.get_message_handler()
         self.auth = hs.get_auth()
 
-    def register(self, http_server):
+    def register(self, http_server: HttpServer) -> None:
         # /room/$roomid/state/$eventtype
         no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
 
@@ -136,13 +141,19 @@ class RoomStateEventRestServlet(TransactionRestServlet):
             self.__class__.__name__,
         )
 
-    def on_GET_no_state_key(self, request, room_id, event_type):
+    def on_GET_no_state_key(
+        self, request: SynapseRequest, room_id: str, event_type: str
+    ) -> Awaitable[Tuple[int, JsonDict]]:
         return self.on_GET(request, room_id, event_type, "")
 
-    def on_PUT_no_state_key(self, request, room_id, event_type):
+    def on_PUT_no_state_key(
+        self, request: SynapseRequest, room_id: str, event_type: str
+    ) -> Awaitable[Tuple[int, JsonDict]]:
         return self.on_PUT(request, room_id, event_type, "")
 
-    async def on_GET(self, request, room_id, event_type, state_key):
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str, event_type: str, state_key: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         format = parse_string(
             request, "format", default="content", allowed_values=["content", "event"]
@@ -165,7 +176,17 @@ class RoomStateEventRestServlet(TransactionRestServlet):
         elif format == "content":
             return 200, data.get_dict()["content"]
 
-    async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
+        # Format must be event or content, per the parse_string call above.
+        raise RuntimeError(f"Unknown format: {format!r}.")
+
+    async def on_PUT(
+        self,
+        request: SynapseRequest,
+        room_id: str,
+        event_type: str,
+        state_key: str,
+        txn_id: Optional[str] = None,
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
         if txn_id:
@@ -211,27 +232,35 @@ class RoomStateEventRestServlet(TransactionRestServlet):
 
 # TODO: Needs unit testing for generic events + feedback
 class RoomSendEventRestServlet(TransactionRestServlet):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.event_creation_handler = hs.get_event_creation_handler()
         self.auth = hs.get_auth()
 
-    def register(self, http_server):
+    def register(self, http_server: HttpServer) -> None:
         # /rooms/$roomid/send/$event_type[/$txn_id]
         PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
         register_txn_path(self, PATTERNS, http_server, with_get=True)
 
-    async def on_POST(self, request, room_id, event_type, txn_id=None):
+    async def on_POST(
+        self,
+        request: SynapseRequest,
+        room_id: str,
+        event_type: str,
+        txn_id: Optional[str] = None,
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         content = parse_json_object_from_request(request)
 
-        event_dict = {
+        event_dict: JsonDict = {
             "type": event_type,
             "content": content,
             "room_id": room_id,
             "sender": requester.user.to_string(),
         }
 
+        # Twisted will have processed the args by now.
+        assert request.args is not None
         if b"ts" in request.args and requester.app_service:
             event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
 
@@ -249,10 +278,14 @@ class RoomSendEventRestServlet(TransactionRestServlet):
         set_tag("event_id", event_id)
         return 200, {"event_id": event_id}
 
-    def on_GET(self, request, room_id, event_type, txn_id):
+    def on_GET(
+        self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
+    ) -> Tuple[int, str]:
         return 200, "Not implemented"
 
-    def on_PUT(self, request, room_id, event_type, txn_id):
+    def on_PUT(
+        self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
+    ) -> Awaitable[Tuple[int, JsonDict]]:
         set_tag("txn_id", txn_id)
 
         return self.txns.fetch_or_execute_request(
@@ -262,12 +295,12 @@ class RoomSendEventRestServlet(TransactionRestServlet):
 
 # TODO: Needs unit testing for room ID + alias joins
 class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         super(ResolveRoomIdMixin, self).__init__(hs)  # ensure the Mixin is set up
         self.auth = hs.get_auth()
 
-    def register(self, http_server):
+    def register(self, http_server: HttpServer) -> None:
         # /join/$room_identifier[/$txn_id]
         PATTERNS = "/join/(?P<room_identifier>[^/]*)"
         register_txn_path(self, PATTERNS, http_server)
@@ -277,7 +310,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
         request: SynapseRequest,
         room_identifier: str,
         txn_id: Optional[str] = None,
-    ):
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         try:
@@ -308,7 +341,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
 
         return 200, {"room_id": room_id}
 
-    def on_PUT(self, request, room_identifier, txn_id):
+    def on_PUT(
+        self, request: SynapseRequest, room_identifier: str, txn_id: str
+    ) -> Awaitable[Tuple[int, JsonDict]]:
         set_tag("txn_id", txn_id)
 
         return self.txns.fetch_or_execute_request(
@@ -320,12 +355,12 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
 class PublicRoomListRestServlet(TransactionRestServlet):
     PATTERNS = client_patterns("/publicRooms$", v1=True)
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.hs = hs
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         server = parse_string(request, "server")
 
         try:
@@ -374,7 +409,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
 
         return 200, data
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await self.auth.get_user_by_req(request, allow_guest=True)
 
         server = parse_string(request, "server")
@@ -438,13 +473,15 @@ class PublicRoomListRestServlet(TransactionRestServlet):
 class RoomMemberListRestServlet(RestServlet):
     PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True)
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.message_handler = hs.get_message_handler()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
-    async def on_GET(self, request, room_id):
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         # TODO support Pagination stream API (limit/tokens)
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         handler = self.message_handler
@@ -490,12 +527,14 @@ class RoomMemberListRestServlet(RestServlet):
 class JoinedRoomMemberListRestServlet(RestServlet):
     PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True)
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.message_handler = hs.get_message_handler()
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request, room_id):
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
         users_with_profile = await self.message_handler.get_joined_members(
@@ -509,17 +548,21 @@ class JoinedRoomMemberListRestServlet(RestServlet):
 class RoomMessageListRestServlet(RestServlet):
     PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True)
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.pagination_handler = hs.get_pagination_handler()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
-    async def on_GET(self, request, room_id):
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         pagination_config = await PaginationConfig.from_request(
             self.store, request, default_limit=10
         )
+        # Twisted will have processed the args by now.
+        assert request.args is not None
         as_client_event = b"raw" not in request.args
         filter_str = parse_string(request, "filter", encoding="utf-8")
         if filter_str:
@@ -549,12 +592,14 @@ class RoomMessageListRestServlet(RestServlet):
 class RoomStateRestServlet(RestServlet):
     PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True)
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.message_handler = hs.get_message_handler()
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request, room_id):
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, List[JsonDict]]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         # Get all the current state for this room
         events = await self.message_handler.get_state_events(
@@ -569,13 +614,15 @@ class RoomStateRestServlet(RestServlet):
 class RoomInitialSyncRestServlet(RestServlet):
     PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True)
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.initial_sync_handler = hs.get_initial_sync_handler()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
-    async def on_GET(self, request, room_id):
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         pagination_config = await PaginationConfig.from_request(self.store, request)
         content = await self.initial_sync_handler.room_initial_sync(
@@ -589,14 +636,16 @@ class RoomEventServlet(RestServlet):
         "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$", v1=True
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.clock = hs.get_clock()
         self.event_handler = hs.get_event_handler()
         self._event_serializer = hs.get_event_client_serializer()
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request, room_id, event_id):
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str, event_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         try:
             event = await self.event_handler.get_event(
@@ -610,10 +659,10 @@ class RoomEventServlet(RestServlet):
 
         time_now = self.clock.time_msec()
         if event:
-            event = await self._event_serializer.serialize_event(event, time_now)
-            return 200, event
+            event_dict = await self._event_serializer.serialize_event(event, time_now)
+            return 200, event_dict
 
-        return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
+        raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
 
 
 class RoomEventContextServlet(RestServlet):
@@ -621,14 +670,16 @@ class RoomEventContextServlet(RestServlet):
         "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$", v1=True
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.clock = hs.get_clock()
         self.room_context_handler = hs.get_room_context_handler()
         self._event_serializer = hs.get_event_client_serializer()
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request, room_id, event_id):
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str, event_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         limit = parse_integer(request, "limit", default=10)
@@ -669,23 +720,27 @@ class RoomEventContextServlet(RestServlet):
 
 
 class RoomForgetRestServlet(TransactionRestServlet):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.room_member_handler = hs.get_room_member_handler()
         self.auth = hs.get_auth()
 
-    def register(self, http_server):
+    def register(self, http_server: HttpServer) -> None:
         PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
         register_txn_path(self, PATTERNS, http_server)
 
-    async def on_POST(self, request, room_id, txn_id=None):
+    async def on_POST(
+        self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=False)
 
         await self.room_member_handler.forget(user=requester.user, room_id=room_id)
 
         return 200, {}
 
-    def on_PUT(self, request, room_id, txn_id):
+    def on_PUT(
+        self, request: SynapseRequest, room_id: str, txn_id: str
+    ) -> Awaitable[Tuple[int, JsonDict]]:
         set_tag("txn_id", txn_id)
 
         return self.txns.fetch_or_execute_request(
@@ -695,12 +750,12 @@ class RoomForgetRestServlet(TransactionRestServlet):
 
 # TODO: Needs unit testing
 class RoomMembershipRestServlet(TransactionRestServlet):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.room_member_handler = hs.get_room_member_handler()
         self.auth = hs.get_auth()
 
-    def register(self, http_server):
+    def register(self, http_server: HttpServer) -> None:
         # /rooms/$roomid/[invite|join|leave]
         PATTERNS = (
             "/rooms/(?P<room_id>[^/]*)/"
@@ -708,7 +763,13 @@ class RoomMembershipRestServlet(TransactionRestServlet):
         )
         register_txn_path(self, PATTERNS, http_server)
 
-    async def on_POST(self, request, room_id, membership_action, txn_id=None):
+    async def on_POST(
+        self,
+        request: SynapseRequest,
+        room_id: str,
+        membership_action: str,
+        txn_id: Optional[str] = None,
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         if requester.is_guest and membership_action not in {
@@ -771,13 +832,15 @@ class RoomMembershipRestServlet(TransactionRestServlet):
 
         return 200, return_value
 
-    def _has_3pid_invite_keys(self, content):
+    def _has_3pid_invite_keys(self, content: JsonDict) -> bool:
         for key in {"id_server", "medium", "address"}:
             if key not in content:
                 return False
         return True
 
-    def on_PUT(self, request, room_id, membership_action, txn_id):
+    def on_PUT(
+        self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str
+    ) -> Awaitable[Tuple[int, JsonDict]]:
         set_tag("txn_id", txn_id)
 
         return self.txns.fetch_or_execute_request(
@@ -786,16 +849,22 @@ class RoomMembershipRestServlet(TransactionRestServlet):
 
 
 class RoomRedactEventRestServlet(TransactionRestServlet):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.event_creation_handler = hs.get_event_creation_handler()
         self.auth = hs.get_auth()
 
-    def register(self, http_server):
+    def register(self, http_server: HttpServer) -> None:
         PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
         register_txn_path(self, PATTERNS, http_server)
 
-    async def on_POST(self, request, room_id, event_id, txn_id=None):
+    async def on_POST(
+        self,
+        request: SynapseRequest,
+        room_id: str,
+        event_id: str,
+        txn_id: Optional[str] = None,
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         content = parse_json_object_from_request(request)
 
@@ -821,7 +890,9 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
         set_tag("event_id", event_id)
         return 200, {"event_id": event_id}
 
-    def on_PUT(self, request, room_id, event_id, txn_id):
+    def on_PUT(
+        self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str
+    ) -> Awaitable[Tuple[int, JsonDict]]:
         set_tag("txn_id", txn_id)
 
         return self.txns.fetch_or_execute_request(
@@ -846,7 +917,9 @@ class RoomTypingRestServlet(RestServlet):
             hs.config.worker.writers.typing == hs.get_instance_name()
         )
 
-    async def on_PUT(self, request, room_id, user_id):
+    async def on_PUT(
+        self, request: SynapseRequest, room_id: str, user_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
         if not self._is_typing_writer:
@@ -897,7 +970,9 @@ class RoomAliasListServlet(RestServlet):
         self.auth = hs.get_auth()
         self.directory_handler = hs.get_directory_handler()
 
-    async def on_GET(self, request, room_id):
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
         alias_list = await self.directory_handler.get_aliases_for_room(
@@ -910,12 +985,12 @@ class RoomAliasListServlet(RestServlet):
 class SearchRestServlet(RestServlet):
     PATTERNS = client_patterns("/search$", v1=True)
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.search_handler = hs.get_search_handler()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
         content = parse_json_object_from_request(request)
@@ -929,19 +1004,24 @@ class SearchRestServlet(RestServlet):
 class JoinedRoomsRestServlet(RestServlet):
     PATTERNS = client_patterns("/joined_rooms$", v1=True)
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         room_ids = await self.store.get_rooms_for_user(requester.user.to_string())
         return 200, {"joined_rooms": list(room_ids)}
 
 
-def register_txn_path(servlet, regex_string, http_server, with_get=False):
+def register_txn_path(
+    servlet: RestServlet,
+    regex_string: str,
+    http_server: HttpServer,
+    with_get: bool = False,
+) -> None:
     """Registers a transaction-based path.
 
     This registers two paths:
@@ -949,28 +1029,37 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
         POST regex_string
 
     Args:
-        regex_string (str): The regex string to register. Must NOT have a
-        trailing $ as this string will be appended to.
-        http_server : The http_server to register paths with.
+        regex_string: The regex string to register. Must NOT have a
+            trailing $ as this string will be appended to.
+        http_server: The http_server to register paths with.
         with_get: True to also register respective GET paths for the PUTs.
     """
+    on_POST = getattr(servlet, "on_POST", None)
+    on_PUT = getattr(servlet, "on_PUT", None)
+    if on_POST is None or on_PUT is None:
+        raise RuntimeError("on_POST and on_PUT must exist when using register_txn_path")
     http_server.register_paths(
         "POST",
         client_patterns(regex_string + "$", v1=True),
-        servlet.on_POST,
+        on_POST,
         servlet.__class__.__name__,
     )
     http_server.register_paths(
         "PUT",
         client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
-        servlet.on_PUT,
+        on_PUT,
         servlet.__class__.__name__,
     )
+    on_GET = getattr(servlet, "on_GET", None)
     if with_get:
+        if on_GET is None:
+            raise RuntimeError(
+                "register_txn_path called with with_get = True, but no on_GET method exists"
+            )
         http_server.register_paths(
             "GET",
             client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
-            servlet.on_GET,
+            on_GET,
             servlet.__class__.__name__,
         )
 
@@ -1120,7 +1209,9 @@ class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet):
         )
 
 
-def register_servlets(hs: "HomeServer", http_server, is_worker=False):
+def register_servlets(
+    hs: "HomeServer", http_server: HttpServer, is_worker: bool = False
+) -> None:
     RoomStateEventRestServlet(hs).register(http_server)
     RoomMemberListRestServlet(hs).register(http_server)
     JoinedRoomMemberListRestServlet(hs).register(http_server)
@@ -1148,5 +1239,5 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False):
         RoomForgetRestServlet(hs).register(http_server)
 
 
-def register_deprecated_servlets(hs, http_server):
+def register_deprecated_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     RoomInitialSyncRestServlet(hs).register(http_server)
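
register_txn_path, typed above, encodes the convention that POST creates an
event directly while PUT carrying a transaction id must be idempotent. A minimal
sketch of that convention (illustrative only; the real servlets delegate to
HttpTransactionCache.fetch_or_execute_request, and the cache key and lifetime
shown here are assumptions):

    import asyncio
    from typing import Dict, Tuple

    _txn_cache: Dict[Tuple[str, str], Tuple[int, dict]] = {}

    async def send_event(room_id: str) -> Tuple[int, dict]:
        # Stand-in for an on_POST handler that creates a new event.
        return 200, {"event_id": "$example:example.org"}

    async def idempotent_put(room_id: str, txn_id: str) -> Tuple[int, dict]:
        # A retried PUT carrying the same txn_id replays the cached response
        # instead of creating a duplicate event.
        key = (room_id, txn_id)
        if key not in _txn_cache:
            _txn_cache[key] = await send_event(room_id)
        return _txn_cache[key]

    assert asyncio.run(idempotent_put("!room:example.org", "txn1"))[0] == 200
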
diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index 3172aba605..ed96978448 100644
--- a/synapse/rest/client/room_batch.py
+++ b/synapse/rest/client/room_batch.py
@@ -14,10 +14,14 @@
 
 import logging
 import re
+from typing import TYPE_CHECKING, Awaitable, List, Tuple
+
+from twisted.web.server import Request
 
 from synapse.api.constants import EventContentFields, EventTypes
 from synapse.api.errors import AuthError, Codes, SynapseError
 from synapse.appservice import ApplicationService
+from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
@@ -25,10 +29,14 @@ from synapse.http.servlet import (
     parse_string,
     parse_strings_from_args,
 )
+from synapse.http.site import SynapseRequest
 from synapse.rest.client.transactions import HttpTransactionCache
-from synapse.types import Requester, UserID, create_requester
+from synapse.types import JsonDict, Requester, UserID, create_requester
 from synapse.util.stringutils import random_string
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -66,7 +74,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
         ),
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.store = hs.get_datastore()
@@ -76,7 +84,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.txns = HttpTransactionCache(hs)
 
-    async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int:
+    async def _inherit_depth_from_prev_ids(self, prev_event_ids: List[str]) -> int:
         (
             most_recent_prev_event_id,
             most_recent_prev_event_depth,
@@ -118,7 +126,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
 
     def _create_insertion_event_dict(
         self, sender: str, room_id: str, origin_server_ts: int
-    ):
+    ) -> JsonDict:
         """Creates an event dict for an "insertion" event with the proper fields
         and a random chunk ID.
 
@@ -128,7 +136,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
             origin_server_ts: Timestamp when the event was sent
 
         Returns:
-            Tuple of event ID and stream ordering position
+            The new event dictionary to insert.
         """
 
         next_chunk_id = random_string(8)
@@ -164,7 +172,9 @@ class RoomBatchSendEventRestServlet(RestServlet):
 
         return create_requester(user_id, app_service=app_service)
 
-    async def on_POST(self, request, room_id):
+    async def on_POST(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=False)
 
         if not requester.app_service:
@@ -176,6 +186,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
         body = parse_json_object_from_request(request)
         assert_params_in_dict(body, ["state_events_at_start", "events"])
 
+        assert request.args is not None
         prev_events_from_query = parse_strings_from_args(request.args, "prev_event")
         chunk_id_from_query = parse_string(request, "chunk_id")
 
@@ -425,16 +436,18 @@ class RoomBatchSendEventRestServlet(RestServlet):
             ],
         }
 
-    def on_GET(self, request, room_id):
+    def on_GET(self, request: Request, room_id: str) -> Tuple[int, str]:
         return 501, "Not implemented"
 
-    def on_PUT(self, request, room_id):
+    def on_PUT(
+        self, request: SynapseRequest, room_id: str
+    ) -> Awaitable[Tuple[int, JsonDict]]:
         return self.txns.fetch_or_execute_request(
             request, self.on_POST, request, room_id
         )
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     msc2716_enabled = hs.config.experimental.msc2716_enabled
 
     if msc2716_enabled:
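
The docstring correction above reflects that _create_insertion_event_dict
returns the event dictionary itself, not an event ID and stream position. A
hedged sketch of the shape such an insertion event takes (the
org.matrix.msc2716.* identifiers follow MSC2716's unstable prefixes; the exact
field set is an assumption, not lifted from the patch):

    import random
    import string

    def make_insertion_event(sender: str, room_id: str, origin_server_ts: int) -> dict:
        # Random 8-character chunk id, as described in the docstring.
        next_chunk_id = "".join(
            random.choices(string.ascii_letters + string.digits, k=8)
        )
        return {
            "type": "org.matrix.msc2716.insertion",  # assumed unstable event type
            "sender": sender,
            "room_id": room_id,
            "origin_server_ts": origin_server_ts,
            "content": {
                # Assumed unstable field names per MSC2716.
                "org.matrix.msc2716.next_chunk_id": next_chunk_id,
                "org.matrix.msc2716.historical": True,
            },
        }
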
diff --git a/synapse/rest/client/room_keys.py b/synapse/rest/client/room_keys.py
index 263596be86..37e39570f6 100644
--- a/synapse/rest/client/room_keys.py
+++ b/synapse/rest/client/room_keys.py
@@ -13,16 +13,23 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Optional, Tuple
 
 from synapse.api.errors import Codes, NotFoundError, SynapseError
+from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     RestServlet,
     parse_json_object_from_request,
     parse_string,
 )
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -31,16 +38,14 @@ class RoomKeysServlet(RestServlet):
         "/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$"
     )
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
 
-    async def on_PUT(self, request, room_id, session_id):
+    async def on_PUT(
+        self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str]
+    ) -> Tuple[int, JsonDict]:
         """
         Uploads one or more encrypted E2E room keys for backup purposes.
         room_id: the ID of the room the keys are for (optional)
@@ -133,7 +138,9 @@ class RoomKeysServlet(RestServlet):
         ret = await self.e2e_room_keys_handler.upload_room_keys(user_id, version, body)
         return 200, ret
 
-    async def on_GET(self, request, room_id, session_id):
+    async def on_GET(
+        self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str]
+    ) -> Tuple[int, JsonDict]:
         """
         Retrieves one or more encrypted E2E room keys for backup purposes.
         Symmetric with the PUT version of the API.
@@ -215,7 +222,9 @@ class RoomKeysServlet(RestServlet):
 
         return 200, room_keys
 
-    async def on_DELETE(self, request, room_id, session_id):
+    async def on_DELETE(
+        self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str]
+    ) -> Tuple[int, JsonDict]:
         """
         Deletes one or more encrypted E2E room keys for a user for backup purposes.
 
@@ -242,16 +251,12 @@ class RoomKeysServlet(RestServlet):
 class RoomKeysNewVersionServlet(RestServlet):
     PATTERNS = client_patterns("/room_keys/version$")
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         """
         Create a new backup version for this user's room_keys with the given
         info.  The version is allocated by the server and returned to the user
@@ -295,16 +300,14 @@ class RoomKeysNewVersionServlet(RestServlet):
 class RoomKeysVersionServlet(RestServlet):
     PATTERNS = client_patterns("/room_keys/version(/(?P<version>[^/]+))?$")
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
 
-    async def on_GET(self, request, version):
+    async def on_GET(
+        self, request: SynapseRequest, version: Optional[str]
+    ) -> Tuple[int, JsonDict]:
         """
         Retrieve the version information about a given version of the user's
         room_keys backup.  If the version part is missing, returns info about the
@@ -332,7 +335,9 @@ class RoomKeysVersionServlet(RestServlet):
                 raise SynapseError(404, "No backup found", Codes.NOT_FOUND)
         return 200, info
 
-    async def on_DELETE(self, request, version):
+    async def on_DELETE(
+        self, request: SynapseRequest, version: Optional[str]
+    ) -> Tuple[int, JsonDict]:
         """
         Delete the information about a given version of the user's
         room_keys backup.  If the version part is missing, deletes the most
@@ -351,7 +356,9 @@ class RoomKeysVersionServlet(RestServlet):
         await self.e2e_room_keys_handler.delete_version(user_id, version)
         return 200, {}
 
-    async def on_PUT(self, request, version):
+    async def on_PUT(
+        self, request: SynapseRequest, version: Optional[str]
+    ) -> Tuple[int, JsonDict]:
         """
         Update the information about a given version of the user's room_keys backup.
 
@@ -385,7 +392,7 @@ class RoomKeysVersionServlet(RestServlet):
         return 200, {}
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     RoomKeysServlet(hs).register(http_server)
     RoomKeysVersionServlet(hs).register(http_server)
     RoomKeysNewVersionServlet(hs).register(http_server)
diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py
index d537d811d8..3322c8ef48 100644
--- a/synapse/rest/client/sendtodevice.py
+++ b/synapse/rest/client/sendtodevice.py
@@ -13,15 +13,21 @@
 # limitations under the License.
 
 import logging
-from typing import Tuple
+from typing import TYPE_CHECKING, Awaitable, Tuple
 
 from synapse.http import servlet
+from synapse.http.server import HttpServer
 from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
 from synapse.logging.opentracing import set_tag, trace
 from synapse.rest.client.transactions import HttpTransactionCache
+from synapse.types import JsonDict
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -30,11 +36,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
         "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$"
     )
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
@@ -42,14 +44,18 @@ class SendToDeviceRestServlet(servlet.RestServlet):
         self.device_message_handler = hs.get_device_message_handler()
 
     @trace(opname="sendToDevice")
-    def on_PUT(self, request, message_type, txn_id):
+    def on_PUT(
+        self, request: SynapseRequest, message_type: str, txn_id: str
+    ) -> Awaitable[Tuple[int, JsonDict]]:
         set_tag("message_type", message_type)
         set_tag("txn_id", txn_id)
         return self.txns.fetch_or_execute_request(
             request, self._put, request, message_type, txn_id
         )
 
-    async def _put(self, request, message_type, txn_id):
+    async def _put(
+        self, request: SynapseRequest, message_type: str, txn_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         content = parse_json_object_from_request(request)
@@ -59,9 +65,8 @@ class SendToDeviceRestServlet(servlet.RestServlet):
             requester, message_type, content["messages"]
         )
 
-        response: Tuple[int, dict] = (200, {})
-        return response
+        return 200, {}
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     SendToDeviceRestServlet(hs).register(http_server)
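
The `on_PUT` annotation above is `Awaitable[...]` rather than the method being a coroutine because it simply returns whatever awaitable the transaction cache produces. A self-contained sketch of that pattern (names and values are illustrative):

    import asyncio
    from typing import Any, Awaitable, Dict, Tuple

    JsonDict = Dict[str, Any]

    async def _put(message_type: str, txn_id: str) -> Tuple[int, JsonDict]:
        # The real work happens here, behind the transaction cache.
        return 200, {}

    def on_PUT(message_type: str, txn_id: str) -> Awaitable[Tuple[int, JsonDict]]:
        # Not declared async: it hands back the awaitable unchanged, which is
        # exactly what the Awaitable[...] return annotation expresses.
        return _put(message_type, txn_id)

    print(asyncio.run(on_PUT("m.example", "txn1")))  # (200, {})
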
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 65c37be3e9..1259058b9b 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -14,12 +14,24 @@
 import itertools
 import logging
 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 
 from synapse.api.constants import Membership, PresenceState
 from synapse.api.errors import Codes, StoreError, SynapseError
 from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
 from synapse.api.presence import UserPresenceState
+from synapse.events import EventBase
 from synapse.events.utils import (
     format_event_for_client_v2_without_room_id,
     format_event_raw,
@@ -504,7 +516,7 @@ class SyncRestServlet(RestServlet):
             The room, encoded in our response format
         """
 
-        def serialize(events):
+        def serialize(events: Iterable[EventBase]) -> Awaitable[List[JsonDict]]:
             return self._event_serializer.serialize_events(
                 events,
                 time_now=time_now,
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 94ff3719ce..914fb3acf5 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -15,28 +15,37 @@
 """This module contains logic for storing HTTP PUT transactions. This is used
 to ensure idempotency when performing PUTs using the REST API."""
 import logging
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Tuple
+
+from twisted.python.failure import Failure
+from twisted.web.server import Request
 
 from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.types import JsonDict
 from synapse.util.async_helpers import ObservableDeferred
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 CLEANUP_PERIOD_MS = 1000 * 60 * 30  # 30 mins
 
 
 class HttpTransactionCache:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = self.hs.get_auth()
         self.clock = self.hs.get_clock()
-        self.transactions = {
-            # $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
-        }
+        # $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
+        self.transactions: Dict[
+            str, Tuple[ObservableDeferred[Tuple[int, JsonDict]], int]
+        ] = {}
         # Try to clean entries every 30 mins. This means entries will exist
         # for at *LEAST* 30 mins, and at *MOST* 60 mins.
         self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
 
-    def _get_transaction_key(self, request):
+    def _get_transaction_key(self, request: Request) -> str:
         """A helper function which returns a transaction key that can be used
         with TransactionCache for idempotent requests.
 
@@ -45,15 +54,21 @@ class HttpTransactionCache:
         path and the access_token for the requesting user.
 
         Args:
-            request (twisted.web.http.Request): The incoming request. Must
-            contain an access_token.
+            request: The incoming request. Must contain an access_token.
         Returns:
-            str: A transaction key
+            A transaction key
         """
+        assert request.path is not None
         token = self.auth.get_access_token_from_request(request)
         return request.path.decode("utf8") + "/" + token
 
-    def fetch_or_execute_request(self, request, fn, *args, **kwargs):
+    def fetch_or_execute_request(
+        self,
+        request: Request,
+        fn: Callable[..., Awaitable[Tuple[int, JsonDict]]],
+        *args: Any,
+        **kwargs: Any,
+    ) -> Awaitable[Tuple[int, JsonDict]]:
         """A helper function for fetch_or_execute which extracts
         a transaction key from the given request.
 
@@ -64,15 +79,20 @@ class HttpTransactionCache:
             self._get_transaction_key(request), fn, *args, **kwargs
         )
 
-    def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
+    def fetch_or_execute(
+        self,
+        txn_key: str,
+        fn: Callable[..., Awaitable[Tuple[int, JsonDict]]],
+        *args: Any,
+        **kwargs: Any,
+    ) -> Awaitable[Tuple[int, JsonDict]]:
         """Fetches the response for this transaction, or executes the given function
         to produce a response for this transaction.
 
         Args:
-            txn_key (str): A key to ensure idempotency should fetch_or_execute be
-            called again at a later point in time.
-            fn (function): A function which returns a tuple of
-            (response_code, response_dict).
+            txn_key: A key to ensure idempotency should fetch_or_execute be
+                called again at a later point in time.
+            fn: A function which returns a tuple of (response_code, response_dict).
             *args: Arguments to pass to fn.
             **kwargs: Keyword arguments to pass to fn.
         Returns:
@@ -90,7 +110,7 @@ class HttpTransactionCache:
             # if the request fails with an exception, remove it
             # from the transaction map. This is done to ensure that we don't
             # cache transient errors like rate-limiting errors, etc.
-            def remove_from_map(err):
+            def remove_from_map(err: Failure) -> None:
                 self.transactions.pop(txn_key, None)
                 # we deliberately do not propagate the error any further, as we
                 # expect the observers to have reported it.
@@ -99,7 +119,7 @@ class HttpTransactionCache:
 
         return make_deferred_yieldable(observable.observe())
 
-    def _cleanup(self):
+    def _cleanup(self) -> None:
         now = self.clock.time_msec()
         for key in list(self.transactions):
             ts = self.transactions[key][1]
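
A simplified, asyncio-based illustration of the idempotency scheme this cache implements (the real code uses Twisted's ObservableDeferred and Synapse's logging-context helpers, and also evicts failed and stale entries; names here are illustrative):

    import asyncio
    from typing import Any, Awaitable, Callable, Dict, Tuple

    JsonDict = Dict[str, Any]

    # Keyed on request path plus access token, as in _get_transaction_key above.
    _transactions: Dict[str, asyncio.Task] = {}

    def fetch_or_execute(
        txn_key: str, fn: Callable[..., Awaitable[Tuple[int, JsonDict]]], *args: Any
    ) -> Awaitable[Tuple[int, JsonDict]]:
        task = _transactions.get(txn_key)
        if task is None:
            # First time we see this transaction: run it and remember the result.
            task = asyncio.ensure_future(fn(*args))
            _transactions[txn_key] = task
        return task

    async def handler(body: JsonDict) -> Tuple[int, JsonDict]:
        await asyncio.sleep(0)  # stand-in for real work
        return 200, {"echo": body}

    async def main() -> None:
        key = "/_matrix/client/r0/sendToDevice/m.example/txn1/ACCESS_TOKEN"
        first = await fetch_or_execute(key, handler, {"attempt": 1})
        second = await fetch_or_execute(key, handler, {"attempt": 2})
        assert first == second  # the retried PUT replays the cached response

    asyncio.run(main())
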
diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py
new file mode 100644
index 0000000000..afe41823e4
--- /dev/null
+++ b/synapse/rest/media/v1/oembed.py
@@ -0,0 +1,135 @@
+#  Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Optional
+
+import attr
+
+from synapse.http.client import SimpleHttpClient
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+@attr.s(slots=True, auto_attribs=True)
+class OEmbedResult:
+    # Either HTML content or URL must be provided.
+    html: Optional[str]
+    url: Optional[str]
+    title: Optional[str]
+    # Number of seconds to cache the content.
+    cache_age: int
+
+
+class OEmbedError(Exception):
+    """An error occurred processing the oEmbed object."""
+
+
+class OEmbedProvider:
+    """
+    A helper for accessing oEmbed content.
+
+    It can be used to check if a URL should be accessed via oEmbed and for
+    requesting/parsing oEmbed content.
+    """
+
+    def __init__(self, hs: "HomeServer", client: SimpleHttpClient):
+        self._oembed_patterns = {}
+        for oembed_endpoint in hs.config.oembed.oembed_patterns:
+            for pattern in oembed_endpoint.url_patterns:
+                self._oembed_patterns[pattern] = oembed_endpoint.api_endpoint
+        self._client = client
+
+    def get_oembed_url(self, url: str) -> Optional[str]:
+        """
+        Check whether the URL should be downloaded as oEmbed content instead.
+
+        Args:
+            url: The URL to check.
+
+        Returns:
+            A URL to use instead or None if the original URL should be used.
+        """
+        for url_pattern, endpoint in self._oembed_patterns.items():
+            if url_pattern.fullmatch(url):
+                return endpoint
+
+        # No match.
+        return None
+
+    async def get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult:
+        """
+        Request content from an oEmbed endpoint.
+
+        Args:
+            endpoint: The oEmbed API endpoint.
+            url: The URL to pass to the API.
+
+        Returns:
+            An object representing the metadata returned.
+
+        Raises:
+            OEmbedError if fetching or parsing of the oEmbed information fails.
+        """
+        try:
+            logger.debug("Trying to get oEmbed content for url '%s'", url)
+            result = await self._client.get_json(
+                endpoint,
+                # TODO Specify max height / width.
+                # Note that only the JSON format is supported.
+                args={"url": url},
+            )
+
+            # Ensure there's a version of 1.0.
+            if result.get("version") != "1.0":
+                raise OEmbedError("Invalid version: %s" % (result.get("version"),))
+
+            oembed_type = result.get("type")
+
+            # Ensure the cache age is None or an int.
+            cache_age = result.get("cache_age")
+            if cache_age:
+                cache_age = int(cache_age)
+
+            oembed_result = OEmbedResult(None, None, result.get("title"), cache_age)
+
+            # HTML content.
+            if oembed_type == "rich":
+                oembed_result.html = result.get("html")
+                return oembed_result
+
+            if oembed_type == "photo":
+                oembed_result.url = result.get("url")
+                return oembed_result
+
+            # TODO Handle link and video types.
+
+            if "thumbnail_url" in result:
+                oembed_result.url = result.get("thumbnail_url")
+                return oembed_result
+
+            raise OEmbedError("Incompatible oEmbed information.")
+
+        except OEmbedError as e:
+            # Trap OEmbedErrors first so we can directly re-raise them.
+            logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
+            raise
+
+        except Exception as e:
+            # Trap any exception and let the code follow as usual.
+            # FIXME: pass through 404s and other error messages nicely
+            logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
+            raise OEmbedError() from e
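
For orientation, the JSON shapes this parser accepts map onto OEmbedResult as follows; the payloads are illustrative examples in the spirit of the oEmbed spec, not responses captured from a real endpoint:

    # A "rich" response supplies HTML; a "photo" response supplies an image URL.
    rich_response = {
        "version": "1.0",
        "type": "rich",
        "title": "Example post",
        "html": "<blockquote>example</blockquote>",
        "cache_age": "3600",  # may arrive as a string; coerced to int above
    }
    photo_response = {
        "version": "1.0",
        "type": "photo",
        "url": "https://example.com/image.jpg",
    }
    # rich_response  -> OEmbedResult(html="<blockquote>example</blockquote>",
    #                                url=None, title="Example post", cache_age=3600)
    # photo_response -> OEmbedResult(html=None, url="https://example.com/image.jpg",
    #                                title=None, cache_age=None)
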
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 0f051d4041..f108da05db 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -22,7 +22,7 @@ import re
 import shutil
 import sys
 import traceback
-from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union
+from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Union
 from urllib import parse as urlparse
 
 import attr
@@ -43,6 +43,8 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.media.v1._base import get_filename_from_headers
 from synapse.rest.media.v1.media_storage import MediaStorage
+from synapse.rest.media.v1.oembed import OEmbedError, OEmbedProvider
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -71,65 +73,44 @@ OG_TAG_VALUE_MAXLEN = 1000
 
 ONE_HOUR = 60 * 60 * 1000
 
-# A map of globs to API endpoints.
-_oembed_globs = {
-    # Twitter.
-    "https://publish.twitter.com/oembed": [
-        "https://twitter.com/*/status/*",
-        "https://*.twitter.com/*/status/*",
-        "https://twitter.com/*/moments/*",
-        "https://*.twitter.com/*/moments/*",
-        # Include the HTTP versions too.
-        "http://twitter.com/*/status/*",
-        "http://*.twitter.com/*/status/*",
-        "http://twitter.com/*/moments/*",
-        "http://*.twitter.com/*/moments/*",
-    ],
-}
-# Convert the globs to regular expressions.
-_oembed_patterns = {}
-for endpoint, globs in _oembed_globs.items():
-    for glob in globs:
-        # Convert the glob into a sane regular expression to match against. The
-        # rules followed will be slightly different for the domain portion vs.
-        # the rest.
-        #
-        # 1. The scheme must be one of HTTP / HTTPS (and have no globs).
-        # 2. The domain can have globs, but we limit it to characters that can
-        #    reasonably be a domain part.
-        #    TODO: This does not attempt to handle Unicode domain names.
-        # 3. Other parts allow a glob to be any one, or more, characters.
-        results = urlparse.urlparse(glob)
-
-        # Ensure the scheme does not have wildcards (and is a sane scheme).
-        if results.scheme not in {"http", "https"}:
-            raise ValueError("Insecure oEmbed glob scheme: %s" % (results.scheme,))
-
-        pattern = urlparse.urlunparse(
-            [
-                results.scheme,
-                re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"),
-            ]
-            + [re.escape(part).replace("\\*", ".+") for part in results[2:]]
-        )
-        _oembed_patterns[re.compile(pattern)] = endpoint
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class MediaInfo:
+    """
+    Information parsed from downloading media being previewed.
+    """
 
-@attr.s(slots=True)
-class OEmbedResult:
-    # Either HTML content or URL must be provided.
-    html = attr.ib(type=Optional[str])
-    url = attr.ib(type=Optional[str])
-    title = attr.ib(type=Optional[str])
-    # Number of seconds to cache the content.
-    cache_age = attr.ib(type=int)
+    # The Content-Type header of the response.
+    media_type: str
+    # The length (in bytes) of the downloaded media.
+    media_length: int
+    # The media filename, according to the server. This is parsed from the
+    # returned headers, if possible.
+    download_name: Optional[str]
+    # The time of the preview.
+    created_ts_ms: int
+    # Information from the media storage provider about where the file is stored
+    # on disk.
+    filesystem_id: str
+    filename: str
+    # The URI being previewed.
+    uri: str
+    # The HTTP response code.
+    response_code: int
+    # The timestamp (in milliseconds) of when this preview expires.
+    expires: int
+    # The ETag header of the response.
+    etag: Optional[str]
 
 
-class OEmbedError(Exception):
-    """An error occurred processing the oEmbed object."""
+class PreviewUrlResource(DirectServeJsonResource):
+    """
+    Generating URL previews is a complicated task with many potential pitfalls.
 
+    See docs/development/url_previews.md for discussion of the design and
+    algorithm followed in this module.
+    """
 
-class PreviewUrlResource(DirectServeJsonResource):
     isLeaf = True
 
     def __init__(
@@ -157,6 +138,8 @@ class PreviewUrlResource(DirectServeJsonResource):
         self.primary_base_path = media_repo.primary_base_path
         self.media_storage = media_storage
 
+        self._oembed = OEmbedProvider(hs, self.client)
+
         # We run the background jobs if we're the instance specified (or no
         # instance is specified, where we assume there is only one instance
         # serving media).
@@ -275,18 +258,17 @@ class PreviewUrlResource(DirectServeJsonResource):
 
         logger.debug("got media_info of '%s'", media_info)
 
-        if _is_media(media_info["media_type"]):
-            file_id = media_info["filesystem_id"]
+        if _is_media(media_info.media_type):
+            file_id = media_info.filesystem_id
             dims = await self.media_repo._generate_thumbnails(
-                None, file_id, file_id, media_info["media_type"], url_cache=True
+                None, file_id, file_id, media_info.media_type, url_cache=True
             )
 
             og = {
-                "og:description": media_info["download_name"],
-                "og:image": "mxc://%s/%s"
-                % (self.server_name, media_info["filesystem_id"]),
-                "og:image:type": media_info["media_type"],
-                "matrix:image:size": media_info["media_length"],
+                "og:description": media_info.download_name,
+                "og:image": f"mxc://{self.server_name}/{media_info.filesystem_id}",
+                "og:image:type": media_info.media_type,
+                "matrix:image:size": media_info.media_length,
             }
 
             if dims:
@@ -296,14 +278,14 @@ class PreviewUrlResource(DirectServeJsonResource):
                 logger.warning("Couldn't get dims for %s" % url)
 
             # define our OG response for this media
-        elif _is_html(media_info["media_type"]):
+        elif _is_html(media_info.media_type):
             # TODO: somehow stop a big HTML tree from exploding synapse's RAM
 
-            with open(media_info["filename"], "rb") as file:
+            with open(media_info.filename, "rb") as file:
                 body = file.read()
 
-            encoding = get_html_media_encoding(body, media_info["media_type"])
-            og = decode_and_calc_og(body, media_info["uri"], encoding)
+            encoding = get_html_media_encoding(body, media_info.media_type)
+            og = decode_and_calc_og(body, media_info.uri, encoding)
 
             # pre-cache the image for posterity
             # FIXME: it might be cleaner to use the same flow as the main /preview_url
@@ -311,14 +293,14 @@ class PreviewUrlResource(DirectServeJsonResource):
             # just rely on the caching on the master request to speed things up.
             if "og:image" in og and og["og:image"]:
                 image_info = await self._download_url(
-                    _rebase_url(og["og:image"], media_info["uri"]), user
+                    _rebase_url(og["og:image"], media_info.uri), user
                 )
 
-                if _is_media(image_info["media_type"]):
+                if _is_media(image_info.media_type):
                     # TODO: make sure we don't choke on white-on-transparent images
-                    file_id = image_info["filesystem_id"]
+                    file_id = image_info.filesystem_id
                     dims = await self.media_repo._generate_thumbnails(
-                        None, file_id, file_id, image_info["media_type"], url_cache=True
+                        None, file_id, file_id, image_info.media_type, url_cache=True
                     )
                     if dims:
                         og["og:image:width"] = dims["width"]
@@ -326,12 +308,11 @@ class PreviewUrlResource(DirectServeJsonResource):
                     else:
                         logger.warning("Couldn't get dims for %s", og["og:image"])
 
-                    og["og:image"] = "mxc://%s/%s" % (
-                        self.server_name,
-                        image_info["filesystem_id"],
-                    )
-                    og["og:image:type"] = image_info["media_type"]
-                    og["matrix:image:size"] = image_info["media_length"]
+                    og[
+                        "og:image"
+                    ] = f"mxc://{self.server_name}/{image_info.filesystem_id}"
+                    og["og:image:type"] = image_info.media_type
+                    og["matrix:image:size"] = image_info.media_length
                 else:
                     del og["og:image"]
         else:
@@ -357,98 +338,17 @@ class PreviewUrlResource(DirectServeJsonResource):
         # store OG in history-aware DB cache
         await self.store.store_url_cache(
             url,
-            media_info["response_code"],
-            media_info["etag"],
-            media_info["expires"] + media_info["created_ts"],
+            media_info.response_code,
+            media_info.etag,
+            media_info.expires + media_info.created_ts_ms,
             jsonog,
-            media_info["filesystem_id"],
-            media_info["created_ts"],
+            media_info.filesystem_id,
+            media_info.created_ts_ms,
         )
 
         return jsonog.encode("utf8")
 
-    def _get_oembed_url(self, url: str) -> Optional[str]:
-        """
-        Check whether the URL should be downloaded as oEmbed content instead.
-
-        Args:
-            url: The URL to check.
-
-        Returns:
-            A URL to use instead or None if the original URL should be used.
-        """
-        for url_pattern, endpoint in _oembed_patterns.items():
-            if url_pattern.fullmatch(url):
-                return endpoint
-
-        # No match.
-        return None
-
-    async def _get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult:
-        """
-        Request content from an oEmbed endpoint.
-
-        Args:
-            endpoint: The oEmbed API endpoint.
-            url: The URL to pass to the API.
-
-        Returns:
-            An object representing the metadata returned.
-
-        Raises:
-            OEmbedError if fetching or parsing of the oEmbed information fails.
-        """
-        try:
-            logger.debug("Trying to get oEmbed content for url '%s'", url)
-            result = await self.client.get_json(
-                endpoint,
-                # TODO Specify max height / width.
-                # Note that only the JSON format is supported.
-                args={"url": url},
-            )
-
-            # Ensure there's a version of 1.0.
-            if result.get("version") != "1.0":
-                raise OEmbedError("Invalid version: %s" % (result.get("version"),))
-
-            oembed_type = result.get("type")
-
-            # Ensure the cache age is None or an int.
-            cache_age = result.get("cache_age")
-            if cache_age:
-                cache_age = int(cache_age)
-
-            oembed_result = OEmbedResult(None, None, result.get("title"), cache_age)
-
-            # HTML content.
-            if oembed_type == "rich":
-                oembed_result.html = result.get("html")
-                return oembed_result
-
-            if oembed_type == "photo":
-                oembed_result.url = result.get("url")
-                return oembed_result
-
-            # TODO Handle link and video types.
-
-            if "thumbnail_url" in result:
-                oembed_result.url = result.get("thumbnail_url")
-                return oembed_result
-
-            raise OEmbedError("Incompatible oEmbed information.")
-
-        except OEmbedError as e:
-            # Trap OEmbedErrors first so we can directly re-raise them.
-            logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
-            raise
-
-        except Exception as e:
-            # Trap any exception and let the code follow as usual.
-            # FIXME: pass through 404s and other error messages nicely
-            logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
-            raise OEmbedError() from e
-
-    async def _download_url(self, url: str, user: str) -> Dict[str, Any]:
+    async def _download_url(self, url: str, user: str) -> MediaInfo:
         # TODO: we should probably honour robots.txt... except in practice
         # we're most likely being explicitly triggered by a human rather than a
         # bot, so are we really a robot?
@@ -459,11 +359,11 @@ class PreviewUrlResource(DirectServeJsonResource):
 
         # If this URL can be accessed via oEmbed, use that instead.
         url_to_download: Optional[str] = url
-        oembed_url = self._get_oembed_url(url)
+        oembed_url = self._oembed.get_oembed_url(url)
         if oembed_url:
             # The result might be a new URL to download, or it might be HTML content.
             try:
-                oembed_result = await self._get_oembed_content(oembed_url, url)
+                oembed_result = await self._oembed.get_oembed_content(oembed_url, url)
                 if oembed_result.url:
                     url_to_download = oembed_result.url
                 elif oembed_result.html:
@@ -560,18 +460,18 @@ class PreviewUrlResource(DirectServeJsonResource):
             # therefore not expire it.
             raise
 
-        return {
-            "media_type": media_type,
-            "media_length": length,
-            "download_name": download_name,
-            "created_ts": time_now_ms,
-            "filesystem_id": file_id,
-            "filename": fname,
-            "uri": uri,
-            "response_code": code,
-            "expires": expires,
-            "etag": etag,
-        }
+        return MediaInfo(
+            media_type=media_type,
+            media_length=length,
+            download_name=download_name,
+            created_ts_ms=time_now_ms,
+            filesystem_id=file_id,
+            filename=fname,
+            uri=uri,
+            response_code=code,
+            expires=expires,
+            etag=etag,
+        )
 
     def _start_expire_url_cache_data(self):
         return run_as_background_process(
@@ -717,7 +617,7 @@ def get_html_media_encoding(body: bytes, content_type: str) -> str:
 
 def decode_and_calc_og(
     body: bytes, media_uri: str, request_encoding: Optional[str] = None
-) -> Dict[str, Optional[str]]:
+) -> JsonDict:
     """
     Calculate metadata for an HTML document.
 
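
The dict-to-attrs change above is the bulk of this file's diff; a minimal stand-alone sketch of the pattern (class and values are illustrative, using the same attrs options as MediaInfo):

    import attr

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class ExampleMediaInfo:
        media_type: str
        media_length: int

    info = ExampleMediaInfo(media_type="image/png", media_length=1024)
    assert info.media_type == "image/png"   # attribute access replaces info["media_type"]
    # info.media_length = 0                 # frozen=True: raises FrozenInstanceError
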
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 95d2caff62..0084d9f96c 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -280,18 +280,18 @@ class LoggingTransaction:
         else:
             self.executemany(sql, args)
 
-    def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
+    def execute_values(self, sql: str, *args: Any, fetch: bool = True) -> List[Tuple]:
         """Corresponds to psycopg2.extras.execute_values. Only available when
         using postgres.
 
-        Always sets fetch=True when caling `execute_values`, so will return the
-        results.
+        The `fetch` parameter must be set to False if the query does not return
+        rows (e.g. INSERTs).
         """
         assert isinstance(self.database_engine, PostgresEngine)
         from psycopg2.extras import execute_values  # type: ignore
 
         return self._do_execute(
-            lambda *x: execute_values(self.txn, *x, fetch=True), sql, *args
+            lambda *x: execute_values(self.txn, *x, fetch=fetch), sql, *args
         )
 
     def execute(self, sql: str, *args: Any) -> None:
@@ -920,13 +920,23 @@ class DatabasePool:
             if k != keys[0]:
                 raise RuntimeError("All items must have the same keys")
 
-        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
-            table,
-            ", ".join(k for k in keys[0]),
-            ", ".join("?" for _ in keys[0]),
-        )
+        if isinstance(txn.database_engine, PostgresEngine):
+            # We use `execute_values` as it can be a lot faster than `execute_batch`,
+            # but it's only available on postgres.
+            sql = "INSERT INTO %s (%s) VALUES ?" % (
+                table,
+                ", ".join(k for k in keys[0]),
+            )
 
-        txn.execute_batch(sql, vals)
+            txn.execute_values(sql, vals, fetch=False)
+        else:
+            sql = "INSERT INTO %s (%s) VALUES(%s)" % (
+                table,
+                ", ".join(k for k in keys[0]),
+                ", ".join("?" for _ in keys[0]),
+            )
+
+            txn.execute_batch(sql, vals)
 
     async def simple_upsert(
         self,
@@ -1281,20 +1291,33 @@ class DatabasePool:
                 k + "=EXCLUDED." + k for k in value_names
             )
 
-        sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
-            table,
-            ", ".join(k for k in allnames),
-            ", ".join("?" for _ in allnames),
-            ", ".join(key_names),
-            latter,
-        )
-
         args = []
 
         for x, y in zip(key_values, value_values):
             args.append(tuple(x) + tuple(y))
 
-        return txn.execute_batch(sql, args)
+        if isinstance(txn.database_engine, PostgresEngine):
+            # We use `execute_values` as it can be a lot faster than `execute_batch`,
+            # but it's only available on postgres.
+            sql = "INSERT INTO %s (%s) VALUES ? ON CONFLICT (%s) DO %s" % (
+                table,
+                ", ".join(k for k in allnames),
+                ", ".join(key_names),
+                latter,
+            )
+
+            txn.execute_values(sql, args, fetch=False)
+
+        else:
+            sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
+                table,
+                ", ".join(k for k in allnames),
+                ", ".join("?" for _ in allnames),
+                ", ".join(key_names),
+                latter,
+            )
+
+            return txn.execute_batch(sql, args)
 
     @overload
     async def simple_select_one(
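
As a standalone psycopg2 illustration of the batching strategy above: the single "VALUES ?" placeholder is Synapse's internal qmark style, which its database layer translates to the "%s" placeholder that psycopg2.extras.execute_values expands into a multi-row VALUES list; fetch=False is used because a plain INSERT returns no rows. Table and column names here are hypothetical:

    from typing import Iterable, Tuple

    from psycopg2.extras import execute_values

    def insert_many(cur, rows: Iterable[Tuple[str, str]]) -> None:
        # `cur` is an open psycopg2 cursor; execute_values batches all rows
        # into one round trip instead of one statement per row.
        execute_values(
            cur,
            "INSERT INTO example_table (id, value) VALUES %s",
            list(rows),
            fetch=False,
        )
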
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 86075bc55b..6daf8b8ffb 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -75,8 +75,6 @@ class DirectoryWorkerStore(SQLBaseStore):
             desc="get_aliases_for_room",
         )
 
-
-class DirectoryStore(DirectoryWorkerStore):
     async def create_room_alias_association(
         self,
         room_alias: RoomAlias,
@@ -126,6 +124,8 @@ class DirectoryStore(DirectoryWorkerStore):
                 409, "Room alias %s already exists" % room_alias.to_string()
             )
 
+
+class DirectoryStore(DirectoryWorkerStore):
     async def delete_room_alias(self, room_alias: RoomAlias) -> str:
         room_id = await self.db_pool.runInteraction(
             "delete_room_alias", self._delete_room_alias_txn, room_alias
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 40b53274fb..f07e288056 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -575,7 +575,13 @@ class PersistEventsStore:
 
             missing_auth_chains.clear()
 
-            for auth_id, event_type, state_key, chain_id, sequence_number in txn:
+            for (
+                auth_id,
+                event_type,
+                state_key,
+                chain_id,
+                sequence_number,
+            ) in txn.fetchall():
                 event_to_types[auth_id] = (event_type, state_key)
 
                 if chain_id is None:
@@ -1379,18 +1385,18 @@ class PersistEventsStore:
         # If we're persisting an unredacted event we go and ensure
         # that we mark any redactions that reference this event as
         # requiring censoring.
-        sql = "UPDATE redactions SET have_censored = ? WHERE redacts = ?"
-        txn.execute_batch(
-            sql,
-            (
-                (
-                    False,
-                    event.event_id,
-                )
-                for event, _ in events_and_contexts
-                if not event.internal_metadata.is_redacted()
-            ),
+        unredacted_events = [
+            event.event_id
+            for event, _ in events_and_contexts
+            if not event.internal_metadata.is_redacted()
+        ]
+        sql = "UPDATE redactions SET have_censored = ? WHERE "
+        clause, args = make_in_list_sql_clause(
+            self.database_engine,
+            "redacts",
+            unredacted_events,
         )
+        txn.execute(sql + clause, [False] + args)
 
         state_events_and_contexts = [
             ec for ec in events_and_contexts if ec[0].is_state()
@@ -1770,10 +1776,21 @@ class PersistEventsStore:
             # Not a insertion event
             return
 
-        # Skip processing a insertion event if the room version doesn't
-        # support it.
+        # Skip processing an insertion event if the room version doesn't
+        # support it or the event is not from the room creator.
         room_version = self.store.get_room_version_txn(txn, event.room_id)
-        if not room_version.msc2716_historical:
+        room_creator = self.db_pool.simple_select_one_onecol_txn(
+            txn,
+            table="rooms",
+            keyvalues={"room_id": event.room_id},
+            retcol="creator",
+            allow_none=True,
+        )
+        if (
+            not room_version.msc2716_historical
+            or not self.hs.config.experimental.msc2716_enabled
+            or event.sender != room_creator
+        ):
             return
 
         next_chunk_id = event.content.get(EventContentFields.MSC2716_NEXT_CHUNK_ID)
@@ -1822,9 +1839,20 @@ class PersistEventsStore:
             return
 
         # Skip processing a chunk event if the room version doesn't
-        # support it.
+        # support it or the event is not from the room creator.
         room_version = self.store.get_room_version_txn(txn, event.room_id)
-        if not room_version.msc2716_historical:
+        room_creator = self.db_pool.simple_select_one_onecol_txn(
+            txn,
+            table="rooms",
+            keyvalues={"room_id": event.room_id},
+            retcol="creator",
+            allow_none=True,
+        )
+        if (
+            not room_version.msc2716_historical
+            or not self.hs.config.experimental.msc2716_enabled
+            or event.sender != room_creator
+        ):
             return
 
         chunk_id = event.content.get(EventContentFields.MSC2716_CHUNK_ID)
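
The redactions change above folds a per-row execute_batch UPDATE into a single statement built with make_in_list_sql_clause. A simplified stand-in for that helper (the real one can emit an array-based form such as "= ANY(?)" on Postgres):

    from typing import Iterable, List, Tuple

    def make_in_list_clause(column: str, values: Iterable[str]) -> Tuple[str, List[str]]:
        # Builds "column IN (?, ?, ...)" plus the matching argument list.
        vals = list(values)
        return f"{column} IN ({', '.join('?' for _ in vals)})", vals

    clause, args = make_in_list_clause("redacts", ["$event1", "$event2"])
    sql = "UPDATE redactions SET have_censored = ? WHERE " + clause
    params = [False] + args
    # sql    == "UPDATE redactions SET have_censored = ? WHERE redacts IN (?, ?)"
    # params == [False, "$event1", "$event2"]
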
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 1388771c40..12cf6995eb 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -29,7 +29,26 @@ if TYPE_CHECKING:
     from synapse.server import HomeServer
 
 
-class PresenceStore(SQLBaseStore):
+class PresenceBackgroundUpdateStore(SQLBaseStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: Connection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+        # Used by `PresenceStore._get_active_presence()`
+        self.db_pool.updates.register_background_index_update(
+            "presence_stream_not_offline_index",
+            index_name="presence_stream_state_not_offline_idx",
+            table="presence_stream",
+            columns=["state"],
+            where_clause="state != 'offline'",
+        )
+
+
+class PresenceStore(PresenceBackgroundUpdateStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -332,6 +351,8 @@ class PresenceStore(SQLBaseStore):
         the appropriate time outs.
         """
 
+        # The `presence_stream_state_not_offline_idx` index should be used for this
+        # query.
         sql = (
             "SELECT user_id, state, last_active_ts, last_federation_update_ts,"
             " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index f98b892598..6e7312266d 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -19,9 +19,10 @@ from abc import abstractmethod
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple
 
-from synapse.api.constants import EventTypes, JoinRules
+from synapse.api.constants import EventContentFields, EventTypes, JoinRules
 from synapse.api.errors import StoreError
 from synapse.api.room_versions import RoomVersion, RoomVersions
+from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.search import SearchStore
@@ -1013,6 +1014,7 @@ class _BackgroundUpdates:
     ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column"
     POPULATE_ROOM_DEPTH_MIN_DEPTH2 = "populate_room_depth_min_depth2"
     REPLACE_ROOM_DEPTH_MIN_DEPTH = "replace_room_depth_min_depth"
+    POPULATE_ROOMS_CREATOR_COLUMN = "populate_rooms_creator_column"
 
 
 _REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
@@ -1054,6 +1056,11 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
             self._background_replace_room_depth_min_depth,
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN,
+            self._background_populate_rooms_creator_column,
+        )
+
     async def _background_insert_retention(self, progress, batch_size):
         """Retrieves a list of all rooms within a range and inserts an entry for each of
         them into the room_retention table.
@@ -1273,7 +1280,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
             keyvalues={"room_id": room_id},
             retcol="MAX(stream_ordering)",
             allow_none=True,
-            desc="upsert_room_on_join",
+            desc="has_auth_chain_index_fallback",
         )
 
         return max_ordering is None
@@ -1343,6 +1350,65 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
 
         return 0
 
+    async def _background_populate_rooms_creator_column(
+        self, progress: dict, batch_size: int
+    ):
+        """Background update to go and add creator information to `rooms`
+        table from `current_state_events` table.
+        """
+
+        last_room_id = progress.get("room_id", "")
+
+        def _background_populate_rooms_creator_column_txn(txn: LoggingTransaction):
+            sql = """
+                SELECT room_id, json FROM event_json
+                INNER JOIN rooms AS room USING (room_id)
+                INNER JOIN current_state_events AS state_event USING (room_id, event_id)
+                WHERE room_id > ? AND (room.creator IS NULL OR room.creator = '') AND state_event.type = 'm.room.create' AND state_event.state_key = ''
+                ORDER BY room_id
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_room_id, batch_size))
+            room_id_to_create_event_results = txn.fetchall()
+
+            new_last_room_id = ""
+            for room_id, event_json in room_id_to_create_event_results:
+                event_dict = db_to_json(event_json)
+
+                creator = event_dict.get("content").get(EventContentFields.ROOM_CREATOR)
+
+                self.db_pool.simple_update_txn(
+                    txn,
+                    table="rooms",
+                    keyvalues={"room_id": room_id},
+                    updatevalues={"creator": creator},
+                )
+                new_last_room_id = room_id
+
+            if new_last_room_id == "":
+                return True
+
+            self.db_pool.updates._background_update_progress_txn(
+                txn,
+                _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN,
+                {"room_id": new_last_room_id},
+            )
+
+            return False
+
+        end = await self.db_pool.runInteraction(
+            "_background_populate_rooms_creator_column",
+            _background_populate_rooms_creator_column_txn,
+        )
+
+        if end:
+            await self.db_pool.updates._end_background_update(
+                _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN
+            )
+
+        return batch_size
+
 
 class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
@@ -1350,7 +1416,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
 
         self.config = hs.config
 
-    async def upsert_room_on_join(self, room_id: str, room_version: RoomVersion):
+    async def upsert_room_on_join(
+        self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase]
+    ):
         """Ensure that the room is stored in the table
 
         Called when we join a room over federation, and overwrites any room version
@@ -1361,6 +1429,24 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
         # mark the room as having an auth chain cover index.
         has_auth_chain_index = await self.has_auth_chain_index(room_id)
 
+        create_event = None
+        for e in auth_events:
+            if (e.type, e.state_key) == (EventTypes.Create, ""):
+                create_event = e
+                break
+
+        if create_event is None:
+            # If the state doesn't have a create event then the room is
+            # invalid, and it would fail auth checks anyway.
+            raise StoreError(400, "No create event in state")
+
+        room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR)
+
+        if not isinstance(room_creator, str):
+            # If the create event does not have a creator then the room is
+            # invalid, and it would fail auth checks anyway.
+            raise StoreError(400, "No creator defined on the create event")
+
         await self.db_pool.simple_upsert(
             desc="upsert_room_on_join",
             table="rooms",
@@ -1368,7 +1454,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             values={"room_version": room_version.identifier},
             insertion_values={
                 "is_public": False,
-                "creator": "",
+                "creator": room_creator,
                 "has_auth_chain_index": has_auth_chain_index,
             },
             # rooms has a unique constraint on room_id, so no need to lock when doing an
@@ -1396,6 +1482,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             insertion_values={
                 "room_version": room_version.identifier,
                 "is_public": False,
+                # We don't worry about setting the `creator` here because
+                # we don't process any messages in a room while a user is
+                # invited (only after the join).
                 "creator": "",
                 "has_auth_chain_index": has_auth_chain_index,
             },
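
The new creator background update above follows the standard batched pattern: process up to batch_size rooms after the last-seen room_id, persist the new progress, and end the update once a run finds nothing left to do. A simplified, self-contained sketch of that control flow (not Synapse's updater API; data is illustrative):

    from typing import Dict, List, Tuple

    def run_batch(
        progress: Dict[str, str], batch_size: int, rows: List[Tuple[str, str]]
    ) -> Dict[str, str]:
        last_room_id = progress.get("room_id", "")
        batch = sorted(r for r in rows if r[0] > last_room_id)[:batch_size]
        if not batch:
            return {}                     # nothing left: the update is finished
        for room_id, _creator in batch:
            pass                          # the real transaction UPDATEs rooms.creator here
        return {"room_id": batch[-1][0]}  # resume after the last processed room

    progress: Dict[str, str] = {}
    rows = [("!a:example.org", "@alice:example.org"),
            ("!b:example.org", "@bob:example.org")]
    while True:
        progress = run_batch(progress, batch_size=1, rows=rows)
        if not progress:
            break
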
diff --git a/synapse/storage/schema/main/delta/63/02populate-rooms-creator.sql b/synapse/storage/schema/main/delta/63/02populate-rooms-creator.sql
new file mode 100644
index 0000000000..f7c0b31261
--- /dev/null
+++ b/synapse/storage/schema/main/delta/63/02populate-rooms-creator.sql
@@ -0,0 +1,17 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json)
+    VALUES (6302, 'populate_rooms_creator_column', '{}');
diff --git a/synapse/storage/schema/main/delta/63/04add_presence_stream_not_offline_index.sql b/synapse/storage/schema/main/delta/63/04add_presence_stream_not_offline_index.sql
new file mode 100644
index 0000000000..b90856004b
--- /dev/null
+++ b/synapse/storage/schema/main/delta/63/04add_presence_stream_not_offline_index.sql
@@ -0,0 +1,18 @@
+/*
+ * Copyright 2021 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (6304, 'presence_stream_not_offline_index', '{}');
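
This delta only queues the work: its update_name must match the index registration added to PresenceBackgroundUpdateStore earlier in this diff. As a hedged approximation (standard partial-index syntax, not the exact DDL Synapse's background updater emits):

    UPDATE_NAME = "presence_stream_not_offline_index"  # matches the row inserted above

    APPROXIMATE_DDL = (
        "CREATE INDEX presence_stream_state_not_offline_idx "
        "ON presence_stream (state) "
        "WHERE state != 'offline'"
    )
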
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index c768fdea56..6f7cbe40f4 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -19,6 +19,7 @@ from contextlib import contextmanager
 from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import attr
+from sortedcontainers import SortedSet
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.database import DatabasePool, LoggingTransaction
@@ -240,7 +241,7 @@ class MultiWriterIdGenerator:
 
         # Set of local IDs that we're still processing. The current position
         # should be less than the minimum of this set (if not empty).
-        self._unfinished_ids: Set[int] = set()
+        self._unfinished_ids: SortedSet[int] = SortedSet()
 
         # Set of local IDs that we've processed that are larger than the current
         # position, due to there being smaller unpersisted IDs.
@@ -473,7 +474,7 @@ class MultiWriterIdGenerator:
 
                 finished = set()
 
-                min_unfinshed = min(self._unfinished_ids)
+                min_unfinshed = self._unfinished_ids[0]
                 for s in self._finished_ids:
                     if s < min_unfinshed:
                         if new_cur is None or new_cur < s:
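
The switch to SortedSet is what makes the indexed lookup above cheap: sortedcontainers keeps the set ordered, so the smallest unfinished ID is an O(log n) index rather than an O(n) min() scan. A small illustration:

    from sortedcontainers import SortedSet

    unfinished = SortedSet()
    for stream_id in (7, 3, 9):
        unfinished.add(stream_id)

    assert unfinished[0] == 3   # equivalent to min(unfinished), without the full scan
    unfinished.discard(3)       # the ID has now been persisted
    assert unfinished[0] == 7
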
diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py
index 522daa323d..cfb5b94ca9 100644
--- a/synapse/util/manhole.py
+++ b/synapse/util/manhole.py
@@ -61,7 +61,7 @@ EddTrx3TNpr1D5m/f+6mnXWrc8u9y1+GNx9yz889xMjIBTBI9KqaaOs=
 -----END RSA PRIVATE KEY-----"""
 
 
-def manhole(username, password, globals):
+def manhole(settings, globals):
     """Starts a ssh listener with password authentication using
     the given username and password. Clients connecting to the ssh
     listener will find themselves in a colored python shell with
@@ -75,6 +75,15 @@ def manhole(username, password, globals):
     Returns:
         twisted.internet.protocol.Factory: A factory to pass to ``listenTCP``
     """
+    username = settings.username
+    password = settings.password
+    priv_key = settings.priv_key
+    if priv_key is None:
+        priv_key = Key.fromString(PRIVATE_KEY)
+    pub_key = settings.pub_key
+    if pub_key is None:
+        pub_key = Key.fromString(PUBLIC_KEY)
+
     if not isinstance(password, bytes):
         password = password.encode("ascii")
 
@@ -86,8 +95,8 @@ def manhole(username, password, globals):
     )
 
     factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
-    factory.publicKeys[b"ssh-rsa"] = Key.fromString(PUBLIC_KEY)
-    factory.privateKeys[b"ssh-rsa"] = Key.fromString(PRIVATE_KEY)
+    factory.privateKeys[b"ssh-rsa"] = priv_key
+    factory.publicKeys[b"ssh-rsa"] = pub_key
 
     return factory
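
A hedged sketch of the settings object the rewritten manhole() now expects: something exposing username, password and optional priv_key / pub_key attributes. The class name and defaults below are illustrative only; the real settings type lives in Synapse's config code and is not part of this file's diff:

    from typing import Any, Optional

    import attr

    @attr.s(auto_attribs=True)
    class ManholeSettings:
        username: str
        password: str
        priv_key: Optional[Any] = None  # None -> manhole() falls back to the bundled key
        pub_key: Optional[Any] = None

    settings = ManholeSettings(username="matrix", password="rabbithole")
    # factory = manhole(settings, {})  # the factory is then passed to listenTCP
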