summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/app/_base.py3
-rw-r--r--synapse/app/phone_stats_home.py9
-rw-r--r--synapse/config/_base.py42
-rw-r--r--synapse/config/_base.pyi4
-rw-r--r--synapse/config/captcha.py4
-rw-r--r--synapse/config/cas.py12
-rw-r--r--synapse/config/consent_config.py2
-rw-r--r--synapse/config/oidc_config.py3
-rw-r--r--synapse/config/ratelimiting.py32
-rw-r--r--synapse/config/registration.py4
-rw-r--r--synapse/crypto/context_factory.py13
-rw-r--r--synapse/federation/federation_client.py2
-rw-r--r--synapse/federation/sender/__init__.py50
-rw-r--r--synapse/handlers/acme.py12
-rw-r--r--synapse/handlers/acme_issuing_service.py27
-rw-r--r--synapse/handlers/auth.py10
-rw-r--r--synapse/handlers/cas_handler.py6
-rw-r--r--synapse/handlers/device.py12
-rw-r--r--synapse/handlers/e2e_keys.py223
-rw-r--r--synapse/handlers/e2e_room_keys.py91
-rw-r--r--synapse/handlers/federation.py9
-rw-r--r--synapse/handlers/groups_local.py83
-rw-r--r--synapse/handlers/identity.py28
-rw-r--r--synapse/handlers/message.py42
-rw-r--r--synapse/handlers/room.py7
-rw-r--r--synapse/handlers/room_member.py25
-rw-r--r--synapse/handlers/search.py38
-rw-r--r--synapse/handlers/set_password.py10
-rw-r--r--synapse/handlers/state_deltas.py14
-rw-r--r--synapse/handlers/stats.py39
-rw-r--r--synapse/handlers/typing.py69
-rw-r--r--synapse/handlers/user_directory.py9
-rw-r--r--synapse/logging/opentracing.py2
-rw-r--r--synapse/push/mailer.py313
-rw-r--r--synapse/push/presentable_names.py26
-rw-r--r--synapse/replication/tcp/external_cache.py105
-rw-r--r--synapse/replication/tcp/handler.py21
-rw-r--r--synapse/replication/tcp/redis.py143
-rw-r--r--synapse/res/templates/sso_auth_bad_user.html2
-rw-r--r--synapse/res/templates/sso_auth_confirm.html6
-rw-r--r--synapse/res/templates/sso_error.html4
-rw-r--r--synapse/res/templates/sso_login_idp_picker.html12
-rw-r--r--synapse/res/templates/sso_redirect_confirm.html10
-rw-r--r--synapse/rest/admin/__init__.py6
-rw-r--r--synapse/rest/admin/rooms.py71
-rw-r--r--synapse/rest/admin/users.py57
-rw-r--r--synapse/rest/client/v2_alpha/account.py12
-rw-r--r--synapse/rest/client/v2_alpha/register.py6
-rw-r--r--synapse/rest/media/v1/_base.py3
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py44
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py236
-rw-r--r--synapse/server.py30
-rw-r--r--synapse/state/__init__.py11
-rw-r--r--synapse/storage/database.py6
-rw-r--r--synapse/storage/databases/main/__init__.py4
-rw-r--r--synapse/storage/databases/main/devices.py4
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py4
-rw-r--r--synapse/storage/databases/main/event_push_actions.py4
-rw-r--r--synapse/storage/databases/main/events.py199
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py12
-rw-r--r--synapse/storage/databases/main/events_forward_extremities.py101
-rw-r--r--synapse/storage/databases/main/media_repository.py10
-rw-r--r--synapse/storage/databases/main/metrics.py56
-rw-r--r--synapse/storage/databases/main/purge_events.py2
-rw-r--r--synapse/storage/databases/main/pusher.py4
-rw-r--r--synapse/storage/databases/main/registration.py31
-rw-r--r--synapse/storage/databases/main/roommember.py6
-rw-r--r--synapse/storage/databases/main/schema/delta/59/01ignored_user.py2
-rw-r--r--synapse/storage/databases/main/search.py7
-rw-r--r--synapse/storage/databases/main/stats.py22
-rw-r--r--synapse/storage/databases/main/user_directory.py2
-rw-r--r--synapse/storage/databases/state/store.py4
-rw-r--r--synapse/storage/util/id_generators.py21
-rw-r--r--synapse/storage/util/sequence.py16
-rw-r--r--synapse/util/module_loader.py3
75 files changed, 1807 insertions, 767 deletions
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 395e202b89..9840a9d55b 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -16,6 +16,7 @@
 import gc
 import logging
 import os
+import platform
 import signal
 import socket
 import sys
@@ -339,7 +340,7 @@ async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerCon
     # rest of time. Doing so means less work each GC (hopefully).
     #
     # This only works on Python 3.7
-    if sys.version_info >= (3, 7):
+    if platform.python_implementation() == "CPython" and sys.version_info >= (3, 7):
         gc.collect()
         gc.freeze()
 
diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
index c38cf8231f..8f86cecb76 100644
--- a/synapse/app/phone_stats_home.py
+++ b/synapse/app/phone_stats_home.py
@@ -93,15 +93,20 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
 
     stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
     stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
+    daily_active_e2ee_rooms = await hs.get_datastore().count_daily_active_e2ee_rooms()
+    stats["daily_active_e2ee_rooms"] = daily_active_e2ee_rooms
+    stats["daily_e2ee_messages"] = await hs.get_datastore().count_daily_e2ee_messages()
+    daily_sent_e2ee_messages = await hs.get_datastore().count_daily_sent_e2ee_messages()
+    stats["daily_sent_e2ee_messages"] = daily_sent_e2ee_messages
     stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
     stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
+    daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
+    stats["daily_sent_messages"] = daily_sent_messages
 
     r30_results = await hs.get_datastore().count_r30_users()
     for name, count in r30_results.items():
         stats["r30_users_" + name] = count
 
-    daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
-    stats["daily_sent_messages"] = daily_sent_messages
     stats["cache_factor"] = hs.config.caches.global_factor
     stats["event_cache_size"] = hs.config.caches.event_cache_size
 
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 35e5594b73..a851f8801d 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -203,11 +203,28 @@ class Config:
         with open(file_path) as file_stream:
             return file_stream.read()
 
+    def read_template(self, filename: str) -> jinja2.Template:
+        """Load a template file from disk.
+
+        This function will attempt to load the given template from the default Synapse
+        template directory.
+
+        Files read are treated as Jinja templates. The templates is not rendered yet
+        and has autoescape enabled.
+
+        Args:
+            filename: A template filename to read.
+
+        Raises:
+            ConfigError: if the file's path is incorrect or otherwise cannot be read.
+
+        Returns:
+            A jinja2 template.
+        """
+        return self.read_templates([filename])[0]
+
     def read_templates(
-        self,
-        filenames: List[str],
-        custom_template_directory: Optional[str] = None,
-        autoescape: bool = False,
+        self, filenames: List[str], custom_template_directory: Optional[str] = None,
     ) -> List[jinja2.Template]:
         """Load a list of template files from disk using the given variables.
 
@@ -215,7 +232,8 @@ class Config:
         template directory. If `custom_template_directory` is supplied, that directory
         is tried first.
 
-        Files read are treated as Jinja templates. These templates are not rendered yet.
+        Files read are treated as Jinja templates. The templates are not rendered yet
+        and have autoescape enabled.
 
         Args:
             filenames: A list of template filenames to read.
@@ -223,16 +241,12 @@ class Config:
             custom_template_directory: A directory to try to look for the templates
                 before using the default Synapse template directory instead.
 
-            autoescape: Whether to autoescape variables before inserting them into the
-                template.
-
         Raises:
             ConfigError: if the file's path is incorrect or otherwise cannot be read.
 
         Returns:
             A list of jinja2 templates.
         """
-        templates = []
         search_directories = [self.default_template_dir]
 
         # The loader will first look in the custom template directory (if specified) for the
@@ -250,7 +264,7 @@ class Config:
 
         # TODO: switch to synapse.util.templates.build_jinja_env
         loader = jinja2.FileSystemLoader(search_directories)
-        env = jinja2.Environment(loader=loader, autoescape=autoescape)
+        env = jinja2.Environment(loader=loader, autoescape=jinja2.select_autoescape(),)
 
         # Update the environment with our custom filters
         env.filters.update(
@@ -260,12 +274,8 @@ class Config:
             }
         )
 
-        for filename in filenames:
-            # Load the template
-            template = env.get_template(filename)
-            templates.append(template)
-
-        return templates
+        # Load the templates
+        return [env.get_template(filename) for filename in filenames]
 
 
 class RootConfig:
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 3ccea4b02d..70025b5d60 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -19,6 +19,7 @@ from synapse.config import (
     password_auth_providers,
     push,
     ratelimiting,
+    redis,
     registration,
     repository,
     room_directory,
@@ -53,7 +54,7 @@ class RootConfig:
     tls: tls.TlsConfig
     database: database.DatabaseConfig
     logging: logger.LoggingConfig
-    ratelimit: ratelimiting.RatelimitConfig
+    ratelimiting: ratelimiting.RatelimitConfig
     media: repository.ContentRepositoryConfig
     captcha: captcha.CaptchaConfig
     voip: voip.VoipConfig
@@ -81,6 +82,7 @@ class RootConfig:
     roomdirectory: room_directory.RoomDirectoryConfig
     thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
     tracer: tracer.TracerConfig
+    redis: redis.RedisConfig
 
     config_classes: List = ...
     def __init__(self) -> None: ...
diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py
index cb00958165..9e48f865cc 100644
--- a/synapse/config/captcha.py
+++ b/synapse/config/captcha.py
@@ -28,9 +28,7 @@ class CaptchaConfig(Config):
             "recaptcha_siteverify_api",
             "https://www.recaptcha.net/recaptcha/api/siteverify",
         )
-        self.recaptcha_template = self.read_templates(
-            ["recaptcha.html"], autoescape=True
-        )[0]
+        self.recaptcha_template = self.read_template("recaptcha.html")
 
     def generate_config_section(self, **kwargs):
         return """\
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index c7877b4095..b226890c2a 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -30,7 +30,13 @@ class CasConfig(Config):
 
         if self.cas_enabled:
             self.cas_server_url = cas_config["server_url"]
-            self.cas_service_url = cas_config["service_url"]
+            public_base_url = cas_config.get("service_url") or self.public_baseurl
+            if public_base_url[-1] != "/":
+                public_base_url += "/"
+            # TODO Update this to a _synapse URL.
+            self.cas_service_url = (
+                public_base_url + "_matrix/client/r0/login/cas/ticket"
+            )
             self.cas_displayname_attribute = cas_config.get("displayname_attribute")
             self.cas_required_attributes = cas_config.get("required_attributes") or {}
         else:
@@ -53,10 +59,6 @@ class CasConfig(Config):
           #
           #server_url: "https://cas-server.com"
 
-          # The public URL of the homeserver.
-          #
-          #service_url: "https://homeserver.domain.com:8448"
-
           # The attribute of the CAS response to use as the display name.
           #
           # If unset, no displayname will be set.
diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py
index 6efa59b110..c47f364b14 100644
--- a/synapse/config/consent_config.py
+++ b/synapse/config/consent_config.py
@@ -89,7 +89,7 @@ class ConsentConfig(Config):
 
     def read_config(self, config, **kwargs):
         consent_config = config.get("user_consent")
-        self.terms_template = self.read_templates(["terms.html"], autoescape=True)[0]
+        self.terms_template = self.read_template("terms.html")
 
         if consent_config is None:
             return
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index 784b416f95..bb122ef182 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -53,8 +53,7 @@ class OIDCConfig(Config):
                     "Multiple OIDC providers have the idp_id %r." % idp_id
                 )
 
-        public_baseurl = self.public_baseurl
-        self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"
+        self.oidc_callback_url = self.public_baseurl + "_synapse/oidc/callback"
 
     @property
     def oidc_enabled(self) -> bool:
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 14b8836197..def33a60ad 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -24,7 +24,7 @@ class RateLimitConfig:
         defaults={"per_second": 0.17, "burst_count": 3.0},
     ):
         self.per_second = config.get("per_second", defaults["per_second"])
-        self.burst_count = config.get("burst_count", defaults["burst_count"])
+        self.burst_count = int(config.get("burst_count", defaults["burst_count"]))
 
 
 class FederationRateLimitConfig:
@@ -102,6 +102,20 @@ class RatelimitConfig(Config):
             defaults={"per_second": 0.01, "burst_count": 3},
         )
 
+        self.rc_3pid_validation = RateLimitConfig(
+            config.get("rc_3pid_validation") or {},
+            defaults={"per_second": 0.003, "burst_count": 5},
+        )
+
+        self.rc_invites_per_room = RateLimitConfig(
+            config.get("rc_invites", {}).get("per_room", {}),
+            defaults={"per_second": 0.3, "burst_count": 10},
+        )
+        self.rc_invites_per_user = RateLimitConfig(
+            config.get("rc_invites", {}).get("per_user", {}),
+            defaults={"per_second": 0.003, "burst_count": 5},
+        )
+
     def generate_config_section(self, **kwargs):
         return """\
         ## Ratelimiting ##
@@ -131,6 +145,9 @@ class RatelimitConfig(Config):
         #     users are joining rooms the server is already in (this is cheap) vs
         #     "remote" for when users are trying to join rooms not on the server (which
         #     can be more expensive)
+        #   - one for ratelimiting how often a user or IP can attempt to validate a 3PID.
+        #   - two for ratelimiting how often invites can be sent in a room or to a
+        #     specific user.
         #
         # The defaults are as shown below.
         #
@@ -164,7 +181,18 @@ class RatelimitConfig(Config):
         #  remote:
         #    per_second: 0.01
         #    burst_count: 3
-
+        #
+        #rc_3pid_validation:
+        #  per_second: 0.003
+        #  burst_count: 5
+        #
+        #rc_invites:
+        #  per_room:
+        #    per_second: 0.3
+        #    burst_count: 10
+        #  per_user:
+        #    per_second: 0.003
+        #    burst_count: 5
 
         # Ratelimiting settings for incoming federation
         #
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 4bfc69cb7a..ac48913a0b 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -176,9 +176,7 @@ class RegistrationConfig(Config):
         self.session_lifetime = session_lifetime
 
         # The success template used during fallback auth.
-        self.fallback_success_template = self.read_templates(
-            ["auth_success.html"], autoescape=True
-        )[0]
+        self.fallback_success_template = self.read_template("auth_success.html")
 
     def generate_config_section(self, generate_secrets=False, **kwargs):
         if generate_secrets:
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 74b67b230a..14b21796d9 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -125,19 +125,24 @@ class FederationPolicyForHTTPS:
         self._no_verify_ssl_context = _no_verify_ssl.getContext()
         self._no_verify_ssl_context.set_info_callback(_context_info_cb)
 
-    def get_options(self, host: bytes):
+        self._should_verify = self._config.federation_verify_certificates
+
+        self._federation_certificate_verification_whitelist = (
+            self._config.federation_certificate_verification_whitelist
+        )
 
+    def get_options(self, host: bytes):
         # IPolicyForHTTPS.get_options takes bytes, but we want to compare
         # against the str whitelist. The hostnames in the whitelist are already
         # IDNA-encoded like the hosts will be here.
         ascii_host = host.decode("ascii")
 
         # Check if certificate verification has been enabled
-        should_verify = self._config.federation_verify_certificates
+        should_verify = self._should_verify
 
         # Check if we've disabled certificate verification for this host
-        if should_verify:
-            for regex in self._config.federation_certificate_verification_whitelist:
+        if self._should_verify:
+            for regex in self._federation_certificate_verification_whitelist:
                 if regex.match(ascii_host):
                     should_verify = False
                     break
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index d330ae5dbc..40e1451201 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -810,7 +810,7 @@ class FederationClient(FederationBase):
                         "User's homeserver does not support this room version",
                         Codes.UNSUPPORTED_ROOM_VERSION,
                     )
-            elif e.code == 403:
+            elif e.code in (403, 429):
                 raise e.to_synapse_error()
             else:
                 raise
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 604cfd1935..643b26ae6d 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -142,6 +142,8 @@ class FederationSender:
             self._wake_destinations_needing_catchup,
         )
 
+        self._external_cache = hs.get_external_cache()
+
     def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
         """Get or create a PerDestinationQueue for the given destination
 
@@ -197,22 +199,40 @@ class FederationSender:
                     if not event.internal_metadata.should_proactively_send():
                         return
 
-                    try:
-                        # Get the state from before the event.
-                        # We need to make sure that this is the state from before
-                        # the event and not from after it.
-                        # Otherwise if the last member on a server in a room is
-                        # banned then it won't receive the event because it won't
-                        # be in the room after the ban.
-                        destinations = await self.state.get_hosts_in_room_at_events(
-                            event.room_id, event_ids=event.prev_event_ids()
-                        )
-                    except Exception:
-                        logger.exception(
-                            "Failed to calculate hosts in room for event: %s",
-                            event.event_id,
+                    destinations = None  # type: Optional[Set[str]]
+                    if not event.prev_event_ids():
+                        # If there are no prev event IDs then the state is empty
+                        # and so no remote servers in the room
+                        destinations = set()
+                    else:
+                        # We check the external cache for the destinations, which is
+                        # stored per state group.
+
+                        sg = await self._external_cache.get(
+                            "event_to_prev_state_group", event.event_id
                         )
-                        return
+                        if sg:
+                            destinations = await self._external_cache.get(
+                                "get_joined_hosts", str(sg)
+                            )
+
+                    if destinations is None:
+                        try:
+                            # Get the state from before the event.
+                            # We need to make sure that this is the state from before
+                            # the event and not from after it.
+                            # Otherwise if the last member on a server in a room is
+                            # banned then it won't receive the event because it won't
+                            # be in the room after the ban.
+                            destinations = await self.state.get_hosts_in_room_at_events(
+                                event.room_id, event_ids=event.prev_event_ids()
+                            )
+                        except Exception:
+                            logger.exception(
+                                "Failed to calculate hosts in room for event: %s",
+                                event.event_id,
+                            )
+                            return
 
                     destinations = {
                         d
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index 8476256a59..5ecb2da1ac 100644
--- a/synapse/handlers/acme.py
+++ b/synapse/handlers/acme.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 import twisted
 import twisted.internet.error
@@ -22,6 +23,9 @@ from twisted.web.resource import Resource
 
 from synapse.app import check_bind_error
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 ACME_REGISTER_FAIL_ERROR = """
@@ -35,12 +39,12 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC
 
 
 class AcmeHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.reactor = hs.get_reactor()
         self._acme_domain = hs.config.acme_domain
 
-    async def start_listening(self):
+    async def start_listening(self) -> None:
         from synapse.handlers import acme_issuing_service
 
         # Configure logging for txacme, if you need to debug
@@ -85,7 +89,7 @@ class AcmeHandler:
             logger.error(ACME_REGISTER_FAIL_ERROR)
             raise
 
-    async def provision_certificate(self):
+    async def provision_certificate(self) -> None:
 
         logger.warning("Reprovisioning %s", self._acme_domain)
 
@@ -110,5 +114,3 @@ class AcmeHandler:
         except Exception:
             logger.exception("Failed saving!")
             raise
-
-        return True
diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py
index 7294649d71..ae2a9dd9c2 100644
--- a/synapse/handlers/acme_issuing_service.py
+++ b/synapse/handlers/acme_issuing_service.py
@@ -22,8 +22,10 @@ only need (and may only have available) if we are doing ACME, so is designed to
 imported conditionally.
 """
 import logging
+from typing import Dict, Iterable, List
 
 import attr
+import pem
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives import serialization
 from josepy import JWKRSA
@@ -36,20 +38,27 @@ from txacme.util import generate_private_key
 from zope.interface import implementer
 
 from twisted.internet import defer
+from twisted.internet.interfaces import IReactorTCP
 from twisted.python.filepath import FilePath
 from twisted.python.url import URL
+from twisted.web.resource import IResource
 
 logger = logging.getLogger(__name__)
 
 
-def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource):
+def create_issuing_service(
+    reactor: IReactorTCP,
+    acme_url: str,
+    account_key_file: str,
+    well_known_resource: IResource,
+) -> AcmeIssuingService:
     """Create an ACME issuing service, and attach it to a web Resource
 
     Args:
         reactor: twisted reactor
-        acme_url (str): URL to use to request certificates
-        account_key_file (str): where to store the account key
-        well_known_resource (twisted.web.IResource): web resource for .well-known.
+        acme_url: URL to use to request certificates
+        account_key_file: where to store the account key
+        well_known_resource: web resource for .well-known.
             we will attach a child resource for "acme-challenge".
 
     Returns:
@@ -83,18 +92,20 @@ class ErsatzStore:
     A store that only stores in memory.
     """
 
-    certs = attr.ib(default=attr.Factory(dict))
+    certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict))
 
-    def store(self, server_name, pem_objects):
+    def store(
+        self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject]
+    ) -> defer.Deferred:
         self.certs[server_name] = [o.as_bytes() for o in pem_objects]
         return defer.succeed(None)
 
 
-def load_or_create_client_key(key_file):
+def load_or_create_client_key(key_file: str) -> JWKRSA:
     """Load the ACME account key from a file, creating it if it does not exist.
 
     Args:
-        key_file (str): name of the file to use as the account key
+        key_file: name of the file to use as the account key
     """
     # this is based on txacme.endpoint.load_or_create_client_key, but doesn't
     # hardcode the 'client.key' filename
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 6f746711ca..a19c556437 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -568,16 +568,6 @@ class AuthHandler(BaseHandler):
                         session.session_id, login_type, result
                     )
             except LoginError as e:
-                if login_type == LoginType.EMAIL_IDENTITY:
-                    # riot used to have a bug where it would request a new
-                    # validation token (thus sending a new email) each time it
-                    # got a 401 with a 'flows' field.
-                    # (https://github.com/vector-im/vector-web/issues/2447).
-                    #
-                    # Grandfather in the old behaviour for now to avoid
-                    # breaking old riot deployments.
-                    raise
-
                 # this step failed. Merge the error dict into the response
                 # so that the client can have another go.
                 errordict = e.error_dict()
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 048523ec94..bd35d1fb87 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -100,11 +100,7 @@ class CasHandler:
         Returns:
             The URL to use as a "service" parameter.
         """
-        return "%s%s?%s" % (
-            self._cas_service_url,
-            "/_matrix/client/r0/login/cas/ticket",
-            urllib.parse.urlencode(args),
-        )
+        return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),)
 
     async def _validate_ticket(
         self, ticket: str, service_args: Dict[str, str]
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index debb1b4f29..0863154f7a 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api import errors
 from synapse.api.constants import EventTypes
@@ -62,7 +62,7 @@ class DeviceWorkerHandler(BaseHandler):
         self._auth_handler = hs.get_auth_handler()
 
     @trace
-    async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]:
+    async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
         """
         Retrieve the given user's devices
 
@@ -85,7 +85,7 @@ class DeviceWorkerHandler(BaseHandler):
         return devices
 
     @trace
-    async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
+    async def get_device(self, user_id: str, device_id: str) -> JsonDict:
         """ Retrieve the given device
 
         Args:
@@ -598,7 +598,7 @@ class DeviceHandler(DeviceWorkerHandler):
 
 
 def _update_device_from_client_ips(
-    device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]]
+    device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict]
 ) -> None:
     ip = client_ips.get((device["user_id"], device["device_id"]), {})
     device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
@@ -946,8 +946,8 @@ class DeviceListUpdater:
     async def process_cross_signing_key_update(
         self,
         user_id: str,
-        master_key: Optional[Dict[str, Any]],
-        self_signing_key: Optional[Dict[str, Any]],
+        master_key: Optional[JsonDict],
+        self_signing_key: Optional[JsonDict],
     ) -> List[str]:
         """Process the given new master and self-signing key for the given remote user.
 
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 929752150d..8f3a6b35a4 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -16,7 +16,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
 
 import attr
 from canonicaljson import encode_canonical_json
@@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
 from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
 from synapse.types import (
+    JsonDict,
     UserID,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
@@ -40,11 +41,14 @@ from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class E2eKeysHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.federation = hs.get_federation_client()
         self.device_handler = hs.get_device_handler()
@@ -78,7 +82,9 @@ class E2eKeysHandler:
         )
 
     @trace
-    async def query_devices(self, query_body, timeout, from_user_id):
+    async def query_devices(
+        self, query_body: JsonDict, timeout: int, from_user_id: str
+    ) -> JsonDict:
         """ Handle a device key query from a client
 
         {
@@ -98,12 +104,14 @@ class E2eKeysHandler:
         }
 
         Args:
-            from_user_id (str): the user making the query.  This is used when
+            from_user_id: the user making the query.  This is used when
                 adding cross-signing signatures to limit what signatures users
                 can see.
         """
 
-        device_keys_query = query_body.get("device_keys", {})
+        device_keys_query = query_body.get(
+            "device_keys", {}
+        )  # type: Dict[str, Iterable[str]]
 
         # separate users by domain.
         # make a map from domain to user_id to device_ids
@@ -121,7 +129,8 @@ class E2eKeysHandler:
         set_tag("remote_key_query", remote_queries)
 
         # First get local devices.
-        failures = {}
+        # A map of destination -> failure response.
+        failures = {}  # type: Dict[str, JsonDict]
         results = {}
         if local_query:
             local_result = await self.query_local_devices(local_query)
@@ -135,9 +144,10 @@ class E2eKeysHandler:
         )
 
         # Now attempt to get any remote devices from our local cache.
-        remote_queries_not_in_cache = {}
+        # A map of destination -> user ID -> device IDs.
+        remote_queries_not_in_cache = {}  # type: Dict[str, Dict[str, Iterable[str]]]
         if remote_queries:
-            query_list = []
+            query_list = []  # type: List[Tuple[str, Optional[str]]]
             for user_id, device_ids in remote_queries.items():
                 if device_ids:
                     query_list.extend((user_id, device_id) for device_id in device_ids)
@@ -284,15 +294,15 @@ class E2eKeysHandler:
         return ret
 
     async def get_cross_signing_keys_from_cache(
-        self, query, from_user_id
+        self, query: Iterable[str], from_user_id: Optional[str]
     ) -> Dict[str, Dict[str, dict]]:
         """Get cross-signing keys for users from the database
 
         Args:
-            query (Iterable[string]) an iterable of user IDs.  A dict whose keys
+            query: an iterable of user IDs.  A dict whose keys
                 are user IDs satisfies this, so the query format used for
                 query_devices can be used here.
-            from_user_id (str): the user making the query.  This is used when
+            from_user_id: the user making the query.  This is used when
                 adding cross-signing signatures to limit what signatures users
                 can see.
 
@@ -315,14 +325,12 @@ class E2eKeysHandler:
             if "self_signing" in user_info:
                 self_signing_keys[user_id] = user_info["self_signing"]
 
-        if (
-            from_user_id in keys
-            and keys[from_user_id] is not None
-            and "user_signing" in keys[from_user_id]
-        ):
-            # users can see other users' master and self-signing keys, but can
-            # only see their own user-signing keys
-            user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]
+        # users can see other users' master and self-signing keys, but can
+        # only see their own user-signing keys
+        if from_user_id:
+            from_user_key = keys.get(from_user_id)
+            if from_user_key and "user_signing" in from_user_key:
+                user_signing_keys[from_user_id] = from_user_key["user_signing"]
 
         return {
             "master_keys": master_keys,
@@ -344,9 +352,9 @@ class E2eKeysHandler:
             A map from user_id -> device_id -> device details
         """
         set_tag("local_query", query)
-        local_query = []
+        local_query = []  # type: List[Tuple[str, Optional[str]]]
 
-        result_dict = {}
+        result_dict = {}  # type: Dict[str, Dict[str, dict]]
         for user_id, device_ids in query.items():
             # we use UserID.from_string to catch invalid user ids
             if not self.is_mine(UserID.from_string(user_id)):
@@ -380,10 +388,14 @@ class E2eKeysHandler:
         log_kv(results)
         return result_dict
 
-    async def on_federation_query_client_keys(self, query_body):
+    async def on_federation_query_client_keys(
+        self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
+    ) -> JsonDict:
         """ Handle a device key query from a federated server
         """
-        device_keys_query = query_body.get("device_keys", {})
+        device_keys_query = query_body.get(
+            "device_keys", {}
+        )  # type: Dict[str, Optional[List[str]]]
         res = await self.query_local_devices(device_keys_query)
         ret = {"device_keys": res}
 
@@ -397,31 +409,34 @@ class E2eKeysHandler:
         return ret
 
     @trace
-    async def claim_one_time_keys(self, query, timeout):
-        local_query = []
-        remote_queries = {}
+    async def claim_one_time_keys(
+        self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
+    ) -> JsonDict:
+        local_query = []  # type: List[Tuple[str, str, str]]
+        remote_queries = {}  # type: Dict[str, Dict[str, Dict[str, str]]]
 
-        for user_id, device_keys in query.get("one_time_keys", {}).items():
+        for user_id, one_time_keys in query.get("one_time_keys", {}).items():
             # we use UserID.from_string to catch invalid user ids
             if self.is_mine(UserID.from_string(user_id)):
-                for device_id, algorithm in device_keys.items():
+                for device_id, algorithm in one_time_keys.items():
                     local_query.append((user_id, device_id, algorithm))
             else:
                 domain = get_domain_from_id(user_id)
-                remote_queries.setdefault(domain, {})[user_id] = device_keys
+                remote_queries.setdefault(domain, {})[user_id] = one_time_keys
 
         set_tag("local_key_query", local_query)
         set_tag("remote_key_query", remote_queries)
 
         results = await self.store.claim_e2e_one_time_keys(local_query)
 
-        json_result = {}
-        failures = {}
+        # A map of user ID -> device ID -> key ID -> key.
+        json_result = {}  # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
+        failures = {}  # type: Dict[str, JsonDict]
         for user_id, device_keys in results.items():
             for device_id, keys in device_keys.items():
-                for key_id, json_bytes in keys.items():
+                for key_id, json_str in keys.items():
                     json_result.setdefault(user_id, {})[device_id] = {
-                        key_id: json_decoder.decode(json_bytes)
+                        key_id: json_decoder.decode(json_str)
                     }
 
         @trace
@@ -468,7 +483,9 @@ class E2eKeysHandler:
         return {"one_time_keys": json_result, "failures": failures}
 
     @tag_args
-    async def upload_keys_for_user(self, user_id, device_id, keys):
+    async def upload_keys_for_user(
+        self, user_id: str, device_id: str, keys: JsonDict
+    ) -> JsonDict:
 
         time_now = self.clock.time_msec()
 
@@ -543,8 +560,8 @@ class E2eKeysHandler:
         return {"one_time_key_counts": result}
 
     async def _upload_one_time_keys_for_user(
-        self, user_id, device_id, time_now, one_time_keys
-    ):
+        self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
+    ) -> None:
         logger.info(
             "Adding one_time_keys %r for device %r for user %r at %d",
             one_time_keys.keys(),
@@ -585,12 +602,14 @@ class E2eKeysHandler:
         log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
         await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
 
-    async def upload_signing_keys_for_user(self, user_id, keys):
+    async def upload_signing_keys_for_user(
+        self, user_id: str, keys: JsonDict
+    ) -> JsonDict:
         """Upload signing keys for cross-signing
 
         Args:
-            user_id (string): the user uploading the keys
-            keys (dict[string, dict]): the signing keys
+            user_id: the user uploading the keys
+            keys: the signing keys
         """
 
         # if a master key is uploaded, then check it.  Otherwise, load the
@@ -667,16 +686,17 @@ class E2eKeysHandler:
 
         return {}
 
-    async def upload_signatures_for_device_keys(self, user_id, signatures):
+    async def upload_signatures_for_device_keys(
+        self, user_id: str, signatures: JsonDict
+    ) -> JsonDict:
         """Upload device signatures for cross-signing
 
         Args:
-            user_id (string): the user uploading the signatures
-            signatures (dict[string, dict[string, dict]]): map of users to
-                devices to signed keys. This is the submission from the user; an
-                exception will be raised if it is malformed.
+            user_id: the user uploading the signatures
+            signatures: map of users to devices to signed keys. This is the submission
+            from the user; an exception will be raised if it is malformed.
         Returns:
-            dict: response to be sent back to the client.  The response will have
+            The response to be sent back to the client.  The response will have
                 a "failures" key, which will be a dict mapping users to devices
                 to errors for the signatures that failed.
         Raises:
@@ -719,7 +739,9 @@ class E2eKeysHandler:
 
         return {"failures": failures}
 
-    async def _process_self_signatures(self, user_id, signatures):
+    async def _process_self_signatures(
+        self, user_id: str, signatures: JsonDict
+    ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
         """Process uploaded signatures of the user's own keys.
 
         Signatures of the user's own keys from this API come in two forms:
@@ -731,15 +753,14 @@ class E2eKeysHandler:
             signatures (dict[string, dict]): map of devices to signed keys
 
         Returns:
-            (list[SignatureListItem], dict[string, dict[string, dict]]):
-            a list of signatures to store, and a map of users to devices to failure
-            reasons
+            A tuple of a list of signatures to store, and a map of users to
+            devices to failure reasons
 
         Raises:
             SynapseError: if the input is malformed
         """
-        signature_list = []
-        failures = {}
+        signature_list = []  # type: List[SignatureListItem]
+        failures = {}  # type: Dict[str, Dict[str, JsonDict]]
         if not signatures:
             return signature_list, failures
 
@@ -834,19 +855,24 @@ class E2eKeysHandler:
         return signature_list, failures
 
     def _check_master_key_signature(
-        self, user_id, master_key_id, signed_master_key, stored_master_key, devices
-    ):
+        self,
+        user_id: str,
+        master_key_id: str,
+        signed_master_key: JsonDict,
+        stored_master_key: JsonDict,
+        devices: Dict[str, Dict[str, JsonDict]],
+    ) -> List["SignatureListItem"]:
         """Check signatures of a user's master key made by their devices.
 
         Args:
-            user_id (string): the user whose master key is being checked
-            master_key_id (string): the ID of the user's master key
-            signed_master_key (dict): the user's signed master key that was uploaded
-            stored_master_key (dict): our previously-stored copy of the user's master key
-            devices (iterable(dict)): the user's devices
+            user_id: the user whose master key is being checked
+            master_key_id: the ID of the user's master key
+            signed_master_key: the user's signed master key that was uploaded
+            stored_master_key: our previously-stored copy of the user's master key
+            devices: the user's devices
 
         Returns:
-            list[SignatureListItem]: a list of signatures to store
+            A list of signatures to store
 
         Raises:
             SynapseError: if a signature is invalid
@@ -877,25 +903,26 @@ class E2eKeysHandler:
 
         return master_key_signature_list
 
-    async def _process_other_signatures(self, user_id, signatures):
+    async def _process_other_signatures(
+        self, user_id: str, signatures: Dict[str, dict]
+    ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
         """Process uploaded signatures of other users' keys.  These will be the
         target user's master keys, signed by the uploading user's user-signing
         key.
 
         Args:
-            user_id (string): the user uploading the keys
-            signatures (dict[string, dict]): map of users to devices to signed keys
+            user_id: the user uploading the keys
+            signatures: map of users to devices to signed keys
 
         Returns:
-            (list[SignatureListItem], dict[string, dict[string, dict]]):
-            a list of signatures to store, and a map of users to devices to failure
+            A list of signatures to store, and a map of users to devices to failure
             reasons
 
         Raises:
             SynapseError: if the input is malformed
         """
-        signature_list = []
-        failures = {}
+        signature_list = []  # type: List[SignatureListItem]
+        failures = {}  # type: Dict[str, Dict[str, JsonDict]]
         if not signatures:
             return signature_list, failures
 
@@ -983,7 +1010,7 @@ class E2eKeysHandler:
 
     async def _get_e2e_cross_signing_verify_key(
         self, user_id: str, key_type: str, from_user_id: str = None
-    ):
+    ) -> Tuple[JsonDict, str, VerifyKey]:
         """Fetch locally or remotely query for a cross-signing public key.
 
         First, attempt to fetch the cross-signing public key from storage.
@@ -997,8 +1024,7 @@ class E2eKeysHandler:
                 This affects what signatures are fetched.
 
         Returns:
-            dict, str, VerifyKey: the raw key data, the key ID, and the
-                signedjson verify key
+            The raw key data, the key ID, and the signedjson verify key
 
         Raises:
             NotFoundError: if the key is not found
@@ -1135,16 +1161,18 @@ class E2eKeysHandler:
         return desired_key, desired_key_id, desired_verify_key
 
 
-def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
+def _check_cross_signing_key(
+    key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None
+) -> None:
     """Check a cross-signing key uploaded by a user.  Performs some basic sanity
     checking, and ensures that it is signed, if a signature is required.
 
     Args:
-        key (dict): the key data to verify
-        user_id (str): the user whose key is being checked
-        key_type (str): the type of key that the key should be
-        signing_key (VerifyKey): (optional) the signing key that the key should
-            be signed with.  If omitted, signatures will not be checked.
+        key: the key data to verify
+        user_id: the user whose key is being checked
+        key_type: the type of key that the key should be
+        signing_key: the signing key that the key should be signed with.  If
+            omitted, signatures will not be checked.
     """
     if (
         key.get("user_id") != user_id
@@ -1162,16 +1190,21 @@ def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
             )
 
 
-def _check_device_signature(user_id, verify_key, signed_device, stored_device):
+def _check_device_signature(
+    user_id: str,
+    verify_key: VerifyKey,
+    signed_device: JsonDict,
+    stored_device: JsonDict,
+) -> None:
     """Check that a signature on a device or cross-signing key is correct and
     matches the copy of the device/key that we have stored.  Throws an
     exception if an error is detected.
 
     Args:
-        user_id (str): the user ID whose signature is being checked
-        verify_key (VerifyKey): the key to verify the device with
-        signed_device (dict): the uploaded signed device data
-        stored_device (dict): our previously stored copy of the device
+        user_id: the user ID whose signature is being checked
+        verify_key: the key to verify the device with
+        signed_device: the uploaded signed device data
+        stored_device: our previously stored copy of the device
 
     Raises:
         SynapseError: if the signature was invalid or the sent device is not the
@@ -1201,7 +1234,7 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device):
         raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE)
 
 
-def _exception_to_failure(e):
+def _exception_to_failure(e: Exception) -> JsonDict:
     if isinstance(e, SynapseError):
         return {"status": e.code, "errcode": e.errcode, "message": str(e)}
 
@@ -1218,7 +1251,7 @@ def _exception_to_failure(e):
     return {"status": 503, "message": str(e)}
 
 
-def _one_time_keys_match(old_key_json, new_key):
+def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool:
     old_key = json_decoder.decode(old_key_json)
 
     # if either is a string rather than an object, they must match exactly
@@ -1239,16 +1272,16 @@ class SignatureListItem:
     """An item in the signature list as used by upload_signatures_for_device_keys.
     """
 
-    signing_key_id = attr.ib()
-    target_user_id = attr.ib()
-    target_device_id = attr.ib()
-    signature = attr.ib()
+    signing_key_id = attr.ib(type=str)
+    target_user_id = attr.ib(type=str)
+    target_device_id = attr.ib(type=str)
+    signature = attr.ib(type=JsonDict)
 
 
 class SigningKeyEduUpdater:
     """Handles incoming signing key updates from federation and updates the DB"""
 
-    def __init__(self, hs, e2e_keys_handler):
+    def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
         self.store = hs.get_datastore()
         self.federation = hs.get_federation_client()
         self.clock = hs.get_clock()
@@ -1257,7 +1290,7 @@ class SigningKeyEduUpdater:
         self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
 
         # user_id -> list of updates waiting to be handled.
-        self._pending_updates = {}
+        self._pending_updates = {}  # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
 
         # Recently seen stream ids. We don't bother keeping these in the DB,
         # but they're useful to have them about to reduce the number of spurious
@@ -1270,13 +1303,15 @@ class SigningKeyEduUpdater:
             iterable=True,
         )
 
-    async def incoming_signing_key_update(self, origin, edu_content):
+    async def incoming_signing_key_update(
+        self, origin: str, edu_content: JsonDict
+    ) -> None:
         """Called on incoming signing key update from federation. Responsible for
         parsing the EDU and adding to pending updates list.
 
         Args:
-            origin (string): the server that sent the EDU
-            edu_content (dict): the contents of the EDU
+            origin: the server that sent the EDU
+            edu_content: the contents of the EDU
         """
 
         user_id = edu_content.pop("user_id")
@@ -1299,11 +1334,11 @@ class SigningKeyEduUpdater:
 
         await self._handle_signing_key_updates(user_id)
 
-    async def _handle_signing_key_updates(self, user_id):
+    async def _handle_signing_key_updates(self, user_id: str) -> None:
         """Actually handle pending updates.
 
         Args:
-            user_id (string): the user whose updates we are processing
+            user_id: the user whose updates we are processing
         """
 
         device_handler = self.e2e_keys_handler.device_handler
@@ -1315,7 +1350,7 @@ class SigningKeyEduUpdater:
                 # This can happen since we batch updates
                 return
 
-            device_ids = []
+            device_ids = []  # type: List[str]
 
             logger.info("pending updates: %r", pending_updates)
 
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index f01b090772..622cae23be 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, List, Optional
 
 from synapse.api.errors import (
     Codes,
@@ -24,8 +25,12 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.logging.opentracing import log_kv, trace
+from synapse.types import JsonDict
 from synapse.util.async_helpers import Linearizer
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -37,7 +42,7 @@ class E2eRoomKeysHandler:
     The actual payload of the encrypted keys is completely opaque to the handler.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
 
         # Used to lock whenever a client is uploading key data.  This prevents collisions
@@ -48,21 +53,27 @@ class E2eRoomKeysHandler:
         self._upload_linearizer = Linearizer("upload_room_keys_lock")
 
     @trace
-    async def get_room_keys(self, user_id, version, room_id=None, session_id=None):
+    async def get_room_keys(
+        self,
+        user_id: str,
+        version: str,
+        room_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+    ) -> List[JsonDict]:
         """Bulk get the E2E room keys for a given backup, optionally filtered to a given
         room, or a given session.
         See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
 
         Args:
-            user_id(str): the user whose keys we're getting
-            version(str): the version ID of the backup we're getting keys from
-            room_id(string): room ID to get keys for, for None to get keys for all rooms
-            session_id(string): session ID to get keys for, for None to get keys for all
+            user_id: the user whose keys we're getting
+            version: the version ID of the backup we're getting keys from
+            room_id: room ID to get keys for, for None to get keys for all rooms
+            session_id: session ID to get keys for, for None to get keys for all
                 sessions
         Raises:
             NotFoundError: if the backup version does not exist
         Returns:
-            A deferred list of dicts giving the session_data and message metadata for
+            A list of dicts giving the session_data and message metadata for
             these room keys.
         """
 
@@ -86,17 +97,23 @@ class E2eRoomKeysHandler:
             return results
 
     @trace
-    async def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
+    async def delete_room_keys(
+        self,
+        user_id: str,
+        version: str,
+        room_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+    ) -> JsonDict:
         """Bulk delete the E2E room keys for a given backup, optionally filtered to a given
         room or a given session.
         See EndToEndRoomKeyStore.delete_e2e_room_keys for full details.
 
         Args:
-            user_id(str): the user whose backup we're deleting
-            version(str): the version ID of the backup we're deleting
-            room_id(string): room ID to delete keys for, for None to delete keys for all
+            user_id: the user whose backup we're deleting
+            version: the version ID of the backup we're deleting
+            room_id: room ID to delete keys for, for None to delete keys for all
                 rooms
-            session_id(string): session ID to delete keys for, for None to delete keys
+            session_id: session ID to delete keys for, for None to delete keys
                 for all sessions
         Raises:
             NotFoundError: if the backup version does not exist
@@ -128,15 +145,17 @@ class E2eRoomKeysHandler:
             return {"etag": str(version_etag), "count": count}
 
     @trace
-    async def upload_room_keys(self, user_id, version, room_keys):
+    async def upload_room_keys(
+        self, user_id: str, version: str, room_keys: JsonDict
+    ) -> JsonDict:
         """Bulk upload a list of room keys into a given backup version, asserting
         that the given version is the current backup version.  room_keys are merged
         into the current backup as described in RoomKeysServlet.on_PUT().
 
         Args:
-            user_id(str): the user whose backup we're setting
-            version(str): the version ID of the backup we're updating
-            room_keys(dict): a nested dict describing the room_keys we're setting:
+            user_id: the user whose backup we're setting
+            version: the version ID of the backup we're updating
+            room_keys: a nested dict describing the room_keys we're setting:
 
         {
             "rooms": {
@@ -254,14 +273,16 @@ class E2eRoomKeysHandler:
             return {"etag": str(version_etag), "count": count}
 
     @staticmethod
-    def _should_replace_room_key(current_room_key, room_key):
+    def _should_replace_room_key(
+        current_room_key: Optional[JsonDict], room_key: JsonDict
+    ) -> bool:
         """
         Determine whether to replace a given current_room_key (if any)
         with a newly uploaded room_key backup
 
         Args:
-            current_room_key (dict): Optional, the current room_key dict if any
-            room_key (dict): The new room_key dict which may or may not be fit to
+            current_room_key: Optional, the current room_key dict if any
+            room_key : The new room_key dict which may or may not be fit to
                 replace the current_room_key
 
         Returns:
@@ -286,14 +307,14 @@ class E2eRoomKeysHandler:
         return True
 
     @trace
-    async def create_version(self, user_id, version_info):
+    async def create_version(self, user_id: str, version_info: JsonDict) -> str:
         """Create a new backup version.  This automatically becomes the new
         backup version for the user's keys; previous backups will no longer be
         writeable to.
 
         Args:
-            user_id(str): the user whose backup version we're creating
-            version_info(dict): metadata about the new version being created
+            user_id: the user whose backup version we're creating
+            version_info: metadata about the new version being created
 
         {
             "algorithm": "m.megolm_backup.v1",
@@ -301,7 +322,7 @@ class E2eRoomKeysHandler:
         }
 
         Returns:
-            A deferred of a string that gives the new version number.
+            The new version number.
         """
 
         # TODO: Validate the JSON to make sure it has the right keys.
@@ -313,17 +334,19 @@ class E2eRoomKeysHandler:
             )
             return new_version
 
-    async def get_version_info(self, user_id, version=None):
+    async def get_version_info(
+        self, user_id: str, version: Optional[str] = None
+    ) -> JsonDict:
         """Get the info about a given version of the user's backup
 
         Args:
-            user_id(str): the user whose current backup version we're querying
-            version(str): Optional; if None gives the most recent version
+            user_id: the user whose current backup version we're querying
+            version: Optional; if None gives the most recent version
                 otherwise a historical one.
         Raises:
             NotFoundError: if the requested backup version doesn't exist
         Returns:
-            A deferred of a info dict that gives the info about the new version.
+            A info dict that gives the info about the new version.
 
         {
             "version": "1234",
@@ -346,7 +369,7 @@ class E2eRoomKeysHandler:
             return res
 
     @trace
-    async def delete_version(self, user_id, version=None):
+    async def delete_version(self, user_id: str, version: Optional[str] = None) -> None:
         """Deletes a given version of the user's e2e_room_keys backup
 
         Args:
@@ -366,17 +389,19 @@ class E2eRoomKeysHandler:
                     raise
 
     @trace
-    async def update_version(self, user_id, version, version_info):
+    async def update_version(
+        self, user_id: str, version: str, version_info: JsonDict
+    ) -> JsonDict:
         """Update the info about a given version of the user's backup
 
         Args:
-            user_id(str): the user whose current backup version we're updating
-            version(str): the backup version we're updating
-            version_info(dict): the new information about the backup
+            user_id: the user whose current backup version we're updating
+            version: the backup version we're updating
+            version_info: the new information about the backup
         Raises:
             NotFoundError: if the requested backup version doesn't exist
         Returns:
-            A deferred of an empty dict.
+            An empty dict.
         """
         if "version" not in version_info:
             version_info["version"] = version
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index fd8de8696d..dbdfd56ff5 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1617,6 +1617,10 @@ class FederationHandler(BaseHandler):
         if event.state_key == self._server_notices_mxid:
             raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
 
+        # We retrieve the room member handler here as to not cause a cyclic dependency
+        member_handler = self.hs.get_room_member_handler()
+        member_handler.ratelimit_invite(event.room_id, event.state_key)
+
         # keep a record of the room version, if we don't yet know it.
         # (this may get overwritten if we later get a different room version in a
         # join dance).
@@ -2093,6 +2097,11 @@ class FederationHandler(BaseHandler):
         if event.type == EventTypes.GuestAccess and not context.rejected:
             await self.maybe_kick_guest_users(event)
 
+        # If we are going to send this event over federation we precaclculate
+        # the joined hosts.
+        if event.internal_metadata.get_send_on_behalf_of():
+            await self.event_creation_handler.cache_joined_hosts_for_event(event)
+
         return context
 
     async def _check_for_soft_fail(
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index df29edeb83..71f11ef94a 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -15,9 +15,13 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set
 
 from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
-from synapse.types import GroupID, get_domain_from_id
+from synapse.types import GroupID, JsonDict, get_domain_from_id
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -56,7 +60,7 @@ def _create_rerouter(func_name):
 
 
 class GroupsLocalWorkerHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.room_list_handler = hs.get_room_list_handler()
@@ -84,7 +88,9 @@ class GroupsLocalWorkerHandler:
     get_group_role = _create_rerouter("get_group_role")
     get_group_roles = _create_rerouter("get_group_roles")
 
-    async def get_group_summary(self, group_id, requester_user_id):
+    async def get_group_summary(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get the group summary for a group.
 
         If the group is remote we check that the users have valid attestations.
@@ -137,14 +143,15 @@ class GroupsLocalWorkerHandler:
 
         return res
 
-    async def get_users_in_group(self, group_id, requester_user_id):
+    async def get_users_in_group(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get users in a group
         """
         if self.is_mine_id(group_id):
-            res = await self.groups_server_handler.get_users_in_group(
+            return await self.groups_server_handler.get_users_in_group(
                 group_id, requester_user_id
             )
-            return res
 
         group_server_name = get_domain_from_id(group_id)
 
@@ -178,11 +185,11 @@ class GroupsLocalWorkerHandler:
 
         return res
 
-    async def get_joined_groups(self, user_id):
+    async def get_joined_groups(self, user_id: str) -> JsonDict:
         group_ids = await self.store.get_joined_groups(user_id)
         return {"groups": group_ids}
 
-    async def get_publicised_groups_for_user(self, user_id):
+    async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict:
         if self.hs.is_mine_id(user_id):
             result = await self.store.get_publicised_groups_for_user(user_id)
 
@@ -206,8 +213,10 @@ class GroupsLocalWorkerHandler:
             # TODO: Verify attestations
             return {"groups": result}
 
-    async def bulk_get_publicised_groups(self, user_ids, proxy=True):
-        destinations = {}
+    async def bulk_get_publicised_groups(
+        self, user_ids: Iterable[str], proxy: bool = True
+    ) -> JsonDict:
+        destinations = {}  # type: Dict[str, Set[str]]
         local_users = set()
 
         for user_id in user_ids:
@@ -220,7 +229,7 @@ class GroupsLocalWorkerHandler:
             raise SynapseError(400, "Some user_ids are not local")
 
         results = {}
-        failed_results = []
+        failed_results = []  # type: List[str]
         for destination, dest_user_ids in destinations.items():
             try:
                 r = await self.transport_client.bulk_get_publicised_groups(
@@ -242,7 +251,7 @@ class GroupsLocalWorkerHandler:
 
 
 class GroupsLocalHandler(GroupsLocalWorkerHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         # Ensure attestations get renewed
@@ -271,7 +280,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
     set_group_join_policy = _create_rerouter("set_group_join_policy")
 
-    async def create_group(self, group_id, user_id, content):
+    async def create_group(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Create a group
         """
 
@@ -284,27 +295,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
             local_attestation = None
             remote_attestation = None
         else:
-            local_attestation = self.attestations.create_attestation(group_id, user_id)
-            content["attestation"] = local_attestation
-
-            content["user_profile"] = await self.profile_handler.get_profile(user_id)
-
-            try:
-                res = await self.transport_client.create_group(
-                    get_domain_from_id(group_id), group_id, user_id, content
-                )
-            except HttpResponseException as e:
-                raise e.to_synapse_error()
-            except RequestSendFailed:
-                raise SynapseError(502, "Failed to contact group server")
-
-            remote_attestation = res["attestation"]
-            await self.attestations.verify_attestation(
-                remote_attestation,
-                group_id=group_id,
-                user_id=user_id,
-                server_name=get_domain_from_id(group_id),
-            )
+            raise SynapseError(400, "Unable to create remote groups")
 
         is_publicised = content.get("publicise", False)
         token = await self.store.register_user_group_membership(
@@ -320,7 +311,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return res
 
-    async def join_group(self, group_id, user_id, content):
+    async def join_group(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Request to join a group
         """
         if self.is_mine_id(group_id):
@@ -365,7 +358,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return {}
 
-    async def accept_invite(self, group_id, user_id, content):
+    async def accept_invite(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Accept an invite to a group
         """
         if self.is_mine_id(group_id):
@@ -410,7 +405,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return {}
 
-    async def invite(self, group_id, user_id, requester_user_id, config):
+    async def invite(
+        self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
+    ) -> JsonDict:
         """Invite a user to a group
         """
         content = {"requester_user_id": requester_user_id, "config": config}
@@ -434,7 +431,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return res
 
-    async def on_invite(self, group_id, user_id, content):
+    async def on_invite(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
         """One of our users were invited to a group
         """
         # TODO: Support auto join and rejection
@@ -465,8 +464,8 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
         return {"state": "invite", "user_profile": user_profile}
 
     async def remove_user_from_group(
-        self, group_id, user_id, requester_user_id, content
-    ):
+        self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Remove a user from a group
         """
         if user_id == requester_user_id:
@@ -499,7 +498,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return res
 
-    async def user_removed_from_group(self, group_id, user_id, content):
+    async def user_removed_from_group(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> None:
         """One of our users was removed/kicked from a group
         """
         # TODO: Check if user in group
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index f61844d688..4f7137539b 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -27,9 +27,11 @@ from synapse.api.errors import (
     HttpResponseException,
     SynapseError,
 )
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.http import RequestTimedOutError
 from synapse.http.client import SimpleHttpClient
+from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict, Requester
 from synapse.util import json_decoder
 from synapse.util.hash import sha256_and_url_safe_base64
@@ -57,6 +59,32 @@ class IdentityHandler(BaseHandler):
 
         self._web_client_location = hs.config.invite_client_location
 
+        # Ratelimiters for `/requestToken` endpoints.
+        self._3pid_validation_ratelimiter_ip = Ratelimiter(
+            clock=hs.get_clock(),
+            rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
+            burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
+        )
+        self._3pid_validation_ratelimiter_address = Ratelimiter(
+            clock=hs.get_clock(),
+            rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
+            burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
+        )
+
+    def ratelimit_request_token_requests(
+        self, request: SynapseRequest, medium: str, address: str,
+    ):
+        """Used to ratelimit requests to `/requestToken` by IP and address.
+
+        Args:
+            request: The associated request
+            medium: The type of threepid, e.g. "msisdn" or "email"
+            address: The actual threepid ID, e.g. the phone number or email address
+        """
+
+        self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
+        self._3pid_validation_ratelimiter_address.ratelimit((medium, address))
+
     async def threepid_from_creds(
         self, id_server: str, creds: Dict[str, str]
     ) -> Optional[JsonDict]:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9dfeab09cd..e2a7d567fa 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -432,6 +432,8 @@ class EventCreationHandler:
 
         self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
 
+        self._external_cache = hs.get_external_cache()
+
     async def create_event(
         self,
         requester: Requester,
@@ -939,6 +941,8 @@ class EventCreationHandler:
 
         await self.action_generator.handle_push_actions_for_event(event, context)
 
+        await self.cache_joined_hosts_for_event(event)
+
         try:
             # If we're a worker we need to hit out to the master.
             writer_instance = self._events_shard_config.get_instance(event.room_id)
@@ -978,6 +982,44 @@ class EventCreationHandler:
             await self.store.remove_push_actions_from_staging(event.event_id)
             raise
 
+    async def cache_joined_hosts_for_event(self, event: EventBase) -> None:
+        """Precalculate the joined hosts at the event, when using Redis, so that
+        external federation senders don't have to recalculate it themselves.
+        """
+
+        if not self._external_cache.is_enabled():
+            return
+
+        # We actually store two mappings, event ID -> prev state group,
+        # state group -> joined hosts, which is much more space efficient
+        # than event ID -> joined hosts.
+        #
+        # Note: We have to cache event ID -> prev state group, as we don't
+        # store that in the DB.
+        #
+        # Note: We always set the state group -> joined hosts cache, even if
+        # we already set it, so that the expiry time is reset.
+
+        state_entry = await self.state.resolve_state_groups_for_events(
+            event.room_id, event_ids=event.prev_event_ids()
+        )
+
+        if state_entry.state_group:
+            joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
+
+            await self._external_cache.set(
+                "event_to_prev_state_group",
+                event.event_id,
+                state_entry.state_group,
+                expiry_ms=60 * 60 * 1000,
+            )
+            await self._external_cache.set(
+                "get_joined_hosts",
+                str(state_entry.state_group),
+                list(joined_hosts),
+                expiry_ms=60 * 60 * 1000,
+            )
+
     async def _validate_canonical_alias(
         self, directory_handler, room_alias_str: str, expected_room_id: str
     ) -> None:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index ee27d99135..07b2187eb1 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -126,6 +126,10 @@ class RoomCreationHandler(BaseHandler):
 
         self.third_party_event_rules = hs.get_third_party_event_rules()
 
+        self._invite_burst_count = (
+            hs.config.ratelimiting.rc_invites_per_room.burst_count
+        )
+
     async def upgrade_room(
         self, requester: Requester, old_room_id: str, new_version: RoomVersion
     ) -> str:
@@ -662,6 +666,9 @@ class RoomCreationHandler(BaseHandler):
             invite_3pid_list = []
             invite_list = []
 
+        if len(invite_list) + len(invite_3pid_list) > self._invite_burst_count:
+            raise SynapseError(400, "Cannot invite so many users at once")
+
         await self.event_creation_handler.assert_accepted_privacy_policy(requester)
 
         power_level_content_override = config.get("power_level_content_override")
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index e001e418f9..d335da6f19 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -85,6 +85,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
         )
 
+        self._invites_per_room_limiter = Ratelimiter(
+            clock=self.clock,
+            rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
+            burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
+        )
+        self._invites_per_user_limiter = Ratelimiter(
+            clock=self.clock,
+            rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
+            burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
+        )
+
         # This is only used to get at ratelimit function, and
         # maybe_kick_guest_users. It's fine there are multiple of these as
         # it doesn't store state.
@@ -144,6 +155,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         """
         raise NotImplementedError()
 
+    def ratelimit_invite(self, room_id: str, invitee_user_id: str):
+        """Ratelimit invites by room and by target user.
+        """
+        self._invites_per_room_limiter.ratelimit(room_id)
+        self._invites_per_user_limiter.ratelimit(invitee_user_id)
+
     async def _local_membership_update(
         self,
         requester: Requester,
@@ -387,8 +404,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 raise SynapseError(403, "This room has been blocked on this server")
 
         if effective_membership_state == Membership.INVITE:
+            target_id = target.to_string()
+            if ratelimit:
+                self.ratelimit_invite(room_id, target_id)
+
             # block any attempts to invite the server notices mxid
-            if target.to_string() == self._server_notices_mxid:
+            if target_id == self._server_notices_mxid:
                 raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
 
             block_invite = False
@@ -412,7 +433,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     block_invite = True
 
                 if not await self.spam_checker.user_may_invite(
-                    requester.user.to_string(), target.to_string(), room_id
+                    requester.user.to_string(), target_id, room_id
                 ):
                     logger.info("Blocking invite due to spam checker")
                     block_invite = True
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 66f1bbcfc4..94062e79cb 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -15,23 +15,28 @@
 
 import itertools
 import logging
-from typing import Iterable
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
 
 from unpaddedbase64 import decode_base64, encode_base64
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.filtering import Filter
+from synapse.events import EventBase
 from synapse.storage.state import StateFilter
+from synapse.types import JsonDict, UserID
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class SearchHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
@@ -87,13 +92,15 @@ class SearchHandler(BaseHandler):
 
         return historical_room_ids
 
-    async def search(self, user, content, batch=None):
+    async def search(
+        self, user: UserID, content: JsonDict, batch: Optional[str] = None
+    ) -> JsonDict:
         """Performs a full text search for a user.
 
         Args:
-            user (UserID)
-            content (dict): Search parameters
-            batch (str): The next_batch parameter. Used for pagination.
+            user
+            content: Search parameters
+            batch: The next_batch parameter. Used for pagination.
 
         Returns:
             dict to be returned to the client with results of search
@@ -186,7 +193,7 @@ class SearchHandler(BaseHandler):
         # If doing a subset of all rooms seearch, check if any of the rooms
         # are from an upgraded room, and search their contents as well
         if search_filter.rooms:
-            historical_room_ids = []
+            historical_room_ids = []  # type: List[str]
             for room_id in search_filter.rooms:
                 # Add any previous rooms to the search if they exist
                 ids = await self.get_old_rooms_from_upgraded_room(room_id)
@@ -209,8 +216,10 @@ class SearchHandler(BaseHandler):
 
         rank_map = {}  # event_id -> rank of event
         allowed_events = []
-        room_groups = {}  # Holds result of grouping by room, if applicable
-        sender_group = {}  # Holds result of grouping by sender, if applicable
+        # Holds result of grouping by room, if applicable
+        room_groups = {}  # type: Dict[str, JsonDict]
+        # Holds result of grouping by sender, if applicable
+        sender_group = {}  # type: Dict[str, JsonDict]
 
         # Holds the next_batch for the entire result set if one of those exists
         global_next_batch = None
@@ -254,7 +263,7 @@ class SearchHandler(BaseHandler):
                 s["results"].append(e.event_id)
 
         elif order_by == "recent":
-            room_events = []
+            room_events = []  # type: List[EventBase]
             i = 0
 
             pagination_token = batch_token
@@ -418,13 +427,10 @@ class SearchHandler(BaseHandler):
 
         state_results = {}
         if include_state:
-            rooms = {e.room_id for e in allowed_events}
-            for room_id in rooms:
+            for room_id in {e.room_id for e in allowed_events}:
                 state = await self.state_handler.get_current_state(room_id)
                 state_results[room_id] = list(state.values())
 
-            state_results.values()
-
         # We're now about to serialize the events. We should not make any
         # blocking calls after this. Otherwise the 'age' will be wrong
 
@@ -448,9 +454,9 @@ class SearchHandler(BaseHandler):
 
         if state_results:
             s = {}
-            for room_id, state in state_results.items():
+            for room_id, state_events in state_results.items():
                 s[room_id] = await self._event_serializer.serialize_events(
-                    state, time_now
+                    state_events, time_now
                 )
 
             rooms_cat_res["state"] = s
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index a5d67f828f..84af2dde7e 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -13,24 +13,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import Codes, StoreError, SynapseError
 from synapse.types import Requester
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class SetPasswordHandler(BaseHandler):
     """Handler which deals with changing user account passwords"""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self._auth_handler = hs.get_auth_handler()
         self._device_handler = hs.get_device_handler()
-        self._password_policy_handler = hs.get_password_policy_handler()
 
     async def set_password(
         self,
@@ -38,7 +40,7 @@ class SetPasswordHandler(BaseHandler):
         password_hash: str,
         logout_devices: bool,
         requester: Optional[Requester] = None,
-    ):
+    ) -> None:
         if not self.hs.config.password_localdb_enabled:
             raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
 
diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
index fb4f70e8e2..b3f9875358 100644
--- a/synapse/handlers/state_deltas.py
+++ b/synapse/handlers/state_deltas.py
@@ -14,15 +14,25 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
 
 class StateDeltasHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
 
-    async def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
+    async def _get_key_change(
+        self,
+        prev_event_id: Optional[str],
+        event_id: Optional[str],
+        key_name: str,
+        public_value: str,
+    ) -> Optional[bool]:
         """Given two events check if the `key_name` field in content changed
         from not matching `public_value` to doing so.
 
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index dc62b21c06..d261d7cd4e 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -12,13 +12,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 from collections import Counter
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
+
+from typing_extensions import Counter as CounterType
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.metrics import event_processing_positions
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -31,7 +37,7 @@ class StatsHandler:
     Heavily derived from UserDirectoryHandler
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
@@ -44,7 +50,7 @@ class StatsHandler:
         self.stats_enabled = hs.config.stats_enabled
 
         # The current position in the current_state_delta stream
-        self.pos = None
+        self.pos = None  # type: Optional[int]
 
         # Guard to ensure we only process deltas one at a time
         self._is_processing = False
@@ -56,7 +62,7 @@ class StatsHandler:
             # we start populating stats
             self.clock.call_later(0, self.notify_new_event)
 
-    def notify_new_event(self):
+    def notify_new_event(self) -> None:
         """Called when there may be more deltas to process
         """
         if not self.stats_enabled or self._is_processing:
@@ -72,7 +78,7 @@ class StatsHandler:
 
         run_as_background_process("stats.notify_new_event", process)
 
-    async def _unsafe_process(self):
+    async def _unsafe_process(self) -> None:
         # If self.pos is None then means we haven't fetched it from DB
         if self.pos is None:
             self.pos = await self.store.get_stats_positions()
@@ -110,10 +116,10 @@ class StatsHandler:
             )
 
             for room_id, fields in room_count.items():
-                room_deltas.setdefault(room_id, {}).update(fields)
+                room_deltas.setdefault(room_id, Counter()).update(fields)
 
             for user_id, fields in user_count.items():
-                user_deltas.setdefault(user_id, {}).update(fields)
+                user_deltas.setdefault(user_id, Counter()).update(fields)
 
             logger.debug("room_deltas: %s", room_deltas)
             logger.debug("user_deltas: %s", user_deltas)
@@ -131,19 +137,20 @@ class StatsHandler:
 
             self.pos = max_pos
 
-    async def _handle_deltas(self, deltas):
+    async def _handle_deltas(
+        self, deltas: Iterable[JsonDict]
+    ) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]:
         """Called with the state deltas to process
 
         Returns:
-            tuple[dict[str, Counter], dict[str, counter]]
             Two dicts: the room deltas and the user deltas,
             mapping from room/user ID to changes in the various fields.
         """
 
-        room_to_stats_deltas = {}
-        user_to_stats_deltas = {}
+        room_to_stats_deltas = {}  # type: Dict[str, CounterType[str]]
+        user_to_stats_deltas = {}  # type: Dict[str, CounterType[str]]
 
-        room_to_state_updates = {}
+        room_to_state_updates = {}  # type: Dict[str, Dict[str, Any]]
 
         for delta in deltas:
             typ = delta["type"]
@@ -173,7 +180,7 @@ class StatsHandler:
                 )
                 continue
 
-            event_content = {}
+            event_content = {}  # type: JsonDict
 
             sender = None
             if event_id is not None:
@@ -257,13 +264,13 @@ class StatsHandler:
                     )
 
                     if has_changed_joinedness:
-                        delta = +1 if membership == Membership.JOIN else -1
+                        membership_delta = +1 if membership == Membership.JOIN else -1
 
                         user_to_stats_deltas.setdefault(user_id, Counter())[
                             "joined_rooms"
-                        ] += delta
+                        ] += membership_delta
 
-                        room_stats_delta["local_users_in_room"] += delta
+                        room_stats_delta["local_users_in_room"] += membership_delta
 
             elif typ == EventTypes.Create:
                 room_state["is_federatable"] = (
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index e919a8f9ed..3f0dfc7a74 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -15,13 +15,13 @@
 import logging
 import random
 from collections import namedtuple
-from typing import TYPE_CHECKING, List, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api.errors import AuthError, ShadowBanError, SynapseError
 from synapse.appservice import ApplicationService
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.streams import TypingStream
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
@@ -65,17 +65,17 @@ class FollowerTypingHandler:
             )
 
         # map room IDs to serial numbers
-        self._room_serials = {}
+        self._room_serials = {}  # type: Dict[str, int]
         # map room IDs to sets of users currently typing
-        self._room_typing = {}
+        self._room_typing = {}  # type: Dict[str, Set[str]]
 
-        self._member_last_federation_poke = {}
+        self._member_last_federation_poke = {}  # type: Dict[RoomMember, int]
         self.wheel_timer = WheelTimer(bucket_size=5000)
         self._latest_room_serial = 0
 
         self.clock.looping_call(self._handle_timeouts, 5000)
 
-    def _reset(self):
+    def _reset(self) -> None:
         """Reset the typing handler's data caches.
         """
         # map room IDs to serial numbers
@@ -86,7 +86,7 @@ class FollowerTypingHandler:
         self._member_last_federation_poke = {}
         self.wheel_timer = WheelTimer(bucket_size=5000)
 
-    def _handle_timeouts(self):
+    def _handle_timeouts(self) -> None:
         logger.debug("Checking for typing timeouts")
 
         now = self.clock.time_msec()
@@ -96,7 +96,7 @@ class FollowerTypingHandler:
         for member in members:
             self._handle_timeout_for_member(now, member)
 
-    def _handle_timeout_for_member(self, now: int, member: RoomMember):
+    def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
         if not self.is_typing(member):
             # Nothing to do if they're no longer typing
             return
@@ -114,10 +114,10 @@ class FollowerTypingHandler:
         # each person typing.
         self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
 
-    def is_typing(self, member):
+    def is_typing(self, member: RoomMember) -> bool:
         return member.user_id in self._room_typing.get(member.room_id, [])
 
-    async def _push_remote(self, member, typing):
+    async def _push_remote(self, member: RoomMember, typing: bool) -> None:
         if not self.federation:
             return
 
@@ -148,7 +148,7 @@ class FollowerTypingHandler:
 
     def process_replication_rows(
         self, token: int, rows: List[TypingStream.TypingStreamRow]
-    ):
+    ) -> None:
         """Should be called whenever we receive updates for typing stream.
         """
 
@@ -178,7 +178,7 @@ class FollowerTypingHandler:
 
     async def _send_changes_in_typing_to_remotes(
         self, room_id: str, prev_typing: Set[str], now_typing: Set[str]
-    ):
+    ) -> None:
         """Process a change in typing of a room from replication, sending EDUs
         for any local users.
         """
@@ -194,12 +194,12 @@ class FollowerTypingHandler:
             if self.is_mine_id(user_id):
                 await self._push_remote(RoomMember(room_id, user_id), False)
 
-    def get_current_token(self):
+    def get_current_token(self) -> int:
         return self._latest_room_serial
 
 
 class TypingWriterHandler(FollowerTypingHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         assert hs.config.worker.writers.typing == hs.get_instance_name()
@@ -213,14 +213,15 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
-        self._member_typing_until = {}  # clock time we expect to stop
+        # clock time we expect to stop
+        self._member_typing_until = {}  # type: Dict[RoomMember, int]
 
         # caches which room_ids changed at which serials
         self._typing_stream_change_cache = StreamChangeCache(
             "TypingStreamChangeCache", self._latest_room_serial
         )
 
-    def _handle_timeout_for_member(self, now: int, member: RoomMember):
+    def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
         super()._handle_timeout_for_member(now, member)
 
         if not self.is_typing(member):
@@ -233,7 +234,9 @@ class TypingWriterHandler(FollowerTypingHandler):
             self._stopped_typing(member)
             return
 
-    async def started_typing(self, target_user, requester, room_id, timeout):
+    async def started_typing(
+        self, target_user: UserID, requester: Requester, room_id: str, timeout: int
+    ) -> None:
         target_user_id = target_user.to_string()
         auth_user_id = requester.user.to_string()
 
@@ -263,11 +266,13 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         if was_present:
             # No point sending another notification
-            return None
+            return
 
         self._push_update(member=member, typing=True)
 
-    async def stopped_typing(self, target_user, requester, room_id):
+    async def stopped_typing(
+        self, target_user: UserID, requester: Requester, room_id: str
+    ) -> None:
         target_user_id = target_user.to_string()
         auth_user_id = requester.user.to_string()
 
@@ -290,23 +295,23 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         self._stopped_typing(member)
 
-    def user_left_room(self, user, room_id):
+    def user_left_room(self, user: UserID, room_id: str) -> None:
         user_id = user.to_string()
         if self.is_mine_id(user_id):
             member = RoomMember(room_id=room_id, user_id=user_id)
             self._stopped_typing(member)
 
-    def _stopped_typing(self, member):
+    def _stopped_typing(self, member: RoomMember) -> None:
         if member.user_id not in self._room_typing.get(member.room_id, set()):
             # No point
-            return None
+            return
 
         self._member_typing_until.pop(member, None)
         self._member_last_federation_poke.pop(member, None)
 
         self._push_update(member=member, typing=False)
 
-    def _push_update(self, member, typing):
+    def _push_update(self, member: RoomMember, typing: bool) -> None:
         if self.hs.is_mine_id(member.user_id):
             # Only send updates for changes to our own users.
             run_as_background_process(
@@ -315,7 +320,7 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         self._push_update_local(member=member, typing=typing)
 
-    async def _recv_edu(self, origin, content):
+    async def _recv_edu(self, origin: str, content: JsonDict) -> None:
         room_id = content["room_id"]
         user_id = content["user_id"]
 
@@ -340,7 +345,7 @@ class TypingWriterHandler(FollowerTypingHandler):
             self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT)
             self._push_update_local(member=member, typing=content["typing"])
 
-    def _push_update_local(self, member, typing):
+    def _push_update_local(self, member: RoomMember, typing: bool) -> None:
         room_set = self._room_typing.setdefault(member.room_id, set())
         if typing:
             room_set.add(member.user_id)
@@ -386,7 +391,7 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
             last_id
-        )
+        )  # type: Optional[Iterable[str]]
 
         if changed_rooms is None:
             changed_rooms = self._room_serials
@@ -412,13 +417,13 @@ class TypingWriterHandler(FollowerTypingHandler):
 
     def process_replication_rows(
         self, token: int, rows: List[TypingStream.TypingStreamRow]
-    ):
+    ) -> None:
         # The writing process should never get updates from replication.
         raise Exception("Typing writer instance got typing info over replication")
 
 
 class TypingNotificationEventSource:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.clock = hs.get_clock()
         # We can't call get_typing_handler here because there's a cycle:
@@ -427,7 +432,7 @@ class TypingNotificationEventSource:
         #
         self.get_typing_handler = hs.get_typing_handler
 
-    def _make_event_for(self, room_id):
+    def _make_event_for(self, room_id: str) -> JsonDict:
         typing = self.get_typing_handler()._room_typing[room_id]
         return {
             "type": "m.typing",
@@ -462,7 +467,9 @@ class TypingNotificationEventSource:
 
             return (events, handler._latest_room_serial)
 
-    async def get_new_events(self, from_key, room_ids, **kwargs):
+    async def get_new_events(
+        self, from_key: int, room_ids: Iterable[str], **kwargs
+    ) -> Tuple[List[JsonDict], int]:
         with Measure(self.clock, "typing.get_new_events"):
             from_key = int(from_key)
             handler = self.get_typing_handler()
@@ -478,5 +485,5 @@ class TypingNotificationEventSource:
 
             return (events, handler._latest_room_serial)
 
-    def get_current_key(self):
+    def get_current_key(self) -> int:
         return self.get_typing_handler()._latest_room_serial
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index d4651c8348..8aedf5072e 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -145,10 +145,6 @@ class UserDirectoryHandler(StateDeltasHandler):
         if self.pos is None:
             self.pos = await self.store.get_user_directory_stream_pos()
 
-        # If still None then the initial background update hasn't happened yet
-        if self.pos is None:
-            return None
-
         # Loop round handling deltas until we're up to date
         while True:
             with Measure(self.clock, "user_dir_delta"):
@@ -233,6 +229,11 @@ class UserDirectoryHandler(StateDeltasHandler):
 
                     if change:  # The user joined
                         event = await self.store.get_event(event_id, allow_none=True)
+                        # It isn't expected for this event to not exist, but we
+                        # don't want the entire background process to break.
+                        if event is None:
+                            continue
+
                         profile = ProfileInfo(
                             avatar_url=event.content.get("avatar_url"),
                             display_name=event.content.get("displayname"),
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index ab586c318c..0538350f38 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -791,7 +791,7 @@ def tag_args(func):
 
     @wraps(func)
     def _tag_args_inner(*args, **kwargs):
-        argspec = inspect.getargspec(func)
+        argspec = inspect.getfullargspec(func)
         for i, arg in enumerate(argspec.args[1:]):
             set_tag("ARG_" + arg, args[i])
         set_tag("args", args[len(argspec.args) :])
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 4d875dcb91..8a6dcff30d 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -267,9 +267,21 @@ class Mailer:
             fallback_to_members=True,
         )
 
-        summary_text = await self.make_summary_text(
-            notifs_by_room, state_by_room, notif_events, user_id, reason
-        )
+        if len(notifs_by_room) == 1:
+            # Only one room has new stuff
+            room_id = list(notifs_by_room.keys())[0]
+
+            summary_text = await self.make_summary_text_single_room(
+                room_id,
+                notifs_by_room[room_id],
+                state_by_room[room_id],
+                notif_events,
+                user_id,
+            )
+        else:
+            summary_text = await self.make_summary_text(
+                notifs_by_room, state_by_room, notif_events, reason
+            )
 
         template_vars = {
             "user_display_name": user_display_name,
@@ -492,139 +504,178 @@ class Mailer:
         if "url" in event.content:
             messagevars["image_url"] = event.content["url"]
 
-    async def make_summary_text(
+    async def make_summary_text_single_room(
         self,
-        notifs_by_room: Dict[str, List[Dict[str, Any]]],
-        room_state_ids: Dict[str, StateMap[str]],
+        room_id: str,
+        notifs: List[Dict[str, Any]],
+        room_state_ids: StateMap[str],
         notif_events: Dict[str, EventBase],
         user_id: str,
-        reason: Dict[str, Any],
-    ):
-        if len(notifs_by_room) == 1:
-            # Only one room has new stuff
-            room_id = list(notifs_by_room.keys())[0]
+    ) -> str:
+        """
+        Make a summary text for the email when only a single room has notifications.
 
-            # If the room has some kind of name, use it, but we don't
-            # want the generated-from-names one here otherwise we'll
-            # end up with, "new message from Bob in the Bob room"
-            room_name = await calculate_room_name(
-                self.store, room_state_ids[room_id], user_id, fallback_to_members=False
-            )
+        Args:
+            room_id: The ID of the room.
+            notifs: The notifications for this room.
+            room_state_ids: The state map for the room.
+            notif_events: A map of event ID -> notification event.
+            user_id: The user receiving the notification.
+
+        Returns:
+            The summary text.
+        """
+        # If the room has some kind of name, use it, but we don't
+        # want the generated-from-names one here otherwise we'll
+        # end up with, "new message from Bob in the Bob room"
+        room_name = await calculate_room_name(
+            self.store, room_state_ids, user_id, fallback_to_members=False
+        )
 
-            # See if one of the notifs is an invite event for the user
-            invite_event = None
-            for n in notifs_by_room[room_id]:
-                ev = notif_events[n["event_id"]]
-                if ev.type == EventTypes.Member and ev.state_key == user_id:
-                    if ev.content.get("membership") == Membership.INVITE:
-                        invite_event = ev
-                        break
-
-            if invite_event:
-                inviter_member_event_id = room_state_ids[room_id].get(
-                    ("m.room.member", invite_event.sender)
-                )
-                inviter_name = invite_event.sender
-                if inviter_member_event_id:
-                    inviter_member_event = await self.store.get_event(
-                        inviter_member_event_id, allow_none=True
-                    )
-                    if inviter_member_event:
-                        inviter_name = name_from_member_event(inviter_member_event)
-
-                if room_name is None:
-                    return self.email_subjects.invite_from_person % {
-                        "person": inviter_name,
-                        "app": self.app_name,
-                    }
-                else:
-                    return self.email_subjects.invite_from_person_to_room % {
-                        "person": inviter_name,
-                        "room": room_name,
-                        "app": self.app_name,
-                    }
+        # See if one of the notifs is an invite event for the user
+        invite_event = None
+        for n in notifs:
+            ev = notif_events[n["event_id"]]
+            if ev.type == EventTypes.Member and ev.state_key == user_id:
+                if ev.content.get("membership") == Membership.INVITE:
+                    invite_event = ev
+                    break
 
-            sender_name = None
-            if len(notifs_by_room[room_id]) == 1:
-                # There is just the one notification, so give some detail
-                event = notif_events[notifs_by_room[room_id][0]["event_id"]]
-                if ("m.room.member", event.sender) in room_state_ids[room_id]:
-                    state_event_id = room_state_ids[room_id][
-                        ("m.room.member", event.sender)
-                    ]
-                    state_event = await self.store.get_event(state_event_id)
-                    sender_name = name_from_member_event(state_event)
-
-                if sender_name is not None and room_name is not None:
-                    return self.email_subjects.message_from_person_in_room % {
-                        "person": sender_name,
-                        "room": room_name,
-                        "app": self.app_name,
-                    }
-                elif sender_name is not None:
-                    return self.email_subjects.message_from_person % {
-                        "person": sender_name,
-                        "app": self.app_name,
-                    }
-            else:
-                # There's more than one notification for this room, so just
-                # say there are several
-                if room_name is not None:
-                    return self.email_subjects.messages_in_room % {
-                        "room": room_name,
-                        "app": self.app_name,
-                    }
-                else:
-                    # If the room doesn't have a name, say who the messages
-                    # are from explicitly to avoid, "messages in the Bob room"
-                    sender_ids = list(
-                        {
-                            notif_events[n["event_id"]].sender
-                            for n in notifs_by_room[room_id]
-                        }
-                    )
-
-                    member_events = await self.store.get_events(
-                        [
-                            room_state_ids[room_id][("m.room.member", s)]
-                            for s in sender_ids
-                        ]
-                    )
-
-                    return self.email_subjects.messages_from_person % {
-                        "person": descriptor_from_member_events(member_events.values()),
-                        "app": self.app_name,
-                    }
-        else:
-            # Stuff's happened in multiple different rooms
+        if invite_event:
+            inviter_member_event_id = room_state_ids.get(
+                ("m.room.member", invite_event.sender)
+            )
+            inviter_name = invite_event.sender
+            if inviter_member_event_id:
+                inviter_member_event = await self.store.get_event(
+                    inviter_member_event_id, allow_none=True
+                )
+                if inviter_member_event:
+                    inviter_name = name_from_member_event(inviter_member_event)
 
-            # ...but we still refer to the 'reason' room which triggered the mail
-            if reason["room_name"] is not None:
-                return self.email_subjects.messages_in_room_and_others % {
-                    "room": reason["room_name"],
+            if room_name is None:
+                return self.email_subjects.invite_from_person % {
+                    "person": inviter_name,
                     "app": self.app_name,
                 }
-            else:
-                # If the reason room doesn't have a name, say who the messages
-                # are from explicitly to avoid, "messages in the Bob room"
-                room_id = reason["room_id"]
-
-                sender_ids = list(
-                    {
-                        notif_events[n["event_id"]].sender
-                        for n in notifs_by_room[room_id]
-                    }
-                )
 
-                member_events = await self.store.get_events(
-                    [room_state_ids[room_id][("m.room.member", s)] for s in sender_ids]
-                )
+            return self.email_subjects.invite_from_person_to_room % {
+                "person": inviter_name,
+                "room": room_name,
+                "app": self.app_name,
+            }
+
+        if len(notifs) == 1:
+            # There is just the one notification, so give some detail
+            sender_name = None
+            event = notif_events[notifs[0]["event_id"]]
+            if ("m.room.member", event.sender) in room_state_ids:
+                state_event_id = room_state_ids[("m.room.member", event.sender)]
+                state_event = await self.store.get_event(state_event_id)
+                sender_name = name_from_member_event(state_event)
+
+            if sender_name is not None and room_name is not None:
+                return self.email_subjects.message_from_person_in_room % {
+                    "person": sender_name,
+                    "room": room_name,
+                    "app": self.app_name,
+                }
+            elif sender_name is not None:
+                return self.email_subjects.message_from_person % {
+                    "person": sender_name,
+                    "app": self.app_name,
+                }
 
-                return self.email_subjects.messages_from_person_and_others % {
-                    "person": descriptor_from_member_events(member_events.values()),
+            # The sender is unknown, just use the room name (or ID).
+            return self.email_subjects.messages_in_room % {
+                "room": room_name or room_id,
+                "app": self.app_name,
+            }
+        else:
+            # There's more than one notification for this room, so just
+            # say there are several
+            if room_name is not None:
+                return self.email_subjects.messages_in_room % {
+                    "room": room_name,
                     "app": self.app_name,
                 }
 
+            return await self.make_summary_text_from_member_events(
+                room_id, notifs, room_state_ids, notif_events
+            )
+
+    async def make_summary_text(
+        self,
+        notifs_by_room: Dict[str, List[Dict[str, Any]]],
+        room_state_ids: Dict[str, StateMap[str]],
+        notif_events: Dict[str, EventBase],
+        reason: Dict[str, Any],
+    ) -> str:
+        """
+        Make a summary text for the email when multiple rooms have notifications.
+
+        Args:
+            notifs_by_room: A map of room ID to the notifications for that room.
+            room_state_ids: A map of room ID to the state map for that room.
+            notif_events: A map of event ID -> notification event.
+            reason: The reason this notification is being sent.
+
+        Returns:
+            The summary text.
+        """
+        # Stuff's happened in multiple different rooms
+        # ...but we still refer to the 'reason' room which triggered the mail
+        if reason["room_name"] is not None:
+            return self.email_subjects.messages_in_room_and_others % {
+                "room": reason["room_name"],
+                "app": self.app_name,
+            }
+
+        room_id = reason["room_id"]
+        return await self.make_summary_text_from_member_events(
+            room_id, notifs_by_room[room_id], room_state_ids[room_id], notif_events
+        )
+
+    async def make_summary_text_from_member_events(
+        self,
+        room_id: str,
+        notifs: List[Dict[str, Any]],
+        room_state_ids: StateMap[str],
+        notif_events: Dict[str, EventBase],
+    ) -> str:
+        """
+        Make a summary text for the email when only a single room has notifications.
+
+        Args:
+            room_id: The ID of the room.
+            notifs: The notifications for this room.
+            room_state_ids: The state map for the room.
+            notif_events: A map of event ID -> notification event.
+
+        Returns:
+            The summary text.
+        """
+        # If the room doesn't have a name, say who the messages
+        # are from explicitly to avoid, "messages in the Bob room"
+        sender_ids = {notif_events[n["event_id"]].sender for n in notifs}
+
+        member_events = await self.store.get_events(
+            [room_state_ids[("m.room.member", s)] for s in sender_ids]
+        )
+
+        # There was a single sender.
+        if len(sender_ids) == 1:
+            return self.email_subjects.messages_from_person % {
+                "person": descriptor_from_member_events(member_events.values()),
+                "app": self.app_name,
+            }
+
+        # There was more than one sender, use the first one and a tweaked template.
+        return self.email_subjects.messages_from_person_and_others % {
+            "person": descriptor_from_member_events(list(member_events.values())[:1]),
+            "app": self.app_name,
+        }
+
     def make_room_link(self, room_id: str) -> str:
         if self.hs.config.email_riot_base_url:
             base_url = "%s/#/room" % (self.hs.config.email_riot_base_url)
@@ -668,6 +719,15 @@ class Mailer:
 
 
 def safe_markup(raw_html: str) -> jinja2.Markup:
+    """
+    Sanitise a raw HTML string to a set of allowed tags and attributes, and linkify any bare URLs.
+
+    Args
+        raw_html: Unsafe HTML.
+
+    Returns:
+        A Markup object ready to safely use in a Jinja template.
+    """
     return jinja2.Markup(
         bleach.linkify(
             bleach.clean(
@@ -684,8 +744,13 @@ def safe_markup(raw_html: str) -> jinja2.Markup:
 
 def safe_text(raw_text: str) -> jinja2.Markup:
     """
-    Process text: treat it as HTML but escape any tags (ie. just escape the
-    HTML) then linkify it.
+    Sanitise text (escape any HTML tags), and then linkify any bare URLs.
+
+    Args
+        raw_text: Unsafe text which might include HTML markup.
+
+    Returns:
+        A Markup object ready to safely use in a Jinja template.
     """
     return jinja2.Markup(
         bleach.linkify(bleach.clean(raw_text, tags=[], attributes={}, strip=False))
diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py
index 7e50341d74..04c2c1482c 100644
--- a/synapse/push/presentable_names.py
+++ b/synapse/push/presentable_names.py
@@ -17,7 +17,7 @@ import logging
 import re
 from typing import TYPE_CHECKING, Dict, Iterable, Optional
 
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.types import StateMap
 
@@ -63,7 +63,7 @@ async def calculate_room_name(
         m_room_name = await store.get_event(
             room_state_ids[(EventTypes.Name, "")], allow_none=True
         )
-        if m_room_name and m_room_name.content and m_room_name.content["name"]:
+        if m_room_name and m_room_name.content and m_room_name.content.get("name"):
             return m_room_name.content["name"]
 
     # does it have a canonical alias?
@@ -74,15 +74,11 @@ async def calculate_room_name(
         if (
             canon_alias
             and canon_alias.content
-            and canon_alias.content["alias"]
+            and canon_alias.content.get("alias")
             and _looks_like_an_alias(canon_alias.content["alias"])
         ):
             return canon_alias.content["alias"]
 
-    # at this point we're going to need to search the state by all state keys
-    # for an event type, so rearrange the data structure
-    room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
-
     if not fallback_to_members:
         return None
 
@@ -94,7 +90,7 @@ async def calculate_room_name(
 
     if (
         my_member_event is not None
-        and my_member_event.content["membership"] == "invite"
+        and my_member_event.content.get("membership") == Membership.INVITE
     ):
         if (EventTypes.Member, my_member_event.sender) in room_state_ids:
             inviter_member_event = await store.get_event(
@@ -111,6 +107,10 @@ async def calculate_room_name(
         else:
             return "Room Invite"
 
+    # at this point we're going to need to search the state by all state keys
+    # for an event type, so rearrange the data structure
+    room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
+
     # we're going to have to generate a name based on who's in the room,
     # so find out who is in the room that isn't the user.
     if EventTypes.Member in room_state_bytype_ids:
@@ -120,8 +120,8 @@ async def calculate_room_name(
         all_members = [
             ev
             for ev in member_events.values()
-            if ev.content["membership"] == "join"
-            or ev.content["membership"] == "invite"
+            if ev.content.get("membership") == Membership.JOIN
+            or ev.content.get("membership") == Membership.INVITE
         ]
         # Sort the member events oldest-first so the we name people in the
         # order the joined (it should at least be deterministic rather than
@@ -194,11 +194,7 @@ def descriptor_from_member_events(member_events: Iterable[EventBase]) -> str:
 
 
 def name_from_member_event(member_event: EventBase) -> str:
-    if (
-        member_event.content
-        and "displayname" in member_event.content
-        and member_event.content["displayname"]
-    ):
+    if member_event.content and member_event.content.get("displayname"):
         return member_event.content["displayname"]
     return member_event.state_key
 
diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py
new file mode 100644
index 0000000000..34fa3ff5b3
--- /dev/null
+++ b/synapse/replication/tcp/external_cache.py
@@ -0,0 +1,105 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, Any, Optional
+
+from prometheus_client import Counter
+
+from synapse.logging.context import make_deferred_yieldable
+from synapse.util import json_decoder, json_encoder
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+set_counter = Counter(
+    "synapse_external_cache_set",
+    "Number of times we set a cache",
+    labelnames=["cache_name"],
+)
+
+get_counter = Counter(
+    "synapse_external_cache_get",
+    "Number of times we get a cache",
+    labelnames=["cache_name", "hit"],
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+class ExternalCache:
+    """A cache backed by an external Redis. Does nothing if no Redis is
+    configured.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        self._redis_connection = hs.get_outbound_redis_connection()
+
+    def _get_redis_key(self, cache_name: str, key: str) -> str:
+        return "cache_v1:%s:%s" % (cache_name, key)
+
+    def is_enabled(self) -> bool:
+        """Whether the external cache is used or not.
+
+        It's safe to use the cache when this returns false, the methods will
+        just no-op, but the function is useful to avoid doing unnecessary work.
+        """
+        return self._redis_connection is not None
+
+    async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
+        """Add the key/value to the named cache, with the expiry time given.
+        """
+
+        if self._redis_connection is None:
+            return
+
+        set_counter.labels(cache_name).inc()
+
+        # txredisapi requires the value to be string, bytes or numbers, so we
+        # encode stuff in JSON.
+        encoded_value = json_encoder.encode(value)
+
+        logger.debug("Caching %s %s: %r", cache_name, key, encoded_value)
+
+        return await make_deferred_yieldable(
+            self._redis_connection.set(
+                self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms,
+            )
+        )
+
+    async def get(self, cache_name: str, key: str) -> Optional[Any]:
+        """Look up a key/value in the named cache.
+        """
+
+        if self._redis_connection is None:
+            return None
+
+        result = await make_deferred_yieldable(
+            self._redis_connection.get(self._get_redis_key(cache_name, key))
+        )
+
+        logger.debug("Got cache result %s %s: %r", cache_name, key, result)
+
+        get_counter.labels(cache_name, result is not None).inc()
+
+        if not result:
+            return None
+
+        # For some reason the integers get magically converted back to integers
+        if isinstance(result, int):
+            return result
+
+        return json_decoder.decode(result)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 317796d5e0..8ea8dcd587 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 import logging
 from typing import (
+    TYPE_CHECKING,
     Any,
     Awaitable,
     Dict,
@@ -63,6 +64,9 @@ from synapse.replication.tcp.streams import (
     TypingStream,
 )
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -88,7 +92,7 @@ class ReplicationCommandHandler:
     back out to connections.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self._replication_data_handler = hs.get_replication_data_handler()
         self._presence_handler = hs.get_presence_handler()
         self._store = hs.get_datastore()
@@ -282,13 +286,6 @@ class ReplicationCommandHandler:
         if hs.config.redis.redis_enabled:
             from synapse.replication.tcp.redis import (
                 RedisDirectTcpReplicationClientFactory,
-                lazyConnection,
-            )
-
-            logger.info(
-                "Connecting to redis (host=%r port=%r)",
-                hs.config.redis_host,
-                hs.config.redis_port,
             )
 
             # First let's ensure that we have a ReplicationStreamer started.
@@ -299,13 +296,7 @@ class ReplicationCommandHandler:
             # connection after SUBSCRIBE is called).
 
             # First create the connection for sending commands.
-            outbound_redis_connection = lazyConnection(
-                reactor=hs.get_reactor(),
-                host=hs.config.redis_host,
-                port=hs.config.redis_port,
-                password=hs.config.redis.redis_password,
-                reconnect=True,
-            )
+            outbound_redis_connection = hs.get_outbound_redis_connection()
 
             # Now create the factory/connection for the subscription stream.
             self._factory = RedisDirectTcpReplicationClientFactory(
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index bc6ba709a7..fdd087683b 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -15,7 +15,7 @@
 
 import logging
 from inspect import isawaitable
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Type, cast
 
 import txredisapi
 
@@ -23,6 +23,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
 from synapse.metrics.background_process_metrics import (
     BackgroundProcessLoggingContext,
     run_as_background_process,
+    wrap_as_background_process,
 )
 from synapse.replication.tcp.commands import (
     Command,
@@ -59,16 +60,16 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
     immediately after initialisation.
 
     Attributes:
-        handler: The command handler to handle incoming commands.
-        stream_name: The *redis* stream name to subscribe to and publish from
-            (not anything to do with Synapse replication streams).
-        outbound_redis_connection: The connection to redis to use to send
+        synapse_handler: The command handler to handle incoming commands.
+        synapse_stream_name: The *redis* stream name to subscribe to and publish
+            from (not anything to do with Synapse replication streams).
+        synapse_outbound_redis_connection: The connection to redis to use to send
             commands.
     """
 
-    handler = None  # type: ReplicationCommandHandler
-    stream_name = None  # type: str
-    outbound_redis_connection = None  # type: txredisapi.RedisProtocol
+    synapse_handler = None  # type: ReplicationCommandHandler
+    synapse_stream_name = None  # type: str
+    synapse_outbound_redis_connection = None  # type: txredisapi.RedisProtocol
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -88,19 +89,19 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         # it's important to make sure that we only send the REPLICATE command once we
         # have successfully subscribed to the stream - otherwise we might miss the
         # POSITION response sent back by the other end.
-        logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
-        await make_deferred_yieldable(self.subscribe(self.stream_name))
+        logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
+        await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
         logger.info(
             "Successfully subscribed to redis stream, sending REPLICATE command"
         )
-        self.handler.new_connection(self)
+        self.synapse_handler.new_connection(self)
         await self._async_send_command(ReplicateCommand())
         logger.info("REPLICATE successfully sent")
 
         # We send out our positions when there is a new connection in case the
         # other side missed updates. We do this for Redis connections as the
         # otherside won't know we've connected and so won't issue a REPLICATE.
-        self.handler.send_positions_to_connection(self)
+        self.synapse_handler.send_positions_to_connection(self)
 
     def messageReceived(self, pattern: str, channel: str, message: str):
         """Received a message from redis.
@@ -137,7 +138,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
             cmd: received command
         """
 
-        cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
+        cmd_func = getattr(self.synapse_handler, "on_%s" % (cmd.NAME,), None)
         if not cmd_func:
             logger.warning("Unhandled command: %r", cmd)
             return
@@ -155,7 +156,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
     def connectionLost(self, reason):
         logger.info("Lost connection to redis")
         super().connectionLost(reason)
-        self.handler.lost_connection(self)
+        self.synapse_handler.lost_connection(self)
 
         # mark the logging context as finished
         self._logging_context.__exit__(None, None, None)
@@ -183,11 +184,54 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
 
         await make_deferred_yieldable(
-            self.outbound_redis_connection.publish(self.stream_name, encoded_string)
+            self.synapse_outbound_redis_connection.publish(
+                self.synapse_stream_name, encoded_string
+            )
+        )
+
+
+class SynapseRedisFactory(txredisapi.RedisFactory):
+    """A subclass of RedisFactory that periodically sends pings to ensure that
+    we detect dead connections.
+    """
+
+    def __init__(
+        self,
+        hs: "HomeServer",
+        uuid: str,
+        dbid: Optional[int],
+        poolsize: int,
+        isLazy: bool = False,
+        handler: Type = txredisapi.ConnectionHandler,
+        charset: str = "utf-8",
+        password: Optional[str] = None,
+        replyTimeout: int = 30,
+        convertNumbers: Optional[int] = True,
+    ):
+        super().__init__(
+            uuid=uuid,
+            dbid=dbid,
+            poolsize=poolsize,
+            isLazy=isLazy,
+            handler=handler,
+            charset=charset,
+            password=password,
+            replyTimeout=replyTimeout,
+            convertNumbers=convertNumbers,
         )
 
+        hs.get_clock().looping_call(self._send_ping, 30 * 1000)
+
+    @wrap_as_background_process("redis_ping")
+    async def _send_ping(self):
+        for connection in self.pool:
+            try:
+                await make_deferred_yieldable(connection.ping())
+            except Exception:
+                logger.warning("Failed to send ping to a redis connection")
 
-class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
+
+class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
     """This is a reconnecting factory that connects to redis and immediately
     subscribes to a stream.
 
@@ -206,65 +250,62 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
         self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
     ):
 
-        super().__init__()
-
-        # This sets the password on the RedisFactory base class (as
-        # SubscriberFactory constructor doesn't pass it through).
-        self.password = hs.config.redis.redis_password
+        super().__init__(
+            hs,
+            uuid="subscriber",
+            dbid=None,
+            poolsize=1,
+            replyTimeout=30,
+            password=hs.config.redis.redis_password,
+        )
 
-        self.handler = hs.get_tcp_replication()
-        self.stream_name = hs.hostname
+        self.synapse_handler = hs.get_tcp_replication()
+        self.synapse_stream_name = hs.hostname
 
-        self.outbound_redis_connection = outbound_redis_connection
+        self.synapse_outbound_redis_connection = outbound_redis_connection
 
     def buildProtocol(self, addr):
-        p = super().buildProtocol(addr)  # type: RedisSubscriber
+        p = super().buildProtocol(addr)
+        p = cast(RedisSubscriber, p)
 
         # We do this here rather than add to the constructor of `RedisSubcriber`
         # as to do so would involve overriding `buildProtocol` entirely, however
         # the base method does some other things than just instantiating the
         # protocol.
-        p.handler = self.handler
-        p.outbound_redis_connection = self.outbound_redis_connection
-        p.stream_name = self.stream_name
-        p.password = self.password
+        p.synapse_handler = self.synapse_handler
+        p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
+        p.synapse_stream_name = self.synapse_stream_name
 
         return p
 
 
 def lazyConnection(
-    reactor,
+    hs: "HomeServer",
     host: str = "localhost",
     port: int = 6379,
     dbid: Optional[int] = None,
     reconnect: bool = True,
-    charset: str = "utf-8",
     password: Optional[str] = None,
-    connectTimeout: Optional[int] = None,
-    replyTimeout: Optional[int] = None,
-    convertNumbers: bool = True,
+    replyTimeout: int = 30,
 ) -> txredisapi.RedisProtocol:
-    """Equivalent to `txredisapi.lazyConnection`, except allows specifying a
-    reactor.
+    """Creates a connection to Redis that is lazily set up and reconnects if the
+    connections is lost.
     """
 
-    isLazy = True
-    poolsize = 1
-
     uuid = "%s:%d" % (host, port)
-    factory = txredisapi.RedisFactory(
-        uuid,
-        dbid,
-        poolsize,
-        isLazy,
-        txredisapi.ConnectionHandler,
-        charset,
-        password,
-        replyTimeout,
-        convertNumbers,
+    factory = SynapseRedisFactory(
+        hs,
+        uuid=uuid,
+        dbid=dbid,
+        poolsize=1,
+        isLazy=True,
+        handler=txredisapi.ConnectionHandler,
+        password=password,
+        replyTimeout=replyTimeout,
     )
     factory.continueTrying = reconnect
-    for x in range(poolsize):
-        reactor.connectTCP(host, port, factory, connectTimeout)
+
+    reactor = hs.get_reactor()
+    reactor.connectTCP(host, port, factory, 30)
 
     return factory.handler
diff --git a/synapse/res/templates/sso_auth_bad_user.html b/synapse/res/templates/sso_auth_bad_user.html
index a75c73a142..c9bd4bef20 100644
--- a/synapse/res/templates/sso_auth_bad_user.html
+++ b/synapse/res/templates/sso_auth_bad_user.html
@@ -12,7 +12,7 @@
         <header>
             <h1>That doesn't look right</h1>
             <p>
-                <strong>We were unable to validate your {{ server_name | e }} account</strong>
+                <strong>We were unable to validate your {{ server_name }} account</strong>
                 via single&nbsp;sign&#8209;on&nbsp;(SSO), because the SSO Identity
                 Provider returned different details than when you logged in.
             </p>
diff --git a/synapse/res/templates/sso_auth_confirm.html b/synapse/res/templates/sso_auth_confirm.html
index d572ab87f7..790470fb59 100644
--- a/synapse/res/templates/sso_auth_confirm.html
+++ b/synapse/res/templates/sso_auth_confirm.html
@@ -12,7 +12,7 @@
         <header>
             <h1>Confirm it's you to continue</h1>
             <p>
-                A client is trying to {{ description | e }}. To confirm this action
+                A client is trying to {{ description }}. To confirm this action
                 re-authorize your account with single sign-on.
             </p>
             <p><strong>
@@ -20,8 +20,8 @@
             </strong></p>
         </header>
         <main>
-            <a href="{{ redirect_url | e }}" class="primary-button"/>
-                Continue with {{ idp.idp_name | e }}
+            <a href="{{ redirect_url }}" class="primary-button"/>
+                Continue with {{ idp.idp_name }}
             </a>
         </main>
     </body>
diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html
index 69b93d65c1..b223ca0f56 100644
--- a/synapse/res/templates/sso_error.html
+++ b/synapse/res/templates/sso_error.html
@@ -22,7 +22,7 @@
         <header>
             <h1>There was an error</h1>
             <p>
-                <strong id="errormsg">{{ error_description | e }}</strong>
+                <strong id="errormsg">{{ error_description }}</strong>
             </p>
             <p>
                 If you are seeing this page after clicking a link sent to you via email,
@@ -35,7 +35,7 @@
             </p>
             <div id="error_code">
                 <p><strong>Error code</strong></p>
-                <p>{{ error | e }}</p>
+                <p>{{ error }}</p>
             </div>
         </header>
 
diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html
index 5b38481012..62a640dad2 100644
--- a/synapse/res/templates/sso_login_idp_picker.html
+++ b/synapse/res/templates/sso_login_idp_picker.html
@@ -3,22 +3,22 @@
     <head>
         <meta charset="UTF-8">
         <link rel="stylesheet" href="/_matrix/static/client/login/style.css">
-        <title>{{server_name | e}} Login</title>
+        <title>{{ server_name }} Login</title>
     </head>
     <body>
         <div id="container">
-            <h1 id="title">{{server_name | e}} Login</h1>
+            <h1 id="title">{{ server_name }} Login</h1>
             <div class="login_flow">
                 <p>Choose one of the following identity providers:</p>
             <form>
-                <input type="hidden" name="redirectUrl" value="{{redirect_url | e}}">
+                <input type="hidden" name="redirectUrl" value="{{ redirect_url }}">
                 <ul class="radiobuttons">
 {% for p in providers %}
                     <li>
-                        <input type="radio" name="idp" id="prov{{loop.index}}" value="{{p.idp_id}}">
-                        <label for="prov{{loop.index}}">{{p.idp_name | e}}</label>
+                        <input type="radio" name="idp" id="prov{{ loop.index }}" value="{{ p.idp_id }}">
+                        <label for="prov{{ loop.index }}">{{ p.idp_name }}</label>
 {% if p.idp_icon %}
-                        <img src="{{p.idp_icon | mxc_to_http(32, 32)}}"/>
+                        <img src="{{ p.idp_icon | mxc_to_http(32, 32) }}"/>
 {% endif %}
                     </li>
 {% endfor %}
diff --git a/synapse/res/templates/sso_redirect_confirm.html b/synapse/res/templates/sso_redirect_confirm.html
index ce4f573848..d1328a6969 100644
--- a/synapse/res/templates/sso_redirect_confirm.html
+++ b/synapse/res/templates/sso_redirect_confirm.html
@@ -12,11 +12,11 @@
         <header>
             {% if new_user %}
             <h1>Your account is now ready</h1>
-            <p>You've made your account on {{ server_name | e }}.</p>
+            <p>You've made your account on {{ server_name }}.</p>
             {% else %}
             <h1>Log in</h1>
             {% endif %}
-            <p>Continue to confirm you trust <strong>{{ display_url | e }}</strong>.</p>
+            <p>Continue to confirm you trust <strong>{{ display_url }}</strong>.</p>
         </header>
         <main>
             {% if user_profile.avatar_url %}
@@ -24,13 +24,13 @@
                 <img src="{{ user_profile.avatar_url | mxc_to_http(64, 64) }}" class="avatar" />
                 <div class="profile-details">
                     {% if user_profile.display_name %}
-                    <div class="display-name">{{ user_profile.display_name | e }}</div>
+                    <div class="display-name">{{ user_profile.display_name }}</div>
                     {% endif %}
-                    <div class="user-id">{{ user_id | e }}</div>
+                    <div class="user-id">{{ user_id }}</div>
                 </div>
             </div>
             {% endif %}
-            <a href="{{ redirect_url | e }}" class="primary-button">Continue</a>
+            <a href="{{ redirect_url }}" class="primary-button">Continue</a>
         </main>
     </body>
 </html>
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 6f7dc06503..57e0a10837 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -1,6 +1,8 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2018-2019 New Vector Ltd
+# Copyright 2020, 2021 The Matrix.org Foundation C.I.C.
+
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -36,6 +38,7 @@ from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_medi
 from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
 from synapse.rest.admin.rooms import (
     DeleteRoomRestServlet,
+    ForwardExtremitiesRestServlet,
     JoinRoomAliasServlet,
     ListRoomRestServlet,
     MakeRoomAdminRestServlet,
@@ -51,6 +54,7 @@ from synapse.rest.admin.users import (
     PushersRestServlet,
     ResetPasswordRestServlet,
     SearchUsersRestServlet,
+    ShadowBanRestServlet,
     UserAdminServlet,
     UserMediaRestServlet,
     UserMembershipRestServlet,
@@ -230,6 +234,8 @@ def register_servlets(hs, http_server):
     EventReportsRestServlet(hs).register(http_server)
     PushersRestServlet(hs).register(http_server)
     MakeRoomAdminRestServlet(hs).register(http_server)
+    ShadowBanRestServlet(hs).register(http_server)
+    ForwardExtremitiesRestServlet(hs).register(http_server)
 
 
 def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index ab7cc9102a..f14915d47e 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -431,7 +431,17 @@ class MakeRoomAdminRestServlet(RestServlet):
             if not admin_users:
                 raise SynapseError(400, "No local admin user in room")
 
-            admin_user_id = admin_users[-1]
+            admin_user_id = None
+
+            for admin_user in reversed(admin_users):
+                if room_state.get((EventTypes.Member, admin_user)):
+                    admin_user_id = admin_user
+                    break
+
+            if not admin_user_id:
+                raise SynapseError(
+                    400, "No local admin user in room",
+                )
 
             pl_content = power_levels.content
         else:
@@ -499,3 +509,60 @@ class MakeRoomAdminRestServlet(RestServlet):
         )
 
         return 200, {}
+
+
+class ForwardExtremitiesRestServlet(RestServlet):
+    """Allows a server admin to get or clear forward extremities.
+
+    Clearing does not require restarting the server.
+
+        Clear forward extremities:
+        DELETE /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+
+        Get forward_extremities:
+        GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+    """
+
+    PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.room_member_handler = hs.get_room_member_handler()
+        self.store = hs.get_datastore()
+
+    async def resolve_room_id(self, room_identifier: str) -> str:
+        """Resolve to a room ID, if necessary."""
+        if RoomID.is_valid(room_identifier):
+            resolved_room_id = room_identifier
+        elif RoomAlias.is_valid(room_identifier):
+            room_alias = RoomAlias.from_string(room_identifier)
+            room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
+            resolved_room_id = room_id.to_string()
+        else:
+            raise SynapseError(
+                400, "%s was not legal room ID or room alias" % (room_identifier,)
+            )
+        if not resolved_room_id:
+            raise SynapseError(
+                400, "Unknown room ID or room alias %s" % room_identifier
+            )
+        return resolved_room_id
+
+    async def on_DELETE(self, request, room_identifier):
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        room_id = await self.resolve_room_id(room_identifier)
+
+        deleted_count = await self.store.delete_forward_extremities_for_room(room_id)
+        return 200, {"deleted": deleted_count}
+
+    async def on_GET(self, request, room_identifier):
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        room_id = await self.resolve_room_id(room_identifier)
+
+        extremities = await self.store.get_forward_extremities_for_room(room_id)
+        return 200, {"count": len(extremities), "results": extremities}
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index f39e3d6d5c..68c3c64a0d 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -83,17 +83,32 @@ class UsersRestServletV2(RestServlet):
     The parameter `deactivated` can be used to include deactivated users.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         start = parse_integer(request, "from", default=0)
         limit = parse_integer(request, "limit", default=100)
+
+        if start < 0:
+            raise SynapseError(
+                400,
+                "Query parameter from must be a string representing a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        if limit < 0:
+            raise SynapseError(
+                400,
+                "Query parameter limit must be a string representing a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
         user_id = parse_string(request, "user_id", default=None)
         name = parse_string(request, "name", default=None)
         guests = parse_boolean(request, "guests", default=True)
@@ -103,7 +118,7 @@ class UsersRestServletV2(RestServlet):
             start, limit, user_id, name, guests, deactivated
         )
         ret = {"users": users, "total": total}
-        if len(users) >= limit:
+        if (start + limit) < total:
             ret["next_token"] = str(start + len(users))
 
         return 200, ret
@@ -875,3 +890,39 @@ class UserTokenRestServlet(RestServlet):
         )
 
         return 200, {"access_token": token}
+
+
+class ShadowBanRestServlet(RestServlet):
+    """An admin API for shadow-banning a user.
+
+    A shadow-banned users receives successful responses to their client-server
+    API requests, but the events are not propagated into rooms.
+
+    Shadow-banning a user should be used as a tool of last resort and may lead
+    to confusing or broken behaviour for the client.
+
+    Example:
+
+        POST /_synapse/admin/v1/users/@test:example.com/shadow_ban
+        {}
+
+        200 OK
+        {}
+    """
+
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    async def on_POST(self, request, user_id):
+        await assert_requester_is_admin(self.auth, request)
+
+        if not self.hs.is_mine_id(user_id):
+            raise SynapseError(400, "Only local users can be shadow-banned")
+
+        await self.store.set_shadow_banned(UserID.from_string(user_id), True)
+
+        return 200, {}
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 65e68d641b..a84a2fb385 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -54,7 +54,7 @@ logger = logging.getLogger(__name__)
 class EmailPasswordRequestTokenRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/password/email/requestToken$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.datastore = hs.get_datastore()
@@ -103,6 +103,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)
 
+        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
         # The email will be sent to the stored address.
         # This avoids a potential account hijack by requesting a password reset to
         # an email address which is controlled by the attacker but which, after
@@ -379,6 +381,8 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
         if next_link:
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)
@@ -430,7 +434,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
 class MsisdnThreepidRequestTokenRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         super().__init__()
         self.store = self.hs.get_datastore()
@@ -458,6 +462,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(
+            request, "msisdn", msisdn
+        )
+
         if next_link:
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index b093183e79..10e1891174 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -126,6 +126,8 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
         existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
             "email", email
         )
@@ -205,6 +207,10 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(
+            request, "msisdn", msisdn
+        )
+
         existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
             "msisdn", msisdn
         )
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 31a41e4a27..f71a03a12d 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -300,6 +300,7 @@ class FileInfo:
         thumbnail_height (int)
         thumbnail_method (str)
         thumbnail_type (str): Content type of thumbnail, e.g. image/png
+        thumbnail_length (int): The size of the media file, in bytes.
     """
 
     def __init__(
@@ -312,6 +313,7 @@ class FileInfo:
         thumbnail_height=None,
         thumbnail_method=None,
         thumbnail_type=None,
+        thumbnail_length=None,
     ):
         self.server_name = server_name
         self.file_id = file_id
@@ -321,6 +323,7 @@ class FileInfo:
         self.thumbnail_height = thumbnail_height
         self.thumbnail_method = thumbnail_method
         self.thumbnail_type = thumbnail_type
+        self.thumbnail_length = thumbnail_length
 
 
 def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index a632099167..bf3be653aa 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -386,7 +386,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         """
         Check whether the URL should be downloaded as oEmbed content instead.
 
-        Params:
+        Args:
             url: The URL to check.
 
         Returns:
@@ -403,7 +403,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         """
         Request content from an oEmbed endpoint.
 
-        Params:
+        Args:
             endpoint: The oEmbed API endpoint.
             url: The URL to pass to the API.
 
@@ -692,27 +692,51 @@ class PreviewUrlResource(DirectServeJsonResource):
 def decode_and_calc_og(
     body: bytes, media_uri: str, request_encoding: Optional[str] = None
 ) -> Dict[str, Optional[str]]:
+    """
+    Calculate metadata for an HTML document.
+
+    This uses lxml to parse the HTML document into the OG response. If errors
+    occur during processing of the document, an empty response is returned.
+
+    Args:
+        body: The HTML document, as bytes.
+        media_url: The URI used to download the body.
+        request_encoding: The character encoding of the body, as a string.
+
+    Returns:
+        The OG response as a dictionary.
+    """
     # If there's no body, nothing useful is going to be found.
     if not body:
         return {}
 
     from lxml import etree
 
+    # Create an HTML parser. If this fails, log and return no metadata.
     try:
         parser = etree.HTMLParser(recover=True, encoding=request_encoding)
-        tree = etree.fromstring(body, parser)
-        og = _calc_og(tree, media_uri)
+    except LookupError:
+        # blindly consider the encoding as utf-8.
+        parser = etree.HTMLParser(recover=True, encoding="utf-8")
+    except Exception as e:
+        logger.warning("Unable to create HTML parser: %s" % (e,))
+        return {}
+
+    def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]:
+        # Attempt to parse the body. If this fails, log and return no metadata.
+        tree = etree.fromstring(body_attempt, parser)
+        return _calc_og(tree, media_uri)
+
+    # Attempt to parse the body. If this fails, log and return no metadata.
+    try:
+        return _attempt_calc_og(body)
     except UnicodeDecodeError:
         # blindly try decoding the body as utf-8, which seems to fix
         # the charset mismatches on https://google.com
-        parser = etree.HTMLParser(recover=True, encoding=request_encoding)
-        tree = etree.fromstring(body.decode("utf-8", "ignore"), parser)
-        og = _calc_og(tree, media_uri)
-
-    return og
+        return _attempt_calc_og(body.decode("utf-8", "ignore"))
 
 
-def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]:
+def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
     # suck our tree into lxml and define our OG response.
 
     # if we see any image URLs in the OG response, then spider them
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index d6880f2e6e..d653a58be9 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -16,7 +16,7 @@
 
 
 import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from twisted.web.http import Request
 
@@ -106,31 +106,17 @@ class ThumbnailResource(DirectServeJsonResource):
             return
 
         thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
-
-        if thumbnail_infos:
-            thumbnail_info = self._select_thumbnail(
-                width, height, method, m_type, thumbnail_infos
-            )
-
-            file_info = FileInfo(
-                server_name=None,
-                file_id=media_id,
-                url_cache=media_info["url_cache"],
-                thumbnail=True,
-                thumbnail_width=thumbnail_info["thumbnail_width"],
-                thumbnail_height=thumbnail_info["thumbnail_height"],
-                thumbnail_type=thumbnail_info["thumbnail_type"],
-                thumbnail_method=thumbnail_info["thumbnail_method"],
-            )
-
-            t_type = file_info.thumbnail_type
-            t_length = thumbnail_info["thumbnail_length"]
-
-            responder = await self.media_storage.fetch_media(file_info)
-            await respond_with_responder(request, responder, t_type, t_length)
-        else:
-            logger.info("Couldn't find any generated thumbnails")
-            respond_404(request)
+        await self._select_and_respond_with_thumbnail(
+            request,
+            width,
+            height,
+            method,
+            m_type,
+            thumbnail_infos,
+            media_id,
+            url_cache=media_info["url_cache"],
+            server_name=None,
+        )
 
     async def _select_or_generate_local_thumbnail(
         self,
@@ -276,26 +262,64 @@ class ThumbnailResource(DirectServeJsonResource):
         thumbnail_infos = await self.store.get_remote_media_thumbnails(
             server_name, media_id
         )
+        await self._select_and_respond_with_thumbnail(
+            request,
+            width,
+            height,
+            method,
+            m_type,
+            thumbnail_infos,
+            media_info["filesystem_id"],
+            url_cache=None,
+            server_name=server_name,
+        )
 
+    async def _select_and_respond_with_thumbnail(
+        self,
+        request: Request,
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+        thumbnail_infos: List[Dict[str, Any]],
+        file_id: str,
+        url_cache: Optional[str] = None,
+        server_name: Optional[str] = None,
+    ) -> None:
+        """
+        Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
+
+        Args:
+            request: The incoming request.
+            desired_width: The desired width, the returned thumbnail may be larger than this.
+            desired_height: The desired height, the returned thumbnail may be larger than this.
+            desired_method: The desired method used to generate the thumbnail.
+            desired_type: The desired content-type of the thumbnail.
+            thumbnail_infos: A list of dictionaries of candidate thumbnails.
+            file_id: The ID of the media that a thumbnail is being requested for.
+            url_cache: The URL cache value.
+            server_name: The server name, if this is a remote thumbnail.
+        """
         if thumbnail_infos:
-            thumbnail_info = self._select_thumbnail(
-                width, height, method, m_type, thumbnail_infos
+            file_info = self._select_thumbnail(
+                desired_width,
+                desired_height,
+                desired_method,
+                desired_type,
+                thumbnail_infos,
+                file_id,
+                url_cache,
+                server_name,
             )
-            file_info = FileInfo(
-                server_name=server_name,
-                file_id=media_info["filesystem_id"],
-                thumbnail=True,
-                thumbnail_width=thumbnail_info["thumbnail_width"],
-                thumbnail_height=thumbnail_info["thumbnail_height"],
-                thumbnail_type=thumbnail_info["thumbnail_type"],
-                thumbnail_method=thumbnail_info["thumbnail_method"],
-            )
-
-            t_type = file_info.thumbnail_type
-            t_length = thumbnail_info["thumbnail_length"]
+            if not file_info:
+                logger.info("Couldn't find a thumbnail matching the desired inputs")
+                respond_404(request)
+                return
 
             responder = await self.media_storage.fetch_media(file_info)
-            await respond_with_responder(request, responder, t_type, t_length)
+            await respond_with_responder(
+                request, responder, file_info.thumbnail_type, file_info.thumbnail_length
+            )
         else:
             logger.info("Failed to find any generated thumbnails")
             respond_404(request)
@@ -306,67 +330,117 @@ class ThumbnailResource(DirectServeJsonResource):
         desired_height: int,
         desired_method: str,
         desired_type: str,
-        thumbnail_infos,
-    ) -> dict:
+        thumbnail_infos: List[Dict[str, Any]],
+        file_id: str,
+        url_cache: Optional[str],
+        server_name: Optional[str],
+    ) -> Optional[FileInfo]:
+        """
+        Choose an appropriate thumbnail from the previously generated thumbnails.
+
+        Args:
+            desired_width: The desired width, the returned thumbnail may be larger than this.
+            desired_height: The desired height, the returned thumbnail may be larger than this.
+            desired_method: The desired method used to generate the thumbnail.
+            desired_type: The desired content-type of the thumbnail.
+            thumbnail_infos: A list of dictionaries of candidate thumbnails.
+            file_id: The ID of the media that a thumbnail is being requested for.
+            url_cache: The URL cache value.
+            server_name: The server name, if this is a remote thumbnail.
+
+        Returns:
+             The thumbnail which best matches the desired parameters.
+        """
+        desired_method = desired_method.lower()
+
+        # The chosen thumbnail.
+        thumbnail_info = None
+
         d_w = desired_width
         d_h = desired_height
 
-        if desired_method.lower() == "crop":
+        if desired_method == "crop":
+            # Thumbnails that match equal or larger sizes of desired width/height.
             crop_info_list = []
+            # Other thumbnails.
             crop_info_list2 = []
             for info in thumbnail_infos:
+                # Skip thumbnails generated with different methods.
+                if info["thumbnail_method"] != "crop":
+                    continue
+
                 t_w = info["thumbnail_width"]
                 t_h = info["thumbnail_height"]
-                t_method = info["thumbnail_method"]
-                if t_method == "crop":
-                    aspect_quality = abs(d_w * t_h - d_h * t_w)
-                    min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
-                    size_quality = abs((d_w - t_w) * (d_h - t_h))
-                    type_quality = desired_type != info["thumbnail_type"]
-                    length_quality = info["thumbnail_length"]
-                    if t_w >= d_w or t_h >= d_h:
-                        crop_info_list.append(
-                            (
-                                aspect_quality,
-                                min_quality,
-                                size_quality,
-                                type_quality,
-                                length_quality,
-                                info,
-                            )
+                aspect_quality = abs(d_w * t_h - d_h * t_w)
+                min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
+                size_quality = abs((d_w - t_w) * (d_h - t_h))
+                type_quality = desired_type != info["thumbnail_type"]
+                length_quality = info["thumbnail_length"]
+                if t_w >= d_w or t_h >= d_h:
+                    crop_info_list.append(
+                        (
+                            aspect_quality,
+                            min_quality,
+                            size_quality,
+                            type_quality,
+                            length_quality,
+                            info,
                         )
-                    else:
-                        crop_info_list2.append(
-                            (
-                                aspect_quality,
-                                min_quality,
-                                size_quality,
-                                type_quality,
-                                length_quality,
-                                info,
-                            )
+                    )
+                else:
+                    crop_info_list2.append(
+                        (
+                            aspect_quality,
+                            min_quality,
+                            size_quality,
+                            type_quality,
+                            length_quality,
+                            info,
                         )
+                    )
             if crop_info_list:
-                return min(crop_info_list)[-1]
-            else:
-                return min(crop_info_list2)[-1]
-        else:
+                thumbnail_info = min(crop_info_list)[-1]
+            elif crop_info_list2:
+                thumbnail_info = min(crop_info_list2)[-1]
+        elif desired_method == "scale":
+            # Thumbnails that match equal or larger sizes of desired width/height.
             info_list = []
+            # Other thumbnails.
             info_list2 = []
+
             for info in thumbnail_infos:
+                # Skip thumbnails generated with different methods.
+                if info["thumbnail_method"] != "scale":
+                    continue
+
                 t_w = info["thumbnail_width"]
                 t_h = info["thumbnail_height"]
-                t_method = info["thumbnail_method"]
                 size_quality = abs((d_w - t_w) * (d_h - t_h))
                 type_quality = desired_type != info["thumbnail_type"]
                 length_quality = info["thumbnail_length"]
-                if t_method == "scale" and (t_w >= d_w or t_h >= d_h):
+                if t_w >= d_w or t_h >= d_h:
                     info_list.append((size_quality, type_quality, length_quality, info))
-                elif t_method == "scale":
+                else:
                     info_list2.append(
                         (size_quality, type_quality, length_quality, info)
                     )
             if info_list:
-                return min(info_list)[-1]
-            else:
-                return min(info_list2)[-1]
+                thumbnail_info = min(info_list)[-1]
+            elif info_list2:
+                thumbnail_info = min(info_list2)[-1]
+
+        if thumbnail_info:
+            return FileInfo(
+                file_id=file_id,
+                url_cache=url_cache,
+                server_name=server_name,
+                thumbnail=True,
+                thumbnail_width=thumbnail_info["thumbnail_width"],
+                thumbnail_height=thumbnail_info["thumbnail_height"],
+                thumbnail_type=thumbnail_info["thumbnail_type"],
+                thumbnail_method=thumbnail_info["thumbnail_method"],
+                thumbnail_length=thumbnail_info["thumbnail_length"],
+            )
+
+        # No matching thumbnail was found.
+        return None
diff --git a/synapse/server.py b/synapse/server.py
index 9cdda83aa1..9bdd3177d7 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -103,6 +103,7 @@ from synapse.notifier import Notifier
 from synapse.push.action_generator import ActionGenerator
 from synapse.push.pusherpool import PusherPool
 from synapse.replication.tcp.client import ReplicationDataHandler
+from synapse.replication.tcp.external_cache import ExternalCache
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.resource import ReplicationStreamer
 from synapse.replication.tcp.streams import STREAMS_MAP, Stream
@@ -128,6 +129,8 @@ from synapse.util.stringutils import random_string
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
+    from txredisapi import RedisProtocol
+
     from synapse.handlers.oidc_handler import OidcHandler
     from synapse.handlers.saml_handler import SamlHandler
 
@@ -716,6 +719,33 @@ class HomeServer(metaclass=abc.ABCMeta):
     def get_account_data_handler(self) -> AccountDataHandler:
         return AccountDataHandler(self)
 
+    @cache_in_self
+    def get_external_cache(self) -> ExternalCache:
+        return ExternalCache(self)
+
+    @cache_in_self
+    def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]:
+        if not self.config.redis.redis_enabled:
+            return None
+
+        # We only want to import redis module if we're using it, as we have
+        # `txredisapi` as an optional dependency.
+        from synapse.replication.tcp.redis import lazyConnection
+
+        logger.info(
+            "Connecting to redis (host=%r port=%r) for external cache",
+            self.config.redis_host,
+            self.config.redis_port,
+        )
+
+        return lazyConnection(
+            hs=self,
+            host=self.config.redis_host,
+            port=self.config.redis_port,
+            password=self.config.redis.redis_password,
+            reconnect=True,
+        )
+
     async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
         return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
 
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 84f59c7d85..3bd9ff8ca0 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -310,6 +310,7 @@ class StateHandler:
             state_group_before_event = None
             state_group_before_event_prev_group = None
             deltas_to_state_group_before_event = None
+            entry = None
 
         else:
             # otherwise, we'll need to resolve the state across the prev_events.
@@ -340,9 +341,13 @@ class StateHandler:
                 current_state_ids=state_ids_before_event,
             )
 
-            # XXX: can we update the state cache entry for the new state group? or
-            # could we set a flag on resolve_state_groups_for_events to tell it to
-            # always make a state group?
+            # Assign the new state group to the cached state entry.
+            #
+            # Note that this can race in that we could generate multiple state
+            # groups for the same state entry, but that is just inefficient
+            # rather than dangerous.
+            if entry and entry.state_group is None:
+                entry.state_group = state_group_before_event
 
         #
         # now if it's not a state event, we're done
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index c7220bc778..d2ba4bd2fc 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -262,6 +262,12 @@ class LoggingTransaction:
         return self.txn.description
 
     def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
+        """Similar to `executemany`, except `txn.rowcount` will not be correct
+        afterwards.
+
+        More efficient than `executemany` on PostgreSQL
+        """
+
         if isinstance(self.database_engine, PostgresEngine):
             from psycopg2.extras import execute_batch  # type: ignore
 
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index ae561a2da3..5d0845588c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -43,6 +43,7 @@ from .end_to_end_keys import EndToEndKeyStore
 from .event_federation import EventFederationStore
 from .event_push_actions import EventPushActionsStore
 from .events_bg_updates import EventsBackgroundUpdatesStore
+from .events_forward_extremities import EventForwardExtremitiesStore
 from .filtering import FilteringStore
 from .group_server import GroupServerStore
 from .keys import KeyStore
@@ -118,6 +119,7 @@ class DataStore(
     UIAuthStore,
     CacheInvalidationWorkerStore,
     ServerMetricsStore,
+    EventForwardExtremitiesStore,
 ):
     def __init__(self, database: DatabasePool, db_conn, hs):
         self.hs = hs
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9097677648..659d8f245f 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -897,7 +897,7 @@ class DeviceWorkerStore(SQLBaseStore):
                 DELETE FROM device_lists_outbound_last_success
                 WHERE destination = ? AND user_id = ?
             """
-            txn.executemany(sql, ((row[0], row[1]) for row in rows))
+            txn.execute_batch(sql, ((row[0], row[1]) for row in rows))
 
             logger.info("Pruned %d device list outbound pokes", count)
 
@@ -1343,7 +1343,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
         # Delete older entries in the table, as we really only care about
         # when the latest change happened.
-        txn.executemany(
+        txn.execute_batch(
             """
             DELETE FROM device_lists_stream
             WHERE user_id = ? AND device_id = ? AND stream_id < ?
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index c128889bf9..309f1e865b 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -634,7 +634,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
     async def get_e2e_cross_signing_keys_bulk(
         self, user_ids: List[str], from_user_id: Optional[str] = None
-    ) -> Dict[str, Dict[str, dict]]:
+    ) -> Dict[str, Optional[Dict[str, dict]]]:
         """Returns the cross-signing keys for a set of users.
 
         Args:
@@ -724,7 +724,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
     async def claim_e2e_one_time_keys(
         self, query_list: Iterable[Tuple[str, str, str]]
-    ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
+    ) -> Dict[str, Dict[str, Dict[str, str]]]:
         """Take a list of one time keys out of the database.
 
         Args:
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 1b657191a9..438383abe1 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -487,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 VALUES (?, ?, ?, ?, ?, ?)
             """
 
-            txn.executemany(
+            txn.execute_batch(
                 sql,
                 (
                     _gen_entry(user_id, actions)
@@ -803,7 +803,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             ],
         )
 
-        txn.executemany(
+        txn.execute_batch(
             """
                 UPDATE event_push_summary
                 SET notif_count = ?, unread_count = ?, stream_ordering = ?
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5db7d7aaa8..ccda9f1caa 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -473,8 +473,9 @@ class PersistEventsStore:
             txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
         )
 
-    @staticmethod
+    @classmethod
     def _add_chain_cover_index(
+        cls,
         txn,
         db_pool: DatabasePool,
         event_to_room_id: Dict[str, str],
@@ -614,60 +615,17 @@ class PersistEventsStore:
         if not events_to_calc_chain_id_for:
             return
 
-        # We now calculate the chain IDs/sequence numbers for the events. We
-        # do this by looking at the chain ID and sequence number of any auth
-        # event with the same type/state_key and incrementing the sequence
-        # number by one. If there was no match or the chain ID/sequence
-        # number is already taken we generate a new chain.
-        #
-        # We need to do this in a topologically sorted order as we want to
-        # generate chain IDs/sequence numbers of an event's auth events
-        # before the event itself.
-        chains_tuples_allocated = set()  # type: Set[Tuple[int, int]]
-        new_chain_tuples = {}  # type: Dict[str, Tuple[int, int]]
-        for event_id in sorted_topologically(
-            events_to_calc_chain_id_for, event_to_auth_chain
-        ):
-            existing_chain_id = None
-            for auth_id in event_to_auth_chain.get(event_id, []):
-                if event_to_types.get(event_id) == event_to_types.get(auth_id):
-                    existing_chain_id = chain_map[auth_id]
-                    break
-
-            new_chain_tuple = None
-            if existing_chain_id:
-                # We found a chain ID/sequence number candidate, check its
-                # not already taken.
-                proposed_new_id = existing_chain_id[0]
-                proposed_new_seq = existing_chain_id[1] + 1
-                if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
-                    already_allocated = db_pool.simple_select_one_onecol_txn(
-                        txn,
-                        table="event_auth_chains",
-                        keyvalues={
-                            "chain_id": proposed_new_id,
-                            "sequence_number": proposed_new_seq,
-                        },
-                        retcol="event_id",
-                        allow_none=True,
-                    )
-                    if already_allocated:
-                        # Mark it as already allocated so we don't need to hit
-                        # the DB again.
-                        chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
-                    else:
-                        new_chain_tuple = (
-                            proposed_new_id,
-                            proposed_new_seq,
-                        )
-
-            if not new_chain_tuple:
-                new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1)
-
-            chains_tuples_allocated.add(new_chain_tuple)
-
-            chain_map[event_id] = new_chain_tuple
-            new_chain_tuples[event_id] = new_chain_tuple
+        # Allocate chain ID/sequence numbers to each new event.
+        new_chain_tuples = cls._allocate_chain_ids(
+            txn,
+            db_pool,
+            event_to_room_id,
+            event_to_types,
+            event_to_auth_chain,
+            events_to_calc_chain_id_for,
+            chain_map,
+        )
+        chain_map.update(new_chain_tuples)
 
         db_pool.simple_insert_many_txn(
             txn,
@@ -794,6 +752,137 @@ class PersistEventsStore:
             ],
         )
 
+    @staticmethod
+    def _allocate_chain_ids(
+        txn,
+        db_pool: DatabasePool,
+        event_to_room_id: Dict[str, str],
+        event_to_types: Dict[str, Tuple[str, str]],
+        event_to_auth_chain: Dict[str, List[str]],
+        events_to_calc_chain_id_for: Set[str],
+        chain_map: Dict[str, Tuple[int, int]],
+    ) -> Dict[str, Tuple[int, int]]:
+        """Allocates, but does not persist, chain ID/sequence numbers for the
+        events in `events_to_calc_chain_id_for`. (c.f. _add_chain_cover_index
+        for info on args)
+        """
+
+        # We now calculate the chain IDs/sequence numbers for the events. We do
+        # this by looking at the chain ID and sequence number of any auth event
+        # with the same type/state_key and incrementing the sequence number by
+        # one. If there was no match or the chain ID/sequence number is already
+        # taken we generate a new chain.
+        #
+        # We try to reduce the number of times that we hit the database by
+        # batching up calls, to make this more efficient when persisting large
+        # numbers of state events (e.g. during joins).
+        #
+        # We do this by:
+        #   1. Calculating for each event which auth event will be used to
+        #      inherit the chain ID, i.e. converting the auth chain graph to a
+        #      tree that we can allocate chains on. We also keep track of which
+        #      existing chain IDs have been referenced.
+        #   2. Fetching the max allocated sequence number for each referenced
+        #      existing chain ID, generating a map from chain ID to the max
+        #      allocated sequence number.
+        #   3. Iterating over the tree and allocating a chain ID/seq no. to the
+        #      new event, by incrementing the sequence number from the
+        #      referenced event's chain ID/seq no. and checking that the
+        #      incremented sequence number hasn't already been allocated (by
+        #      looking in the map generated in the previous step). We generate a
+        #      new chain if the sequence number has already been allocated.
+        #
+
+        existing_chains = set()  # type: Set[int]
+        tree = []  # type: List[Tuple[str, Optional[str]]]
+
+        # We need to do this in a topologically sorted order as we want to
+        # generate chain IDs/sequence numbers of an event's auth events before
+        # the event itself.
+        for event_id in sorted_topologically(
+            events_to_calc_chain_id_for, event_to_auth_chain
+        ):
+            for auth_id in event_to_auth_chain.get(event_id, []):
+                if event_to_types.get(event_id) == event_to_types.get(auth_id):
+                    existing_chain_id = chain_map.get(auth_id)
+                    if existing_chain_id:
+                        existing_chains.add(existing_chain_id[0])
+
+                    tree.append((event_id, auth_id))
+                    break
+            else:
+                tree.append((event_id, None))
+
+        # Fetch the current max sequence number for each existing referenced chain.
+        sql = """
+            SELECT chain_id, MAX(sequence_number) FROM event_auth_chains
+            WHERE %s
+            GROUP BY chain_id
+        """
+        clause, args = make_in_list_sql_clause(
+            db_pool.engine, "chain_id", existing_chains
+        )
+        txn.execute(sql % (clause,), args)
+
+        chain_to_max_seq_no = {row[0]: row[1] for row in txn}  # type: Dict[Any, int]
+
+        # Allocate the new events chain ID/sequence numbers.
+        #
+        # To reduce the number of calls to the database we don't allocate a
+        # chain ID number in the loop, instead we use a temporary `object()` for
+        # each new chain ID. Once we've done the loop we generate the necessary
+        # number of new chain IDs in one call, replacing all temporary
+        # objects with real allocated chain IDs.
+
+        unallocated_chain_ids = set()  # type: Set[object]
+        new_chain_tuples = {}  # type: Dict[str, Tuple[Any, int]]
+        for event_id, auth_event_id in tree:
+            # If we reference an auth_event_id we fetch the allocated chain ID,
+            # either from the existing `chain_map` or the newly generated
+            # `new_chain_tuples` map.
+            existing_chain_id = None
+            if auth_event_id:
+                existing_chain_id = new_chain_tuples.get(auth_event_id)
+                if not existing_chain_id:
+                    existing_chain_id = chain_map[auth_event_id]
+
+            new_chain_tuple = None  # type: Optional[Tuple[Any, int]]
+            if existing_chain_id:
+                # We found a chain ID/sequence number candidate, check its
+                # not already taken.
+                proposed_new_id = existing_chain_id[0]
+                proposed_new_seq = existing_chain_id[1] + 1
+
+                if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq:
+                    new_chain_tuple = (
+                        proposed_new_id,
+                        proposed_new_seq,
+                    )
+
+            # If we need to start a new chain we allocate a temporary chain ID.
+            if not new_chain_tuple:
+                new_chain_tuple = (object(), 1)
+                unallocated_chain_ids.add(new_chain_tuple[0])
+
+            new_chain_tuples[event_id] = new_chain_tuple
+            chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
+
+        # Generate new chain IDs for all unallocated chain IDs.
+        newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
+            txn, len(unallocated_chain_ids)
+        )
+
+        # Map from potentially temporary chain ID to real chain ID
+        chain_id_to_allocated_map = dict(
+            zip(unallocated_chain_ids, newly_allocated_chain_ids)
+        )  # type: Dict[Any, int]
+        chain_id_to_allocated_map.update((c, c) for c in existing_chains)
+
+        return {
+            event_id: (chain_id_to_allocated_map[chain_id], seq)
+            for event_id, (chain_id, seq) in new_chain_tuples.items()
+        }
+
     def _persist_transaction_ids_txn(
         self,
         txn: LoggingTransaction,
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index e46e44ba54..5ca4fa6817 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -139,8 +139,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
-        INSERT_CLUMP_SIZE = 1000
-
         def reindex_txn(txn):
             sql = (
                 "SELECT stream_ordering, event_id, json FROM events"
@@ -178,9 +176,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
 
-            for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
-                clump = update_rows[index : index + INSERT_CLUMP_SIZE]
-                txn.executemany(sql, clump)
+            txn.execute_batch(sql, update_rows)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
@@ -210,8 +206,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
-        INSERT_CLUMP_SIZE = 1000
-
         def reindex_search_txn(txn):
             sql = (
                 "SELECT stream_ordering, event_id FROM events"
@@ -256,9 +250,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
 
-            for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
-                clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
-                txn.executemany(sql, clump)
+            txn.execute_batch(sql, rows_to_update)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py
new file mode 100644
index 0000000000..0ac1da9c35
--- /dev/null
+++ b/synapse/storage/databases/main/events_forward_extremities.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Dict, List
+
+from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
+
+logger = logging.getLogger(__name__)
+
+
+class EventForwardExtremitiesStore(SQLBaseStore):
+    async def delete_forward_extremities_for_room(self, room_id: str) -> int:
+        """Delete any extra forward extremities for a room.
+
+        Invalidates the "get_latest_event_ids_in_room" cache if any forward
+        extremities were deleted.
+
+        Returns count deleted.
+        """
+
+        def delete_forward_extremities_for_room_txn(txn):
+            # First we need to get the event_id to not delete
+            sql = """
+                SELECT event_id FROM event_forward_extremities
+                INNER JOIN events USING (room_id, event_id)
+                WHERE room_id = ?
+                ORDER BY stream_ordering DESC
+                LIMIT 1
+            """
+            txn.execute(sql, (room_id,))
+            rows = txn.fetchall()
+            try:
+                event_id = rows[0][0]
+                logger.debug(
+                    "Found event_id %s as the forward extremity to keep for room %s",
+                    event_id,
+                    room_id,
+                )
+            except KeyError:
+                msg = "No forward extremity event found for room %s" % room_id
+                logger.warning(msg)
+                raise SynapseError(400, msg)
+
+            # Now delete the extra forward extremities
+            sql = """
+                DELETE FROM event_forward_extremities
+                WHERE event_id != ? AND room_id = ?
+            """
+
+            txn.execute(sql, (event_id, room_id))
+            logger.info(
+                "Deleted %s extra forward extremities for room %s",
+                txn.rowcount,
+                room_id,
+            )
+
+            if txn.rowcount > 0:
+                # Invalidate the cache
+                self._invalidate_cache_and_stream(
+                    txn, self.get_latest_event_ids_in_room, (room_id,),
+                )
+
+            return txn.rowcount
+
+        return await self.db_pool.runInteraction(
+            "delete_forward_extremities_for_room",
+            delete_forward_extremities_for_room_txn,
+        )
+
+    async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]:
+        """Get list of forward extremities for a room."""
+
+        def get_forward_extremities_for_room_txn(txn):
+            sql = """
+                SELECT event_id, state_group, depth, received_ts
+                FROM event_forward_extremities
+                INNER JOIN event_to_state_groups USING (event_id)
+                INNER JOIN events USING (room_id, event_id)
+                WHERE room_id = ?
+            """
+
+            txn.execute(sql, (room_id,))
+            return self.db_pool.cursor_to_dict(txn)
+
+        return await self.db_pool.runInteraction(
+            "get_forward_extremities_for_room", get_forward_extremities_for_room_txn,
+        )
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 283c8a5e22..e017177655 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -417,7 +417,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 " WHERE media_origin = ? AND media_id = ?"
             )
 
-            txn.executemany(
+            txn.execute_batch(
                 sql,
                 (
                     (time_ms, media_origin, media_id)
@@ -430,7 +430,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 " WHERE media_id = ?"
             )
 
-            txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
+            txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))
 
         return await self.db_pool.runInteraction(
             "update_cached_last_access_time", update_cache_txn
@@ -557,7 +557,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
 
         def _delete_url_cache_txn(txn):
-            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
 
         return await self.db_pool.runInteraction(
             "delete_url_cache", _delete_url_cache_txn
@@ -586,11 +586,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         def _delete_url_cache_media_txn(txn):
             sql = "DELETE FROM local_media_repository WHERE media_id = ?"
 
-            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
 
             sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
 
-            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
 
         return await self.db_pool.runInteraction(
             "delete_url_cache_media", _delete_url_cache_media_txn
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index ab18cc4d79..92e65aa640 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -88,6 +88,62 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             (x[0] - 1) * x[1] for x in res if x[1]
         )
 
+    async def count_daily_e2ee_messages(self):
+        """
+        Returns an estimate of the number of messages sent in the last day.
+
+        If it has been significantly less or more than one day since the last
+        call to this function, it will return None.
+        """
+
+        def _count_messages(txn):
+            sql = """
+                SELECT COALESCE(COUNT(*), 0) FROM events
+                WHERE type = 'm.room.encrypted'
+                AND stream_ordering > ?
+            """
+            txn.execute(sql, (self.stream_ordering_day_ago,))
+            (count,) = txn.fetchone()
+            return count
+
+        return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
+
+    async def count_daily_sent_e2ee_messages(self):
+        def _count_messages(txn):
+            # This is good enough as if you have silly characters in your own
+            # hostname then thats your own fault.
+            like_clause = "%:" + self.hs.hostname
+
+            sql = """
+                SELECT COALESCE(COUNT(*), 0) FROM events
+                WHERE type = 'm.room.encrypted'
+                    AND sender LIKE ?
+                AND stream_ordering > ?
+            """
+
+            txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
+            (count,) = txn.fetchone()
+            return count
+
+        return await self.db_pool.runInteraction(
+            "count_daily_sent_e2ee_messages", _count_messages
+        )
+
+    async def count_daily_active_e2ee_rooms(self):
+        def _count(txn):
+            sql = """
+                SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+                WHERE type = 'm.room.encrypted'
+                AND stream_ordering > ?
+            """
+            txn.execute(sql, (self.stream_ordering_day_ago,))
+            (count,) = txn.fetchone()
+            return count
+
+        return await self.db_pool.runInteraction(
+            "count_daily_active_e2ee_rooms", _count
+        )
+
     async def count_daily_messages(self):
         """
         Returns an estimate of the number of messages sent in the last day.
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 5d668aadb2..ecfc9f20b1 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -172,7 +172,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
         )
 
         # Update backward extremeties
-        txn.executemany(
+        txn.execute_batch(
             "INSERT INTO event_backward_extremities (room_id, event_id)"
             " VALUES (?, ?)",
             [(room_id, event_id) for event_id, in new_backwards_extrems],
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index bc7621b8d6..2687ef3e43 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -344,7 +344,9 @@ class PusherStore(PusherWorkerStore):
                 txn, self.get_if_user_has_pusher, (user_id,)
             )
 
-            self.db_pool.simple_delete_one_txn(
+            # It is expected that there is exactly one pusher to delete, but
+            # if it isn't there (or there are multiple) delete them all.
+            self.db_pool.simple_delete_txn(
                 txn,
                 "pushers",
                 {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 14c0878d81..8405dd460f 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -360,6 +360,35 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
 
+    async def set_shadow_banned(self, user: UserID, shadow_banned: bool) -> None:
+        """Sets whether a user shadow-banned.
+
+        Args:
+            user: user ID of the user to test
+            shadow_banned: true iff the user is to be shadow-banned, false otherwise.
+        """
+
+        def set_shadow_banned_txn(txn):
+            self.db_pool.simple_update_one_txn(
+                txn,
+                table="users",
+                keyvalues={"name": user.to_string()},
+                updatevalues={"shadow_banned": shadow_banned},
+            )
+            # In order for this to apply immediately, clear the cache for this user.
+            tokens = self.db_pool.simple_select_onecol_txn(
+                txn,
+                table="access_tokens",
+                keyvalues={"user_id": user.to_string()},
+                retcol="token",
+            )
+            for token in tokens:
+                self._invalidate_cache_and_stream(
+                    txn, self.get_user_by_access_token, (token,)
+                )
+
+        await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
+
     def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
         sql = """
             SELECT users.name as user_id,
@@ -1124,7 +1153,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
                 FROM user_threepids
             """
 
-            txn.executemany(sql, [(id_server,) for id_server in id_servers])
+            txn.execute_batch(sql, [(id_server,) for id_server in id_servers])
 
         if id_servers:
             await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index dcdaf09682..92382bed28 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -873,8 +873,6 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
             "max_stream_id_exclusive", self._stream_order_on_start + 1
         )
 
-        INSERT_CLUMP_SIZE = 1000
-
         def add_membership_profile_txn(txn):
             sql = """
                 SELECT stream_ordering, event_id, events.room_id, event_json.json
@@ -915,9 +913,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
                 UPDATE room_memberships SET display_name = ?, avatar_url = ?
                 WHERE event_id = ? AND room_id = ?
             """
-            for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
-                clump = to_update[index : index + INSERT_CLUMP_SIZE]
-                txn.executemany(to_update_sql, clump)
+            txn.execute_batch(to_update_sql, to_update)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
diff --git a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
index f35c70b699..9e8f35c1d2 100644
--- a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
+++ b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
@@ -55,7 +55,7 @@ def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs
         # { "ignored_users": "@someone:example.org": {} }
         ignored_users = content.get("ignored_users", {})
         if isinstance(ignored_users, dict) and ignored_users:
-            cur.executemany(insert_sql, [(user_id, u) for u in ignored_users])
+            cur.execute_batch(insert_sql, [(user_id, u) for u in ignored_users])
 
     # Add indexes after inserting data for efficiency.
     logger.info("Adding constraints to ignored_users table")
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index e34fce6281..f5e7d9ef98 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -24,6 +24,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import Collection
 
 logger = logging.getLogger(__name__)
 
@@ -63,7 +64,7 @@ class SearchWorkerStore(SQLBaseStore):
                 for entry in entries
             )
 
-            txn.executemany(sql, args)
+            txn.execute_batch(sql, args)
 
         elif isinstance(self.database_engine, Sqlite3Engine):
             sql = (
@@ -75,7 +76,7 @@ class SearchWorkerStore(SQLBaseStore):
                 for entry in entries
             )
 
-            txn.executemany(sql, args)
+            txn.execute_batch(sql, args)
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
@@ -460,7 +461,7 @@ class SearchStore(SearchBackgroundUpdateStore):
 
     async def search_rooms(
         self,
-        room_ids: List[str],
+        room_ids: Collection[str],
         search_term: str,
         keys: List[str],
         limit,
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 0cdb3ec1f7..d421d18f8d 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -15,11 +15,12 @@
 # limitations under the License.
 
 import logging
-from collections import Counter
 from enum import Enum
 from itertools import chain
 from typing import Any, Dict, List, Optional, Tuple
 
+from typing_extensions import Counter
+
 from twisted.internet.defer import DeferredLock
 
 from synapse.api.constants import EventTypes, Membership
@@ -319,7 +320,9 @@ class StatsStore(StateDeltasStore):
         return slice_list
 
     @cached()
-    async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
+    async def get_earliest_token_for_stats(
+        self, stats_type: str, id: str
+    ) -> Optional[int]:
         """
         Fetch the "earliest token". This is used by the room stats delta
         processor to ignore deltas that have been processed between the
@@ -339,7 +342,7 @@ class StatsStore(StateDeltasStore):
         )
 
     async def bulk_update_stats_delta(
-        self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int
+        self, ts: int, updates: Dict[str, Dict[str, Counter[str]]], stream_id: int
     ) -> None:
         """Bulk update stats tables for a given stream_id and updates the stats
         incremental position.
@@ -665,7 +668,7 @@ class StatsStore(StateDeltasStore):
 
     async def get_changes_room_total_events_and_bytes(
         self, min_pos: int, max_pos: int
-    ) -> Dict[str, Dict[str, int]]:
+    ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
         """Fetches the counts of events in the given range of stream IDs.
 
         Args:
@@ -683,18 +686,19 @@ class StatsStore(StateDeltasStore):
             max_pos,
         )
 
-    def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos):
+    def get_changes_room_total_events_and_bytes_txn(
+        self, txn, low_pos: int, high_pos: int
+    ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
         """Gets the total_events and total_event_bytes counts for rooms and
         senders, in a range of stream_orderings (including backfilled events).
 
         Args:
             txn
-            low_pos (int): Low stream ordering
-            high_pos (int): High stream ordering
+            low_pos: Low stream ordering
+            high_pos: High stream ordering
 
         Returns:
-            tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The
-            room and user deltas for total_events/total_event_bytes in the
+            The room and user deltas for total_events/total_event_bytes in the
             format of `stats_id` -> fields
         """
 
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index ef11f1c3b3..7b9729da09 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -540,7 +540,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             desc="get_user_in_directory",
         )
 
-    async def update_user_directory_stream_pos(self, stream_id: str) -> None:
+    async def update_user_directory_stream_pos(self, stream_id: int) -> None:
         await self.db_pool.simple_update_one(
             table="user_directory_stream_pos",
             keyvalues={},
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 0e31cc811a..89cdc84a9c 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -565,11 +565,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             )
 
         logger.info("[purge] removing redundant state groups")
-        txn.executemany(
+        txn.execute_batch(
             "DELETE FROM state_groups_state WHERE state_group = ?",
             ((sg,) for sg in state_groups_to_delete),
         )
-        txn.executemany(
+        txn.execute_batch(
             "DELETE FROM state_groups WHERE id = ?",
             ((sg,) for sg in state_groups_to_delete),
         )
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index bb84c0d792..71ef5a72dc 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -15,12 +15,11 @@
 import heapq
 import logging
 import threading
-from collections import deque
+from collections import OrderedDict
 from contextlib import contextmanager
 from typing import Dict, List, Optional, Set, Tuple, Union
 
 import attr
-from typing_extensions import Deque
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.database import DatabasePool, LoggingTransaction
@@ -101,7 +100,13 @@ class StreamIdGenerator:
             self._current = (max if step > 0 else min)(
                 self._current, _load_current_id(db_conn, table, column, step)
             )
-        self._unfinished_ids = deque()  # type: Deque[int]
+
+        # We use this as an ordered set, as we want to efficiently append items,
+        # remove items and get the first item. Since we insert IDs in order, the
+        # insertion ordering will ensure its in the correct ordering.
+        #
+        # The key and values are the same, but we never look at the values.
+        self._unfinished_ids = OrderedDict()  # type: OrderedDict[int, int]
 
     def get_next(self):
         """
@@ -113,7 +118,7 @@ class StreamIdGenerator:
             self._current += self._step
             next_id = self._current
 
-            self._unfinished_ids.append(next_id)
+            self._unfinished_ids[next_id] = next_id
 
         @contextmanager
         def manager():
@@ -121,7 +126,7 @@ class StreamIdGenerator:
                 yield next_id
             finally:
                 with self._lock:
-                    self._unfinished_ids.remove(next_id)
+                    self._unfinished_ids.pop(next_id)
 
         return _AsyncCtxManagerWrapper(manager())
 
@@ -140,7 +145,7 @@ class StreamIdGenerator:
             self._current += n * self._step
 
             for next_id in next_ids:
-                self._unfinished_ids.append(next_id)
+                self._unfinished_ids[next_id] = next_id
 
         @contextmanager
         def manager():
@@ -149,7 +154,7 @@ class StreamIdGenerator:
             finally:
                 with self._lock:
                     for next_id in next_ids:
-                        self._unfinished_ids.remove(next_id)
+                        self._unfinished_ids.pop(next_id)
 
         return _AsyncCtxManagerWrapper(manager())
 
@@ -162,7 +167,7 @@ class StreamIdGenerator:
         """
         with self._lock:
             if self._unfinished_ids:
-                return self._unfinished_ids[0] - self._step
+                return next(iter(self._unfinished_ids)) - self._step
 
             return self._current
 
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index c780ade077..0ec4dc2918 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -70,6 +70,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
         ...
 
     @abc.abstractmethod
+    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+        """Get the next `n` IDs in the sequence"""
+        ...
+
+    @abc.abstractmethod
     def check_consistency(
         self,
         db_conn: "LoggingDatabaseConnection",
@@ -219,6 +224,17 @@ class LocalSequenceGenerator(SequenceGenerator):
             self._current_max_id += 1
             return self._current_max_id
 
+    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+        with self._lock:
+            if self._current_max_id is None:
+                assert self._callback is not None
+                self._current_max_id = self._callback(txn)
+                self._callback = None
+
+            first_id = self._current_max_id + 1
+            self._current_max_id += n
+            return [first_id + i for i in range(n)]
+
     def check_consistency(
         self,
         db_conn: Connection,
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 1ee61851e4..09b094ded7 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -49,7 +49,8 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
     module = importlib.import_module(module)
     provider_class = getattr(module, clz)
 
-    module_config = provider.get("config")
+    # Load the module config. If None, pass an empty dictionary instead
+    module_config = provider.get("config") or {}
     try:
         provider_config = provider_class.parse_config(module_config)
     except jsonschema.ValidationError as e: