summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
author Ben Banfield-Zanin <benbz@matrix.org> 2021-03-01 10:06:09 +0000
committer Ben Banfield-Zanin <benbz@matrix.org> 2021-03-01 10:06:09 +0000
commit b26bee9faf957643cd34c4146b250b0009be205d (patch)
tree a7a7e29f30acb437d010bdf6116c0f2729f21a1b /synapse
parent Merge remote-tracking branch 'origin/release-v1.26.0' into toml/keycloak_hints (diff)
parent Fixup changelog (diff)
download synapse-github/toml/keycloak_hints.tar.xz
Merge remote-tracking branch 'origin/release-v1.28.0' into toml/keycloak_hints github/toml/keycloak_hints toml/keycloak_hints
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/auth.py16
-rw-r--r--synapse/api/constants.py8
-rw-r--r--synapse/api/errors.py21
-rw-r--r--synapse/api/presence.py3
-rw-r--r--synapse/api/urls.py2
-rw-r--r--synapse/app/_base.py17
-rw-r--r--synapse/app/generic_worker.py14
-rw-r--r--synapse/app/homeserver.py16
-rw-r--r--synapse/app/phone_stats_home.py9
-rw-r--r--synapse/appservice/__init__.py5
-rw-r--r--synapse/appservice/api.py7
-rw-r--r--synapse/appservice/scheduler.py2
-rw-r--r--synapse/config/_base.py83
-rw-r--r--synapse/config/_base.pyi6
-rw-r--r--synapse/config/auth.py15
-rw-r--r--synapse/config/captcha.py4
-rw-r--r--synapse/config/cas.py48
-rw-r--r--synapse/config/consent_config.py2
-rw-r--r--synapse/config/database.py3
-rw-r--r--synapse/config/emailconfig.py14
-rw-r--r--synapse/config/experimental.py29
-rw-r--r--synapse/config/homeserver.py2
-rw-r--r--synapse/config/logger.py5
-rw-r--r--synapse/config/oidc_config.py102
-rw-r--r--synapse/config/ratelimiting.py32
-rw-r--r--synapse/config/registration.py27
-rw-r--r--synapse/config/repository.py21
-rw-r--r--synapse/config/room_directory.py2
-rw-r--r--synapse/config/saml2_config.py35
-rw-r--r--synapse/config/server.py127
-rw-r--r--synapse/config/sso.py133
-rw-r--r--synapse/config/workers.py19
-rw-r--r--synapse/crypto/context_factory.py13
-rw-r--r--synapse/event_auth.py10
-rw-r--r--synapse/events/builder.py4
-rw-r--r--synapse/events/snapshot.py3
-rw-r--r--synapse/events/spamcheck.py47
-rw-r--r--synapse/events/third_party_rules.py3
-rw-r--r--synapse/events/utils.py2
-rw-r--r--synapse/federation/federation_client.py133
-rw-r--r--synapse/federation/federation_server.py21
-rw-r--r--synapse/federation/persistence.py6
-rw-r--r--synapse/federation/send_queue.py8
-rw-r--r--synapse/federation/sender/__init__.py60
-rw-r--r--synapse/federation/sender/per_destination_queue.py12
-rw-r--r--synapse/federation/sender/transaction_manager.py5
-rw-r--r--synapse/federation/transport/client.py86
-rw-r--r--synapse/federation/transport/server.py105
-rw-r--r--synapse/federation/units.py8
-rw-r--r--synapse/groups/attestations.py49
-rw-r--r--synapse/groups/groups_server.py312
-rw-r--r--synapse/handlers/acme.py12
-rw-r--r--synapse/handlers/acme_issuing_service.py27
-rw-r--r--synapse/handlers/admin.py6
-rw-r--r--synapse/handlers/appservice.py4
-rw-r--r--synapse/handlers/auth.py88
-rw-r--r--synapse/handlers/cas_handler.py57
-rw-r--r--synapse/handlers/deactivate_account.py6
-rw-r--r--synapse/handlers/device.py42
-rw-r--r--synapse/handlers/devicemessage.py7
-rw-r--r--synapse/handlers/e2e_keys.py247
-rw-r--r--synapse/handlers/e2e_room_keys.py91
-rw-r--r--synapse/handlers/events.py3
-rw-r--r--synapse/handlers/federation.py119
-rw-r--r--synapse/handlers/groups_local.py107
-rw-r--r--synapse/handlers/identity.py35
-rw-r--r--synapse/handlers/initial_sync.py12
-rw-r--r--synapse/handlers/message.py76
-rw-r--r--synapse/handlers/oidc_handler.py294
-rw-r--r--synapse/handlers/pagination.py14
-rw-r--r--synapse/handlers/presence.py33
-rw-r--r--synapse/handlers/profile.py3
-rw-r--r--synapse/handlers/receipts.py9
-rw-r--r--synapse/handlers/register.py35
-rw-r--r--synapse/handlers/room.py72
-rw-r--r--synapse/handlers/room_member.py58
-rw-r--r--synapse/handlers/room_member_worker.py6
-rw-r--r--synapse/handlers/saml_handler.py37
-rw-r--r--synapse/handlers/search.py38
-rw-r--r--synapse/handlers/set_password.py10
-rw-r--r--synapse/handlers/sso.py319
-rw-r--r--synapse/handlers/state_deltas.py14
-rw-r--r--synapse/handlers/stats.py42
-rw-r--r--synapse/handlers/sync.py37
-rw-r--r--synapse/handlers/typing.py78
-rw-r--r--synapse/handlers/user_directory.py16
-rw-r--r--synapse/http/__init__.py3
-rw-r--r--synapse/http/client.py33
-rw-r--r--synapse/http/federation/matrix_federation_agent.py13
-rw-r--r--synapse/http/federation/well_known_resolver.py3
-rw-r--r--synapse/http/matrixfederationclient.py20
-rw-r--r--synapse/http/request_metrics.py3
-rw-r--r--synapse/http/server.py123
-rw-r--r--synapse/http/servlet.py2
-rw-r--r--synapse/http/site.py9
-rw-r--r--synapse/logging/_remote.py4
-rw-r--r--synapse/logging/_structured.py9
-rw-r--r--synapse/logging/context.py12
-rw-r--r--synapse/logging/opentracing.py10
-rw-r--r--synapse/logging/utils.py3
-rw-r--r--synapse/metrics/__init__.py7
-rw-r--r--synapse/metrics/_exposition.py2
-rw-r--r--synapse/metrics/background_process_metrics.py9
-rw-r--r--synapse/module_api/__init__.py19
-rw-r--r--synapse/notifier.py30
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py9
-rw-r--r--synapse/push/emailpusher.py26
-rw-r--r--synapse/push/httppusher.py23
-rw-r--r--synapse/push/mailer.py492
-rw-r--r--synapse/push/presentable_names.py26
-rw-r--r--synapse/push/pusherpool.py9
-rw-r--r--synapse/python_dependencies.py8
-rw-r--r--synapse/replication/http/_base.py5
-rw-r--r--synapse/replication/http/account_data.py6
-rw-r--r--synapse/replication/http/membership.py5
-rw-r--r--synapse/replication/http/register.py6
-rw-r--r--synapse/replication/tcp/commands.py3
-rw-r--r--synapse/replication/tcp/external_cache.py105
-rw-r--r--synapse/replication/tcp/handler.py51
-rw-r--r--synapse/replication/tcp/protocol.py27
-rw-r--r--synapse/replication/tcp/redis.py173
-rw-r--r--synapse/replication/tcp/resource.py6
-rw-r--r--synapse/replication/tcp/streams/_base.py26
-rw-r--r--synapse/replication/tcp/streams/events.py3
-rw-r--r--synapse/res/templates/sso.css129
-rw-r--r--synapse/res/templates/sso_account_deactivated.html27
-rw-r--r--synapse/res/templates/sso_auth_account_details.html188
-rw-r--r--synapse/res/templates/sso_auth_account_details.js116
-rw-r--r--synapse/res/templates/sso_auth_bad_user.html28
-rw-r--r--synapse/res/templates/sso_auth_confirm.html33
-rw-r--r--synapse/res/templates/sso_auth_success.html40
-rw-r--r--synapse/res/templates/sso_error.html102
-rw-r--r--synapse/res/templates/sso_footer.html19
-rw-r--r--synapse/res/templates/sso_login_idp_picker.html76
-rw-r--r--synapse/res/templates/sso_new_user_consent.html32
-rw-r--r--synapse/res/templates/sso_partial_profile.html19
-rw-r--r--synapse/res/templates/sso_redirect_confirm.html40
-rw-r--r--synapse/res/username_picker/index.html19
-rw-r--r--synapse/res/username_picker/script.js95
-rw-r--r--synapse/res/username_picker/style.css27
-rw-r--r--synapse/rest/admin/__init__.py10
-rw-r--r--synapse/rest/admin/groups.py3
-rw-r--r--synapse/rest/admin/media.py9
-rw-r--r--synapse/rest/admin/rooms.py182
-rw-r--r--synapse/rest/admin/users.py61
-rw-r--r--synapse/rest/client/v1/login.py63
-rw-r--r--synapse/rest/client/v1/profile.py4
-rw-r--r--synapse/rest/client/v1/pusher.py4
-rw-r--r--synapse/rest/client/v1/room.py10
-rw-r--r--synapse/rest/client/v2_alpha/account.py27
-rw-r--r--synapse/rest/client/v2_alpha/devices.py22
-rw-r--r--synapse/rest/client/v2_alpha/groups.py346
-rw-r--r--synapse/rest/client/v2_alpha/keys.py5
-rw-r--r--synapse/rest/client/v2_alpha/register.py17
-rw-r--r--synapse/rest/client/v2_alpha/relations.py8
-rw-r--r--synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py2
-rw-r--r--synapse/rest/consent/consent_resource.py1
-rw-r--r--synapse/rest/media/v1/_base.py5
-rw-r--r--synapse/rest/media/v1/download_resource.py3
-rw-r--r--synapse/rest/media/v1/media_repository.py34
-rw-r--r--synapse/rest/media/v1/media_storage.py61
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py118
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py236
-rw-r--r--synapse/rest/media/v1/upload_resource.py12
-rw-r--r--synapse/rest/synapse/client/__init__.py55
-rw-r--r--synapse/rest/synapse/client/new_user_consent.py97
-rw-r--r--synapse/rest/synapse/client/oidc/__init__.py (renamed from synapse/rest/oidc/__init__.py)6
-rw-r--r--synapse/rest/synapse/client/oidc/callback_resource.py (renamed from synapse/rest/oidc/callback_resource.py)13
-rw-r--r--synapse/rest/synapse/client/pick_username.py101
-rw-r--r--synapse/rest/synapse/client/saml2/__init__.py (renamed from synapse/rest/saml2/__init__.py)8
-rw-r--r--synapse/rest/synapse/client/saml2/metadata_resource.py (renamed from synapse/rest/saml2/metadata_resource.py)0
-rw-r--r--synapse/rest/synapse/client/saml2/response_resource.py (renamed from synapse/rest/saml2/response_resource.py)0
-rw-r--r--synapse/rest/synapse/client/sso_register.py50
-rw-r--r--synapse/rest/well_known.py4
-rw-r--r--synapse/server.py46
-rw-r--r--synapse/server_notices/resource_limits_server_notices.py2
-rw-r--r--synapse/state/__init__.py27
-rw-r--r--synapse/state/v1.py14
-rw-r--r--synapse/state/v2.py6
-rw-r--r--synapse/storage/__init__.py3
-rw-r--r--synapse/storage/background_updates.py8
-rw-r--r--synapse/storage/database.py52
-rw-r--r--synapse/storage/databases/__init__.py5
-rw-r--r--synapse/storage/databases/main/__init__.py6
-rw-r--r--synapse/storage/databases/main/appservice.py3
-rw-r--r--synapse/storage/databases/main/client_ips.py12
-rw-r--r--synapse/storage/databases/main/deviceinbox.py2
-rw-r--r--synapse/storage/databases/main/devices.py42
-rw-r--r--synapse/storage/databases/main/directory.py7
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py15
-rw-r--r--synapse/storage/databases/main/event_federation.py19
-rw-r--r--synapse/storage/databases/main/event_push_actions.py22
-rw-r--r--synapse/storage/databases/main/events.py262
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py41
-rw-r--r--synapse/storage/databases/main/events_forward_extremities.py104
-rw-r--r--synapse/storage/databases/main/events_worker.py16
-rw-r--r--synapse/storage/databases/main/group_server.py40
-rw-r--r--synapse/storage/databases/main/keys.py7
-rw-r--r--synapse/storage/databases/main/media_repository.py25
-rw-r--r--synapse/storage/databases/main/metrics.py58
-rw-r--r--synapse/storage/databases/main/presence.py4
-rw-r--r--synapse/storage/databases/main/profile.py6
-rw-r--r--synapse/storage/databases/main/purge_events.py2
-rw-r--r--synapse/storage/databases/main/push_rule.py8
-rw-r--r--synapse/storage/databases/main/pusher.py11
-rw-r--r--synapse/storage/databases/main/receipts.py13
-rw-r--r--synapse/storage/databases/main/registration.py81
-rw-r--r--synapse/storage/databases/main/room.py25
-rw-r--r--synapse/storage/databases/main/roommember.py23
-rw-r--r--synapse/storage/databases/main/schema/delta/33/remote_media_ts.py3
-rw-r--r--synapse/storage/databases/main/schema/delta/59/01ignored_user.py2
-rw-r--r--synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite10
-rw-r--r--synapse/storage/databases/main/search.py7
-rw-r--r--synapse/storage/databases/main/state.py11
-rw-r--r--synapse/storage/databases/main/state_deltas.py4
-rw-r--r--synapse/storage/databases/main/stats.py26
-rw-r--r--synapse/storage/databases/main/stream.py42
-rw-r--r--synapse/storage/databases/main/transactions.py21
-rw-r--r--synapse/storage/databases/main/ui_auth.py22
-rw-r--r--synapse/storage/databases/main/user_directory.py16
-rw-r--r--synapse/storage/databases/state/bg_updates.py2
-rw-r--r--synapse/storage/databases/state/store.py10
-rw-r--r--synapse/storage/engines/__init__.py8
-rw-r--r--synapse/storage/engines/_base.py6
-rw-r--r--synapse/storage/engines/postgres.py3
-rw-r--r--synapse/storage/engines/sqlite.py14
-rw-r--r--synapse/storage/persist_events.py12
-rw-r--r--synapse/storage/prepare_database.py14
-rw-r--r--synapse/storage/purge_events.py6
-rw-r--r--synapse/storage/state.py5
-rw-r--r--synapse/storage/types.py37
-rw-r--r--synapse/storage/util/id_generators.py45
-rw-r--r--synapse/storage/util/sequence.py27
-rw-r--r--synapse/types.py11
-rw-r--r--synapse/util/async_helpers.py15
-rw-r--r--synapse/util/caches/__init__.py6
-rw-r--r--synapse/util/caches/cached_call.py129
-rw-r--r--synapse/util/caches/descriptors.py17
-rw-r--r--synapse/util/caches/stream_change_cache.py6
-rw-r--r--synapse/util/distributor.py5
-rw-r--r--synapse/util/file_consumer.py15
-rw-r--r--synapse/util/iterutils.py3
-rw-r--r--synapse/util/jsonobject.py6
-rw-r--r--synapse/util/metrics.py3
-rw-r--r--synapse/util/module_loader.py5
-rw-r--r--synapse/util/patch_inline_callbacks.py17
-rw-r--r--synapse/util/stringutils.py33
-rw-r--r--synapse/util/templates.py115
-rw-r--r--synapse/visibility.py3
250 files changed, 6589 insertions, 2832 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 93601dbad0..869e860fb0 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.26.0"
+__version__ = "1.28.0"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 67ecbd32ff..89e62b0e36 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -168,7 +168,7 @@ class Auth:
         rights: str = "access",
         allow_expired: bool = False,
     ) -> synapse.types.Requester:
-        """ Get a registered user's ID.
+        """Get a registered user's ID.
 
         Args:
             request: An HTTP request with an access_token query parameter.
@@ -294,9 +294,12 @@ class Auth:
         return user_id, app_service
 
     async def get_user_by_access_token(
-        self, token: str, rights: str = "access", allow_expired: bool = False,
+        self,
+        token: str,
+        rights: str = "access",
+        allow_expired: bool = False,
     ) -> TokenLookupResult:
-        """ Validate access token and get user_id from it
+        """Validate access token and get user_id from it
 
         Args:
             token: The access token to get the user by
@@ -489,7 +492,7 @@ class Auth:
         return service
 
     async def is_server_admin(self, user: UserID) -> bool:
-        """ Check if the given user is a local server admin.
+        """Check if the given user is a local server admin.
 
         Args:
             user: user to check
@@ -500,7 +503,10 @@ class Auth:
         return await self.store.is_server_admin(user)
 
     def compute_auth_events(
-        self, event, current_state_ids: StateMap[str], for_verification: bool = False,
+        self,
+        event,
+        current_state_ids: StateMap[str],
+        for_verification: bool = False,
     ) -> List[str]:
         """Given an event and current state return the list of event IDs used
         to auth an event.
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 565a8cd76a..af8d59cf87 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -27,6 +27,11 @@ MAX_ALIAS_LENGTH = 255
 # the maximum length for a user id is 255 characters
 MAX_USERID_LENGTH = 255
 
+# The maximum length for a group id is 255 characters
+MAX_GROUPID_LENGTH = 255
+MAX_GROUP_CATEGORYID_LENGTH = 255
+MAX_GROUP_ROLEID_LENGTH = 255
+
 
 class Membership:
 
@@ -128,8 +133,7 @@ class UserTypes:
 
 
 class RelationTypes:
-    """The types of relations known to this server.
-    """
+    """The types of relations known to this server."""
 
     ANNOTATION = "m.annotation"
     REPLACE = "m.replace"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index cd6670d0a2..2a789ea3e8 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -390,8 +390,7 @@ class InvalidCaptchaError(SynapseError):
 
 
 class LimitExceededError(SynapseError):
-    """A client has sent too many requests and is being throttled.
-    """
+    """A client has sent too many requests and is being throttled."""
 
     def __init__(
         self,
@@ -408,8 +407,7 @@ class LimitExceededError(SynapseError):
 
 
 class RoomKeysVersionError(SynapseError):
-    """A client has tried to upload to a non-current version of the room_keys store
-    """
+    """A client has tried to upload to a non-current version of the room_keys store"""
 
     def __init__(self, current_version: str):
         """
@@ -426,7 +424,9 @@ class UnsupportedRoomVersionError(SynapseError):
 
     def __init__(self, msg: str = "Homeserver does not support this room version"):
         super().__init__(
-            code=400, msg=msg, errcode=Codes.UNSUPPORTED_ROOM_VERSION,
+            code=400,
+            msg=msg,
+            errcode=Codes.UNSUPPORTED_ROOM_VERSION,
         )
 
 
@@ -461,8 +461,7 @@ class IncompatibleRoomVersionError(SynapseError):
 
 
 class PasswordRefusedError(SynapseError):
-    """A password has been refused, either during password reset/change or registration.
-    """
+    """A password has been refused, either during password reset/change or registration."""
 
     def __init__(
         self,
@@ -470,7 +469,9 @@ class PasswordRefusedError(SynapseError):
         errcode: str = Codes.WEAK_PASSWORD,
     ):
         super().__init__(
-            code=400, msg=msg, errcode=errcode,
+            code=400,
+            msg=msg,
+            errcode=errcode,
         )
 
 
@@ -493,7 +494,7 @@ class RequestSendFailed(RuntimeError):
 
 
 def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs):
-    """ Utility method for constructing an error response for client-server
+    """Utility method for constructing an error response for client-server
     interactions.
 
     Args:
@@ -510,7 +511,7 @@ def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs):
 
 
 class FederationError(RuntimeError):
-    """  This class is used to inform remote homeservers about erroneous
+    """This class is used to inform remote homeservers about erroneous
     PDUs they sent us.
 
     FATAL: The remote server could not interpret the source event.
diff --git a/synapse/api/presence.py b/synapse/api/presence.py
index 18a462f0ee..b9a8e29460 100644
--- a/synapse/api/presence.py
+++ b/synapse/api/presence.py
@@ -56,8 +56,7 @@ class UserPresenceState(
 
     @classmethod
     def default(cls, user_id):
-        """Returns a default presence state.
-        """
+        """Returns a default presence state."""
         return cls(
             user_id=user_id,
             state=PresenceState.OFFLINE,
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index e36aeef31f..6379c86dde 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -42,6 +42,8 @@ class ConsentURIBuilder:
         """
         if hs_config.form_secret is None:
             raise ConfigError("form_secret not set in config")
+        if hs_config.public_baseurl is None:
+            raise ConfigError("public_baseurl not set in config")
 
         self._hmac_secret = hs_config.form_secret.encode("utf-8")
         self._public_baseurl = hs_config.public_baseurl
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 395e202b89..43b1f1e94b 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -16,6 +16,7 @@
 import gc
 import logging
 import os
+import platform
 import signal
 import socket
 import sys
@@ -57,7 +58,7 @@ def register_sighup(func, *args, **kwargs):
 
 
 def start_worker_reactor(appname, config, run_command=reactor.run):
-    """ Run the reactor in the main process
+    """Run the reactor in the main process
 
     Daemonizes if necessary, and then configures some resources, before starting
     the reactor. Pulls configuration from the 'worker' settings in 'config'.
@@ -92,7 +93,7 @@ def start_reactor(
     logger,
     run_command=reactor.run,
 ):
-    """ Run the reactor in the main process
+    """Run the reactor in the main process
 
     Daemonizes if necessary, and then configures some resources, before starting
     the reactor
@@ -312,9 +313,7 @@ async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerCon
     refresh_certificate(hs)
 
     # Start the tracer
-    synapse.logging.opentracing.init_tracer(  # type: ignore[attr-defined] # noqa
-        hs
-    )
+    synapse.logging.opentracing.init_tracer(hs)  # type: ignore[attr-defined] # noqa
 
     # It is now safe to start your Synapse.
     hs.start_listening(listeners)
@@ -339,7 +338,7 @@ async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerCon
     # rest of time. Doing so means less work each GC (hopefully).
     #
     # This only works on Python 3.7
-    if sys.version_info >= (3, 7):
+    if platform.python_implementation() == "CPython" and sys.version_info >= (3, 7):
         gc.collect()
         gc.freeze()
 
@@ -369,8 +368,7 @@ def setup_sentry(hs):
 
 
 def setup_sdnotify(hs):
-    """Adds process state hooks to tell systemd what we are up to.
-    """
+    """Adds process state hooks to tell systemd what we are up to."""
 
     # Tell systemd our state, if we're using it. This will silently fail if
     # we're not using systemd.
@@ -404,8 +402,7 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
 
 
 class _LimitedHostnameResolver:
-    """Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups.
-    """
+    """Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups."""
 
     def __init__(self, resolver, max_dns_requests_in_flight):
         self._resolver = resolver
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index e60988fa4a..6526acb2f2 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -22,6 +22,7 @@ from typing import Dict, Iterable, Optional, Set
 from typing_extensions import ContextManager
 
 from twisted.internet import address
+from twisted.web.resource import IResource
 
 import synapse
 import synapse.events
@@ -90,9 +91,8 @@ from synapse.replication.tcp.streams import (
     ToDeviceStream,
 )
 from synapse.rest.admin import register_servlets_for_media_repo
-from synapse.rest.client.v1 import events, room
+from synapse.rest.client.v1 import events, login, room
 from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
-from synapse.rest.client.v1.login import LoginRestServlet
 from synapse.rest.client.v1.profile import (
     ProfileAvatarURLRestServlet,
     ProfileDisplaynameRestServlet,
@@ -127,6 +127,7 @@ from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet
 from synapse.rest.client.versions import VersionsRestServlet
 from synapse.rest.health import HealthResource
 from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
 from synapse.server import HomeServer, cache_in_self
 from synapse.storage.databases.main.censor_events import CensorEventsStore
 from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
@@ -420,8 +421,7 @@ class GenericWorkerPresence(BasePresenceHandler):
         ]
 
     async def set_state(self, target_user, state, ignore_status_msg=False):
-        """Set the presence state of the user.
-        """
+        """Set the presence state of the user."""
         presence = state["presence"]
 
         valid_presence = (
@@ -507,7 +507,7 @@ class GenericWorkerServer(HomeServer):
             site_tag = port
 
         # We always include a health resource.
-        resources = {"/health": HealthResource()}
+        resources = {"/health": HealthResource()}  # type: Dict[str, IResource]
 
         for res in listener_config.http_options.resources:
             for name in res.names:
@@ -517,7 +517,7 @@ class GenericWorkerServer(HomeServer):
                     resource = JsonResource(self, canonical_json=False)
 
                     RegisterRestServlet(self).register(resource)
-                    LoginRestServlet(self).register(resource)
+                    login.register_servlets(self, resource)
                     ThreepidRestServlet(self).register(resource)
                     DevicesRestServlet(self).register(resource)
                     KeyQueryServlet(self).register(resource)
@@ -557,6 +557,8 @@ class GenericWorkerServer(HomeServer):
                     groups.register_servlets(self, resource)
 
                     resources.update({CLIENT_API_PREFIX: resource})
+
+                    resources.update(build_synapse_client_resource_tree(self))
                 elif name == "federation":
                     resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
                 elif name == "media":
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 57a2f5237c..244657cb88 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -60,8 +60,7 @@ from synapse.rest import ClientRestResource
 from synapse.rest.admin import AdminRestResource
 from synapse.rest.health import HealthResource
 from synapse.rest.key.v2 import KeyApiV2Resource
-from synapse.rest.synapse.client.pick_idp import PickIdpResource
-from synapse.rest.synapse.client.pick_username import pick_username_resource
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
 from synapse.rest.well_known import WellKnownResource
 from synapse.server import HomeServer
 from synapse.storage import DataStore
@@ -190,21 +189,10 @@ class SynapseHomeServer(HomeServer):
                     "/_matrix/client/versions": client_resource,
                     "/.well-known/matrix/client": WellKnownResource(self),
                     "/_synapse/admin": AdminRestResource(self),
-                    "/_synapse/client/pick_username": pick_username_resource(self),
-                    "/_synapse/client/pick_idp": PickIdpResource(self),
+                    **build_synapse_client_resource_tree(self),
                 }
             )
 
-            if self.get_config().oidc_enabled:
-                from synapse.rest.oidc import OIDCResource
-
-                resources["/_synapse/oidc"] = OIDCResource(self)
-
-            if self.get_config().saml2_enabled:
-                from synapse.rest.saml2 import SAML2Resource
-
-                resources["/_matrix/saml2"] = SAML2Resource(self)
-
             if self.get_config().threepid_behaviour_email == ThreepidBehaviour.LOCAL:
                 from synapse.rest.synapse.client.password_reset import (
                     PasswordResetSubmitTokenResource,
diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
index c38cf8231f..8f86cecb76 100644
--- a/synapse/app/phone_stats_home.py
+++ b/synapse/app/phone_stats_home.py
@@ -93,15 +93,20 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
 
     stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
     stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
+    daily_active_e2ee_rooms = await hs.get_datastore().count_daily_active_e2ee_rooms()
+    stats["daily_active_e2ee_rooms"] = daily_active_e2ee_rooms
+    stats["daily_e2ee_messages"] = await hs.get_datastore().count_daily_e2ee_messages()
+    daily_sent_e2ee_messages = await hs.get_datastore().count_daily_sent_e2ee_messages()
+    stats["daily_sent_e2ee_messages"] = daily_sent_e2ee_messages
     stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
     stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
+    daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
+    stats["daily_sent_messages"] = daily_sent_messages
 
     r30_results = await hs.get_datastore().count_r30_users()
     for name, count in r30_results.items():
         stats["r30_users_" + name] = count
 
-    daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
-    stats["daily_sent_messages"] = daily_sent_messages
     stats["cache_factor"] = hs.config.caches.global_factor
     stats["event_cache_size"] = hs.config.caches.event_cache_size
 
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 3944780a42..0bfc5e445f 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -166,7 +166,10 @@ class ApplicationService:
 
     @cached(num_args=1, cache_context=True)
     async def matches_user_in_member_list(
-        self, room_id: str, store: "DataStore", cache_context: _CacheContext,
+        self,
+        room_id: str,
+        store: "DataStore",
+        cache_context: _CacheContext,
     ) -> bool:
         """Check if this service is interested a room based upon it's membership
 
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index e366a982b8..93c2aabcca 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -76,9 +76,6 @@ def _is_valid_3pe_result(r, field):
     fields = r["fields"]
     if not isinstance(fields, dict):
         return False
-    for k in fields.keys():
-        if not isinstance(fields[k], str):
-            return False
 
     return True
 
@@ -230,7 +227,9 @@ class ApplicationServiceApi(SimpleHttpClient):
 
         try:
             await self.put_json(
-                uri=uri, json_body=body, args={"access_token": service.hs_token},
+                uri=uri,
+                json_body=body,
+                args={"access_token": service.hs_token},
             )
             sent_transactions_counter.labels(service.id).inc()
             sent_events_counter.labels(service.id).inc(len(events))
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 58291afc22..366c476f80 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -68,7 +68,7 @@ MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100
 
 
 class ApplicationServiceScheduler:
-    """ Public facing API for this module. Does the required DI to tie the
+    """Public facing API for this module. Does the required DI to tie the
     components together. This also serves as the "event_pool", which in this
     case is a simple array.
     """
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 94144efc87..97399eb9ba 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -18,18 +18,18 @@
 import argparse
 import errno
 import os
-import time
-import urllib.parse
 from collections import OrderedDict
 from hashlib import sha256
 from textwrap import dedent
-from typing import Any, Callable, Iterable, List, MutableMapping, Optional
+from typing import Any, Iterable, List, MutableMapping, Optional
 
 import attr
 import jinja2
 import pkg_resources
 import yaml
 
+from synapse.util.templates import _create_mxc_to_http_filter, _format_ts_filter
+
 
 class ConfigError(Exception):
     """Represents a problem parsing the configuration
@@ -203,11 +203,30 @@ class Config:
         with open(file_path) as file_stream:
             return file_stream.read()
 
+    def read_template(self, filename: str) -> jinja2.Template:
+        """Load a template file from disk.
+
+        This function will attempt to load the given template from the default Synapse
+        template directory.
+
+        Files read are treated as Jinja templates. The template is not rendered yet
+        and has autoescape enabled.
+
+        Args:
+            filename: A template filename to read.
+
+        Raises:
+            ConfigError: if the file's path is incorrect or otherwise cannot be read.
+
+        Returns:
+            A jinja2 template.
+        """
+        return self.read_templates([filename])[0]
+
     def read_templates(
         self,
         filenames: List[str],
         custom_template_directory: Optional[str] = None,
-        autoescape: bool = False,
     ) -> List[jinja2.Template]:
         """Load a list of template files from disk using the given variables.
 
@@ -215,7 +234,8 @@ class Config:
         template directory. If `custom_template_directory` is supplied, that directory
         is tried first.
 
-        Files read are treated as Jinja templates. These templates are not rendered yet.
+        Files read are treated as Jinja templates. The templates are not rendered yet
+        and have autoescape enabled.
 
         Args:
             filenames: A list of template filenames to read.
@@ -223,16 +243,12 @@ class Config:
             custom_template_directory: A directory to try to look for the templates
                 before using the default Synapse template directory instead.
 
-            autoescape: Whether to autoescape variables before inserting them into the
-                template.
-
         Raises:
             ConfigError: if the file's path is incorrect or otherwise cannot be read.
 
         Returns:
             A list of jinja2 templates.
         """
-        templates = []
         search_directories = [self.default_template_dir]
 
         # The loader will first look in the custom template directory (if specified) for the
@@ -248,8 +264,12 @@ class Config:
             # Search the custom template directory as well
             search_directories.insert(0, custom_template_directory)
 
+        # TODO: switch to synapse.util.templates.build_jinja_env
         loader = jinja2.FileSystemLoader(search_directories)
-        env = jinja2.Environment(loader=loader, autoescape=autoescape)
+        env = jinja2.Environment(
+            loader=loader,
+            autoescape=jinja2.select_autoescape(),
+        )
 
         # Update the environment with our custom filters
         env.filters.update(
@@ -259,44 +279,8 @@ class Config:
             }
         )
 
-        for filename in filenames:
-            # Load the template
-            template = env.get_template(filename)
-            templates.append(template)
-
-        return templates
-
-
-def _format_ts_filter(value: int, format: str):
-    return time.strftime(format, time.localtime(value / 1000))
-
-
-def _create_mxc_to_http_filter(public_baseurl: str) -> Callable:
-    """Create and return a jinja2 filter that converts MXC urls to HTTP
-
-    Args:
-        public_baseurl: The public, accessible base URL of the homeserver
-    """
-
-    def mxc_to_http_filter(value, width, height, resize_method="crop"):
-        if value[0:6] != "mxc://":
-            return ""
-
-        server_and_media_id = value[6:]
-        fragment = None
-        if "#" in server_and_media_id:
-            server_and_media_id, fragment = server_and_media_id.split("#", 1)
-            fragment = "#" + fragment
-
-        params = {"width": width, "height": height, "method": resize_method}
-        return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
-            public_baseurl,
-            server_and_media_id,
-            urllib.parse.urlencode(params),
-            fragment or "",
-        )
-
-    return mxc_to_http_filter
+        # Load the templates
+        return [env.get_template(filename) for filename in filenames]
 
 
 class RootConfig:
@@ -846,8 +830,7 @@ class ShardedWorkerHandlingConfig:
     instances = attr.ib(type=List[str])
 
     def should_handle(self, instance_name: str, key: str) -> bool:
-        """Whether this instance is responsible for handling the given key.
-        """
+        """Whether this instance is responsible for handling the given key."""
         # If multiple instances are not defined we always return true
         if not self.instances or len(self.instances) == 1:
             return True
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 29aa064e57..70025b5d60 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -9,6 +9,7 @@ from synapse.config import (
     consent_config,
     database,
     emailconfig,
+    experimental,
     groups,
     jwt_config,
     key,
@@ -18,6 +19,7 @@ from synapse.config import (
     password_auth_providers,
     push,
     ratelimiting,
+    redis,
     registration,
     repository,
     room_directory,
@@ -48,10 +50,11 @@ def path_exists(file_path: str): ...
 
 class RootConfig:
     server: server.ServerConfig
+    experimental: experimental.ExperimentalConfig
     tls: tls.TlsConfig
     database: database.DatabaseConfig
     logging: logger.LoggingConfig
-    ratelimit: ratelimiting.RatelimitConfig
+    ratelimiting: ratelimiting.RatelimitConfig
     media: repository.ContentRepositoryConfig
     captcha: captcha.CaptchaConfig
     voip: voip.VoipConfig
@@ -79,6 +82,7 @@ class RootConfig:
     roomdirectory: room_directory.RoomDirectoryConfig
     thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
     tracer: tracer.TracerConfig
+    redis: redis.RedisConfig
 
     config_classes: List = ...
     def __init__(self) -> None: ...
diff --git a/synapse/config/auth.py b/synapse/config/auth.py
index 2b3e2ce87b..9aabaadf9e 100644
--- a/synapse/config/auth.py
+++ b/synapse/config/auth.py
@@ -18,8 +18,7 @@ from ._base import Config
 
 
 class AuthConfig(Config):
-    """Password and login configuration
-    """
+    """Password and login configuration"""
 
     section = "auth"
 
@@ -38,7 +37,9 @@ class AuthConfig(Config):
 
         # User-interactive authentication
         ui_auth = config.get("ui_auth") or {}
-        self.ui_auth_session_timeout = ui_auth.get("session_timeout", 0)
+        self.ui_auth_session_timeout = self.parse_duration(
+            ui_auth.get("session_timeout", 0)
+        )
 
     def generate_config_section(self, config_dir_path, server_name, **kwargs):
         return """\
@@ -94,11 +95,11 @@ class AuthConfig(Config):
               #require_uppercase: true
 
         ui_auth:
-            # The number of milliseconds to allow a user-interactive authentication
-            # session to be active.
+            # The amount of time to allow a user-interactive authentication session
+            # to be active.
             #
             # This defaults to 0, meaning the user is queried for their credentials
-            # before every action, but this can be overridden to alow a single
+            # before every action, but this can be overridden to allow a single
             # validation to be re-used.  This weakens the protections afforded by
             # the user-interactive authentication process, by allowing for multiple
             # (and potentially different) operations to use the same validation session.
@@ -106,5 +107,5 @@ class AuthConfig(Config):
             # Uncomment below to allow for credential validation to last for 15
             # seconds.
             #
-            #session_timeout: 15000
+            #session_timeout: "15s"
         """
diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py
index cb00958165..9e48f865cc 100644
--- a/synapse/config/captcha.py
+++ b/synapse/config/captcha.py
@@ -28,9 +28,7 @@ class CaptchaConfig(Config):
             "recaptcha_siteverify_api",
             "https://www.recaptcha.net/recaptcha/api/siteverify",
         )
-        self.recaptcha_template = self.read_templates(
-            ["recaptcha.html"], autoescape=True
-        )[0]
+        self.recaptcha_template = self.read_template("recaptcha.html")
 
     def generate_config_section(self, **kwargs):
         return """\
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index c7877b4095..dbf5085965 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -13,7 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import Config
+from typing import Any, List
+
+from synapse.config.sso import SsoAttributeRequirement
+
+from ._base import Config, ConfigError
+from ._util import validate_config
 
 
 class CasConfig(Config):
@@ -30,14 +35,26 @@ class CasConfig(Config):
 
         if self.cas_enabled:
             self.cas_server_url = cas_config["server_url"]
-            self.cas_service_url = cas_config["service_url"]
+
+            # The public baseurl is required because it is used by the redirect
+            # template.
+            public_baseurl = self.public_baseurl
+            if not public_baseurl:
+                raise ConfigError("cas_config requires a public_baseurl to be set")
+
+            # TODO Update this to a _synapse URL.
+            self.cas_service_url = public_baseurl + "_matrix/client/r0/login/cas/ticket"
             self.cas_displayname_attribute = cas_config.get("displayname_attribute")
-            self.cas_required_attributes = cas_config.get("required_attributes") or {}
+            required_attributes = cas_config.get("required_attributes") or {}
+            self.cas_required_attributes = _parsed_required_attributes_def(
+                required_attributes
+            )
+
         else:
             self.cas_server_url = None
             self.cas_service_url = None
             self.cas_displayname_attribute = None
-            self.cas_required_attributes = {}
+            self.cas_required_attributes = []
 
     def generate_config_section(self, config_dir_path, server_name, **kwargs):
         return """\
@@ -53,10 +70,6 @@ class CasConfig(Config):
           #
           #server_url: "https://cas-server.com"
 
-          # The public URL of the homeserver.
-          #
-          #service_url: "https://homeserver.domain.com:8448"
-
           # The attribute of the CAS response to use as the display name.
           #
           # If unset, no displayname will be set.
@@ -73,3 +86,22 @@ class CasConfig(Config):
           #  userGroup: "staff"
           #  department: None
         """
+
+
+# CAS uses a legacy required attributes mapping, not the one provided by
+# SsoAttributeRequirement.
+REQUIRED_ATTRIBUTES_SCHEMA = {
+    "type": "object",
+    "additionalProperties": {"anyOf": [{"type": "string"}, {"type": "null"}]},
+}
+
+
+def _parsed_required_attributes_def(
+    required_attributes: Any,
+) -> List[SsoAttributeRequirement]:
+    validate_config(
+        REQUIRED_ATTRIBUTES_SCHEMA,
+        required_attributes,
+        config_path=("cas_config", "required_attributes"),
+    )
+    return [SsoAttributeRequirement(k, v) for k, v in required_attributes.items()]
diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py
index 6efa59b110..c47f364b14 100644
--- a/synapse/config/consent_config.py
+++ b/synapse/config/consent_config.py
@@ -89,7 +89,7 @@ class ConsentConfig(Config):
 
     def read_config(self, config, **kwargs):
         consent_config = config.get("user_consent")
-        self.terms_template = self.read_templates(["terms.html"], autoescape=True)[0]
+        self.terms_template = self.read_template("terms.html")
 
         if consent_config is None:
             return
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 8a18a9ca2a..e7889b9c20 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -207,8 +207,7 @@ class DatabaseConfig(Config):
         )
 
     def get_single_database(self) -> DatabaseConnectionConfig:
-        """Returns the database if there is only one, useful for e.g. tests
-        """
+        """Returns the database if there is only one, useful for e.g. tests"""
         if not self.databases:
             raise Exception("More than one database exists")
 
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 6a487afd34..52505ac5d2 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -166,6 +166,11 @@ class EmailConfig(Config):
             if not self.email_notif_from:
                 missing.append("email.notif_from")
 
+            # public_baseurl is required to build password reset and validation links that
+            # will be emailed to users
+            if config.get("public_baseurl") is None:
+                missing.append("public_baseurl")
+
             if missing:
                 raise ConfigError(
                     MISSING_PASSWORD_RESET_CONFIG_ERROR % (", ".join(missing),)
@@ -264,6 +269,9 @@ class EmailConfig(Config):
             if not self.email_notif_from:
                 missing.append("email.notif_from")
 
+            if config.get("public_baseurl") is None:
+                missing.append("public_baseurl")
+
             if missing:
                 raise ConfigError(
                     "email.enable_notifs is True but required keys are missing: %s"
@@ -281,7 +289,8 @@ class EmailConfig(Config):
                 self.email_notif_template_html,
                 self.email_notif_template_text,
             ) = self.read_templates(
-                [notif_template_html, notif_template_text], template_dir,
+                [notif_template_html, notif_template_text],
+                template_dir,
             )
 
             self.email_notif_for_new_users = email_config.get(
@@ -303,7 +312,8 @@ class EmailConfig(Config):
                 self.account_validity_template_html,
                 self.account_validity_template_text,
             ) = self.read_templates(
-                [expiry_template_html, expiry_template_text], template_dir,
+                [expiry_template_html, expiry_template_text],
+                template_dir,
             )
 
         subjects_config = email_config.get("subjects", {})
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
new file mode 100644
index 0000000000..b1c1c51e4d
--- /dev/null
+++ b/synapse/config/experimental.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.config._base import Config
+from synapse.types import JsonDict
+
+
+class ExperimentalConfig(Config):
+    """Config section for enabling experimental features"""
+
+    section = "experimental"
+
+    def read_config(self, config: JsonDict, **kwargs):
+        experimental = config.get("experimental_features") or {}
+
+        # MSC2858 (multiple SSO identity providers)
+        self.msc2858_enabled = experimental.get("msc2858_enabled", False)  # type: bool
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 4bd2b3587b..64a2429f77 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -24,6 +24,7 @@ from .cas import CasConfig
 from .consent_config import ConsentConfig
 from .database import DatabaseConfig
 from .emailconfig import EmailConfig
+from .experimental import ExperimentalConfig
 from .federation import FederationConfig
 from .groups import GroupsConfig
 from .jwt_config import JWTConfig
@@ -57,6 +58,7 @@ class HomeServerConfig(RootConfig):
 
     config_classes = [
         ServerConfig,
+        ExperimentalConfig,
         TlsConfig,
         FederationConfig,
         CacheConfig,
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 4df3f93c1c..e56cf846f5 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -162,7 +162,10 @@ class LoggingConfig(Config):
         )
 
         logging_group.add_argument(
-            "-f", "--log-file", dest="log_file", help=argparse.SUPPRESS,
+            "-f",
+            "--log-file",
+            dest="log_file",
+            help=argparse.SUPPRESS,
         )
 
     def generate_files(self, config, config_dir_path):
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index d58a83be7f..a27594befc 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import string
+from collections import Counter
 from typing import Iterable, Optional, Tuple, Type
 
 import attr
@@ -43,8 +43,20 @@ class OIDCConfig(Config):
         except DependencyException as e:
             raise ConfigError(e.message) from e
 
+        # check we don't have any duplicate idp_ids now. (The SSO handler will also
+        # check for duplicates when the REST listeners get registered, but that happens
+        # after synapse has forked so doesn't give nice errors.)
+        c = Counter([i.idp_id for i in self.oidc_providers])
+        for idp_id, count in c.items():
+            if count > 1:
+                raise ConfigError(
+                    "Multiple OIDC providers have the idp_id %r." % idp_id
+                )
+
         public_baseurl = self.public_baseurl
-        self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"
+        if public_baseurl is None:
+            raise ConfigError("oidc_config requires a public_baseurl to be set")
+        self.oidc_callback_url = public_baseurl + "_synapse/client/oidc/callback"
 
     @property
     def oidc_enabled(self) -> bool:
@@ -68,10 +80,14 @@ class OIDCConfig(Config):
         #       offer the user a choice of login mechanisms.
         #
         #   idp_icon: An optional icon for this identity provider, which is presented
-        #       by identity picker pages. If given, must be an MXC URI of the format
-        #       mxc://<server-name>/<media-id>. (An easy way to obtain such an MXC URI
-        #       is to upload an image to an (unencrypted) room and then copy the "url"
-        #       from the source of the event.)
+        #       by clients and Synapse's own IdP picker page. If given, must be an
+        #       MXC URI of the format mxc://<server-name>/<media-id>. (An easy way to
+        #       obtain such an MXC URI is to upload an image to an (unencrypted) room
+        #       and then copy the "url" from the source of the event.)
+        #
+        #   idp_brand: An optional brand for this identity provider, allowing clients
+        #       to style the login flow according to the identity provider in question.
+        #       See the spec for possible options here.
         #
         #   discover: set to 'false' to disable the use of the OIDC discovery mechanism
         #       to discover endpoints. Defaults to true.
@@ -132,17 +148,21 @@ class OIDCConfig(Config):
         #
         #           For the default provider, the following settings are available:
         #
-        #             sub: name of the claim containing a unique identifier for the
-        #                 user. Defaults to 'sub', which OpenID Connect compliant
-        #                 providers should provide.
+        #             subject_claim: name of the claim containing a unique identifier
+        #                 for the user. Defaults to 'sub', which OpenID Connect
+        #                 compliant providers should provide.
         #
         #             localpart_template: Jinja2 template for the localpart of the MXID.
         #                 If this is not set, the user will be prompted to choose their
-        #                 own username.
+        #                 own username (see 'sso_auth_account_details.html' in the 'sso'
+        #                 section of this file).
         #
         #             display_name_template: Jinja2 template for the display name to set
         #                 on first login. If unset, no displayname will be set.
         #
+        #             email_template: Jinja2 template for the email address of the user.
+        #                 If unset, no email address will be added to the account.
+        #
         #             extra_attributes: a map of Jinja2 templates for extra attributes
         #                 to send back to the client during login.
         #                 Note that these are non-standard and clients will ignore them
@@ -178,6 +198,12 @@ class OIDCConfig(Config):
           #  userinfo_endpoint: "https://accounts.example.com/userinfo"
           #  jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
           #  skip_verification: true
+          #  user_mapping_provider:
+          #    config:
+          #      subject_claim: "id"
+          #      localpart_template: "{{{{ user.login }}}}"
+          #      display_name_template: "{{{{ user.name }}}}"
+          #      email_template: "{{{{ user.email }}}}"
 
           # For use with Keycloak
           #
@@ -192,6 +218,7 @@ class OIDCConfig(Config):
           #
           #- idp_id: github
           #  idp_name: Github
+          #  idp_brand: org.matrix.github
           #  discover: false
           #  issuer: "https://github.com/"
           #  client_id: "your-client-id" # TO BE FILLED
@@ -203,8 +230,8 @@ class OIDCConfig(Config):
           #  user_mapping_provider:
           #    config:
           #      subject_claim: "id"
-          #      localpart_template: "{{ user.login }}"
-          #      display_name_template: "{{ user.name }}"
+          #      localpart_template: "{{{{ user.login }}}}"
+          #      display_name_template: "{{{{ user.name }}}}"
         """.format(
             mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
         )
@@ -215,11 +242,22 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
     "type": "object",
     "required": ["issuer", "client_id", "client_secret"],
     "properties": {
-        # TODO: fix the maxLength here depending on what MSC2528 decides
-        #   remember that we prefix the ID given here with `oidc-`
-        "idp_id": {"type": "string", "minLength": 1, "maxLength": 128},
+        "idp_id": {
+            "type": "string",
+            "minLength": 1,
+            # MSC2858 allows a maxlen of 255, but we prefix with "oidc-"
+            "maxLength": 250,
+            "pattern": "^[A-Za-z0-9._~-]+$",
+        },
         "idp_name": {"type": "string"},
         "idp_icon": {"type": "string"},
+        "idp_brand": {
+            "type": "string",
+            # MSC2758-style namespaced identifier
+            "minLength": 1,
+            "maxLength": 255,
+            "pattern": "^[a-z][a-z0-9_.-]*$",
+        },
         "discover": {"type": "boolean"},
         "issuer": {"type": "string"},
         "client_id": {"type": "string"},
@@ -317,9 +355,10 @@ def _parse_oidc_config_dict(
     ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
     ump_config.setdefault("config", {})
 
-    (user_mapping_provider_class, user_mapping_provider_config,) = load_module(
-        ump_config, config_path + ("user_mapping_provider",)
-    )
+    (
+        user_mapping_provider_class,
+        user_mapping_provider_config,
+    ) = load_module(ump_config, config_path + ("user_mapping_provider",))
 
     # Ensure loaded user mapping module has defined all necessary methods
     required_methods = [
@@ -334,29 +373,16 @@ def _parse_oidc_config_dict(
     if missing_methods:
         raise ConfigError(
             "Class %s is missing required "
-            "methods: %s" % (user_mapping_provider_class, ", ".join(missing_methods),),
+            "methods: %s"
+            % (
+                user_mapping_provider_class,
+                ", ".join(missing_methods),
+            ),
             config_path + ("user_mapping_provider", "module"),
         )
 
-    # MSC2858 will apply certain limits in what can be used as an IdP id, so let's
-    # enforce those limits now.
-    # TODO: factor out this stuff to a generic function
     idp_id = oidc_config.get("idp_id", "oidc")
 
-    # TODO: update this validity check based on what MSC2858 decides.
-    valid_idp_chars = set(string.ascii_lowercase + string.digits + "-._")
-
-    if any(c not in valid_idp_chars for c in idp_id):
-        raise ConfigError(
-            'idp_id may only contain a-z, 0-9, "-", ".", "_"',
-            config_path + ("idp_id",),
-        )
-
-    if idp_id[0] not in string.ascii_lowercase:
-        raise ConfigError(
-            "idp_id must start with a-z", config_path + ("idp_id",),
-        )
-
     # prefix the given IDP with a prefix specific to the SSO mechanism, to avoid
     # clashes with other mechs (such as SAML, CAS).
     #
@@ -382,6 +408,7 @@ def _parse_oidc_config_dict(
         idp_id=idp_id,
         idp_name=oidc_config.get("idp_name", "OIDC"),
         idp_icon=idp_icon,
+        idp_brand=oidc_config.get("idp_brand"),
         discover=oidc_config.get("discover", True),
         issuer=oidc_config["issuer"],
         client_id=oidc_config["client_id"],
@@ -412,6 +439,9 @@ class OidcProviderConfig:
     # Optional MXC URI for icon for this IdP.
     idp_icon = attr.ib(type=Optional[str])
 
+    # Optional brand identifier for this IdP.
+    idp_brand = attr.ib(type=Optional[str])
+
     # whether the OIDC discovery mechanism is used to discover endpoints
     discover = attr.ib(type=bool)
 
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 14b8836197..def33a60ad 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -24,7 +24,7 @@ class RateLimitConfig:
         defaults={"per_second": 0.17, "burst_count": 3.0},
     ):
         self.per_second = config.get("per_second", defaults["per_second"])
-        self.burst_count = config.get("burst_count", defaults["burst_count"])
+        self.burst_count = int(config.get("burst_count", defaults["burst_count"]))
 
 
 class FederationRateLimitConfig:
@@ -102,6 +102,20 @@ class RatelimitConfig(Config):
             defaults={"per_second": 0.01, "burst_count": 3},
         )
 
+        self.rc_3pid_validation = RateLimitConfig(
+            config.get("rc_3pid_validation") or {},
+            defaults={"per_second": 0.003, "burst_count": 5},
+        )
+
+        self.rc_invites_per_room = RateLimitConfig(
+            config.get("rc_invites", {}).get("per_room", {}),
+            defaults={"per_second": 0.3, "burst_count": 10},
+        )
+        self.rc_invites_per_user = RateLimitConfig(
+            config.get("rc_invites", {}).get("per_user", {}),
+            defaults={"per_second": 0.003, "burst_count": 5},
+        )
+
     def generate_config_section(self, **kwargs):
         return """\
         ## Ratelimiting ##
@@ -131,6 +145,9 @@ class RatelimitConfig(Config):
         #     users are joining rooms the server is already in (this is cheap) vs
         #     "remote" for when users are trying to join rooms not on the server (which
         #     can be more expensive)
+        #   - one for ratelimiting how often a user or IP can attempt to validate a 3PID.
+        #   - two for ratelimiting how often invites can be sent in a room or to a
+        #     specific user.
         #
         # The defaults are as shown below.
         #
@@ -164,7 +181,18 @@ class RatelimitConfig(Config):
         #  remote:
         #    per_second: 0.01
         #    burst_count: 3
-
+        #
+        #rc_3pid_validation:
+        #  per_second: 0.003
+        #  burst_count: 5
+        #
+        #rc_invites:
+        #  per_room:
+        #    per_second: 0.3
+        #    burst_count: 10
+        #  per_user:
+        #    per_second: 0.003
+        #    burst_count: 5
 
         # Ratelimiting settings for incoming federation
         #
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 4bfc69cb7a..ead007ba5a 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -49,6 +49,10 @@ class AccountValidityConfig(Config):
 
             self.startup_job_max_delta = self.period * 10.0 / 100.0
 
+        if self.renew_by_email_enabled:
+            if "public_baseurl" not in synapse_config:
+                raise ConfigError("Can't send renewal emails without 'public_baseurl'")
+
         template_dir = config.get("template_dir")
 
         if not template_dir:
@@ -105,6 +109,13 @@ class RegistrationConfig(Config):
         account_threepid_delegates = config.get("account_threepid_delegates") or {}
         self.account_threepid_delegate_email = account_threepid_delegates.get("email")
         self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
+        if self.account_threepid_delegate_msisdn and not self.public_baseurl:
+            raise ConfigError(
+                "The configuration option `public_baseurl` is required if "
+                "`account_threepid_delegate.msisdn` is set, such that "
+                "clients know where to submit validation tokens to. Please "
+                "configure `public_baseurl`."
+            )
 
         self.default_identity_server = config.get("default_identity_server")
         self.allow_guest_access = config.get("allow_guest_access", False)
@@ -176,9 +187,7 @@ class RegistrationConfig(Config):
         self.session_lifetime = session_lifetime
 
         # The success template used during fallback auth.
-        self.fallback_success_template = self.read_templates(
-            ["auth_success.html"], autoescape=True
-        )[0]
+        self.fallback_success_template = self.read_template("auth_success.html")
 
     def generate_config_section(self, generate_secrets=False, **kwargs):
         if generate_secrets:
@@ -229,9 +238,8 @@ class RegistrationConfig(Config):
           # send an email to the account's email address with a renewal link. By
           # default, no such emails are sent.
           #
-          # If you enable this setting, you will also need to fill out the 'email'
-          # configuration section. You should also check that 'public_baseurl' is set
-          # correctly.
+          # If you enable this setting, you will also need to fill out the 'email' and
+          # 'public_baseurl' configuration sections.
           #
           #renew_at: 1w
 
@@ -322,7 +330,8 @@ class RegistrationConfig(Config):
         # The identity server which we suggest that clients should use when users log
         # in on this server.
         #
-        # (By default, no suggestion is made, so it is left up to the client.)
+        # (By default, no suggestion is made, so it is left up to the client.
+        # This setting is ignored unless public_baseurl is also set.)
         #
         #default_identity_server: https://matrix.org
 
@@ -347,6 +356,8 @@ class RegistrationConfig(Config):
         # by the Matrix Identity Service API specification:
         # https://matrix.org/docs/spec/identity_service/latest
         #
+        # If a delegate is specified, the config option public_baseurl must also be filled out.
+        #
         account_threepid_delegates:
             #email: https://example.com     # Delegate email sending to example.com
             #msisdn: http://localhost:8090  # Delegate SMS sending to this local process
@@ -380,6 +391,8 @@ class RegistrationConfig(Config):
         # By default, any room aliases included in this list will be created
         # as a publicly joinable room when the first user registers for the
         # homeserver. This behaviour can be customised with the settings below.
+        # If the room already exists, make certain it is a publicly joinable
+        # room. The join rule of the room must be set to 'public'.
         #
         #auto_join_rooms:
         #  - "#example:example.com"
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 850ac3ebd6..52849c3256 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -17,9 +17,7 @@ import os
 from collections import namedtuple
 from typing import Dict, List
 
-from netaddr import IPSet
-
-from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST
+from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set
 from synapse.python_dependencies import DependencyException, check_requirements
 from synapse.util.module_loader import load_module
 
@@ -54,7 +52,7 @@ MediaStorageProviderConfig = namedtuple(
 
 
 def parse_thumbnail_requirements(thumbnail_sizes):
-    """ Takes a list of dictionaries with "width", "height", and "method" keys
+    """Takes a list of dictionaries with "width", "height", and "method" keys
     and creates a map from image media types to the thumbnail size, thumbnailing
     method, and thumbnail media type to precalculate
 
@@ -187,16 +185,17 @@ class ContentRepositoryConfig(Config):
                     "to work"
                 )
 
-            self.url_preview_ip_range_blacklist = IPSet(
-                config["url_preview_ip_range_blacklist"]
-            )
-
             # we always blacklist '0.0.0.0' and '::', which are supposed to be
             # unroutable addresses.
-            self.url_preview_ip_range_blacklist.update(["0.0.0.0", "::"])
+            self.url_preview_ip_range_blacklist = generate_ip_set(
+                config["url_preview_ip_range_blacklist"],
+                ["0.0.0.0", "::"],
+                config_path=("url_preview_ip_range_blacklist",),
+            )
 
-            self.url_preview_ip_range_whitelist = IPSet(
-                config.get("url_preview_ip_range_whitelist", ())
+            self.url_preview_ip_range_whitelist = generate_ip_set(
+                config.get("url_preview_ip_range_whitelist", ()),
+                config_path=("url_preview_ip_range_whitelist",),
             )
 
             self.url_preview_url_blacklist = config.get("url_preview_url_blacklist", ())
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index 9a3e1c3e7d..2dd719c388 100644
--- a/synapse/config/room_directory.py
+++ b/synapse/config/room_directory.py
@@ -123,7 +123,7 @@ class RoomDirectoryConfig(Config):
             alias (str)
 
         Returns:
-            boolean: True if user is allowed to crate the alias
+            boolean: True if user is allowed to create the alias
         """
         for rule in self._alias_creation_rules:
             if rule.matches(user_id, room_id, [alias]):
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index f33dfa0d6a..4b494f217f 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -17,8 +17,7 @@
 import logging
 from typing import Any, List
 
-import attr
-
+from synapse.config.sso import SsoAttributeRequirement
 from synapse.python_dependencies import DependencyException, check_requirements
 from synapse.util.module_loader import load_module, load_python_module
 
@@ -189,13 +188,15 @@ class SAML2Config(Config):
         import saml2
 
         public_baseurl = self.public_baseurl
+        if public_baseurl is None:
+            raise ConfigError("saml2_config requires a public_baseurl to be set")
 
         if self.saml2_grandfathered_mxid_source_attribute:
             optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
         optional_attributes -= required_attributes
 
-        metadata_url = public_baseurl + "_matrix/saml2/metadata.xml"
-        response_url = public_baseurl + "_matrix/saml2/authn_response"
+        metadata_url = public_baseurl + "_synapse/client/saml2/metadata.xml"
+        response_url = public_baseurl + "_synapse/client/saml2/authn_response"
         return {
             "entityid": metadata_url,
             "service": {
@@ -233,10 +234,10 @@ class SAML2Config(Config):
         # enable SAML login.
         #
         # Once SAML support is enabled, a metadata file will be exposed at
-        # https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to
+        # https://<server>:<port>/_synapse/client/saml2/metadata.xml, which you may be able to
         # use to configure your SAML IdP with. Alternatively, you can manually configure
         # the IdP to use an ACS location of
-        # https://<server>:<port>/_matrix/saml2/authn_response.
+        # https://<server>:<port>/_synapse/client/saml2/authn_response.
         #
         saml2_config:
           # `sp_config` is the configuration for the pysaml2 Service Provider.
@@ -396,32 +397,18 @@ class SAML2Config(Config):
         }
 
 
-@attr.s(frozen=True)
-class SamlAttributeRequirement:
-    """Object describing a single requirement for SAML attributes."""
-
-    attribute = attr.ib(type=str)
-    value = attr.ib(type=str)
-
-    JSON_SCHEMA = {
-        "type": "object",
-        "properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
-        "required": ["attribute", "value"],
-    }
-
-
 ATTRIBUTE_REQUIREMENTS_SCHEMA = {
     "type": "array",
-    "items": SamlAttributeRequirement.JSON_SCHEMA,
+    "items": SsoAttributeRequirement.JSON_SCHEMA,
 }
 
 
 def _parse_attribute_requirements_def(
     attribute_requirements: Any,
-) -> List[SamlAttributeRequirement]:
+) -> List[SsoAttributeRequirement]:
     validate_config(
         ATTRIBUTE_REQUIREMENTS_SCHEMA,
         attribute_requirements,
-        config_path=["saml2_config", "attribute_requirements"],
+        config_path=("saml2_config", "attribute_requirements"),
     )
-    return [SamlAttributeRequirement(**x) for x in attribute_requirements]
+    return [SsoAttributeRequirement(**x) for x in attribute_requirements]
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 47a0370173..6f3325ff81 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import itertools
 import logging
 import os.path
 import re
@@ -23,7 +24,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set
 
 import attr
 import yaml
-from netaddr import IPSet
+from netaddr import AddrFormatError, IPNetwork, IPSet
 
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.util.stringutils import parse_and_validate_server_name
@@ -40,6 +41,71 @@ logger = logging.Logger(__name__)
 # in the list.
 DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"]
 
+
+def _6to4(network: IPNetwork) -> IPNetwork:
+    """Convert an IPv4 network into a 6to4 IPv6 network per RFC 3056."""
+
+    # 6to4 networks consist of:
+    # * 2002 as the first 16 bits
+    # * The first IPv4 address in the network hex-encoded as the next 32 bits
+    # * The new prefix length needs to include the bits from the 2002 prefix.
+    hex_network = hex(network.first)[2:]
+    hex_network = ("0" * (8 - len(hex_network))) + hex_network
+    return IPNetwork(
+        "2002:%s:%s::/%d"
+        % (
+            hex_network[:4],
+            hex_network[4:],
+            16 + network.prefixlen,
+        )
+    )
+
+
+def generate_ip_set(
+    ip_addresses: Optional[Iterable[str]],
+    extra_addresses: Optional[Iterable[str]] = None,
+    config_path: Optional[Iterable[str]] = None,
+) -> IPSet:
+    """
+    Generate an IPSet from a list of IP addresses or CIDRs.
+
+    Additionally, for each IPv4 network in the list of IP addresses, also
+    includes the corresponding IPv6 networks.
+
+    This includes:
+
+    * IPv4-Compatible IPv6 Address (see RFC 4291, section 2.5.5.1)
+    * IPv4-Mapped IPv6 Address (see RFC 4291, section 2.5.5.2)
+    * 6to4 Address (see RFC 3056, section 2)
+
+    Args:
+        ip_addresses: An iterable of IP addresses or CIDRs.
+        extra_addresses: Additional IP addresses or CIDRs to include in the set.
+        config_path: The path in the configuration for error messages.
+
+    Returns:
+        A new IP set.
+    """
+    result = IPSet()
+    for ip in itertools.chain(ip_addresses or (), extra_addresses or ()):
+        try:
+            network = IPNetwork(ip)
+        except AddrFormatError as e:
+            raise ConfigError(
+                "Invalid IP range provided: %s." % (ip,), config_path
+            ) from e
+        result.add(network)
+
+        # It is possible that these already exist in the set, but that's OK.
+        if ":" not in str(network):
+            result.add(IPNetwork(network).ipv6(ipv4_compatible=True))
+            result.add(IPNetwork(network).ipv6(ipv4_compatible=False))
+            result.add(_6to4(network))
+
+    return result
+
+
+# IP ranges that are considered private / unroutable / don't make sense.
 DEFAULT_IP_RANGE_BLACKLIST = [
     # Localhost
     "127.0.0.0/8",
@@ -53,6 +119,8 @@ DEFAULT_IP_RANGE_BLACKLIST = [
     "192.0.0.0/24",
     # Link-local networks.
     "169.254.0.0/16",
+    # Formerly used for 6to4 relay.
+    "192.88.99.0/24",
     # Testing networks.
     "198.18.0.0/15",
     "192.0.2.0/24",
@@ -66,6 +134,12 @@ DEFAULT_IP_RANGE_BLACKLIST = [
     "fe80::/10",
     # Unique local addresses.
     "fc00::/7",
+    # Testing networks.
+    "2001:db8::/32",
+    # Multicast.
+    "ff00::/8",
+    # Site-local addresses
+    "fec0::/10",
 ]
 
 DEFAULT_ROOM_VERSION = "6"
@@ -161,11 +235,7 @@ class ServerConfig(Config):
         self.print_pidfile = config.get("print_pidfile")
         self.user_agent_suffix = config.get("user_agent_suffix")
         self.use_frozen_dicts = config.get("use_frozen_dicts", False)
-        self.public_baseurl = config.get("public_baseurl") or "https://%s/" % (
-            self.server_name,
-        )
-        if self.public_baseurl[-1] != "/":
-            self.public_baseurl += "/"
+        self.public_baseurl = config.get("public_baseurl")
 
         # Whether to enable user presence.
         self.use_presence = config.get("use_presence", True)
@@ -189,7 +259,8 @@ class ServerConfig(Config):
         # Whether to require sharing a room with a user to retrieve their
         # profile data
         self.limit_profile_requests_to_users_who_share_rooms = config.get(
-            "limit_profile_requests_to_users_who_share_rooms", False,
+            "limit_profile_requests_to_users_who_share_rooms",
+            False,
         )
 
         if "restrict_public_rooms_to_local_users" in config and (
@@ -294,17 +365,15 @@ class ServerConfig(Config):
         )
 
         # Attempt to create an IPSet from the given ranges
-        try:
-            self.ip_range_blacklist = IPSet(ip_range_blacklist)
-        except Exception as e:
-            raise ConfigError("Invalid range(s) provided in ip_range_blacklist.") from e
+
         # Always blacklist 0.0.0.0, ::
-        self.ip_range_blacklist.update(["0.0.0.0", "::"])
+        self.ip_range_blacklist = generate_ip_set(
+            ip_range_blacklist, ["0.0.0.0", "::"], config_path=("ip_range_blacklist",)
+        )
 
-        try:
-            self.ip_range_whitelist = IPSet(config.get("ip_range_whitelist", ()))
-        except Exception as e:
-            raise ConfigError("Invalid range(s) provided in ip_range_whitelist.") from e
+        self.ip_range_whitelist = generate_ip_set(
+            config.get("ip_range_whitelist", ()), config_path=("ip_range_whitelist",)
+        )
 
         # The federation_ip_range_blacklist is used for backwards-compatibility
         # and only applies to federation and identity servers. If it is not given,
@@ -312,15 +381,16 @@ class ServerConfig(Config):
         federation_ip_range_blacklist = config.get(
             "federation_ip_range_blacklist", ip_range_blacklist
         )
-        try:
-            self.federation_ip_range_blacklist = IPSet(federation_ip_range_blacklist)
-        except Exception as e:
-            raise ConfigError(
-                "Invalid range(s) provided in federation_ip_range_blacklist."
-            ) from e
         # Always blacklist 0.0.0.0, ::
-        self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
+        self.federation_ip_range_blacklist = generate_ip_set(
+            federation_ip_range_blacklist,
+            ["0.0.0.0", "::"],
+            config_path=("federation_ip_range_blacklist",),
+        )
 
+        if self.public_baseurl is not None:
+            if self.public_baseurl[-1] != "/":
+                self.public_baseurl += "/"
         self.start_pushers = config.get("start_pushers", True)
 
         # (undocumented) option for torturing the worker-mode replication a bit,
@@ -550,7 +620,9 @@ class ServerConfig(Config):
         if manhole:
             self.listeners.append(
                 ListenerConfig(
-                    port=manhole, bind_addresses=["127.0.0.1"], type="manhole",
+                    port=manhole,
+                    bind_addresses=["127.0.0.1"],
+                    type="manhole",
                 )
             )
 
@@ -586,7 +658,8 @@ class ServerConfig(Config):
         # and letting the client know which email address is bound to an account and
         # which one isn't.
         self.request_token_inhibit_3pid_errors = config.get(
-            "request_token_inhibit_3pid_errors", False,
+            "request_token_inhibit_3pid_errors",
+            False,
         )
 
         # List of users trialing the new experimental default push rules. This setting is
@@ -748,10 +821,6 @@ class ServerConfig(Config):
         # Otherwise, it should be the URL to reach Synapse's client HTTP listener (see
         # 'listeners' below).
         #
-        # If this is left unset, it defaults to 'https://<server_name>/'. (Note that
-        # that will not work unless you configure Synapse or a reverse-proxy to listen
-        # on port 443.)
-        #
         #public_baseurl: https://example.com/
 
         # Set the soft limit on the number of file descriptors synapse can use
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 59be825532..243cc681e8 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -12,14 +12,30 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict
+from typing import Any, Dict, Optional
+
+import attr
 
 from ._base import Config
 
 
+@attr.s(frozen=True)
+class SsoAttributeRequirement:
+    """Object describing a single requirement for SSO attributes."""
+
+    attribute = attr.ib(type=str)
+    # If a value is not given, then the attribute must simply exist.
+    value = attr.ib(type=Optional[str])
+
+    JSON_SCHEMA = {
+        "type": "object",
+        "properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
+        "required": ["attribute", "value"],
+    }
+
+
 class SSOConfig(Config):
-    """SSO Configuration
-    """
+    """SSO Configuration"""
 
     section = "sso"
 
@@ -27,7 +43,7 @@ class SSOConfig(Config):
         sso_config = config.get("sso") or {}  # type: Dict[str, Any]
 
         # The sso-specific template_dir
-        template_dir = sso_config.get("template_dir")
+        self.sso_template_dir = sso_config.get("template_dir")
 
         # Read templates from disk
         (
@@ -48,7 +64,7 @@ class SSOConfig(Config):
                 "sso_auth_success.html",
                 "sso_auth_bad_user.html",
             ],
-            template_dir,
+            self.sso_template_dir,
         )
 
         # These templates have no placeholders, so render them here
@@ -64,8 +80,11 @@ class SSOConfig(Config):
         # gracefully to the client). This would make it pointless to ask the user for
         # confirmation, since the URL the confirmation page would be showing wouldn't be
         # the client's.
-        login_fallback_url = self.public_baseurl + "_matrix/static/client/login"
-        self.sso_client_whitelist.append(login_fallback_url)
+        # public_baseurl is an optional setting, so we only add the fallback's URL to the
+        # list if it's provided (because we can't figure out what that URL is otherwise).
+        if self.public_baseurl:
+            login_fallback_url = self.public_baseurl + "_matrix/static/client/login"
+            self.sso_client_whitelist.append(login_fallback_url)
 
     def generate_config_section(self, **kwargs):
         return """\
@@ -83,9 +102,9 @@ class SSOConfig(Config):
             # phishing attacks from evil.site. To avoid this, include a slash after the
             # hostname: "https://my.client/".
             #
-            # The login fallback page (used by clients that don't natively support the
-            # required login flows) is automatically whitelisted in addition to any URLs
-            # in this list.
+            # If public_baseurl is set, then the login fallback page (used by clients
+            # that don't natively support the required login flows) is whitelisted in
+            # addition to any URLs in this list.
             #
             # By default, this list is empty.
             #
@@ -106,15 +125,19 @@ class SSOConfig(Config):
             #
             #   When rendering, this template is given the following variables:
             #     * redirect_url: the URL that the user will be redirected to after
-            #       login. Needs manual escaping (see
-            #       https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+            #       login.
             #
             #     * server_name: the homeserver's name.
             #
             #     * providers: a list of available Identity Providers. Each element is
             #       an object with the following attributes:
+            #
             #         * idp_id: unique identifier for the IdP
             #         * idp_name: user-facing name for the IdP
+            #         * idp_icon: if specified in the IdP config, an MXC URI for an icon
+            #              for the IdP
+            #         * idp_brand: if specified in the IdP config, a textual identifier
+            #              for the brand of the IdP
             #
             #   The rendered HTML page should contain a form which submits its results
             #   back as a GET request, with the following query parameters:
@@ -124,33 +147,101 @@ class SSOConfig(Config):
             #
             #     * idp: the 'idp_id' of the chosen IDP.
             #
+            # * HTML page to prompt new users to enter a userid and confirm other
+            #   details: 'sso_auth_account_details.html'. This is only shown if the
+            #   SSO implementation (with any user_mapping_provider) does not return
+            #   a localpart.
+            #
+            #   When rendering, this template is given the following variables:
+            #
+            #     * server_name: the homeserver's name.
+            #
+            #     * idp: details of the SSO Identity Provider that the user logged in
+            #       with: an object with the following attributes:
+            #
+            #         * idp_id: unique identifier for the IdP
+            #         * idp_name: user-facing name for the IdP
+            #         * idp_icon: if specified in the IdP config, an MXC URI for an icon
+            #              for the IdP
+            #         * idp_brand: if specified in the IdP config, a textual identifier
+            #              for the brand of the IdP
+            #
+            #     * user_attributes: an object containing details about the user that
+            #       we received from the IdP. May have the following attributes:
+            #
+            #         * display_name: the user's display_name
+            #         * emails: a list of email addresses
+            #
+            #   The template should render a form which submits the following fields:
+            #
+            #     * username: the localpart of the user's chosen user id
+            #
+            # * HTML page allowing the user to consent to the server's terms and
+            #   conditions. This is only shown for new users, and only if
+            #   `user_consent.require_at_registration` is set.
+            #
+            #   When rendering, this template is given the following variables:
+            #
+            #     * server_name: the homeserver's name.
+            #
+            #     * user_id: the user's matrix proposed ID.
+            #
+            #     * user_profile.display_name: the user's proposed display name, if any.
+            #
+            #     * consent_version: the version of the terms that the user will be
+            #       shown
+            #
+            #     * terms_url: a link to the page showing the terms.
+            #
+            #   The template should render a form which submits the following fields:
+            #
+            #     * accepted_version: the version of the terms accepted by the user
+            #       (ie, 'consent_version' from the input variables).
+            #
             # * HTML page for a confirmation step before redirecting back to the client
             #   with the login token: 'sso_redirect_confirm.html'.
             #
-            #   When rendering, this template is given three variables:
-            #     * redirect_url: the URL the user is about to be redirected to. Needs
-            #                     manual escaping (see
-            #                     https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+            #   When rendering, this template is given the following variables:
+            #
+            #     * redirect_url: the URL the user is about to be redirected to.
             #
             #     * display_url: the same as `redirect_url`, but with the query
             #                    parameters stripped. The intention is to have a
             #                    human-readable URL to show to users, not to use it as
-            #                    the final address to redirect to. Needs manual escaping
-            #                    (see https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+            #                    the final address to redirect to.
             #
             #     * server_name: the homeserver's name.
             #
+            #     * new_user: a boolean indicating whether this is the user's first time
+            #          logging in.
+            #
+            #     * user_id: the user's matrix ID.
+            #
+            #     * user_profile.avatar_url: an MXC URI for the user's avatar, if any.
+            #           None if the user has not set an avatar.
+            #
+            #     * user_profile.display_name: the user's display name. None if the user
+            #           has not set a display name.
+            #
             # * HTML page which notifies the user that they are authenticating to confirm
             #   an operation on their account during the user interactive authentication
             #   process: 'sso_auth_confirm.html'.
             #
             #   When rendering, this template is given the following variables:
-            #     * redirect_url: the URL the user is about to be redirected to. Needs
-            #                     manual escaping (see
-            #                     https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+            #     * redirect_url: the URL the user is about to be redirected to.
             #
             #     * description: the operation which the user is being asked to confirm
             #
+            #     * idp: details of the Identity Provider that we will use to confirm
+            #       the user's identity: an object with the following attributes:
+            #
+            #         * idp_id: unique identifier for the IdP
+            #         * idp_name: user-facing name for the IdP
+            #         * idp_icon: if specified in the IdP config, an MXC URI for an icon
+            #              for the IdP
+            #         * idp_brand: if specified in the IdP config, a textual identifier
+            #              for the brand of the IdP
+            #
             # * HTML page shown after a successful user interactive authentication session:
             #   'sso_auth_success.html'.
             #
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index f10e33f7b8..7a0ca16da8 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -33,8 +33,7 @@ def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
 
 @attr.s
 class InstanceLocationConfig:
-    """The host and port to talk to an instance via HTTP replication.
-    """
+    """The host and port to talk to an instance via HTTP replication."""
 
     host = attr.ib(type=str)
     port = attr.ib(type=int)
@@ -54,13 +53,19 @@ class WriterLocations:
     )
     typing = attr.ib(default="master", type=str)
     to_device = attr.ib(
-        default=["master"], type=List[str], converter=_instance_to_list_converter,
+        default=["master"],
+        type=List[str],
+        converter=_instance_to_list_converter,
     )
     account_data = attr.ib(
-        default=["master"], type=List[str], converter=_instance_to_list_converter,
+        default=["master"],
+        type=List[str],
+        converter=_instance_to_list_converter,
     )
     receipts = attr.ib(
-        default=["master"], type=List[str], converter=_instance_to_list_converter,
+        default=["master"],
+        type=List[str],
+        converter=_instance_to_list_converter,
     )
 
 
@@ -107,7 +112,9 @@ class WorkerConfig(Config):
         if manhole:
             self.worker_listeners.append(
                 ListenerConfig(
-                    port=manhole, bind_addresses=["127.0.0.1"], type="manhole",
+                    port=manhole,
+                    bind_addresses=["127.0.0.1"],
+                    type="manhole",
                 )
             )
 
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 74b67b230a..14b21796d9 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -125,19 +125,24 @@ class FederationPolicyForHTTPS:
         self._no_verify_ssl_context = _no_verify_ssl.getContext()
         self._no_verify_ssl_context.set_info_callback(_context_info_cb)
 
-    def get_options(self, host: bytes):
+        self._should_verify = self._config.federation_verify_certificates
+
+        self._federation_certificate_verification_whitelist = (
+            self._config.federation_certificate_verification_whitelist
+        )
 
+    def get_options(self, host: bytes):
         # IPolicyForHTTPS.get_options takes bytes, but we want to compare
         # against the str whitelist. The hostnames in the whitelist are already
         # IDNA-encoded like the hosts will be here.
         ascii_host = host.decode("ascii")
 
         # Check if certificate verification has been enabled
-        should_verify = self._config.federation_verify_certificates
+        should_verify = self._should_verify
 
         # Check if we've disabled certificate verification for this host
-        if should_verify:
-            for regex in self._config.federation_certificate_verification_whitelist:
+        if self._should_verify:
+            for regex in self._federation_certificate_verification_whitelist:
                 if regex.match(ascii_host):
                     should_verify = False
                     break
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 56f8dc9caf..91ad5b3d3c 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -42,7 +42,7 @@ def check(
     do_sig_check: bool = True,
     do_size_check: bool = True,
 ) -> None:
-    """ Checks if this event is correctly authed.
+    """Checks if this event is correctly authed.
 
     Args:
         room_version_obj: the version of the room
@@ -423,7 +423,9 @@ def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
 
 
 def check_redaction(
-    room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase],
+    room_version_obj: RoomVersion,
+    event: EventBase,
+    auth_events: StateMap[EventBase],
 ) -> bool:
     """Check whether the event sender is allowed to redact the target event.
 
@@ -459,7 +461,9 @@ def check_redaction(
 
 
 def _check_power_levels(
-    room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase],
+    room_version_obj: RoomVersion,
+    event: EventBase,
+    auth_events: StateMap[EventBase],
 ) -> None:
     user_list = event.content.get("users", {})
     # Validate users
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 07df258e6e..c1c0426f6e 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -98,7 +98,9 @@ class EventBuilder:
         return self._state_key is not None
 
     async def build(
-        self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]],
+        self,
+        prev_event_ids: List[str],
+        auth_event_ids: Optional[List[str]],
     ) -> EventBase:
         """Transform into a fully signed and hashed event
 
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index afecafe15c..7295df74fe 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -341,8 +341,7 @@ def _encode_state_dict(state_dict):
 
 
 def _decode_state_dict(input):
-    """Decodes a state dict encoded using `_encode_state_dict` above
-    """
+    """Decodes a state dict encoded using `_encode_state_dict` above"""
     if input is None:
         return None
 
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index e7e3a7b9a4..8cfc0bb3cb 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -17,6 +17,8 @@
 import inspect
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
+from synapse.rest.media.v1._base import FileInfo
+from synapse.rest.media.v1.media_storage import ReadableFileWrapper
 from synapse.spam_checker_api import RegistrationBehaviour
 from synapse.types import Collection
 from synapse.util.async_helpers import maybe_awaitable
@@ -214,3 +216,48 @@ class SpamChecker:
                     return behaviour
 
         return RegistrationBehaviour.ALLOW
+
+    async def check_media_file_for_spam(
+        self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
+    ) -> bool:
+        """Checks if a piece of newly uploaded media should be blocked.
+
+        This will be called for local uploads, downloads of remote media, each
+        thumbnail generated for those, and web pages/images used for URL
+        previews.
+
+        Note that care should be taken to not do blocking IO operations in the
+        main thread. For example, to get the contents of a file a module
+        should do::
+
+            async def check_media_file_for_spam(
+                self, file: ReadableFileWrapper, file_info: FileInfo
+            ) -> bool:
+                buffer = BytesIO()
+                await file.write_chunks_to(buffer.write)
+
+                if buffer.getvalue() == b"Hello World":
+                    return True
+
+                return False
+
+
+        Args:
+            file_wrapper: An object that allows reading the contents of the media.
+            file_info: Metadata about the file.
+
+        Returns:
+            True if the media should be blocked or False if it should be
+            allowed.
+        """
+
+        for spam_checker in self.spam_checkers:
+            # For backwards compatibility, only run if the method exists on the
+            # spam checker
+            checker = getattr(spam_checker, "check_media_file_for_spam", None)
+            if checker:
+                spam = await maybe_awaitable(checker(file_wrapper, file_info))
+                if spam:
+                    return True
+
+        return False
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 77fbd3f68a..02bce8b5c9 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -40,7 +40,8 @@ class ThirdPartyEventRules:
 
         if module is not None:
             self.third_party_rules = module(
-                config=config, module_api=hs.get_module_api(),
+                config=config,
+                module_api=hs.get_module_api(),
             )
 
     async def check_event_allowed(
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 9c22e33813..7ca5c9940a 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -34,7 +34,7 @@ SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.")
 
 
 def prune_event(event: EventBase) -> EventBase:
-    """ Returns a pruned version of the given event, which removes all keys we
+    """Returns a pruned version of the given event, which removes all keys we
     don't know about or think could potentially be dodgy.
 
     This is used when we "redact" an event. We want to remove all fields that
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 302b2f69bc..bee81fc019 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -18,6 +18,7 @@ import copy
 import itertools
 import logging
 from typing import (
+    TYPE_CHECKING,
     Any,
     Awaitable,
     Callable,
@@ -26,7 +27,6 @@ from typing import (
     List,
     Mapping,
     Optional,
-    Sequence,
     Tuple,
     TypeVar,
     Union,
@@ -61,6 +61,9 @@ from synapse.util import unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"])
@@ -80,10 +83,10 @@ class InvalidResponseError(RuntimeError):
 
 
 class FederationClient(FederationBase):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
-        self.pdu_destination_tried = {}
+        self.pdu_destination_tried = {}  # type: Dict[str, Dict[str, int]]
         self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
         self.state = hs.get_state_handler()
         self.transport_layer = hs.get_federation_transport_client()
@@ -116,33 +119,32 @@ class FederationClient(FederationBase):
                 self.pdu_destination_tried[event_id] = destination_dict
 
     @log_function
-    def make_query(
+    async def make_query(
         self,
-        destination,
-        query_type,
-        args,
-        retry_on_dns_fail=False,
-        ignore_backoff=False,
-    ):
+        destination: str,
+        query_type: str,
+        args: dict,
+        retry_on_dns_fail: bool = False,
+        ignore_backoff: bool = False,
+    ) -> JsonDict:
         """Sends a federation Query to a remote homeserver of the given type
         and arguments.
 
         Args:
-            destination (str): Domain name of the remote homeserver
-            query_type (str): Category of the query type; should match the
+            destination: Domain name of the remote homeserver
+            query_type: Category of the query type; should match the
                 handler name used in register_query_handler().
-            args (dict): Mapping of strings to strings containing the details
+            args: Mapping of strings to strings containing the details
                 of the query request.
-            ignore_backoff (bool): true to ignore the historical backoff data
+            ignore_backoff: true to ignore the historical backoff data
                 and try the request anyway.
 
         Returns:
-            a Awaitable which will eventually yield a JSON object from the
-            response
+            The JSON object from the response
         """
         sent_queries_counter.labels(query_type).inc()
 
-        return self.transport_layer.make_query(
+        return await self.transport_layer.make_query(
             destination,
             query_type,
             args,
@@ -151,42 +153,52 @@ class FederationClient(FederationBase):
         )
 
     @log_function
-    def query_client_keys(self, destination, content, timeout):
+    async def query_client_keys(
+        self, destination: str, content: JsonDict, timeout: int
+    ) -> JsonDict:
         """Query device keys for a device hosted on a remote server.
 
         Args:
-            destination (str): Domain name of the remote homeserver
-            content (dict): The query content.
+            destination: Domain name of the remote homeserver
+            content: The query content.
 
         Returns:
-            an Awaitable which will eventually yield a JSON object from the
-            response
+            The JSON object from the response
         """
         sent_queries_counter.labels("client_device_keys").inc()
-        return self.transport_layer.query_client_keys(destination, content, timeout)
+        return await self.transport_layer.query_client_keys(
+            destination, content, timeout
+        )
 
     @log_function
-    def query_user_devices(self, destination, user_id, timeout=30000):
+    async def query_user_devices(
+        self, destination: str, user_id: str, timeout: int = 30000
+    ) -> JsonDict:
         """Query the device keys for a list of user ids hosted on a remote
         server.
         """
         sent_queries_counter.labels("user_devices").inc()
-        return self.transport_layer.query_user_devices(destination, user_id, timeout)
+        return await self.transport_layer.query_user_devices(
+            destination, user_id, timeout
+        )
 
     @log_function
-    def claim_client_keys(self, destination, content, timeout):
+    async def claim_client_keys(
+        self, destination: str, content: JsonDict, timeout: int
+    ) -> JsonDict:
         """Claims one-time keys for a device hosted on a remote server.
 
         Args:
-            destination (str): Domain name of the remote homeserver
-            content (dict): The query content.
+            destination: Domain name of the remote homeserver
+            content: The query content.
 
         Returns:
-            an Awaitable which will eventually yield a JSON object from the
-            response
+            The JSON object from the response
         """
         sent_queries_counter.labels("client_one_time_keys").inc()
-        return self.transport_layer.claim_client_keys(destination, content, timeout)
+        return await self.transport_layer.claim_client_keys(
+            destination, content, timeout
+        )
 
     async def backfill(
         self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
@@ -195,10 +207,10 @@ class FederationClient(FederationBase):
         given destination server.
 
         Args:
-            dest (str): The remote homeserver to ask.
-            room_id (str): The room_id to backfill.
-            limit (int): The maximum number of events to return.
-            extremities (list): our current backwards extremities, to backfill from
+            dest: The remote homeserver to ask.
+            room_id: The room_id to backfill.
+            limit: The maximum number of events to return.
+            extremities: our current backwards extremities, to backfill from
         """
         logger.debug("backfill extrem=%s", extremities)
 
@@ -370,7 +382,7 @@ class FederationClient(FederationBase):
                 for events that have failed their checks
 
         Returns:
-            Deferred : A list of PDUs that have valid signatures and hashes.
+            A list of PDUs that have valid signatures and hashes.
         """
         deferreds = self._check_sigs_and_hashes(room_version, pdus)
 
@@ -418,7 +430,9 @@ class FederationClient(FederationBase):
         else:
             return [p for p in valid_pdus if p]
 
-    async def get_event_auth(self, destination, room_id, event_id):
+    async def get_event_auth(
+        self, destination: str, room_id: str, event_id: str
+    ) -> List[EventBase]:
         res = await self.transport_layer.get_event_auth(destination, room_id, event_id)
 
         room_version = await self.store.get_room_version(room_id)
@@ -700,18 +714,16 @@ class FederationClient(FederationBase):
 
         return await self._try_destination_list("send_join", destinations, send_request)
 
-    async def _do_send_join(self, destination: str, pdu: EventBase):
+    async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict:
         time_now = self._clock.time_msec()
 
         try:
-            content = await self.transport_layer.send_join_v2(
+            return await self.transport_layer.send_join_v2(
                 destination=destination,
                 room_id=pdu.room_id,
                 event_id=pdu.event_id,
                 content=pdu.get_pdu_json(time_now),
             )
-
-            return content
         except HttpResponseException as e:
             if e.code in [400, 404]:
                 err = e.to_synapse_error()
@@ -738,7 +750,11 @@ class FederationClient(FederationBase):
         return resp[1]
 
     async def send_invite(
-        self, destination: str, room_id: str, event_id: str, pdu: EventBase,
+        self,
+        destination: str,
+        room_id: str,
+        event_id: str,
+        pdu: EventBase,
     ) -> EventBase:
         room_version = await self.store.get_room_version(room_id)
 
@@ -769,7 +785,7 @@ class FederationClient(FederationBase):
         time_now = self._clock.time_msec()
 
         try:
-            content = await self.transport_layer.send_invite_v2(
+            return await self.transport_layer.send_invite_v2(
                 destination=destination,
                 room_id=pdu.room_id,
                 event_id=pdu.event_id,
@@ -779,7 +795,6 @@ class FederationClient(FederationBase):
                     "invite_room_state": pdu.unsigned.get("invite_room_state", []),
                 },
             )
-            return content
         except HttpResponseException as e:
             if e.code in [400, 404]:
                 err = e.to_synapse_error()
@@ -799,7 +814,7 @@ class FederationClient(FederationBase):
                         "User's homeserver does not support this room version",
                         Codes.UNSUPPORTED_ROOM_VERSION,
                     )
-            elif e.code == 403:
+            elif e.code in (403, 429):
                 raise e.to_synapse_error()
             else:
                 raise
@@ -842,18 +857,16 @@ class FederationClient(FederationBase):
             "send_leave", destinations, send_request
         )
 
-    async def _do_send_leave(self, destination, pdu):
+    async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict:
         time_now = self._clock.time_msec()
 
         try:
-            content = await self.transport_layer.send_leave_v2(
+            return await self.transport_layer.send_leave_v2(
                 destination=destination,
                 room_id=pdu.room_id,
                 event_id=pdu.event_id,
                 content=pdu.get_pdu_json(time_now),
             )
-
-            return content
         except HttpResponseException as e:
             if e.code in [400, 404]:
                 err = e.to_synapse_error()
@@ -879,7 +892,7 @@ class FederationClient(FederationBase):
         # content.
         return resp[1]
 
-    def get_public_rooms(
+    async def get_public_rooms(
         self,
         remote_server: str,
         limit: Optional[int] = None,
@@ -887,7 +900,7 @@ class FederationClient(FederationBase):
         search_filter: Optional[Dict] = None,
         include_all_networks: bool = False,
         third_party_instance_id: Optional[str] = None,
-    ):
+    ) -> JsonDict:
         """Get the list of public rooms from a remote homeserver
 
         Args:
@@ -901,8 +914,7 @@ class FederationClient(FederationBase):
                 party instance
 
         Returns:
-            Awaitable[Dict[str, Any]]: The response from the remote server, or None if
-            `remote_server` is the same as the local server_name
+            The response from the remote server.
 
         Raises:
             HttpResponseException: There was an exception returned from the remote server
@@ -910,7 +922,7 @@ class FederationClient(FederationBase):
                 requests over federation
 
         """
-        return self.transport_layer.get_public_rooms(
+        return await self.transport_layer.get_public_rooms(
             remote_server,
             limit,
             since_token,
@@ -923,7 +935,7 @@ class FederationClient(FederationBase):
         self,
         destination: str,
         room_id: str,
-        earliest_events_ids: Sequence[str],
+        earliest_events_ids: Iterable[str],
         latest_events: Iterable[EventBase],
         limit: int,
         min_depth: int,
@@ -974,7 +986,9 @@ class FederationClient(FederationBase):
 
         return signed_events
 
-    async def forward_third_party_invite(self, destinations, room_id, event_dict):
+    async def forward_third_party_invite(
+        self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
+    ) -> None:
         for destination in destinations:
             if destination == self.server_name:
                 continue
@@ -983,7 +997,7 @@ class FederationClient(FederationBase):
                 await self.transport_layer.exchange_third_party_invite(
                     destination=destination, room_id=room_id, event_dict=event_dict
                 )
-                return None
+                return
             except CodeMessageException:
                 raise
             except Exception as e:
@@ -995,7 +1009,7 @@ class FederationClient(FederationBase):
 
     async def get_room_complexity(
         self, destination: str, room_id: str
-    ) -> Optional[dict]:
+    ) -> Optional[JsonDict]:
         """
         Fetch the complexity of a remote room from another server.
 
@@ -1008,10 +1022,9 @@ class FederationClient(FederationBase):
             could not fetch the complexity.
         """
         try:
-            complexity = await self.transport_layer.get_room_complexity(
+            return await self.transport_layer.get_room_complexity(
                 destination=destination, room_id=room_id
             )
-            return complexity
         except CodeMessageException as e:
             # We didn't manage to get it -- probably a 404. We are okay if other
             # servers don't give it to us.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 171d25c945..8d4bb621e7 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -85,7 +85,8 @@ received_queries_counter = Counter(
 )
 
 pdu_process_time = Histogram(
-    "synapse_federation_server_pdu_process_time", "Time taken to process an event",
+    "synapse_federation_server_pdu_process_time",
+    "Time taken to process an event",
 )
 
 
@@ -204,7 +205,7 @@ class FederationServer(FederationBase):
     async def _handle_incoming_transaction(
         self, origin: str, transaction: Transaction, request_time: int
     ) -> Tuple[int, Dict[str, Any]]:
-        """ Process an incoming transaction and return the HTTP response
+        """Process an incoming transaction and return the HTTP response
 
         Args:
             origin: the server making the request
@@ -373,8 +374,7 @@ class FederationServer(FederationBase):
         return pdu_results
 
     async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
-        """Process the EDUs in a received transaction.
-        """
+        """Process the EDUs in a received transaction."""
 
         async def _process_edu(edu_dict):
             received_edus_counter.inc()
@@ -437,7 +437,10 @@ class FederationServer(FederationBase):
             raise AuthError(403, "Host not in room.")
 
         resp = await self._state_ids_resp_cache.wrap(
-            (room_id, event_id), self._on_state_ids_request_compute, room_id, event_id,
+            (room_id, event_id),
+            self._on_state_ids_request_compute,
+            room_id,
+            event_id,
         )
 
         return 200, resp
@@ -679,7 +682,7 @@ class FederationServer(FederationBase):
         )
 
     async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None:
-        """ Process a PDU received in a federation /send/ transaction.
+        """Process a PDU received in a federation /send/ transaction.
 
         If the event is invalid, then this method throws a FederationError.
         (The error will then be logged and sent back to the sender (which
@@ -906,13 +909,11 @@ class FederationHandlerRegistry:
         self.query_handlers[query_type] = handler
 
     def register_instance_for_edu(self, edu_type: str, instance_name: str):
-        """Register that the EDU handler is on a different instance than master.
-        """
+        """Register that the EDU handler is on a different instance than master."""
         self._edu_type_to_instance[edu_type] = [instance_name]
 
     def register_instances_for_edu(self, edu_type: str, instance_names: List[str]):
-        """Register that the EDU handler is on multiple instances.
-        """
+        """Register that the EDU handler is on multiple instances."""
         self._edu_type_to_instance[edu_type] = instance_names
 
     async def on_edu(self, edu_type: str, origin: str, content: dict):
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 079e2b2fe0..ce5fc758f0 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -30,8 +30,7 @@ logger = logging.getLogger(__name__)
 
 
 class TransactionActions:
-    """ Defines persistence actions that relate to handling Transactions.
-    """
+    """Defines persistence actions that relate to handling Transactions."""
 
     def __init__(self, datastore):
         self.store = datastore
@@ -57,8 +56,7 @@ class TransactionActions:
     async def set_response(
         self, origin: str, transaction: Transaction, code: int, response: JsonDict
     ) -> None:
-        """Persist how we responded to a transaction.
-        """
+        """Persist how we responded to a transaction."""
         transaction_id = transaction.transaction_id  # type: ignore
         if not transaction_id:
             raise RuntimeError("Cannot persist a transaction with no transaction_id")
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 5f1bf492c1..3e993b428b 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -468,8 +468,7 @@ class KeyedEduRow(
 
 
 class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))):  # Edu
-    """Streams EDUs that don't have keys. See KeyedEduRow
-    """
+    """Streams EDUs that don't have keys. See KeyedEduRow"""
 
     TypeId = "e"
 
@@ -519,7 +518,10 @@ def process_rows_for_federation(transaction_queue, rows):
     # them into the appropriate collection and then send them off.
 
     buff = ParsedFederationStreamData(
-        presence=[], presence_destinations=[], keyed_edus={}, edus={},
+        presence=[],
+        presence_destinations=[],
+        keyed_edus={},
+        edus={},
     )
 
     # Parse the rows in the stream and add to the buffer
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 604cfd1935..97fc4d0a82 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -142,6 +142,8 @@ class FederationSender:
             self._wake_destinations_needing_catchup,
         )
 
+        self._external_cache = hs.get_external_cache()
+
     def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
         """Get or create a PerDestinationQueue for the given destination
 
@@ -197,22 +199,40 @@ class FederationSender:
                     if not event.internal_metadata.should_proactively_send():
                         return
 
-                    try:
-                        # Get the state from before the event.
-                        # We need to make sure that this is the state from before
-                        # the event and not from after it.
-                        # Otherwise if the last member on a server in a room is
-                        # banned then it won't receive the event because it won't
-                        # be in the room after the ban.
-                        destinations = await self.state.get_hosts_in_room_at_events(
-                            event.room_id, event_ids=event.prev_event_ids()
-                        )
-                    except Exception:
-                        logger.exception(
-                            "Failed to calculate hosts in room for event: %s",
-                            event.event_id,
+                    destinations = None  # type: Optional[Set[str]]
+                    if not event.prev_event_ids():
+                        # If there are no prev event IDs then the state is empty
+                        # and so there are no remote servers in the room
+                        destinations = set()
+                    else:
+                        # We check the external cache for the destinations, which is
+                        # stored per state group.
+
+                        sg = await self._external_cache.get(
+                            "event_to_prev_state_group", event.event_id
                         )
-                        return
+                        if sg:
+                            destinations = await self._external_cache.get(
+                                "get_joined_hosts", str(sg)
+                            )
+
+                    if destinations is None:
+                        try:
+                            # Get the state from before the event.
+                            # We need to make sure that this is the state from before
+                            # the event and not from after it.
+                            # Otherwise if the last member on a server in a room is
+                            # banned then it won't receive the event because it won't
+                            # be in the room after the ban.
+                            destinations = await self.state.get_hosts_in_room_at_events(
+                                event.room_id, event_ids=event.prev_event_ids()
+                            )
+                        except Exception:
+                            logger.exception(
+                                "Failed to calculate hosts in room for event: %s",
+                                event.event_id,
+                            )
+                            return
 
                     destinations = {
                         d
@@ -308,7 +328,9 @@ class FederationSender:
         # to allow us to perform catch-up later on if the remote is unreachable
         # for a while.
         await self.store.store_destination_rooms_entries(
-            destinations, pdu.room_id, pdu.internal_metadata.stream_ordering,
+            destinations,
+            pdu.room_id,
+            pdu.internal_metadata.stream_ordering,
         )
 
         for destination in destinations:
@@ -455,7 +477,7 @@ class FederationSender:
         self, states: List[UserPresenceState], destinations: List[str]
     ) -> None:
         """Send the given presence states to the given destinations.
-            destinations (list[str])
+        destinations (list[str])
         """
 
         if not states or not self.hs.config.use_presence:
@@ -596,8 +618,8 @@ class FederationSender:
         last_processed = None  # type: Optional[str]
 
         while True:
-            destinations_to_wake = await self.store.get_catch_up_outstanding_destinations(
-                last_processed
+            destinations_to_wake = (
+                await self.store.get_catch_up_outstanding_destinations(last_processed)
             )
 
             if not destinations_to_wake:
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index db8e456fe8..deb519f3ef 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -85,7 +85,8 @@ class PerDestinationQueue:
             # processing. We have a guard in `attempt_new_transaction` that
             # ensure we don't start sending stuff.
             logger.error(
-                "Create a per destination queue for %s on wrong worker", destination,
+                "Create a per destination queue for %s on wrong worker",
+                destination,
             )
             self._should_send_on_this_instance = False
 
@@ -440,8 +441,10 @@ class PerDestinationQueue:
 
         if first_catch_up_check:
             # first catchup so get last_successful_stream_ordering from database
-            self._last_successful_stream_ordering = await self._store.get_destination_last_successful_stream_ordering(
-                self._destination
+            self._last_successful_stream_ordering = (
+                await self._store.get_destination_last_successful_stream_ordering(
+                    self._destination
+                )
             )
 
         if self._last_successful_stream_ordering is None:
@@ -457,7 +460,8 @@ class PerDestinationQueue:
         # get at most 50 catchup room/PDUs
         while True:
             event_ids = await self._store.get_catch_up_room_event_ids(
-                self._destination, self._last_successful_stream_ordering,
+                self._destination,
+                self._last_successful_stream_ordering,
             )
 
             if not event_ids:
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 3e07f925e0..763aff296c 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -65,7 +65,10 @@ class TransactionManager:
 
     @measure_func("_send_new_transaction")
     async def send_new_transaction(
-        self, destination: str, pdus: List[EventBase], edus: List[Edu],
+        self,
+        destination: str,
+        pdus: List[EventBase],
+        edus: List[Edu],
     ) -> bool:
         """
         Args:
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index abe9168c78..10c4747f97 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -39,7 +39,7 @@ class TransportLayerClient:
 
     @log_function
     def get_room_state_ids(self, destination, room_id, event_id):
-        """ Requests all state for a given room from the given server at the
+        """Requests all state for a given room from the given server at the
         given event. Returns the state's event_id's
 
         Args:
@@ -63,7 +63,7 @@ class TransportLayerClient:
 
     @log_function
     def get_event(self, destination, event_id, timeout=None):
-        """ Requests the pdu with give id and origin from the given server.
+        """Requests the pdu with given id and origin from the given server.
 
         Args:
             destination (str): The host name of the remote homeserver we want
@@ -84,7 +84,7 @@ class TransportLayerClient:
 
     @log_function
     def backfill(self, destination, room_id, event_tuples, limit):
-        """ Requests `limit` previous PDUs in a given context before list of
+        """Requests `limit` previous PDUs in a given context before list of
         PDUs.
 
         Args:
@@ -118,7 +118,7 @@ class TransportLayerClient:
 
     @log_function
     async def send_transaction(self, transaction, json_data_callback=None):
-        """ Sends the given Transaction to its destination
+        """Sends the given Transaction to its destination
 
         Args:
             transaction (Transaction)
@@ -551,8 +551,7 @@ class TransportLayerClient:
 
     @log_function
     def get_group_profile(self, destination, group_id, requester_user_id):
-        """Get a group profile
-        """
+        """Get a group profile"""
         path = _create_v1_path("/groups/%s/profile", group_id)
 
         return self.client.get_json(
@@ -584,8 +583,7 @@ class TransportLayerClient:
 
     @log_function
     def get_group_summary(self, destination, group_id, requester_user_id):
-        """Get a group summary
-        """
+        """Get a group summary"""
         path = _create_v1_path("/groups/%s/summary", group_id)
 
         return self.client.get_json(
@@ -597,8 +595,7 @@ class TransportLayerClient:
 
     @log_function
     def get_rooms_in_group(self, destination, group_id, requester_user_id):
-        """Get all rooms in a group
-        """
+        """Get all rooms in a group"""
         path = _create_v1_path("/groups/%s/rooms", group_id)
 
         return self.client.get_json(
@@ -611,8 +608,7 @@ class TransportLayerClient:
     def add_room_to_group(
         self, destination, group_id, requester_user_id, room_id, content
     ):
-        """Add a room to a group
-        """
+        """Add a room to a group"""
         path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
 
         return self.client.post_json(
@@ -626,8 +622,7 @@ class TransportLayerClient:
     def update_room_in_group(
         self, destination, group_id, requester_user_id, room_id, config_key, content
     ):
-        """Update room in group
-        """
+        """Update room in group"""
         path = _create_v1_path(
             "/groups/%s/room/%s/config/%s", group_id, room_id, config_key
         )
@@ -641,8 +636,7 @@ class TransportLayerClient:
         )
 
     def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
-        """Remove a room from a group
-        """
+        """Remove a room from a group"""
         path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
 
         return self.client.delete_json(
@@ -654,8 +648,7 @@ class TransportLayerClient:
 
     @log_function
     def get_users_in_group(self, destination, group_id, requester_user_id):
-        """Get users in a group
-        """
+        """Get users in a group"""
         path = _create_v1_path("/groups/%s/users", group_id)
 
         return self.client.get_json(
@@ -667,8 +660,7 @@ class TransportLayerClient:
 
     @log_function
     def get_invited_users_in_group(self, destination, group_id, requester_user_id):
-        """Get users that have been invited to a group
-        """
+        """Get users that have been invited to a group"""
         path = _create_v1_path("/groups/%s/invited_users", group_id)
 
         return self.client.get_json(
@@ -680,8 +672,7 @@ class TransportLayerClient:
 
     @log_function
     def accept_group_invite(self, destination, group_id, user_id, content):
-        """Accept a group invite
-        """
+        """Accept a group invite"""
         path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id)
 
         return self.client.post_json(
@@ -690,8 +681,7 @@ class TransportLayerClient:
 
     @log_function
     def join_group(self, destination, group_id, user_id, content):
-        """Attempts to join a group
-        """
+        """Attempts to join a group"""
         path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)
 
         return self.client.post_json(
@@ -702,8 +692,7 @@ class TransportLayerClient:
     def invite_to_group(
         self, destination, group_id, user_id, requester_user_id, content
     ):
-        """Invite a user to a group
-        """
+        """Invite a user to a group"""
         path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id)
 
         return self.client.post_json(
@@ -730,8 +719,7 @@ class TransportLayerClient:
     def remove_user_from_group(
         self, destination, group_id, requester_user_id, user_id, content
     ):
-        """Remove a user from a group
-        """
+        """Remove a user from a group"""
         path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
 
         return self.client.post_json(
@@ -772,8 +760,7 @@ class TransportLayerClient:
     def update_group_summary_room(
         self, destination, group_id, user_id, room_id, category_id, content
     ):
-        """Update a room entry in a group summary
-        """
+        """Update a room entry in a group summary"""
         if category_id:
             path = _create_v1_path(
                 "/groups/%s/summary/categories/%s/rooms/%s",
@@ -796,8 +783,7 @@ class TransportLayerClient:
     def delete_group_summary_room(
         self, destination, group_id, user_id, room_id, category_id
     ):
-        """Delete a room entry in a group summary
-        """
+        """Delete a room entry in a group summary"""
         if category_id:
             path = _create_v1_path(
                 "/groups/%s/summary/categories/%s/rooms/%s",
@@ -817,8 +803,7 @@ class TransportLayerClient:
 
     @log_function
     def get_group_categories(self, destination, group_id, requester_user_id):
-        """Get all categories in a group
-        """
+        """Get all categories in a group"""
         path = _create_v1_path("/groups/%s/categories", group_id)
 
         return self.client.get_json(
@@ -830,8 +815,7 @@ class TransportLayerClient:
 
     @log_function
     def get_group_category(self, destination, group_id, requester_user_id, category_id):
-        """Get category info in a group
-        """
+        """Get category info in a group"""
         path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
 
         return self.client.get_json(
@@ -845,8 +829,7 @@ class TransportLayerClient:
     def update_group_category(
         self, destination, group_id, requester_user_id, category_id, content
     ):
-        """Update a category in a group
-        """
+        """Update a category in a group"""
         path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
 
         return self.client.post_json(
@@ -861,8 +844,7 @@ class TransportLayerClient:
     def delete_group_category(
         self, destination, group_id, requester_user_id, category_id
     ):
-        """Delete a category in a group
-        """
+        """Delete a category in a group"""
         path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
 
         return self.client.delete_json(
@@ -874,8 +856,7 @@ class TransportLayerClient:
 
     @log_function
     def get_group_roles(self, destination, group_id, requester_user_id):
-        """Get all roles in a group
-        """
+        """Get all roles in a group"""
         path = _create_v1_path("/groups/%s/roles", group_id)
 
         return self.client.get_json(
@@ -887,8 +868,7 @@ class TransportLayerClient:
 
     @log_function
     def get_group_role(self, destination, group_id, requester_user_id, role_id):
-        """Get a roles info
-        """
+        """Get a role's info"""
         path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
 
         return self.client.get_json(
@@ -902,8 +882,7 @@ class TransportLayerClient:
     def update_group_role(
         self, destination, group_id, requester_user_id, role_id, content
     ):
-        """Update a role in a group
-        """
+        """Update a role in a group"""
         path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
 
         return self.client.post_json(
@@ -916,8 +895,7 @@ class TransportLayerClient:
 
     @log_function
     def delete_group_role(self, destination, group_id, requester_user_id, role_id):
-        """Delete a role in a group
-        """
+        """Delete a role in a group"""
         path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
 
         return self.client.delete_json(
@@ -931,8 +909,7 @@ class TransportLayerClient:
     def update_group_summary_user(
         self, destination, group_id, requester_user_id, user_id, role_id, content
     ):
-        """Update a users entry in a group
-        """
+        """Update a user's entry in a group summary"""
         if role_id:
             path = _create_v1_path(
                 "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
@@ -950,8 +927,7 @@ class TransportLayerClient:
 
     @log_function
     def set_group_join_policy(self, destination, group_id, requester_user_id, content):
-        """Sets the join policy for a group
-        """
+        """Sets the join policy for a group"""
         path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id)
 
         return self.client.put_json(
@@ -966,8 +942,7 @@ class TransportLayerClient:
     def delete_group_summary_user(
         self, destination, group_id, requester_user_id, user_id, role_id
     ):
-        """Delete a users entry in a group
-        """
+        """Delete a user's entry in a group summary"""
         if role_id:
             path = _create_v1_path(
                 "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
@@ -983,8 +958,7 @@ class TransportLayerClient:
         )
 
     def bulk_get_publicised_groups(self, destination, user_ids):
-        """Get the groups a list of users are publicising
-        """
+        """Get the groups a list of users are publicising"""
 
         path = _create_v1_path("/get_groups_publicised")
 
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 95c64510a9..cce83704d4 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -21,6 +21,7 @@ import re
 from typing import Optional, Tuple, Type
 
 import synapse
+from synapse.api.constants import MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_ROLEID_LENGTH
 from synapse.api.errors import Codes, FederationDeniedError, SynapseError
 from synapse.api.room_versions import RoomVersions
 from synapse.api.urls import (
@@ -364,7 +365,10 @@ class BaseFederationServlet:
                 continue
 
             server.register_paths(
-                method, (pattern,), self._wrap(code), self.__class__.__name__,
+                method,
+                (pattern,),
+                self._wrap(code),
+                self.__class__.__name__,
             )
 
 
@@ -381,7 +385,7 @@ class FederationSendServlet(BaseFederationServlet):
 
     # This is when someone is trying to send us a bunch of data.
     async def on_PUT(self, origin, content, query, transaction_id):
-        """ Called on PUT /send/<transaction_id>/
+        """Called on PUT /send/<transaction_id>/
 
         Args:
             request (twisted.web.http.Request): The HTTP request.
@@ -855,8 +859,7 @@ class FederationVersionServlet(BaseFederationServlet):
 
 
 class FederationGroupsProfileServlet(BaseFederationServlet):
-    """Get/set the basic profile of a group on behalf of a user
-    """
+    """Get/set the basic profile of a group on behalf of a user"""
 
     PATH = "/groups/(?P<group_id>[^/]*)/profile"
 
@@ -895,8 +898,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet):
 
 
 class FederationGroupsRoomsServlet(BaseFederationServlet):
-    """Get the rooms in a group on behalf of a user
-    """
+    """Get the rooms in a group on behalf of a user"""
 
     PATH = "/groups/(?P<group_id>[^/]*)/rooms"
 
@@ -911,8 +913,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet):
 
 
 class FederationGroupsAddRoomsServlet(BaseFederationServlet):
-    """Add/remove room from group
-    """
+    """Add/remove room from group"""
 
     PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
 
@@ -940,8 +941,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet):
 
 
 class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
-    """Update room config in group
-    """
+    """Update room config in group"""
 
     PATH = (
         "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
@@ -961,8 +961,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
 
 
 class FederationGroupsUsersServlet(BaseFederationServlet):
-    """Get the users in a group on behalf of a user
-    """
+    """Get the users in a group on behalf of a user"""
 
     PATH = "/groups/(?P<group_id>[^/]*)/users"
 
@@ -977,8 +976,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet):
 
 
 class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
-    """Get the users that have been invited to a group
-    """
+    """Get the users that have been invited to a group"""
 
     PATH = "/groups/(?P<group_id>[^/]*)/invited_users"
 
@@ -995,8 +993,7 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
 
 
 class FederationGroupsInviteServlet(BaseFederationServlet):
-    """Ask a group server to invite someone to the group
-    """
+    """Ask a group server to invite someone to the group"""
 
     PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
 
@@ -1013,8 +1010,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet):
 
 
 class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
-    """Accept an invitation from the group server
-    """
+    """Accept an invitation from the group server"""
 
     PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"
 
@@ -1028,8 +1024,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
 
 
 class FederationGroupsJoinServlet(BaseFederationServlet):
-    """Attempt to join a group
-    """
+    """Attempt to join a group"""
 
     PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"
 
@@ -1043,8 +1038,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet):
 
 
 class FederationGroupsRemoveUserServlet(BaseFederationServlet):
-    """Leave or kick a user from the group
-    """
+    """Leave or kick a user from the group"""
 
     PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
 
@@ -1061,8 +1055,7 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet):
 
 
 class FederationGroupsLocalInviteServlet(BaseFederationServlet):
-    """A group server has invited a local user
-    """
+    """A group server has invited a local user"""
 
     PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
 
@@ -1076,8 +1069,7 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet):
 
 
 class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
-    """A group server has removed a local user
-    """
+    """A group server has removed a local user"""
 
     PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
 
@@ -1093,8 +1085,7 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
 
 
 class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
-    """A group or user's server renews their attestation
-    """
+    """A group or user's server renews their attestation"""
 
     PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)"
 
@@ -1128,7 +1119,17 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
         if category_id == "":
-            raise SynapseError(400, "category_id cannot be empty string")
+            raise SynapseError(
+                400, "category_id cannot be empty string", Codes.INVALID_PARAM
+            )
+
+        if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
+            raise SynapseError(
+                400,
+                "category_id may not be longer than %s characters"
+                % (MAX_GROUP_CATEGORYID_LENGTH,),
+                Codes.INVALID_PARAM,
+            )
 
         resp = await self.handler.update_group_summary_room(
             group_id,
@@ -1156,8 +1157,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
 
 
 class FederationGroupsCategoriesServlet(BaseFederationServlet):
-    """Get all categories for a group
-    """
+    """Get all categories for a group"""
 
     PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
 
@@ -1172,8 +1172,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
 
 
 class FederationGroupsCategoryServlet(BaseFederationServlet):
-    """Add/remove/get a category in a group
-    """
+    """Add/remove/get a category in a group"""
 
     PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
 
@@ -1196,6 +1195,14 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
         if category_id == "":
             raise SynapseError(400, "category_id cannot be empty string")
 
+        if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
+            raise SynapseError(
+                400,
+                "category_id may not be longer than %s characters"
+                % (MAX_GROUP_CATEGORYID_LENGTH,),
+                Codes.INVALID_PARAM,
+            )
+
         resp = await self.handler.upsert_group_category(
             group_id, requester_user_id, category_id, content
         )
@@ -1218,8 +1225,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
 
 
 class FederationGroupsRolesServlet(BaseFederationServlet):
-    """Get roles in a group
-    """
+    """Get roles in a group"""
 
     PATH = "/groups/(?P<group_id>[^/]*)/roles/?"
 
@@ -1234,8 +1240,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
 
 
 class FederationGroupsRoleServlet(BaseFederationServlet):
-    """Add/remove/get a role in a group
-    """
+    """Add/remove/get a role in a group"""
 
     PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
 
@@ -1254,7 +1259,17 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
         if role_id == "":
-            raise SynapseError(400, "role_id cannot be empty string")
+            raise SynapseError(
+                400, "role_id cannot be empty string", Codes.INVALID_PARAM
+            )
+
+        if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
+            raise SynapseError(
+                400,
+                "role_id may not be longer than %s characters"
+                % (MAX_GROUP_ROLEID_LENGTH,),
+                Codes.INVALID_PARAM,
+            )
 
         resp = await self.handler.update_group_role(
             group_id, requester_user_id, role_id, content
@@ -1299,6 +1314,14 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
         if role_id == "":
             raise SynapseError(400, "role_id cannot be empty string")
 
+        if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
+            raise SynapseError(
+                400,
+                "role_id may not be longer than %s characters"
+                % (MAX_GROUP_ROLEID_LENGTH,),
+                Codes.INVALID_PARAM,
+            )
+
         resp = await self.handler.update_group_summary_user(
             group_id,
             requester_user_id,
@@ -1325,8 +1348,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
 
 
 class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
-    """Get roles in a group
-    """
+    """Get the groups a list of users are publicising"""
 
     PATH = "/get_groups_publicised"
 
@@ -1339,8 +1361,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
 
 
 class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
-    """Sets whether a group is joinable without an invite or knock
-    """
+    """Sets whether a group is joinable without an invite or knock"""
 
     PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"
 
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index 64d98fc8f6..b662c42621 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
 
 @attr.s(slots=True)
 class Edu(JsonEncodedObject):
-    """ An Edu represents a piece of data sent from one homeserver to another.
+    """An Edu represents a piece of data sent from one homeserver to another.
 
     In comparison to Pdus, Edus are not persisted for a long time on disk, are
     not meaningful beyond a given pair of homeservers, and don't have an
@@ -63,7 +63,7 @@ class Edu(JsonEncodedObject):
 
 
 class Transaction(JsonEncodedObject):
-    """ A transaction is a list of Pdus and Edus to be sent to a remote home
+    """A transaction is a list of Pdus and Edus to be sent to a remote home
     server with some extra metadata.
 
     Example transaction::
@@ -99,7 +99,7 @@ class Transaction(JsonEncodedObject):
     ]
 
     def __init__(self, transaction_id=None, pdus=[], **kwargs):
-        """ If we include a list of pdus then we decode then as PDU's
+        """If we include a list of pdus then we decode them as PDUs
         automatically.
         """
 
@@ -111,7 +111,7 @@ class Transaction(JsonEncodedObject):
 
     @staticmethod
     def create_new(pdus, **kwargs):
-        """ Used to create a new transaction. Will auto fill out
+        """Used to create a new transaction. Will auto fill out
         transaction_id and origin_server_ts keys.
         """
         if "origin_server_ts" not in kwargs:
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index 41cf07cc88..a3f8d92d08 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -37,13 +37,16 @@ An attestation is a signed blob of json that looks like:
 
 import logging
 import random
-from typing import Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
 
 from signedjson.sign import sign_json
 
 from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import get_domain_from_id
+from synapse.types import JsonDict, get_domain_from_id
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -61,18 +64,21 @@ UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
 
 
 class GroupAttestationSigning:
-    """Creates and verifies group attestations.
-    """
+    """Creates and verifies group attestations."""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.keyring = hs.get_keyring()
         self.clock = hs.get_clock()
         self.server_name = hs.hostname
         self.signing_key = hs.signing_key
 
     async def verify_attestation(
-        self, attestation, group_id, user_id, server_name=None
-    ):
+        self,
+        attestation: JsonDict,
+        group_id: str,
+        user_id: str,
+        server_name: Optional[str] = None,
+    ) -> None:
         """Verifies that the given attestation matches the given parameters.
 
         An optional server_name can be supplied to explicitly set which server's
@@ -101,16 +107,18 @@ class GroupAttestationSigning:
         if valid_until_ms < now:
             raise SynapseError(400, "Attestation expired")
 
+        assert server_name is not None
         await self.keyring.verify_json_for_server(
             server_name, attestation, now, "Group attestation"
         )
 
-    def create_attestation(self, group_id, user_id):
+    def create_attestation(self, group_id: str, user_id: str) -> JsonDict:
         """Create an attestation for the group_id and user_id with default
         validity length.
         """
-        validity_period = DEFAULT_ATTESTATION_LENGTH_MS
-        validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER)
+        validity_period = DEFAULT_ATTESTATION_LENGTH_MS * random.uniform(
+            *DEFAULT_ATTESTATION_JITTER
+        )
         valid_until_ms = int(self.clock.time_msec() + validity_period)
 
         return sign_json(
@@ -125,10 +133,9 @@ class GroupAttestationSigning:
 
 
 class GroupAttestionRenewer:
-    """Responsible for sending and receiving attestation updates.
-    """
+    """Responsible for sending and receiving attestation updates."""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.assestations = hs.get_groups_attestation_signing()
@@ -141,9 +148,10 @@ class GroupAttestionRenewer:
                 self._start_renew_attestations, 30 * 60 * 1000
             )
 
-    async def on_renew_attestation(self, group_id, user_id, content):
-        """When a remote updates an attestation
-        """
+    async def on_renew_attestation(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
+        """When a remote updates an attestation"""
         attestation = content["attestation"]
 
         if not self.is_mine_id(group_id) and not self.is_mine_id(user_id):
@@ -157,12 +165,11 @@ class GroupAttestionRenewer:
 
         return {}
 
-    def _start_renew_attestations(self):
+    def _start_renew_attestations(self) -> None:
         return run_as_background_process("renew_attestations", self._renew_attestations)
 
-    async def _renew_attestations(self):
-        """Called periodically to check if we need to update any of our attestations
-        """
+    async def _renew_attestations(self) -> None:
+        """Called periodically to check if we need to update any of our attestations"""
 
         now = self.clock.time_msec()
 
@@ -170,7 +177,7 @@ class GroupAttestionRenewer:
             now + UPDATE_ATTESTATION_TIME_MS
         )
 
-        async def _renew_attestation(group_user: Tuple[str, str]):
+        async def _renew_attestation(group_user: Tuple[str, str]) -> None:
             group_id, user_id = group_user
             try:
                 if not self.is_mine_id(group_id):
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 0d042cbfac..f9a0f40221 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -16,11 +16,17 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import Codes, SynapseError
-from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
+from synapse.handlers.groups_local import GroupsLocalHandler
+from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
+from synapse.types import GroupID, JsonDict, RoomID, UserID, get_domain_from_id
 from synapse.util.async_helpers import concurrently_execute
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -32,8 +38,13 @@ logger = logging.getLogger(__name__)
 # TODO: Flairs
 
 
+# Note that the maximum lengths are somewhat arbitrary.
+MAX_SHORT_DESC_LEN = 1000
+MAX_LONG_DESC_LEN = 10000
+
+
 class GroupsServerWorkerHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.room_list_handler = hs.get_room_list_handler()
@@ -48,16 +59,21 @@ class GroupsServerWorkerHandler:
         self.profile_handler = hs.get_profile_handler()
 
     async def check_group_is_ours(
-        self, group_id, requester_user_id, and_exists=False, and_is_admin=None
-    ):
+        self,
+        group_id: str,
+        requester_user_id: str,
+        and_exists: bool = False,
+        and_is_admin: Optional[str] = None,
+    ) -> Optional[dict]:
         """Check that the group is ours, and optionally if it exists.
 
         If group does exist then return group.
 
         Args:
-            group_id (str)
-            and_exists (bool): whether to also check if group exists
-            and_is_admin (str): whether to also check if given str is a user_id
+            group_id: The group ID to check.
+            requester_user_id: The user ID of the requester.
+            and_exists: whether to also check if group exists
+            and_is_admin: whether to also check if given str is a user_id
                 that is an admin
         """
         if not self.is_mine_id(group_id):
@@ -80,7 +96,9 @@ class GroupsServerWorkerHandler:
 
         return group
 
-    async def get_group_summary(self, group_id, requester_user_id):
+    async def get_group_summary(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get the summary for a group as seen by requester_user_id.
 
         The group summary consists of the profile of the room, and a curated
@@ -113,6 +131,8 @@ class GroupsServerWorkerHandler:
             entry = await self.room_list_handler.generate_room_entry(
                 room_id, len(joined_users), with_alias=False, allow_private=True
             )
+            if entry is None:
+                continue
             entry = dict(entry)  # so we don't change what's cached
             entry.pop("room_id", None)
 
@@ -120,22 +140,22 @@ class GroupsServerWorkerHandler:
 
         rooms.sort(key=lambda e: e.get("order", 0))
 
-        for entry in users:
-            user_id = entry["user_id"]
+        for user in users:
+            user_id = user["user_id"]
 
             if not self.is_mine_id(requester_user_id):
                 attestation = await self.store.get_remote_attestation(group_id, user_id)
                 if not attestation:
                     continue
 
-                entry["attestation"] = attestation
+                user["attestation"] = attestation
             else:
-                entry["attestation"] = self.attestations.create_attestation(
+                user["attestation"] = self.attestations.create_attestation(
                     group_id, user_id
                 )
 
             user_profile = await self.profile_handler.get_profile_from_cache(user_id)
-            entry.update(user_profile)
+            user.update(user_profile)
 
         users.sort(key=lambda e: e.get("order", 0))
 
@@ -158,46 +178,44 @@ class GroupsServerWorkerHandler:
             "user": membership_info,
         }
 
-    async def get_group_categories(self, group_id, requester_user_id):
-        """Get all categories in a group (as seen by user)
-        """
+    async def get_group_categories(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
+        """Get all categories in a group (as seen by user)"""
         await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
 
         categories = await self.store.get_group_categories(group_id=group_id)
         return {"categories": categories}
 
-    async def get_group_category(self, group_id, requester_user_id, category_id):
-        """Get a specific category in a group (as seen by user)
-        """
+    async def get_group_category(
+        self, group_id: str, requester_user_id: str, category_id: str
+    ) -> JsonDict:
+        """Get a specific category in a group (as seen by user)"""
         await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
 
-        res = await self.store.get_group_category(
+        return await self.store.get_group_category(
             group_id=group_id, category_id=category_id
         )
 
-        logger.info("group %s", res)
-
-        return res
-
-    async def get_group_roles(self, group_id, requester_user_id):
-        """Get all roles in a group (as seen by user)
-        """
+    async def get_group_roles(self, group_id: str, requester_user_id: str) -> JsonDict:
+        """Get all roles in a group (as seen by user)"""
         await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
 
         roles = await self.store.get_group_roles(group_id=group_id)
         return {"roles": roles}
 
-    async def get_group_role(self, group_id, requester_user_id, role_id):
-        """Get a specific role in a group (as seen by user)
-        """
+    async def get_group_role(
+        self, group_id: str, requester_user_id: str, role_id: str
+    ) -> JsonDict:
+        """Get a specific role in a group (as seen by user)"""
         await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
 
-        res = await self.store.get_group_role(group_id=group_id, role_id=role_id)
-        return res
+        return await self.store.get_group_role(group_id=group_id, role_id=role_id)
 
-    async def get_group_profile(self, group_id, requester_user_id):
-        """Get the group profile as seen by requester_user_id
-        """
+    async def get_group_profile(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
+        """Get the group profile as seen by requester_user_id"""
 
         await self.check_group_is_ours(group_id, requester_user_id)
 
@@ -218,7 +236,9 @@ class GroupsServerWorkerHandler:
         else:
             raise SynapseError(404, "Unknown group")
 
-    async def get_users_in_group(self, group_id, requester_user_id):
+    async def get_users_in_group(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get the users in group as seen by requester_user_id.
 
         The ordering is arbitrary at the moment
@@ -267,7 +287,9 @@ class GroupsServerWorkerHandler:
 
         return {"chunk": chunk, "total_user_count_estimate": len(user_results)}
 
-    async def get_invited_users_in_group(self, group_id, requester_user_id):
+    async def get_invited_users_in_group(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get the users that have been invited to a group as seen by requester_user_id.
 
         The ordering is arbitrary at the moment
@@ -297,7 +319,9 @@ class GroupsServerWorkerHandler:
 
         return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)}
 
-    async def get_rooms_in_group(self, group_id, requester_user_id):
+    async def get_rooms_in_group(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get the rooms in group as seen by requester_user_id
 
         This returns rooms in order of decreasing number of joined users
@@ -335,17 +359,21 @@ class GroupsServerWorkerHandler:
 
 
 class GroupsServerHandler(GroupsServerWorkerHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         # Ensure attestations get renewed
         hs.get_groups_attestation_renewer()
 
     async def update_group_summary_room(
-        self, group_id, requester_user_id, room_id, category_id, content
-    ):
-        """Add/update a room to the group summary
-        """
+        self,
+        group_id: str,
+        requester_user_id: str,
+        room_id: str,
+        category_id: str,
+        content: JsonDict,
+    ) -> JsonDict:
+        """Add/update a room to the group summary"""
         await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
         )
@@ -367,10 +395,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         return {}
 
     async def delete_group_summary_room(
-        self, group_id, requester_user_id, room_id, category_id
-    ):
-        """Remove a room from the summary
-        """
+        self, group_id: str, requester_user_id: str, room_id: str, category_id: str
+    ) -> JsonDict:
+        """Remove a room from the summary"""
         await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
         )
@@ -381,7 +408,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return {}
 
-    async def set_group_join_policy(self, group_id, requester_user_id, content):
+    async def set_group_join_policy(
+        self, group_id: str, requester_user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Sets the group join policy.
 
         Currently supported policies are:
@@ -401,10 +430,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         return {}
 
     async def update_group_category(
-        self, group_id, requester_user_id, category_id, content
-    ):
-        """Add/Update a group category
-        """
+        self, group_id: str, requester_user_id: str, category_id: str, content: JsonDict
+    ) -> JsonDict:
+        """Add/Update a group category"""
         await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
         )
@@ -421,9 +449,10 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return {}
 
-    async def delete_group_category(self, group_id, requester_user_id, category_id):
-        """Delete a group category
-        """
+    async def delete_group_category(
+        self, group_id: str, requester_user_id: str, category_id: str
+    ) -> JsonDict:
+        """Delete a group category"""
         await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
         )
@@ -434,9 +463,10 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return {}
 
-    async def update_group_role(self, group_id, requester_user_id, role_id, content):
-        """Add/update a role in a group
-        """
+    async def update_group_role(
+        self, group_id: str, requester_user_id: str, role_id: str, content: JsonDict
+    ) -> JsonDict:
+        """Add/update a role in a group"""
         await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
         )
@@ -451,9 +481,10 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return {}
 
-    async def delete_group_role(self, group_id, requester_user_id, role_id):
-        """Remove role from group
-        """
+    async def delete_group_role(
+        self, group_id: str, requester_user_id: str, role_id: str
+    ) -> JsonDict:
+        """Remove role from group"""
         await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
         )
@@ -463,10 +494,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         return {}
 
     async def update_group_summary_user(
-        self, group_id, requester_user_id, user_id, role_id, content
-    ):
-        """Add/update a users entry in the group summary
-        """
+        self,
+        group_id: str,
+        requester_user_id: str,
+        user_id: str,
+        role_id: str,
+        content: JsonDict,
+    ) -> JsonDict:
+        """Add/update a users entry in the group summary"""
         await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
         )
@@ -486,10 +521,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         return {}
 
     async def delete_group_summary_user(
-        self, group_id, requester_user_id, user_id, role_id
-    ):
-        """Remove a user from the group summary
-        """
+        self, group_id: str, requester_user_id: str, user_id: str, role_id: str
+    ) -> JsonDict:
+        """Remove a user from the group summary"""
         await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
         )
@@ -500,26 +534,43 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return {}
 
-    async def update_group_profile(self, group_id, requester_user_id, content):
-        """Update the group profile
-        """
+    async def update_group_profile(
+        self, group_id: str, requester_user_id: str, content: JsonDict
+    ) -> None:
+        """Update the group profile"""
         await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
         )
 
         profile = {}
-        for keyname in ("name", "avatar_url", "short_description", "long_description"):
+        for keyname, max_length in (
+            ("name", MAX_DISPLAYNAME_LEN),
+            ("avatar_url", MAX_AVATAR_URL_LEN),
+            ("short_description", MAX_SHORT_DESC_LEN),
+            ("long_description", MAX_LONG_DESC_LEN),
+        ):
             if keyname in content:
                 value = content[keyname]
                 if not isinstance(value, str):
-                    raise SynapseError(400, "%r value is not a string" % (keyname,))
+                    raise SynapseError(
+                        400,
+                        "%r value is not a string" % (keyname,),
+                        errcode=Codes.INVALID_PARAM,
+                    )
+                if len(value) > max_length:
+                    raise SynapseError(
+                        400,
+                        "Invalid %s parameter" % (keyname,),
+                        errcode=Codes.INVALID_PARAM,
+                    )
                 profile[keyname] = value
 
         await self.store.update_group_profile(group_id, profile)
 
-    async def add_room_to_group(self, group_id, requester_user_id, room_id, content):
-        """Add room to group
-        """
+    async def add_room_to_group(
+        self, group_id: str, requester_user_id: str, room_id: str, content: JsonDict
+    ) -> JsonDict:
+        """Add room to group"""
         RoomID.from_string(room_id)  # Ensure valid room id
 
         await self.check_group_is_ours(
@@ -533,10 +584,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         return {}
 
     async def update_room_in_group(
-        self, group_id, requester_user_id, room_id, config_key, content
-    ):
-        """Update room in group
-        """
+        self,
+        group_id: str,
+        requester_user_id: str,
+        room_id: str,
+        config_key: str,
+        content: JsonDict,
+    ) -> JsonDict:
+        """Update room in group"""
         RoomID.from_string(room_id)  # Ensure valid room id
 
         await self.check_group_is_ours(
@@ -554,9 +609,10 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return {}
 
-    async def remove_room_from_group(self, group_id, requester_user_id, room_id):
-        """Remove room from group
-        """
+    async def remove_room_from_group(
+        self, group_id: str, requester_user_id: str, room_id: str
+    ) -> JsonDict:
+        """Remove room from group"""
         await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
         )
@@ -565,13 +621,16 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return {}
 
-    async def invite_to_group(self, group_id, user_id, requester_user_id, content):
-        """Invite user to group
-        """
+    async def invite_to_group(
+        self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
+    ) -> JsonDict:
+        """Invite user to group"""
 
         group = await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
         )
+        if not group:
+            raise SynapseError(400, "Group does not exist", errcode=Codes.BAD_STATE)
 
         # TODO: Check if user knocked
 
@@ -594,6 +653,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         if self.hs.is_mine_id(user_id):
             groups_local = self.hs.get_groups_local_handler()
+            assert isinstance(
+                groups_local, GroupsLocalHandler
+            ), "Workers cannot invites users to groups."
             res = await groups_local.on_invite(group_id, user_id, content)
             local_attestation = None
         else:
@@ -629,6 +691,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
                 local_attestation=local_attestation,
                 remote_attestation=remote_attestation,
             )
+            return {"state": "join"}
         elif res["state"] == "invite":
             await self.store.add_group_invite(group_id, user_id)
             return {"state": "invite"}
@@ -637,13 +700,17 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         else:
             raise SynapseError(502, "Unknown state returned by HS")
 
-    async def _add_user(self, group_id, user_id, content):
+    async def _add_user(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> Optional[JsonDict]:
         """Add a user to a group based on a content dict.
 
         See accept_invite, join_group.
         """
         if not self.hs.is_mine_id(user_id):
-            local_attestation = self.attestations.create_attestation(group_id, user_id)
+            local_attestation = self.attestations.create_attestation(
+                group_id, user_id
+            )  # type: Optional[JsonDict]
 
             remote_attestation = content["attestation"]
 
@@ -667,7 +734,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return local_attestation
 
-    async def accept_invite(self, group_id, requester_user_id, content):
+    async def accept_invite(
+        self, group_id: str, requester_user_id: str, content: JsonDict
+    ) -> JsonDict:
         """User tries to accept an invite to the group.
 
         This is different from them asking to join, and so should error if no
@@ -686,7 +755,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return {"state": "join", "attestation": local_attestation}
 
-    async def join_group(self, group_id, requester_user_id, content):
+    async def join_group(
+        self, group_id: str, requester_user_id: str, content: JsonDict
+    ) -> JsonDict:
         """User tries to join the group.
 
         This will error if the group requires an invite/knock to join
@@ -695,6 +766,8 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         group_info = await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True
         )
+        if not group_info:
+            raise SynapseError(404, "Group does not exist", errcode=Codes.NOT_FOUND)
         if group_info["join_policy"] != "open":
             raise SynapseError(403, "Group is not publicly joinable")
 
@@ -702,26 +775,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return {"state": "join", "attestation": local_attestation}
 
-    async def knock(self, group_id, requester_user_id, content):
-        """A user requests becoming a member of the group
-        """
-        await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
-        raise NotImplementedError()
-
-    async def accept_knock(self, group_id, requester_user_id, content):
-        """Accept a users knock to the room.
-
-        Errors if the user hasn't knocked, rather than inviting them.
-        """
-
-        await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
-        raise NotImplementedError()
-
     async def remove_user_from_group(
-        self, group_id, user_id, requester_user_id, content
-    ):
+        self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Remove a user from the group; either a user is leaving or an admin
         kicked them.
         """
@@ -743,6 +799,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         if is_kick:
             if self.hs.is_mine_id(user_id):
                 groups_local = self.hs.get_groups_local_handler()
+                assert isinstance(
+                    groups_local, GroupsLocalHandler
+                ), "Workers cannot remove users from groups."
                 await groups_local.user_removed_from_group(group_id, user_id, {})
             else:
                 await self.transport_client.remove_user_from_group_notification(
@@ -759,14 +818,15 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return {}
 
-    async def create_group(self, group_id, requester_user_id, content):
-        group = await self.check_group_is_ours(group_id, requester_user_id)
-
+    async def create_group(
+        self, group_id: str, requester_user_id: str, content: JsonDict
+    ) -> JsonDict:
         logger.info("Attempting to create group with ID: %r", group_id)
 
         # parsing the id into a GroupID validates it.
         group_id_obj = GroupID.from_string(group_id)
 
+        group = await self.check_group_is_ours(group_id, requester_user_id)
         if group:
             raise SynapseError(400, "Group already exists")
 
@@ -811,7 +871,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
             local_attestation = self.attestations.create_attestation(
                 group_id, requester_user_id
-            )
+            )  # type: Optional[JsonDict]
         else:
             local_attestation = None
             remote_attestation = None
@@ -834,15 +894,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return {"group_id": group_id}
 
-    async def delete_group(self, group_id, requester_user_id):
+    async def delete_group(self, group_id: str, requester_user_id: str) -> None:
         """Deletes a group, kicking out all current members.
 
         Only group admins or server admins can call this request
 
         Args:
-            group_id (str)
-            request_user_id (str)
-
+            group_id: The group ID to delete.
+            requester_user_id: The user requesting to delete the group.
         """
 
         await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
@@ -865,6 +924,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         async def _kick_user_from_group(user_id):
             if self.hs.is_mine_id(user_id):
                 groups_local = self.hs.get_groups_local_handler()
+                assert isinstance(
+                    groups_local, GroupsLocalHandler
+                ), "Workers cannot kick users from groups."
                 await groups_local.user_removed_from_group(group_id, user_id, {})
             else:
                 await self.transport_client.remove_user_from_group_notification(
@@ -896,9 +958,8 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         await self.store.delete_group(group_id)
 
 
-def _parse_join_policy_from_contents(content):
-    """Given a content for a request, return the specified join policy or None
-    """
+def _parse_join_policy_from_contents(content: JsonDict) -> Optional[str]:
+    """Given a content for a request, return the specified join policy or None"""
 
     join_policy_dict = content.get("m.join_policy")
     if join_policy_dict:
@@ -907,9 +968,8 @@ def _parse_join_policy_from_contents(content):
         return None
 
 
-def _parse_join_policy_dict(join_policy_dict):
-    """Given a dict for the "m.join_policy" config return the join policy specified
-    """
+def _parse_join_policy_dict(join_policy_dict: JsonDict) -> str:
+    """Given a dict for the "m.join_policy" config return the join policy specified"""
     join_policy_type = join_policy_dict.get("type")
     if not join_policy_type:
         return "invite"
@@ -919,7 +979,7 @@ def _parse_join_policy_dict(join_policy_dict):
     return join_policy_type
 
 
-def _parse_visibility_from_contents(content):
+def _parse_visibility_from_contents(content: JsonDict) -> bool:
     """Given a content for a request parse out whether the entity should be
     public or not
     """
@@ -933,7 +993,7 @@ def _parse_visibility_from_contents(content):
     return is_public
 
 
-def _parse_visibility_dict(visibility):
+def _parse_visibility_dict(visibility: JsonDict) -> bool:
     """Given a dict for the "m.visibility" config return if the entity should
     be public or not
     """
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index 8476256a59..5ecb2da1ac 100644
--- a/synapse/handlers/acme.py
+++ b/synapse/handlers/acme.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 import twisted
 import twisted.internet.error
@@ -22,6 +23,9 @@ from twisted.web.resource import Resource
 
 from synapse.app import check_bind_error
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 ACME_REGISTER_FAIL_ERROR = """
@@ -35,12 +39,12 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC
 
 
 class AcmeHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.reactor = hs.get_reactor()
         self._acme_domain = hs.config.acme_domain
 
-    async def start_listening(self):
+    async def start_listening(self) -> None:
         from synapse.handlers import acme_issuing_service
 
         # Configure logging for txacme, if you need to debug
@@ -85,7 +89,7 @@ class AcmeHandler:
             logger.error(ACME_REGISTER_FAIL_ERROR)
             raise
 
-    async def provision_certificate(self):
+    async def provision_certificate(self) -> None:
 
         logger.warning("Reprovisioning %s", self._acme_domain)
 
@@ -110,5 +114,3 @@ class AcmeHandler:
         except Exception:
             logger.exception("Failed saving!")
             raise
-
-        return True
diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py
index 7294649d71..ae2a9dd9c2 100644
--- a/synapse/handlers/acme_issuing_service.py
+++ b/synapse/handlers/acme_issuing_service.py
@@ -22,8 +22,10 @@ only need (and may only have available) if we are doing ACME, so is designed to
 imported conditionally.
 """
 import logging
+from typing import Dict, Iterable, List
 
 import attr
+import pem
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives import serialization
 from josepy import JWKRSA
@@ -36,20 +38,27 @@ from txacme.util import generate_private_key
 from zope.interface import implementer
 
 from twisted.internet import defer
+from twisted.internet.interfaces import IReactorTCP
 from twisted.python.filepath import FilePath
 from twisted.python.url import URL
+from twisted.web.resource import IResource
 
 logger = logging.getLogger(__name__)
 
 
-def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource):
+def create_issuing_service(
+    reactor: IReactorTCP,
+    acme_url: str,
+    account_key_file: str,
+    well_known_resource: IResource,
+) -> AcmeIssuingService:
     """Create an ACME issuing service, and attach it to a web Resource
 
     Args:
         reactor: twisted reactor
-        acme_url (str): URL to use to request certificates
-        account_key_file (str): where to store the account key
-        well_known_resource (twisted.web.IResource): web resource for .well-known.
+        acme_url: URL to use to request certificates
+        account_key_file: where to store the account key
+        well_known_resource: web resource for .well-known.
             we will attach a child resource for "acme-challenge".
 
     Returns:
@@ -83,18 +92,20 @@ class ErsatzStore:
     A store that only stores in memory.
     """
 
-    certs = attr.ib(default=attr.Factory(dict))
+    certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict))
 
-    def store(self, server_name, pem_objects):
+    def store(
+        self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject]
+    ) -> defer.Deferred:
         self.certs[server_name] = [o.as_bytes() for o in pem_objects]
         return defer.succeed(None)
 
 
-def load_or_create_client_key(key_file):
+def load_or_create_client_key(key_file: str) -> JWKRSA:
     """Load the ACME account key from a file, creating it if it does not exist.
 
     Args:
-        key_file (str): name of the file to use as the account key
+        key_file: name of the file to use as the account key
     """
     # this is based on txacme.endpoint.load_or_create_client_key, but doesn't
     # hardcode the 'client.key' filename
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 37e63da9b1..db68c94c50 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -203,13 +203,11 @@ class AdminHandler(BaseHandler):
 
 
 class ExfiltrationWriter(metaclass=abc.ABCMeta):
-    """Interface used to specify how to write exported data.
-    """
+    """Interface used to specify how to write exported data."""
 
     @abc.abstractmethod
     def write_events(self, room_id: str, events: List[EventBase]) -> None:
-        """Write a batch of events for a room.
-        """
+        """Write a batch of events for a room."""
         raise NotImplementedError()
 
     @abc.abstractmethod
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 5c6458eb52..deab8ff2d0 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -290,7 +290,9 @@ class ApplicationServicesHandler:
             if not interested:
                 continue
             presence_events, _ = await presence_source.get_new_events(
-                user=user, service=service, from_key=from_key,
+                user=user,
+                service=service,
+                from_key=from_key,
             )
             time_now = self.clock.time_msec()
             events.extend(
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 0e98db22b3..9ba9f591d9 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -61,6 +61,7 @@ from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.module_api import ModuleApi
+from synapse.storage.roommember import ProfileInfo
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util import stringutils as stringutils
 from synapse.util.async_helpers import maybe_awaitable
@@ -119,7 +120,9 @@ def convert_client_dict_legacy_fields_to_identifier(
     # Ensure the identifier has a type
     if "type" not in identifier:
         raise SynapseError(
-            400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM,
+            400,
+            "'identifier' dict has no key 'type'",
+            errcode=Codes.MISSING_PARAM,
         )
 
     return identifier
@@ -350,7 +353,11 @@ class AuthHandler(BaseHandler):
 
         try:
             result, params, session_id = await self.check_ui_auth(
-                flows, request, request_body, description, get_new_session_data,
+                flows,
+                request,
+                request_body,
+                description,
+                get_new_session_data,
             )
         except LoginError:
             # Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
@@ -378,8 +385,7 @@ class AuthHandler(BaseHandler):
         return params, session_id
 
     async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
-        """Get a list of the authentication types this user can use
-        """
+        """Get a list of the authentication types this user can use"""
 
         ui_auth_types = set()
 
@@ -567,16 +573,6 @@ class AuthHandler(BaseHandler):
                         session.session_id, login_type, result
                     )
             except LoginError as e:
-                if login_type == LoginType.EMAIL_IDENTITY:
-                    # riot used to have a bug where it would request a new
-                    # validation token (thus sending a new email) each time it
-                    # got a 401 with a 'flows' field.
-                    # (https://github.com/vector-im/vector-web/issues/2447).
-                    #
-                    # Grandfather in the old behaviour for now to avoid
-                    # breaking old riot deployments.
-                    raise
-
                 # this step failed. Merge the error dict into the response
                 # so that the client can have another go.
                 errordict = e.error_dict()
@@ -732,7 +728,9 @@ class AuthHandler(BaseHandler):
         }
 
     def _auth_dict_for_flows(
-        self, flows: List[List[str]], session_id: str,
+        self,
+        flows: List[List[str]],
+        session_id: str,
     ) -> Dict[str, Any]:
         public_flows = []
         for f in flows:
@@ -889,7 +887,9 @@ class AuthHandler(BaseHandler):
         return self._supported_login_types
 
     async def validate_login(
-        self, login_submission: Dict[str, Any], ratelimit: bool = False,
+        self,
+        login_submission: Dict[str, Any],
+        ratelimit: bool = False,
     ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
         """Authenticates the user for the /login API
 
@@ -1032,7 +1032,9 @@ class AuthHandler(BaseHandler):
             raise
 
     async def _validate_userid_login(
-        self, username: str, login_submission: Dict[str, Any],
+        self,
+        username: str,
+        login_submission: Dict[str, Any],
     ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
         """Helper for validate_login
 
@@ -1387,7 +1389,9 @@ class AuthHandler(BaseHandler):
         )
 
         return self._sso_auth_confirm_template.render(
-            description=session.description, redirect_url=redirect_url,
+            description=session.description,
+            redirect_url=redirect_url,
+            idp=sso_auth_provider,
         )
 
     async def complete_sso_login(
@@ -1396,6 +1400,7 @@ class AuthHandler(BaseHandler):
         request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
+        new_user: bool = False,
     ):
         """Having figured out a mxid for this user, complete the HTTP request
 
@@ -1406,6 +1411,8 @@ class AuthHandler(BaseHandler):
                 process.
             extra_attributes: Extra attributes which will be passed to the client
                 during successful login. Must be JSON serializable.
+            new_user: True if we should use wording appropriate to a user who has just
+                registered.
         """
         # If the account has been deactivated, do not proceed with the login
         # flow.
@@ -1414,8 +1421,17 @@ class AuthHandler(BaseHandler):
             respond_with_html(request, 403, self._sso_account_deactivated_template)
             return
 
+        profile = await self.store.get_profileinfo(
+            UserID.from_string(registered_user_id).localpart
+        )
+
         self._complete_sso_login(
-            registered_user_id, request, client_redirect_url, extra_attributes
+            registered_user_id,
+            request,
+            client_redirect_url,
+            extra_attributes,
+            new_user=new_user,
+            user_profile_data=profile,
         )
 
     def _complete_sso_login(
@@ -1424,18 +1440,25 @@ class AuthHandler(BaseHandler):
         request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
+        new_user: bool = False,
+        user_profile_data: Optional[ProfileInfo] = None,
     ):
         """
         The synchronous portion of complete_sso_login.
 
         This exists purely for backwards compatibility of synapse.module_api.ModuleApi.
         """
+
+        if user_profile_data is None:
+            user_profile_data = ProfileInfo(None, None)
+
         # Store any extra attributes which will be passed in the login response.
         # Note that this is per-user so it may overwrite a previous value, this
         # is considered OK since the newest SSO attributes should be most valid.
         if extra_attributes:
             self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes(
-                self._clock.time_msec(), extra_attributes,
+                self._clock.time_msec(),
+                extra_attributes,
             )
 
         # Create a login token
@@ -1461,12 +1484,27 @@ class AuthHandler(BaseHandler):
         # Remove the query parameters from the redirect URL to get a shorter version of
         # it. This is only to display a human-readable URL in the template, but not the
         # URL we redirect users to.
-        redirect_url_no_params = client_redirect_url.split("?")[0]
+        url_parts = urllib.parse.urlsplit(client_redirect_url)
+
+        if url_parts.scheme == "https":
+            # for an https uri, just show the netloc (ie, the hostname. Specifically,
+            # the bit between "//" and "/"; this includes any potential
+            # "username:password@" prefix.)
+            display_url = url_parts.netloc
+        else:
+            # for other uris, strip the query-params (including the login token) and
+            # fragment.
+            display_url = urllib.parse.urlunsplit(
+                (url_parts.scheme, url_parts.netloc, url_parts.path, "", "")
+            )
 
         html = self._sso_redirect_confirm_template.render(
-            display_url=redirect_url_no_params,
+            display_url=display_url,
             redirect_url=redirect_url,
             server_name=self._server_name,
+            new_user=new_user,
+            user_id=registered_user_id,
+            user_profile=user_profile_data,
         )
         respond_with_html(request, 200, html)
 
@@ -1676,5 +1714,9 @@ class PasswordProvider:
         # This might return an awaitable, if it does block the log out
         # until it completes.
         await maybe_awaitable(
-            g(user_id=user_id, device_id=device_id, access_token=access_token,)
+            g(
+                user_id=user_id,
+                device_id=device_id,
+                access_token=access_token,
+            )
         )
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 0f342c607b..04972f9cf0 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import urllib.parse
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional
 from xml.etree import ElementTree as ET
 
 import attr
@@ -33,8 +33,7 @@ logger = logging.getLogger(__name__)
 
 
 class CasError(Exception):
-    """Used to catch errors when validating the CAS ticket.
-    """
+    """Used to catch errors when validating the CAS ticket."""
 
     def __init__(self, error, error_description=None):
         self.error = error
@@ -49,7 +48,7 @@ class CasError(Exception):
 @attr.s(slots=True, frozen=True)
 class CasResponse:
     username = attr.ib(type=str)
-    attributes = attr.ib(type=Dict[str, Optional[str]])
+    attributes = attr.ib(type=Dict[str, List[Optional[str]]])
 
 
 class CasHandler:
@@ -80,9 +79,10 @@ class CasHandler:
         # user-facing name of this auth provider
         self.idp_name = "CAS"
 
-        # we do not currently support icons for CAS auth, but this is required by
+        # we do not currently support brands/icons for CAS auth, but this is required by
         # the SsoIdentityProvider protocol type.
         self.idp_icon = None
+        self.idp_brand = None
 
         self._sso_handler = hs.get_sso_handler()
 
@@ -99,9 +99,8 @@ class CasHandler:
         Returns:
             The URL to use as a "service" parameter.
         """
-        return "%s%s?%s" % (
+        return "%s?%s" % (
             self._cas_service_url,
-            "/_matrix/client/r0/login/cas/ticket",
             urllib.parse.urlencode(args),
         )
 
@@ -172,7 +171,7 @@ class CasHandler:
 
         # Iterate through the nodes and pull out the user and any extra attributes.
         user = None
-        attributes = {}
+        attributes = {}  # type: Dict[str, List[Optional[str]]]
         for child in root[0]:
             if child.tag.endswith("user"):
                 user = child.text
@@ -185,7 +184,7 @@ class CasHandler:
                     tag = attribute.tag
                     if "}" in tag:
                         tag = tag.split("}")[1]
-                    attributes[tag] = attribute.text
+                    attributes.setdefault(tag, []).append(attribute.text)
 
         # Ensure a user was found.
         if user is None:
@@ -299,36 +298,20 @@ class CasHandler:
         # first check if we're doing a UIA
         if session:
             return await self._sso_handler.complete_sso_ui_auth_request(
-                self.idp_id, cas_response.username, session, request,
+                self.idp_id,
+                cas_response.username,
+                session,
+                request,
             )
 
         # otherwise, we're handling a login request.
 
         # Ensure that the attributes of the logged in user meet the required
         # attributes.
-        for required_attribute, required_value in self._cas_required_attributes.items():
-            # If required attribute was not in CAS Response - Forbidden
-            if required_attribute not in cas_response.attributes:
-                self._sso_handler.render_error(
-                    request,
-                    "unauthorised",
-                    "You are not authorised to log in here.",
-                    401,
-                )
-                return
-
-            # Also need to check value
-            if required_value is not None:
-                actual_value = cas_response.attributes[required_attribute]
-                # If required attribute value does not match expected - Forbidden
-                if required_value != actual_value:
-                    self._sso_handler.render_error(
-                        request,
-                        "unauthorised",
-                        "You are not authorised to log in here.",
-                        401,
-                    )
-                    return
+        if not self._sso_handler.check_required_attributes(
+            request, cas_response.attributes, self._cas_required_attributes
+        ):
+            return
 
         # Call the mapper to register/login the user
 
@@ -375,9 +358,10 @@ class CasHandler:
             if failures:
                 raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
 
+            # Arbitrarily use the first attribute found.
             display_name = cas_response.attributes.get(
-                self._cas_displayname_attribute, None
-            )
+                self._cas_displayname_attribute, [None]
+            )[0]
 
             return UserAttributes(localpart=localpart, display_name=display_name)
 
@@ -387,7 +371,8 @@ class CasHandler:
             user_id = UserID(localpart, self._hostname).to_string()
 
             logger.debug(
-                "Looking for existing account based on mapped %s", user_id,
+                "Looking for existing account based on mapped %s",
+                user_id,
             )
 
             users = await self._store.get_users_by_id_case_insensitive(user_id)
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index c4a3b26a84..94f3f3163f 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -196,8 +196,7 @@ class DeactivateAccountHandler(BaseHandler):
             run_as_background_process("user_parter_loop", self._user_parter_loop)
 
     async def _user_parter_loop(self) -> None:
-        """Loop that parts deactivated users from rooms
-        """
+        """Loop that parts deactivated users from rooms"""
         self._user_parter_running = True
         logger.info("Starting user parter")
         try:
@@ -214,8 +213,7 @@ class DeactivateAccountHandler(BaseHandler):
             self._user_parter_running = False
 
     async def _part_user(self, user_id: str) -> None:
-        """Causes the given user_id to leave all the rooms they're joined to
-        """
+        """Causes the given user_id to leave all the rooms they're joined to"""
         user = UserID.from_string(user_id)
 
         rooms_for_user = await self.store.get_rooms_for_user(user_id)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index debb1b4f29..df3cdc8fba 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api import errors
 from synapse.api.constants import EventTypes
@@ -62,7 +62,7 @@ class DeviceWorkerHandler(BaseHandler):
         self._auth_handler = hs.get_auth_handler()
 
     @trace
-    async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]:
+    async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
         """
         Retrieve the given user's devices
 
@@ -85,8 +85,8 @@ class DeviceWorkerHandler(BaseHandler):
         return devices
 
     @trace
-    async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
-        """ Retrieve the given device
+    async def get_device(self, user_id: str, device_id: str) -> JsonDict:
+        """Retrieve the given device
 
         Args:
             user_id: The user to get the device from
@@ -341,7 +341,7 @@ class DeviceHandler(DeviceWorkerHandler):
 
     @trace
     async def delete_device(self, user_id: str, device_id: str) -> None:
-        """ Delete the given device
+        """Delete the given device
 
         Args:
             user_id: The user to delete the device from.
@@ -386,7 +386,7 @@ class DeviceHandler(DeviceWorkerHandler):
         await self.delete_devices(user_id, device_ids)
 
     async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
-        """ Delete several devices
+        """Delete several devices
 
         Args:
             user_id: The user to delete devices from.
@@ -417,7 +417,7 @@ class DeviceHandler(DeviceWorkerHandler):
         await self.notify_device_update(user_id, device_ids)
 
     async def update_device(self, user_id: str, device_id: str, content: dict) -> None:
-        """ Update the given device
+        """Update the given device
 
         Args:
             user_id: The user to update devices of.
@@ -534,7 +534,9 @@ class DeviceHandler(DeviceWorkerHandler):
             device id of the dehydrated device
         """
         device_id = await self.check_device_registered(
-            user_id, None, initial_device_display_name,
+            user_id,
+            None,
+            initial_device_display_name,
         )
         old_device_id = await self.store.store_dehydrated_device(
             user_id, device_id, device_data
@@ -598,7 +600,7 @@ class DeviceHandler(DeviceWorkerHandler):
 
 
 def _update_device_from_client_ips(
-    device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]]
+    device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict]
 ) -> None:
     ip = client_ips.get((device["user_id"], device["device_id"]), {})
     device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
@@ -803,7 +805,8 @@ class DeviceListUpdater:
                 try:
                     # Try to resync the current user's devices list.
                     result = await self.user_device_resync(
-                        user_id=user_id, mark_failed_as_stale=False,
+                        user_id=user_id,
+                        mark_failed_as_stale=False,
                     )
 
                     # user_device_resync only returns a result if it managed to
@@ -813,14 +816,17 @@ class DeviceListUpdater:
                     # self.store.update_remote_device_list_cache).
                     if result:
                         logger.debug(
-                            "Successfully resynced the device list for %s", user_id,
+                            "Successfully resynced the device list for %s",
+                            user_id,
                         )
                 except Exception as e:
                     # If there was an issue resyncing this user, e.g. if the remote
                     # server sent a malformed result, just log the error instead of
                     # aborting all the subsequent resyncs.
                     logger.debug(
-                        "Could not resync the device list for %s: %s", user_id, e,
+                        "Could not resync the device list for %s: %s",
+                        user_id,
+                        e,
                     )
         finally:
             # Allow future calls to retry resyncinc out of sync device lists.
@@ -855,7 +861,9 @@ class DeviceListUpdater:
             return None
         except (RequestSendFailed, HttpResponseException) as e:
             logger.warning(
-                "Failed to handle device list update for %s: %s", user_id, e,
+                "Failed to handle device list update for %s: %s",
+                user_id,
+                e,
             )
 
             if mark_failed_as_stale:
@@ -931,7 +939,9 @@ class DeviceListUpdater:
 
         # Handle cross-signing keys.
         cross_signing_device_ids = await self.process_cross_signing_key_update(
-            user_id, master_key, self_signing_key,
+            user_id,
+            master_key,
+            self_signing_key,
         )
         device_ids = device_ids + cross_signing_device_ids
 
@@ -946,8 +956,8 @@ class DeviceListUpdater:
     async def process_cross_signing_key_update(
         self,
         user_id: str,
-        master_key: Optional[Dict[str, Any]],
-        self_signing_key: Optional[Dict[str, Any]],
+        master_key: Optional[JsonDict],
+        self_signing_key: Optional[JsonDict],
     ) -> List[str]:
         """Process the given new master and self-signing key for the given remote user.
 
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 0c7737e09d..1aa7d803b5 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -62,7 +62,8 @@ class DeviceMessageHandler:
             )
         else:
             hs.get_federation_registry().register_instances_for_edu(
-                "m.direct_to_device", hs.config.worker.writers.to_device,
+                "m.direct_to_device",
+                hs.config.worker.writers.to_device,
             )
 
         # The handler to call when we think a user's device list might be out of
@@ -73,8 +74,8 @@ class DeviceMessageHandler:
                 hs.get_device_handler().device_list_updater.user_device_resync
             )
         else:
-            self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
-                hs
+            self._user_device_resync = (
+                ReplicationUserDevicesResyncRestServlet.make_client(hs)
             )
 
     async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 929752150d..9a946a3cfe 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -16,7 +16,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
 
 import attr
 from canonicaljson import encode_canonical_json
@@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
 from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
 from synapse.types import (
+    JsonDict,
     UserID,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
@@ -40,11 +41,14 @@ from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class E2eKeysHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.federation = hs.get_federation_client()
         self.device_handler = hs.get_device_handler()
@@ -57,8 +61,8 @@ class E2eKeysHandler:
 
         self._is_master = hs.config.worker_app is None
         if not self._is_master:
-            self._user_device_resync_client = ReplicationUserDevicesResyncRestServlet.make_client(
-                hs
+            self._user_device_resync_client = (
+                ReplicationUserDevicesResyncRestServlet.make_client(hs)
             )
         else:
             # Only register this edu handler on master as it requires writing
@@ -78,8 +82,10 @@ class E2eKeysHandler:
         )
 
     @trace
-    async def query_devices(self, query_body, timeout, from_user_id):
-        """ Handle a device key query from a client
+    async def query_devices(
+        self, query_body: JsonDict, timeout: int, from_user_id: str
+    ) -> JsonDict:
+        """Handle a device key query from a client
 
         {
             "device_keys": {
@@ -98,12 +104,14 @@ class E2eKeysHandler:
         }
 
         Args:
-            from_user_id (str): the user making the query.  This is used when
+            from_user_id: the user making the query.  This is used when
                 adding cross-signing signatures to limit what signatures users
                 can see.
         """
 
-        device_keys_query = query_body.get("device_keys", {})
+        device_keys_query = query_body.get(
+            "device_keys", {}
+        )  # type: Dict[str, Iterable[str]]
 
         # separate users by domain.
         # make a map from domain to user_id to device_ids
@@ -121,7 +129,8 @@ class E2eKeysHandler:
         set_tag("remote_key_query", remote_queries)
 
         # First get local devices.
-        failures = {}
+        # A map of destination -> failure response.
+        failures = {}  # type: Dict[str, JsonDict]
         results = {}
         if local_query:
             local_result = await self.query_local_devices(local_query)
@@ -135,9 +144,10 @@ class E2eKeysHandler:
         )
 
         # Now attempt to get any remote devices from our local cache.
-        remote_queries_not_in_cache = {}
+        # A map of destination -> user ID -> device IDs.
+        remote_queries_not_in_cache = {}  # type: Dict[str, Dict[str, Iterable[str]]]
         if remote_queries:
-            query_list = []
+            query_list = []  # type: List[Tuple[str, Optional[str]]]
             for user_id, device_ids in remote_queries.items():
                 if device_ids:
                     query_list.extend((user_id, device_id) for device_id in device_ids)
@@ -284,15 +294,15 @@ class E2eKeysHandler:
         return ret
 
     async def get_cross_signing_keys_from_cache(
-        self, query, from_user_id
+        self, query: Iterable[str], from_user_id: Optional[str]
     ) -> Dict[str, Dict[str, dict]]:
         """Get cross-signing keys for users from the database
 
         Args:
-            query (Iterable[string]) an iterable of user IDs.  A dict whose keys
+            query: an iterable of user IDs.  A dict whose keys
                 are user IDs satisfies this, so the query format used for
                 query_devices can be used here.
-            from_user_id (str): the user making the query.  This is used when
+            from_user_id: the user making the query.  This is used when
                 adding cross-signing signatures to limit what signatures users
                 can see.
 
@@ -315,14 +325,12 @@ class E2eKeysHandler:
             if "self_signing" in user_info:
                 self_signing_keys[user_id] = user_info["self_signing"]
 
-        if (
-            from_user_id in keys
-            and keys[from_user_id] is not None
-            and "user_signing" in keys[from_user_id]
-        ):
-            # users can see other users' master and self-signing keys, but can
-            # only see their own user-signing keys
-            user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]
+        # users can see other users' master and self-signing keys, but can
+        # only see their own user-signing keys
+        if from_user_id:
+            from_user_key = keys.get(from_user_id)
+            if from_user_key and "user_signing" in from_user_key:
+                user_signing_keys[from_user_id] = from_user_key["user_signing"]
 
         return {
             "master_keys": master_keys,
@@ -344,9 +352,9 @@ class E2eKeysHandler:
             A map from user_id -> device_id -> device details
         """
         set_tag("local_query", query)
-        local_query = []
+        local_query = []  # type: List[Tuple[str, Optional[str]]]
 
-        result_dict = {}
+        result_dict = {}  # type: Dict[str, Dict[str, dict]]
         for user_id, device_ids in query.items():
             # we use UserID.from_string to catch invalid user ids
             if not self.is_mine(UserID.from_string(user_id)):
@@ -380,10 +388,13 @@ class E2eKeysHandler:
         log_kv(results)
         return result_dict
 
-    async def on_federation_query_client_keys(self, query_body):
-        """ Handle a device key query from a federated server
-        """
-        device_keys_query = query_body.get("device_keys", {})
+    async def on_federation_query_client_keys(
+        self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
+    ) -> JsonDict:
+        """Handle a device key query from a federated server"""
+        device_keys_query = query_body.get(
+            "device_keys", {}
+        )  # type: Dict[str, Optional[List[str]]]
         res = await self.query_local_devices(device_keys_query)
         ret = {"device_keys": res}
 
@@ -397,31 +408,34 @@ class E2eKeysHandler:
         return ret
 
     @trace
-    async def claim_one_time_keys(self, query, timeout):
-        local_query = []
-        remote_queries = {}
+    async def claim_one_time_keys(
+        self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
+    ) -> JsonDict:
+        local_query = []  # type: List[Tuple[str, str, str]]
+        remote_queries = {}  # type: Dict[str, Dict[str, Dict[str, str]]]
 
-        for user_id, device_keys in query.get("one_time_keys", {}).items():
+        for user_id, one_time_keys in query.get("one_time_keys", {}).items():
             # we use UserID.from_string to catch invalid user ids
             if self.is_mine(UserID.from_string(user_id)):
-                for device_id, algorithm in device_keys.items():
+                for device_id, algorithm in one_time_keys.items():
                     local_query.append((user_id, device_id, algorithm))
             else:
                 domain = get_domain_from_id(user_id)
-                remote_queries.setdefault(domain, {})[user_id] = device_keys
+                remote_queries.setdefault(domain, {})[user_id] = one_time_keys
 
         set_tag("local_key_query", local_query)
         set_tag("remote_key_query", remote_queries)
 
         results = await self.store.claim_e2e_one_time_keys(local_query)
 
-        json_result = {}
-        failures = {}
+        # A map of user ID -> device ID -> key ID -> key.
+        json_result = {}  # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
+        failures = {}  # type: Dict[str, JsonDict]
         for user_id, device_keys in results.items():
             for device_id, keys in device_keys.items():
-                for key_id, json_bytes in keys.items():
+                for key_id, json_str in keys.items():
                     json_result.setdefault(user_id, {})[device_id] = {
-                        key_id: json_decoder.decode(json_bytes)
+                        key_id: json_decoder.decode(json_str)
                     }
 
         @trace
@@ -468,7 +482,9 @@ class E2eKeysHandler:
         return {"one_time_keys": json_result, "failures": failures}
 
     @tag_args
-    async def upload_keys_for_user(self, user_id, device_id, keys):
+    async def upload_keys_for_user(
+        self, user_id: str, device_id: str, keys: JsonDict
+    ) -> JsonDict:
 
         time_now = self.clock.time_msec()
 
@@ -543,8 +559,8 @@ class E2eKeysHandler:
         return {"one_time_key_counts": result}
 
     async def _upload_one_time_keys_for_user(
-        self, user_id, device_id, time_now, one_time_keys
-    ):
+        self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
+    ) -> None:
         logger.info(
             "Adding one_time_keys %r for device %r for user %r at %d",
             one_time_keys.keys(),
@@ -585,12 +601,14 @@ class E2eKeysHandler:
         log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
         await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
 
-    async def upload_signing_keys_for_user(self, user_id, keys):
+    async def upload_signing_keys_for_user(
+        self, user_id: str, keys: JsonDict
+    ) -> JsonDict:
         """Upload signing keys for cross-signing
 
         Args:
-            user_id (string): the user uploading the keys
-            keys (dict[string, dict]): the signing keys
+            user_id: the user uploading the keys
+            keys: the signing keys
         """
 
         # if a master key is uploaded, then check it.  Otherwise, load the
@@ -667,16 +685,17 @@ class E2eKeysHandler:
 
         return {}
 
-    async def upload_signatures_for_device_keys(self, user_id, signatures):
+    async def upload_signatures_for_device_keys(
+        self, user_id: str, signatures: JsonDict
+    ) -> JsonDict:
         """Upload device signatures for cross-signing
 
         Args:
-            user_id (string): the user uploading the signatures
-            signatures (dict[string, dict[string, dict]]): map of users to
-                devices to signed keys. This is the submission from the user; an
-                exception will be raised if it is malformed.
+            user_id: the user uploading the signatures
+            signatures: map of users to devices to signed keys. This is the submission
+            from the user; an exception will be raised if it is malformed.
         Returns:
-            dict: response to be sent back to the client.  The response will have
+            The response to be sent back to the client.  The response will have
                 a "failures" key, which will be a dict mapping users to devices
                 to errors for the signatures that failed.
         Raises:
@@ -719,7 +738,9 @@ class E2eKeysHandler:
 
         return {"failures": failures}
 
-    async def _process_self_signatures(self, user_id, signatures):
+    async def _process_self_signatures(
+        self, user_id: str, signatures: JsonDict
+    ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
         """Process uploaded signatures of the user's own keys.
 
         Signatures of the user's own keys from this API come in two forms:
@@ -731,15 +752,14 @@ class E2eKeysHandler:
             signatures (dict[string, dict]): map of devices to signed keys
 
         Returns:
-            (list[SignatureListItem], dict[string, dict[string, dict]]):
-            a list of signatures to store, and a map of users to devices to failure
-            reasons
+            A tuple of a list of signatures to store, and a map of users to
+            devices to failure reasons
 
         Raises:
             SynapseError: if the input is malformed
         """
-        signature_list = []
-        failures = {}
+        signature_list = []  # type: List[SignatureListItem]
+        failures = {}  # type: Dict[str, Dict[str, JsonDict]]
         if not signatures:
             return signature_list, failures
 
@@ -834,19 +854,24 @@ class E2eKeysHandler:
         return signature_list, failures
 
     def _check_master_key_signature(
-        self, user_id, master_key_id, signed_master_key, stored_master_key, devices
-    ):
+        self,
+        user_id: str,
+        master_key_id: str,
+        signed_master_key: JsonDict,
+        stored_master_key: JsonDict,
+        devices: Dict[str, Dict[str, JsonDict]],
+    ) -> List["SignatureListItem"]:
         """Check signatures of a user's master key made by their devices.
 
         Args:
-            user_id (string): the user whose master key is being checked
-            master_key_id (string): the ID of the user's master key
-            signed_master_key (dict): the user's signed master key that was uploaded
-            stored_master_key (dict): our previously-stored copy of the user's master key
-            devices (iterable(dict)): the user's devices
+            user_id: the user whose master key is being checked
+            master_key_id: the ID of the user's master key
+            signed_master_key: the user's signed master key that was uploaded
+            stored_master_key: our previously-stored copy of the user's master key
+            devices: the user's devices
 
         Returns:
-            list[SignatureListItem]: a list of signatures to store
+            A list of signatures to store
 
         Raises:
             SynapseError: if a signature is invalid
@@ -877,25 +902,26 @@ class E2eKeysHandler:
 
         return master_key_signature_list
 
-    async def _process_other_signatures(self, user_id, signatures):
+    async def _process_other_signatures(
+        self, user_id: str, signatures: Dict[str, dict]
+    ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
         """Process uploaded signatures of other users' keys.  These will be the
         target user's master keys, signed by the uploading user's user-signing
         key.
 
         Args:
-            user_id (string): the user uploading the keys
-            signatures (dict[string, dict]): map of users to devices to signed keys
+            user_id: the user uploading the keys
+            signatures: map of users to devices to signed keys
 
         Returns:
-            (list[SignatureListItem], dict[string, dict[string, dict]]):
-            a list of signatures to store, and a map of users to devices to failure
+            A list of signatures to store, and a map of users to devices to failure
             reasons
 
         Raises:
             SynapseError: if the input is malformed
         """
-        signature_list = []
-        failures = {}
+        signature_list = []  # type: List[SignatureListItem]
+        failures = {}  # type: Dict[str, Dict[str, JsonDict]]
         if not signatures:
             return signature_list, failures
 
@@ -983,7 +1009,7 @@ class E2eKeysHandler:
 
     async def _get_e2e_cross_signing_verify_key(
         self, user_id: str, key_type: str, from_user_id: str = None
-    ):
+    ) -> Tuple[JsonDict, str, VerifyKey]:
         """Fetch locally or remotely query for a cross-signing public key.
 
         First, attempt to fetch the cross-signing public key from storage.
@@ -997,8 +1023,7 @@ class E2eKeysHandler:
                 This affects what signatures are fetched.
 
         Returns:
-            dict, str, VerifyKey: the raw key data, the key ID, and the
-                signedjson verify key
+            The raw key data, the key ID, and the signedjson verify key
 
         Raises:
             NotFoundError: if the key is not found
@@ -1039,7 +1064,9 @@ class E2eKeysHandler:
         return key, key_id, verify_key
 
     async def _retrieve_cross_signing_keys_for_remote_user(
-        self, user: UserID, desired_key_type: str,
+        self,
+        user: UserID,
+        desired_key_type: str,
     ) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]:
         """Queries cross-signing keys for a remote user and saves them to the database
 
@@ -1135,16 +1162,18 @@ class E2eKeysHandler:
         return desired_key, desired_key_id, desired_verify_key
 
 
-def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
+def _check_cross_signing_key(
+    key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None
+) -> None:
     """Check a cross-signing key uploaded by a user.  Performs some basic sanity
     checking, and ensures that it is signed, if a signature is required.
 
     Args:
-        key (dict): the key data to verify
-        user_id (str): the user whose key is being checked
-        key_type (str): the type of key that the key should be
-        signing_key (VerifyKey): (optional) the signing key that the key should
-            be signed with.  If omitted, signatures will not be checked.
+        key: the key data to verify
+        user_id: the user whose key is being checked
+        key_type: the type of key that the key should be
+        signing_key: the signing key that the key should be signed with.  If
+            omitted, signatures will not be checked.
     """
     if (
         key.get("user_id") != user_id
@@ -1162,16 +1191,21 @@ def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
             )
 
 
-def _check_device_signature(user_id, verify_key, signed_device, stored_device):
+def _check_device_signature(
+    user_id: str,
+    verify_key: VerifyKey,
+    signed_device: JsonDict,
+    stored_device: JsonDict,
+) -> None:
     """Check that a signature on a device or cross-signing key is correct and
     matches the copy of the device/key that we have stored.  Throws an
     exception if an error is detected.
 
     Args:
-        user_id (str): the user ID whose signature is being checked
-        verify_key (VerifyKey): the key to verify the device with
-        signed_device (dict): the uploaded signed device data
-        stored_device (dict): our previously stored copy of the device
+        user_id: the user ID whose signature is being checked
+        verify_key: the key to verify the device with
+        signed_device: the uploaded signed device data
+        stored_device: our previously stored copy of the device
 
     Raises:
         SynapseError: if the signature was invalid or the sent device is not the
@@ -1201,7 +1235,7 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device):
         raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE)
 
 
-def _exception_to_failure(e):
+def _exception_to_failure(e: Exception) -> JsonDict:
     if isinstance(e, SynapseError):
         return {"status": e.code, "errcode": e.errcode, "message": str(e)}
 
@@ -1218,7 +1252,7 @@ def _exception_to_failure(e):
     return {"status": 503, "message": str(e)}
 
 
-def _one_time_keys_match(old_key_json, new_key):
+def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool:
     old_key = json_decoder.decode(old_key_json)
 
     # if either is a string rather than an object, they must match exactly
@@ -1236,19 +1270,18 @@ def _one_time_keys_match(old_key_json, new_key):
 
 @attr.s(slots=True)
 class SignatureListItem:
-    """An item in the signature list as used by upload_signatures_for_device_keys.
-    """
+    """An item in the signature list as used by upload_signatures_for_device_keys."""
 
-    signing_key_id = attr.ib()
-    target_user_id = attr.ib()
-    target_device_id = attr.ib()
-    signature = attr.ib()
+    signing_key_id = attr.ib(type=str)
+    target_user_id = attr.ib(type=str)
+    target_device_id = attr.ib(type=str)
+    signature = attr.ib(type=JsonDict)
 
 
 class SigningKeyEduUpdater:
     """Handles incoming signing key updates from federation and updates the DB"""
 
-    def __init__(self, hs, e2e_keys_handler):
+    def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
         self.store = hs.get_datastore()
         self.federation = hs.get_federation_client()
         self.clock = hs.get_clock()
@@ -1257,7 +1290,7 @@ class SigningKeyEduUpdater:
         self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
 
         # user_id -> list of updates waiting to be handled.
-        self._pending_updates = {}
+        self._pending_updates = {}  # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
 
         # Recently seen stream ids. We don't bother keeping these in the DB,
         # but they're useful to have them about to reduce the number of spurious
@@ -1270,13 +1303,15 @@ class SigningKeyEduUpdater:
             iterable=True,
         )
 
-    async def incoming_signing_key_update(self, origin, edu_content):
+    async def incoming_signing_key_update(
+        self, origin: str, edu_content: JsonDict
+    ) -> None:
         """Called on incoming signing key update from federation. Responsible for
         parsing the EDU and adding to pending updates list.
 
         Args:
-            origin (string): the server that sent the EDU
-            edu_content (dict): the contents of the EDU
+            origin: the server that sent the EDU
+            edu_content: the contents of the EDU
         """
 
         user_id = edu_content.pop("user_id")
@@ -1299,11 +1334,11 @@ class SigningKeyEduUpdater:
 
         await self._handle_signing_key_updates(user_id)
 
-    async def _handle_signing_key_updates(self, user_id):
+    async def _handle_signing_key_updates(self, user_id: str) -> None:
         """Actually handle pending updates.
 
         Args:
-            user_id (string): the user whose updates we are processing
+            user_id: the user whose updates we are processing
         """
 
         device_handler = self.e2e_keys_handler.device_handler
@@ -1315,13 +1350,17 @@ class SigningKeyEduUpdater:
                 # This can happen since we batch updates
                 return
 
-            device_ids = []
+            device_ids = []  # type: List[str]
 
             logger.info("pending updates: %r", pending_updates)
 
             for master_key, self_signing_key in pending_updates:
-                new_device_ids = await device_list_updater.process_cross_signing_key_update(
-                    user_id, master_key, self_signing_key,
+                new_device_ids = (
+                    await device_list_updater.process_cross_signing_key_update(
+                        user_id,
+                        master_key,
+                        self_signing_key,
+                    )
                 )
                 device_ids = device_ids + new_device_ids
 
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index f01b090772..622cae23be 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, List, Optional
 
 from synapse.api.errors import (
     Codes,
@@ -24,8 +25,12 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.logging.opentracing import log_kv, trace
+from synapse.types import JsonDict
 from synapse.util.async_helpers import Linearizer
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -37,7 +42,7 @@ class E2eRoomKeysHandler:
     The actual payload of the encrypted keys is completely opaque to the handler.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
 
         # Used to lock whenever a client is uploading key data.  This prevents collisions
@@ -48,21 +53,27 @@ class E2eRoomKeysHandler:
         self._upload_linearizer = Linearizer("upload_room_keys_lock")
 
     @trace
-    async def get_room_keys(self, user_id, version, room_id=None, session_id=None):
+    async def get_room_keys(
+        self,
+        user_id: str,
+        version: str,
+        room_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+    ) -> List[JsonDict]:
         """Bulk get the E2E room keys for a given backup, optionally filtered to a given
         room, or a given session.
         See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
 
         Args:
-            user_id(str): the user whose keys we're getting
-            version(str): the version ID of the backup we're getting keys from
-            room_id(string): room ID to get keys for, for None to get keys for all rooms
-            session_id(string): session ID to get keys for, for None to get keys for all
+            user_id: the user whose keys we're getting
+            version: the version ID of the backup we're getting keys from
+            room_id: room ID to get keys for, or None to get keys for all rooms
+            session_id: session ID to get keys for, or None to get keys for all
                 sessions
         Raises:
             NotFoundError: if the backup version does not exist
         Returns:
-            A deferred list of dicts giving the session_data and message metadata for
+            A list of dicts giving the session_data and message metadata for
             these room keys.
         """
 
@@ -86,17 +97,23 @@ class E2eRoomKeysHandler:
             return results
 
     @trace
-    async def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
+    async def delete_room_keys(
+        self,
+        user_id: str,
+        version: str,
+        room_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+    ) -> JsonDict:
         """Bulk delete the E2E room keys for a given backup, optionally filtered to a given
         room or a given session.
         See EndToEndRoomKeyStore.delete_e2e_room_keys for full details.
 
         Args:
-            user_id(str): the user whose backup we're deleting
-            version(str): the version ID of the backup we're deleting
-            room_id(string): room ID to delete keys for, for None to delete keys for all
+            user_id: the user whose backup we're deleting
+            version: the version ID of the backup we're deleting
+            room_id: room ID to delete keys for, or None to delete keys for all
                 rooms
-            session_id(string): session ID to delete keys for, for None to delete keys
+            session_id: session ID to delete keys for, or None to delete keys
                 for all sessions
         Raises:
             NotFoundError: if the backup version does not exist
@@ -128,15 +145,17 @@ class E2eRoomKeysHandler:
             return {"etag": str(version_etag), "count": count}
 
     @trace
-    async def upload_room_keys(self, user_id, version, room_keys):
+    async def upload_room_keys(
+        self, user_id: str, version: str, room_keys: JsonDict
+    ) -> JsonDict:
         """Bulk upload a list of room keys into a given backup version, asserting
         that the given version is the current backup version.  room_keys are merged
         into the current backup as described in RoomKeysServlet.on_PUT().
 
         Args:
-            user_id(str): the user whose backup we're setting
-            version(str): the version ID of the backup we're updating
-            room_keys(dict): a nested dict describing the room_keys we're setting:
+            user_id: the user whose backup we're setting
+            version: the version ID of the backup we're updating
+            room_keys: a nested dict describing the room_keys we're setting:
 
         {
             "rooms": {
@@ -254,14 +273,16 @@ class E2eRoomKeysHandler:
             return {"etag": str(version_etag), "count": count}
 
     @staticmethod
-    def _should_replace_room_key(current_room_key, room_key):
+    def _should_replace_room_key(
+        current_room_key: Optional[JsonDict], room_key: JsonDict
+    ) -> bool:
         """
         Determine whether to replace a given current_room_key (if any)
         with a newly uploaded room_key backup
 
         Args:
-            current_room_key (dict): Optional, the current room_key dict if any
-            room_key (dict): The new room_key dict which may or may not be fit to
+            current_room_key: Optional, the current room_key dict if any
+            room_key: The new room_key dict which may or may not be fit to
                 replace the current_room_key
 
         Returns:
@@ -286,14 +307,14 @@ class E2eRoomKeysHandler:
         return True
 
     @trace
-    async def create_version(self, user_id, version_info):
+    async def create_version(self, user_id: str, version_info: JsonDict) -> str:
         """Create a new backup version.  This automatically becomes the new
         backup version for the user's keys; previous backups will no longer be
         writeable to.
 
         Args:
-            user_id(str): the user whose backup version we're creating
-            version_info(dict): metadata about the new version being created
+            user_id: the user whose backup version we're creating
+            version_info: metadata about the new version being created
 
         {
             "algorithm": "m.megolm_backup.v1",
@@ -301,7 +322,7 @@ class E2eRoomKeysHandler:
         }
 
         Returns:
-            A deferred of a string that gives the new version number.
+            The new version number.
         """
 
         # TODO: Validate the JSON to make sure it has the right keys.
@@ -313,17 +334,19 @@ class E2eRoomKeysHandler:
             )
             return new_version
 
-    async def get_version_info(self, user_id, version=None):
+    async def get_version_info(
+        self, user_id: str, version: Optional[str] = None
+    ) -> JsonDict:
         """Get the info about a given version of the user's backup
 
         Args:
-            user_id(str): the user whose current backup version we're querying
-            version(str): Optional; if None gives the most recent version
+            user_id: the user whose current backup version we're querying
+            version: Optional; if None gives the most recent version
                 otherwise a historical one.
         Raises:
             NotFoundError: if the requested backup version doesn't exist
         Returns:
-            A deferred of a info dict that gives the info about the new version.
+            An info dict that gives the info about the new version.
 
         {
             "version": "1234",
@@ -346,7 +369,7 @@ class E2eRoomKeysHandler:
             return res
 
     @trace
-    async def delete_version(self, user_id, version=None):
+    async def delete_version(self, user_id: str, version: Optional[str] = None) -> None:
         """Deletes a given version of the user's e2e_room_keys backup
 
         Args:
@@ -366,17 +389,19 @@ class E2eRoomKeysHandler:
                     raise
 
     @trace
-    async def update_version(self, user_id, version, version_info):
+    async def update_version(
+        self, user_id: str, version: str, version_info: JsonDict
+    ) -> JsonDict:
         """Update the info about a given version of the user's backup
 
         Args:
-            user_id(str): the user whose current backup version we're updating
-            version(str): the backup version we're updating
-            version_info(dict): the new information about the backup
+            user_id: the user whose current backup version we're updating
+            version: the backup version we're updating
+            version_info: the new information about the backup
         Raises:
             NotFoundError: if the requested backup version doesn't exist
         Returns:
-            A deferred of an empty dict.
+            An empty dict.
         """
         if "version" not in version_info:
             version_info["version"] = version
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 539b4fc32e..3e23f82cf7 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -57,8 +57,7 @@ class EventStreamHandler(BaseHandler):
         room_id: Optional[str] = None,
         is_guest: bool = False,
     ) -> JsonDict:
-        """Fetches the events stream for a given user.
-        """
+        """Fetches the events stream for a given user."""
 
         if room_id:
             blocked = await self.store.is_room_blocked(room_id)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index fd8de8696d..2ead626a4d 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -111,13 +111,13 @@ class _NewEventInfo:
 
 class FederationHandler(BaseHandler):
     """Handles events that originated from federation.
-        Responsible for:
-        a) handling received Pdus before handing them on as Events to the rest
-        of the homeserver (including auth and state conflict resolutions)
-        b) converting events that were produced by local clients that may need
-        to be sent to remote homeservers.
-        c) doing the necessary dances to invite remote users and join remote
-        rooms.
+    Responsible for:
+    a) handling received Pdus before handing them on as Events to the rest
+    of the homeserver (including auth and state conflict resolutions)
+    b) converting events that were produced by local clients that may need
+    to be sent to remote homeservers.
+    c) doing the necessary dances to invite remote users and join remote
+    rooms.
     """
 
     def __init__(self, hs: "HomeServer"):
@@ -150,11 +150,11 @@ class FederationHandler(BaseHandler):
         )
 
         if hs.config.worker_app:
-            self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
-                hs
+            self._user_device_resync = (
+                ReplicationUserDevicesResyncRestServlet.make_client(hs)
             )
-            self._maybe_store_room_on_outlier_membership = ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(
-                hs
+            self._maybe_store_room_on_outlier_membership = (
+                ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(hs)
             )
         else:
             self._device_list_updater = hs.get_device_handler().device_list_updater
@@ -172,7 +172,7 @@ class FederationHandler(BaseHandler):
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
 
     async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
-        """ Process a PDU received via a federation /send/ transaction, or
+        """Process a PDU received via a federation /send/ transaction, or
         via backfill of missing prev_events
 
         Args:
@@ -368,7 +368,8 @@ class FederationHandler(BaseHandler):
                     # know about
                     for p in prevs - seen:
                         logger.info(
-                            "Requesting state at missing prev_event %s", event_id,
+                            "Requesting state at missing prev_event %s",
+                            event_id,
                         )
 
                         with nested_logging_context(p):
@@ -388,12 +389,14 @@ class FederationHandler(BaseHandler):
                                 event_map[x.event_id] = x
 
                     room_version = await self.store.get_room_version_id(room_id)
-                    state_map = await self._state_resolution_handler.resolve_events_with_store(
-                        room_id,
-                        room_version,
-                        state_maps,
-                        event_map,
-                        state_res_store=StateResolutionStore(self.store),
+                    state_map = (
+                        await self._state_resolution_handler.resolve_events_with_store(
+                            room_id,
+                            room_version,
+                            state_maps,
+                            event_map,
+                            state_res_store=StateResolutionStore(self.store),
+                        )
                     )
 
                     # We need to give _process_received_pdu the actual state events
@@ -687,9 +690,12 @@ class FederationHandler(BaseHandler):
         return fetched_events
 
     async def _process_received_pdu(
-        self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]],
+        self,
+        origin: str,
+        event: EventBase,
+        state: Optional[Iterable[EventBase]],
     ):
-        """ Called when we have a new pdu. We need to do auth checks and put it
+        """Called when we have a new pdu. We need to do auth checks and put it
         through the StateHandler.
 
         Args:
@@ -801,7 +807,7 @@ class FederationHandler(BaseHandler):
 
     @log_function
     async def backfill(self, dest, room_id, limit, extremities):
-        """ Trigger a backfill request to `dest` for the given `room_id`
+        """Trigger a backfill request to `dest` for the given `room_id`
 
         This will attempt to get more events from the remote. If the other side
         has no new events to offer, this will return an empty list.
@@ -1204,11 +1210,16 @@ class FederationHandler(BaseHandler):
             with nested_logging_context(event_id):
                 try:
                     event = await self.federation_client.get_pdu(
-                        [destination], event_id, room_version, outlier=True,
+                        [destination],
+                        event_id,
+                        room_version,
+                        outlier=True,
                     )
                     if event is None:
                         logger.warning(
-                            "Server %s didn't return event %s", destination, event_id,
+                            "Server %s didn't return event %s",
+                            destination,
+                            event_id,
                         )
                         return
 
@@ -1235,7 +1246,8 @@ class FederationHandler(BaseHandler):
             if aid not in event_map
         ]
         persisted_events = await self.store.get_events(
-            auth_events, allow_rejected=True,
+            auth_events,
+            allow_rejected=True,
         )
 
         event_infos = []
@@ -1251,7 +1263,9 @@ class FederationHandler(BaseHandler):
             event_infos.append(_NewEventInfo(event, None, auth))
 
         await self._handle_new_events(
-            destination, room_id, event_infos,
+            destination,
+            room_id,
+            event_infos,
         )
 
     def _sanity_check_event(self, ev):
@@ -1287,7 +1301,7 @@ class FederationHandler(BaseHandler):
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
 
     async def send_invite(self, target_host, event):
-        """ Sends the invite to the remote server for signing.
+        """Sends the invite to the remote server for signing.
 
         Invites must be signed by the invitee's server before distribution.
         """
@@ -1310,7 +1324,7 @@ class FederationHandler(BaseHandler):
     async def do_invite_join(
         self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict
     ) -> Tuple[str, int]:
-        """ Attempts to join the `joinee` to the room `room_id` via the
+        """Attempts to join the `joinee` to the room `room_id` via the
         servers contained in `target_hosts`.
 
         This first triggers a /make_join/ request that returns a partial
@@ -1354,8 +1368,6 @@ class FederationHandler(BaseHandler):
 
         await self._clean_room_for_join(room_id)
 
-        handled_events = set()
-
         try:
             # Try the host we successfully got a response to /make_join/
             # request first.
@@ -1375,10 +1387,6 @@ class FederationHandler(BaseHandler):
             auth_chain = ret["auth_chain"]
             auth_chain.sort(key=lambda e: e.depth)
 
-            handled_events.update([s.event_id for s in state])
-            handled_events.update([a.event_id for a in auth_chain])
-            handled_events.add(event.event_id)
-
             logger.debug("do_invite_join auth_chain: %s", auth_chain)
             logger.debug("do_invite_join state: %s", state)
 
@@ -1394,7 +1402,8 @@ class FederationHandler(BaseHandler):
             # so we can rely on it now.
             #
             await self.store.upsert_room_on_join(
-                room_id=room_id, room_version=room_version_obj,
+                room_id=room_id,
+                room_version=room_version_obj,
             )
 
             max_stream_id = await self._persist_auth_tree(
@@ -1464,7 +1473,7 @@ class FederationHandler(BaseHandler):
     async def on_make_join_request(
         self, origin: str, room_id: str, user_id: str
     ) -> EventBase:
-        """ We've received a /make_join/ request, so we create a partial
+        """We've received a /make_join/ request, so we create a partial
         join event for the room and return that. We do *not* persist or
         process it until the other server has signed it and sent it back.
 
@@ -1489,7 +1498,8 @@ class FederationHandler(BaseHandler):
         is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
         if not is_in_room:
             logger.info(
-                "Got /make_join request for room %s we are no longer in", room_id,
+                "Got /make_join request for room %s we are no longer in",
+                room_id,
             )
             raise NotFoundError("Not an active room on this server")
 
@@ -1523,7 +1533,7 @@ class FederationHandler(BaseHandler):
         return event
 
     async def on_send_join_request(self, origin, pdu):
-        """ We have received a join event for a room. Fully process it and
+        """We have received a join event for a room. Fully process it and
         respond with the current state and auth chains.
         """
         event = pdu
@@ -1579,7 +1589,7 @@ class FederationHandler(BaseHandler):
     async def on_invite_request(
         self, origin: str, event: EventBase, room_version: RoomVersion
     ):
-        """ We've got an invite event. Process and persist it. Sign it.
+        """We've got an invite event. Process and persist it. Sign it.
 
         Respond with the now signed event.
         """
@@ -1617,6 +1627,12 @@ class FederationHandler(BaseHandler):
         if event.state_key == self._server_notices_mxid:
             raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
 
+        # We retrieve the room member handler here so as not to cause a cyclic dependency
+        member_handler = self.hs.get_room_member_handler()
+        # We don't rate limit based on room ID, as that should be done by
+        # the sending server.
+        member_handler.ratelimit_invite(None, event.state_key)
+
         # keep a record of the room version, if we don't yet know it.
         # (this may get overwritten if we later get a different room version in a
         # join dance).
@@ -1700,7 +1716,7 @@ class FederationHandler(BaseHandler):
     async def on_make_leave_request(
         self, origin: str, room_id: str, user_id: str
     ) -> EventBase:
-        """ We've received a /make_leave/ request, so we create a partial
+        """We've received a /make_leave/ request, so we create a partial
         leave event for the room and return that. We do *not* persist or
         process it until the other server has signed it and sent it back.
 
@@ -1776,8 +1792,7 @@ class FederationHandler(BaseHandler):
         return None
 
     async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
-        """Returns the state at the event. i.e. not including said event.
-        """
+        """Returns the state at the event. i.e. not including said event."""
 
         event = await self.store.get_event(event_id, check_room_id=room_id)
 
@@ -1803,8 +1818,7 @@ class FederationHandler(BaseHandler):
             return []
 
     async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
-        """Returns the state at the event. i.e. not including said event.
-        """
+        """Returns the state at the event. i.e. not including said event."""
         event = await self.store.get_event(event_id, check_room_id=room_id)
 
         state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
@@ -2010,7 +2024,11 @@ class FederationHandler(BaseHandler):
 
         for e_id in missing_auth_events:
             m_ev = await self.federation_client.get_pdu(
-                [origin], e_id, room_version=room_version, outlier=True, timeout=10000,
+                [origin],
+                e_id,
+                room_version=room_version,
+                outlier=True,
+                timeout=10000,
             )
             if m_ev and m_ev.event_id == e_id:
                 event_map[e_id] = m_ev
@@ -2093,6 +2111,11 @@ class FederationHandler(BaseHandler):
         if event.type == EventTypes.GuestAccess and not context.rejected:
             await self.maybe_kick_guest_users(event)
 
+        # If we are going to send this event over federation we precalculate
+        # the joined hosts.
+        if event.internal_metadata.get_send_on_behalf_of():
+            await self.event_creation_handler.cache_joined_hosts_for_event(event)
+
         return context
 
     async def _check_for_soft_fail(
@@ -2155,7 +2178,9 @@ class FederationHandler(BaseHandler):
             )
 
         logger.debug(
-            "Doing soft-fail check for %s: state %s", event.event_id, current_state_ids,
+            "Doing soft-fail check for %s: state %s",
+            event.event_id,
+            current_state_ids,
         )
 
         # Now check if event pass auth against said current state
@@ -2508,7 +2533,7 @@ class FederationHandler(BaseHandler):
     async def construct_auth_difference(
         self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase]
     ) -> Dict:
-        """ Given a local and remote auth chain, find the differences. This
+        """Given a local and remote auth chain, find the differences. This
         assumes that we have already processed all events in remote_auth
 
         Params:
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index df29edeb83..bfb95e3eee 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -15,9 +15,13 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set
 
 from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
-from synapse.types import GroupID, get_domain_from_id
+from synapse.types import GroupID, JsonDict, get_domain_from_id
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -56,7 +60,7 @@ def _create_rerouter(func_name):
 
 
 class GroupsLocalWorkerHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.room_list_handler = hs.get_room_list_handler()
@@ -84,7 +88,9 @@ class GroupsLocalWorkerHandler:
     get_group_role = _create_rerouter("get_group_role")
     get_group_roles = _create_rerouter("get_group_roles")
 
-    async def get_group_summary(self, group_id, requester_user_id):
+    async def get_group_summary(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get the group summary for a group.
 
         If the group is remote we check that the users have valid attestations.
@@ -137,14 +143,14 @@ class GroupsLocalWorkerHandler:
 
         return res
 
-    async def get_users_in_group(self, group_id, requester_user_id):
-        """Get users in a group
-        """
+    async def get_users_in_group(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
+        """Get users in a group"""
         if self.is_mine_id(group_id):
-            res = await self.groups_server_handler.get_users_in_group(
+            return await self.groups_server_handler.get_users_in_group(
                 group_id, requester_user_id
             )
-            return res
 
         group_server_name = get_domain_from_id(group_id)
 
@@ -178,11 +184,11 @@ class GroupsLocalWorkerHandler:
 
         return res
 
-    async def get_joined_groups(self, user_id):
+    async def get_joined_groups(self, user_id: str) -> JsonDict:
         group_ids = await self.store.get_joined_groups(user_id)
         return {"groups": group_ids}
 
-    async def get_publicised_groups_for_user(self, user_id):
+    async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict:
         if self.hs.is_mine_id(user_id):
             result = await self.store.get_publicised_groups_for_user(user_id)
 
@@ -206,8 +212,10 @@ class GroupsLocalWorkerHandler:
             # TODO: Verify attestations
             return {"groups": result}
 
-    async def bulk_get_publicised_groups(self, user_ids, proxy=True):
-        destinations = {}
+    async def bulk_get_publicised_groups(
+        self, user_ids: Iterable[str], proxy: bool = True
+    ) -> JsonDict:
+        destinations = {}  # type: Dict[str, Set[str]]
         local_users = set()
 
         for user_id in user_ids:
@@ -220,7 +228,7 @@ class GroupsLocalWorkerHandler:
             raise SynapseError(400, "Some user_ids are not local")
 
         results = {}
-        failed_results = []
+        failed_results = []  # type: List[str]
         for destination, dest_user_ids in destinations.items():
             try:
                 r = await self.transport_client.bulk_get_publicised_groups(
@@ -242,7 +250,7 @@ class GroupsLocalWorkerHandler:
 
 
 class GroupsLocalHandler(GroupsLocalWorkerHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         # Ensure attestations get renewed
@@ -271,9 +279,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
     set_group_join_policy = _create_rerouter("set_group_join_policy")
 
-    async def create_group(self, group_id, user_id, content):
-        """Create a group
-        """
+    async def create_group(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
+        """Create a group"""
 
         logger.info("Asking to create group with ID: %r", group_id)
 
@@ -284,27 +293,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
             local_attestation = None
             remote_attestation = None
         else:
-            local_attestation = self.attestations.create_attestation(group_id, user_id)
-            content["attestation"] = local_attestation
-
-            content["user_profile"] = await self.profile_handler.get_profile(user_id)
-
-            try:
-                res = await self.transport_client.create_group(
-                    get_domain_from_id(group_id), group_id, user_id, content
-                )
-            except HttpResponseException as e:
-                raise e.to_synapse_error()
-            except RequestSendFailed:
-                raise SynapseError(502, "Failed to contact group server")
-
-            remote_attestation = res["attestation"]
-            await self.attestations.verify_attestation(
-                remote_attestation,
-                group_id=group_id,
-                user_id=user_id,
-                server_name=get_domain_from_id(group_id),
-            )
+            raise SynapseError(400, "Unable to create remote groups")
 
         is_publicised = content.get("publicise", False)
         token = await self.store.register_user_group_membership(
@@ -320,9 +309,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return res
 
-    async def join_group(self, group_id, user_id, content):
-        """Request to join a group
-        """
+    async def join_group(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
+        """Request to join a group"""
         if self.is_mine_id(group_id):
             await self.groups_server_handler.join_group(group_id, user_id, content)
             local_attestation = None
@@ -365,9 +355,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return {}
 
-    async def accept_invite(self, group_id, user_id, content):
-        """Accept an invite to a group
-        """
+    async def accept_invite(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
+        """Accept an invite to a group"""
         if self.is_mine_id(group_id):
             await self.groups_server_handler.accept_invite(group_id, user_id, content)
             local_attestation = None
@@ -410,9 +401,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return {}
 
-    async def invite(self, group_id, user_id, requester_user_id, config):
-        """Invite a user to a group
-        """
+    async def invite(
+        self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
+    ) -> JsonDict:
+        """Invite a user to a group"""
         content = {"requester_user_id": requester_user_id, "config": config}
         if self.is_mine_id(group_id):
             res = await self.groups_server_handler.invite_to_group(
@@ -434,9 +426,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return res
 
-    async def on_invite(self, group_id, user_id, content):
-        """One of our users were invited to a group
-        """
+    async def on_invite(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
+        """One of our users was invited to a group"""
         # TODO: Support auto join and rejection
 
         if not self.is_mine_id(user_id):
@@ -465,10 +458,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
         return {"state": "invite", "user_profile": user_profile}
 
     async def remove_user_from_group(
-        self, group_id, user_id, requester_user_id, content
-    ):
-        """Remove a user from a group
-        """
+        self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
+    ) -> JsonDict:
+        """Remove a user from a group"""
         if user_id == requester_user_id:
             token = await self.store.register_user_group_membership(
                 group_id, user_id, membership="leave"
@@ -499,9 +491,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return res
 
-    async def user_removed_from_group(self, group_id, user_id, content):
-        """One of our users was removed/kicked from a group
-        """
+    async def user_removed_from_group(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> None:
+        """One of our users was removed/kicked from a group"""
         # TODO: Check if user in group
         token = await self.store.register_user_group_membership(
             group_id, user_id, membership="leave"
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index f61844d688..5f346f6d6d 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -27,9 +27,11 @@ from synapse.api.errors import (
     HttpResponseException,
     SynapseError,
 )
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.http import RequestTimedOutError
 from synapse.http.client import SimpleHttpClient
+from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict, Requester
 from synapse.util import json_decoder
 from synapse.util.hash import sha256_and_url_safe_base64
@@ -57,6 +59,35 @@ class IdentityHandler(BaseHandler):
 
         self._web_client_location = hs.config.invite_client_location
 
+        # Ratelimiters for `/requestToken` endpoints.
+        self._3pid_validation_ratelimiter_ip = Ratelimiter(
+            clock=hs.get_clock(),
+            rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
+            burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
+        )
+        self._3pid_validation_ratelimiter_address = Ratelimiter(
+            clock=hs.get_clock(),
+            rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
+            burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
+        )
+
+    def ratelimit_request_token_requests(
+        self,
+        request: SynapseRequest,
+        medium: str,
+        address: str,
+    ):
+        """Used to ratelimit requests to `/requestToken` by IP and address.
+
+        Args:
+            request: The associated request
+            medium: The type of threepid, e.g. "msisdn" or "email"
+            address: The actual threepid ID, e.g. the phone number or email address
+        """
+
+        self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
+        self._3pid_validation_ratelimiter_address.ratelimit((medium, address))
+
     async def threepid_from_creds(
         self, id_server: str, creds: Dict[str, str]
     ) -> Optional[JsonDict]:
@@ -476,6 +507,10 @@ class IdentityHandler(BaseHandler):
         except RequestTimedOutError:
             raise SynapseError(500, "Timed out contacting identity server")
 
+        # It is already checked that public_baseurl is configured since this code
+        # should only be used if account_threepid_delegate_msisdn is true.
+        assert self.hs.config.public_baseurl
+
         # we need to tell the client to send the token back to us, since it doesn't
         # otherwise know where to send it, so add submit_url response parameter
         # (see also MSC2078)
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index fbd8df9dcc..78c3e5a10b 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -124,7 +124,8 @@ class InitialSyncHandler(BaseHandler):
 
         joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN]
         receipt = await self.store.get_linearized_receipts_for_rooms(
-            joined_rooms, to_key=int(now_token.receipt_key),
+            joined_rooms,
+            to_key=int(now_token.receipt_key),
         )
 
         tags_by_room = await self.store.get_tags_for_user(user_id)
@@ -169,7 +170,10 @@ class InitialSyncHandler(BaseHandler):
                         self.state_handler.get_current_state, event.room_id
                     )
                 elif event.membership == Membership.LEAVE:
-                    room_end_token = RoomStreamToken(None, event.stream_ordering,)
+                    room_end_token = RoomStreamToken(
+                        None,
+                        event.stream_ordering,
+                    )
                     deferred_room_state = run_in_background(
                         self.state_store.get_state_for_events, [event.event_id]
                     )
@@ -284,7 +288,9 @@ class InitialSyncHandler(BaseHandler):
             membership,
             member_event_id,
         ) = await self.auth.check_user_in_room_or_world_readable(
-            room_id, user_id, allow_departed_users=True,
+            room_id,
+            user_id,
+            allow_departed_users=True,
         )
         is_peeking = member_event_id is None
 
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9dfeab09cd..c03f6c997b 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -65,8 +65,7 @@ logger = logging.getLogger(__name__)
 
 
 class MessageHandler:
-    """Contains some read only APIs to get state about a room
-    """
+    """Contains some read only APIs to get state about a room"""
 
     def __init__(self, hs):
         self.auth = hs.get_auth()
@@ -88,9 +87,13 @@ class MessageHandler:
             )
 
     async def get_room_data(
-        self, user_id: str, room_id: str, event_type: str, state_key: str,
+        self,
+        user_id: str,
+        room_id: str,
+        event_type: str,
+        state_key: str,
     ) -> dict:
-        """ Get data from a room.
+        """Get data from a room.
 
         Args:
             user_id
@@ -174,7 +177,10 @@ class MessageHandler:
                 raise NotFoundError("Can't find event for token %s" % (at_token,))
 
             visible_events = await filter_events_for_client(
-                self.storage, user_id, last_events, filter_send_to_client=False
+                self.storage,
+                user_id,
+                last_events,
+                filter_send_to_client=False,
             )
 
             event = last_events[0]
@@ -432,6 +438,8 @@ class EventCreationHandler:
 
         self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
 
+        self._external_cache = hs.get_external_cache()
+
     async def create_event(
         self,
         requester: Requester,
@@ -569,7 +577,7 @@ class EventCreationHandler:
     async def _is_exempt_from_privacy_policy(
         self, builder: EventBuilder, requester: Requester
     ) -> bool:
-        """"Determine if an event to be sent is exempt from having to consent
+        """Determine if an event to be sent is exempt from having to consent
         to the privacy policy
 
         Args:
@@ -791,9 +799,10 @@ class EventCreationHandler:
         """
 
         if prev_event_ids is not None:
-            assert len(prev_event_ids) <= 10, (
-                "Attempting to create an event with %i prev_events"
-                % (len(prev_event_ids),)
+            assert (
+                len(prev_event_ids) <= 10
+            ), "Attempting to create an event with %i prev_events" % (
+                len(prev_event_ids),
             )
         else:
             prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
@@ -819,7 +828,8 @@ class EventCreationHandler:
         )
         if not third_party_result:
             logger.info(
-                "Event %s forbidden by third-party rules", event,
+                "Event %s forbidden by third-party rules",
+                event,
             )
             raise SynapseError(
                 403, "This event is not allowed in this context", Codes.FORBIDDEN
@@ -939,6 +949,8 @@ class EventCreationHandler:
 
         await self.action_generator.handle_push_actions_for_event(event, context)
 
+        await self.cache_joined_hosts_for_event(event)
+
         try:
             # If we're a worker we need to hit out to the master.
             writer_instance = self._events_shard_config.get_instance(event.room_id)
@@ -978,6 +990,44 @@ class EventCreationHandler:
             await self.store.remove_push_actions_from_staging(event.event_id)
             raise
 
+    async def cache_joined_hosts_for_event(self, event: EventBase) -> None:
+        """Precalculate the joined hosts at the event, when using Redis, so that
+        external federation senders don't have to recalculate it themselves.
+        """
+
+        if not self._external_cache.is_enabled():
+            return
+
+        # We actually store two mappings, event ID -> prev state group,
+        # state group -> joined hosts, which is much more space efficient
+        # than event ID -> joined hosts.
+        #
+        # Note: We have to cache event ID -> prev state group, as we don't
+        # store that in the DB.
+        #
+        # Note: We always set the state group -> joined hosts cache, even if
+        # we already set it, so that the expiry time is reset.
+
+        state_entry = await self.state.resolve_state_groups_for_events(
+            event.room_id, event_ids=event.prev_event_ids()
+        )
+
+        if state_entry.state_group:
+            joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
+
+            await self._external_cache.set(
+                "event_to_prev_state_group",
+                event.event_id,
+                state_entry.state_group,
+                expiry_ms=60 * 60 * 1000,
+            )
+            await self._external_cache.set(
+                "get_joined_hosts",
+                str(state_entry.state_group),
+                list(joined_hosts),
+                expiry_ms=60 * 60 * 1000,
+            )
+
     async def _validate_canonical_alias(
         self, directory_handler, room_alias_str: str, expected_room_id: str
     ) -> None:
@@ -1274,7 +1324,11 @@ class EventCreationHandler:
                 # Since this is a dummy-event it is OK if it is sent by a
                 # shadow-banned user.
                 await self.handle_new_client_event(
-                    requester, event, context, ratelimit=False, ignore_shadow_ban=True,
+                    requester,
+                    event,
+                    context,
+                    ratelimit=False,
+                    ignore_shadow_ban=True,
                 )
                 return True
             except AuthError:
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 81cb2ffc6b..f73cbe2af3 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -41,13 +41,33 @@ from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
 from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
 from synapse.util import json_decoder
+from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
-SESSION_COOKIE_NAME = b"oidc_session"
+# we want the cookie to be returned to us even when the request is the POSTed
+# result of a form on another domain, as is used with `response_mode=form_post`.
+#
+# Modern browsers will not do so unless we set SameSite=None; however *older*
+# browsers (including all versions of Safari on iOS 12?) don't support
+# SameSite=None, and interpret it as SameSite=Strict:
+# https://bugs.webkit.org/show_bug.cgi?id=198181
+#
+# As a rather painful workaround, we set *two* cookies, one with SameSite=None
+# and one with no SameSite, in the hope that at least one of them will get
+# back to us.
+#
+# Secure is necessary for SameSite=None (and, empirically, also breaks things
+# on iOS 12.)
+#
+# Here we have the names of the cookies, and the options we use to set them.
+_SESSION_COOKIES = [
+    (b"oidc_session", b"Path=/_synapse/client/oidc; HttpOnly; Secure; SameSite=None"),
+    (b"oidc_session_no_samesite", b"Path=/_synapse/client/oidc; HttpOnly"),
+]
 
 #: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
 #: OpenID.Core sec 3.1.3.3.
@@ -72,8 +92,7 @@ JWKS = TypedDict("JWKS", {"keys": List[JWK]})
 
 
 class OidcHandler:
-    """Handles requests related to the OpenID Connect login flow.
-    """
+    """Handles requests related to the OpenID Connect login flow."""
 
     def __init__(self, hs: "HomeServer"):
         self._sso_handler = hs.get_sso_handler()
@@ -102,7 +121,7 @@ class OidcHandler:
                 ) from e
 
     async def handle_oidc_callback(self, request: SynapseRequest) -> None:
-        """Handle an incoming request to /_synapse/oidc/callback
+        """Handle an incoming request to /_synapse/client/oidc/callback
 
         Since we might want to display OIDC-related errors in a user-friendly
         way, we don't raise SynapseError from here. Instead, we call
@@ -123,7 +142,6 @@ class OidcHandler:
         Args:
             request: the incoming request from the browser.
         """
-
         # The provider might redirect with an error.
         # In that case, just display it as-is.
         if b"error" in request.args:
@@ -137,8 +155,12 @@ class OidcHandler:
             # either the provider misbehaving or Synapse being misconfigured.
             # The only exception of that is "access_denied", where the user
             # probably cancelled the login flow. In other cases, log those errors.
-            if error != "access_denied":
-                logger.error("Error from the OIDC provider: %s %s", error, description)
+            logger.log(
+                logging.INFO if error == "access_denied" else logging.ERROR,
+                "Received OIDC callback with error: %s %s",
+                error,
+                description,
+            )
 
             self._sso_handler.render_error(request, error, description)
             return
@@ -146,30 +168,37 @@ class OidcHandler:
         # otherwise, it is presumably a successful response. see:
         #   https://tools.ietf.org/html/rfc6749#section-4.1.2
 
-        # Fetch the session cookie
-        session = request.getCookie(SESSION_COOKIE_NAME)  # type: Optional[bytes]
-        if session is None:
-            logger.info("No session cookie found")
+        # Fetch the session cookie. See the comments on _SESSION_COOKIES for why
+        # there are two.
+
+        for cookie_name, _ in _SESSION_COOKIES:
+            session = request.getCookie(cookie_name)  # type: Optional[bytes]
+            if session is not None:
+                break
+        else:
+            logger.info("Received OIDC callback, with no session cookie")
             self._sso_handler.render_error(
                 request, "missing_session", "No session cookie found"
             )
             return
 
-        # Remove the cookie. There is a good chance that if the callback failed
+        # Remove the cookies. There is a good chance that if the callback failed
         # once, it will fail next time and the code will already be exchanged.
-        # Removing it early avoids spamming the provider with token requests.
-        request.addCookie(
-            SESSION_COOKIE_NAME,
-            b"",
-            path="/_synapse/oidc",
-            expires="Thu, Jan 01 1970 00:00:00 UTC",
-            httpOnly=True,
-            sameSite="lax",
-        )
+        # Removing the cookies early avoids spamming the provider with token requests.
+        #
+        # we have to build the header by hand rather than calling request.addCookie
+        # because the latter does not support SameSite=None
+        # (https://twistedmatrix.com/trac/ticket/10088)
+
+        for cookie_name, options in _SESSION_COOKIES:
+            request.cookies.append(
+                b"%s=; Expires=Thu, Jan 01 1970 00:00:00 UTC; %s"
+                % (cookie_name, options)
+            )
 
         # Check for the state query parameter
         if b"state" not in request.args:
-            logger.info("State parameter is missing")
+            logger.info("Received OIDC callback, with no state parameter")
             self._sso_handler.render_error(
                 request, "invalid_request", "State parameter is missing"
             )
@@ -183,14 +212,16 @@ class OidcHandler:
                 session, state
             )
         except (MacaroonDeserializationException, ValueError) as e:
-            logger.exception("Invalid session")
+            logger.exception("Invalid session for OIDC callback")
             self._sso_handler.render_error(request, "invalid_session", str(e))
             return
         except MacaroonInvalidSignatureException as e:
-            logger.exception("Could not verify session")
+            logger.exception("Could not verify session for OIDC callback")
             self._sso_handler.render_error(request, "mismatching_session", str(e))
             return
 
+        logger.info("Received OIDC callback for IdP %s", session_data.idp_id)
+
         oidc_provider = self._providers.get(session_data.idp_id)
         if not oidc_provider:
             logger.error("OIDC session uses unknown IdP %r", oidc_provider)
@@ -210,8 +241,7 @@ class OidcHandler:
 
 
 class OidcError(Exception):
-    """Used to catch errors when calling the token_endpoint
-    """
+    """Used to catch errors when calling the token_endpoint"""
 
     def __init__(self, error, error_description=None):
         self.error = error
@@ -240,22 +270,27 @@ class OidcProvider:
 
         self._token_generator = token_generator
 
+        self._config = provider
         self._callback_url = hs.config.oidc_callback_url  # type: str
 
         self._scopes = provider.scopes
         self._user_profile_method = provider.user_profile_method
         self._client_auth = ClientAuth(
-            provider.client_id, provider.client_secret, provider.client_auth_method,
+            provider.client_id,
+            provider.client_secret,
+            provider.client_auth_method,
         )  # type: ClientAuth
         self._client_auth_method = provider.client_auth_method
-        self._provider_metadata = OpenIDProviderMetadata(
-            issuer=provider.issuer,
-            authorization_endpoint=provider.authorization_endpoint,
-            token_endpoint=provider.token_endpoint,
-            userinfo_endpoint=provider.userinfo_endpoint,
-            jwks_uri=provider.jwks_uri,
-        )  # type: OpenIDProviderMetadata
-        self._provider_needs_discovery = provider.discover
+
+        # cache of metadata for the identity provider (endpoint uris, mostly). This is
+        # loaded on-demand from the discovery endpoint (if discovery is enabled), with
+        # possible overrides from the config.  Access via `load_metadata`.
+        self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata)
+
+        # cache of JWKs used by the identity provider to sign tokens. Loaded on demand
+        # from the IdP's jwks_uri, if required.
+        self._jwks = RetryOnExceptionCachedCall(self._load_jwks)
+
         self._user_mapping_provider = provider.user_mapping_provider_class(
             provider.user_mapping_provider_config
         )
@@ -274,11 +309,14 @@ class OidcProvider:
         # MXC URI for icon for this auth provider
         self.idp_icon = provider.idp_icon
 
+        # optional brand identifier for this auth provider
+        self.idp_brand = provider.idp_brand
+
         self._sso_handler = hs.get_sso_handler()
 
         self._sso_handler.register_identity_provider(self)
 
-    def _validate_metadata(self):
+    def _validate_metadata(self, m: OpenIDProviderMetadata) -> None:
         """Verifies the provider metadata.
 
         This checks the validity of the currently loaded provider. Not
@@ -297,7 +335,6 @@ class OidcProvider:
         if self._skip_verification is True:
             return
 
-        m = self._provider_metadata
         m.validate_issuer()
         m.validate_authorization_endpoint()
         m.validate_token_endpoint()
@@ -332,11 +369,7 @@ class OidcProvider:
                 )
         else:
             # If we're not using userinfo, we need a valid jwks to validate the ID token
-            if m.get("jwks") is None:
-                if m.get("jwks_uri") is not None:
-                    m.validate_jwks_uri()
-                else:
-                    raise ValueError('"jwks_uri" must be set')
+            m.validate_jwks_uri()
 
     @property
     def _uses_userinfo(self) -> bool:
@@ -353,11 +386,15 @@ class OidcProvider:
             or self._user_profile_method == "userinfo_endpoint"
         )
 
-    async def load_metadata(self) -> OpenIDProviderMetadata:
-        """Load and validate the provider metadata.
+    async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata:
+        """Return the provider metadata.
 
-        The values metadatas are discovered if ``oidc_config.discovery`` is
-        ``True`` and then cached.
+        If this is the first call, the metadata is built from the config and from the
+        metadata discovery endpoint (if enabled), and then validated. If the metadata
+        is successfully validated, it is then cached for future use.
+
+        Args:
+            force: If true, any cached metadata is discarded to force a reload.
 
         Raises:
             ValueError: if something in the provider is not valid
@@ -365,18 +402,41 @@ class OidcProvider:
         Returns:
             The provider's metadata.
         """
-        # If we are using the OpenID Discovery documents, it needs to be loaded once
-        # FIXME: should there be a lock here?
-        if self._provider_needs_discovery:
-            url = get_well_known_url(self._provider_metadata["issuer"], external=True)
+        if force:
+            # reset the cached call to ensure we get a new result
+            self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata)
+
+        return await self._provider_metadata.get()
+
+    async def _load_metadata(self) -> OpenIDProviderMetadata:
+        # start out with just the issuer (unlike the other settings, discovered issuer
+        # takes precedence over configured issuer, because configured issuer is
+        # required for discovery to take place.)
+        #
+        metadata = OpenIDProviderMetadata(issuer=self._config.issuer)
+
+        # load any data from the discovery endpoint, if enabled
+        if self._config.discover:
+            url = get_well_known_url(self._config.issuer, external=True)
             metadata_response = await self._http_client.get_json(url)
-            # TODO: maybe update the other way around to let user override some values?
-            self._provider_metadata.update(metadata_response)
-            self._provider_needs_discovery = False
+            metadata.update(metadata_response)
 
-        self._validate_metadata()
+        # override any discovered data with any settings in our config
+        if self._config.authorization_endpoint:
+            metadata["authorization_endpoint"] = self._config.authorization_endpoint
 
-        return self._provider_metadata
+        if self._config.token_endpoint:
+            metadata["token_endpoint"] = self._config.token_endpoint
+
+        if self._config.userinfo_endpoint:
+            metadata["userinfo_endpoint"] = self._config.userinfo_endpoint
+
+        if self._config.jwks_uri:
+            metadata["jwks_uri"] = self._config.jwks_uri
+
+        self._validate_metadata(metadata)
+
+        return metadata
 
     async def load_jwks(self, force: bool = False) -> JWKS:
         """Load the JSON Web Key Set used to sign ID tokens.
@@ -406,27 +466,27 @@ class OidcProvider:
                     ]
                 }
         """
+        if force:
+            # reset the cached call to ensure we get a new result
+            self._jwks = RetryOnExceptionCachedCall(self._load_jwks)
+        return await self._jwks.get()
+
+    async def _load_jwks(self) -> JWKS:
         if self._uses_userinfo:
             # We're not using jwt signing, return an empty jwk set
             return {"keys": []}
 
-        # First check if the JWKS are loaded in the provider metadata.
-        # It can happen either if the provider gives its JWKS in the discovery
-        # document directly or if it was already loaded once.
         metadata = await self.load_metadata()
-        jwk_set = metadata.get("jwks")
-        if jwk_set is not None and not force:
-            return jwk_set
 
-        # Loading the JWKS using the `jwks_uri` metadata
+        # Load the JWKS using the `jwks_uri` metadata.
         uri = metadata.get("jwks_uri")
         if not uri:
+            # this should be unreachable: load_metadata validates that
+            # there is a jwks_uri in the metadata if _uses_userinfo is unset
             raise RuntimeError('Missing "jwks_uri" in metadata')
 
         jwk_set = await self._http_client.get_json(uri)
 
-        # Caching the JWKS in the provider's metadata
-        self._provider_metadata["jwks"] = jwk_set
         return jwk_set
 
     async def _exchange_code(self, code: str) -> Token:
@@ -484,7 +544,10 @@ class OidcProvider:
         # We're not using the SimpleHttpClient util methods as we don't want to
         # check the HTTP status code and we do the body encoding ourself.
         response = await self._http_client.request(
-            method="POST", uri=uri, data=body.encode("utf-8"), headers=headers,
+            method="POST",
+            uri=uri,
+            data=body.encode("utf-8"),
+            headers=headers,
         )
 
         # This is used in multiple error messages below
@@ -562,6 +625,7 @@ class OidcProvider:
         Returns:
             UserInfo: an object representing the user.
         """
+        logger.debug("Using the OAuth2 access_token to request userinfo")
         metadata = await self.load_metadata()
 
         resp = await self._http_client.get_json(
@@ -569,6 +633,8 @@ class OidcProvider:
             headers={"Authorization": ["Bearer {}".format(token["access_token"])]},
         )
 
+        logger.debug("Retrieved user info from userinfo endpoint: %r", resp)
+
         return UserInfo(resp)
 
     async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
@@ -597,17 +663,19 @@ class OidcProvider:
             claims_cls = ImplicitIDToken
 
         alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
-
         jwt = JsonWebToken(alg_values)
 
         claim_options = {"iss": {"values": [metadata["issuer"]]}}
 
+        id_token = token["id_token"]
+        logger.debug("Attempting to decode JWT id_token %r", id_token)
+
         # Try to decode the keys in cache first, then retry by forcing the keys
         # to be reloaded
         jwk_set = await self.load_jwks()
         try:
             claims = jwt.decode(
-                token["id_token"],
+                id_token,
                 key=jwk_set,
                 claims_cls=claims_cls,
                 claims_options=claim_options,
@@ -617,13 +685,15 @@ class OidcProvider:
             logger.info("Reloading JWKS after decode error")
             jwk_set = await self.load_jwks(force=True)  # try reloading the jwks
             claims = jwt.decode(
-                token["id_token"],
+                id_token,
                 key=jwk_set,
                 claims_cls=claims_cls,
                 claims_options=claim_options,
                 claims_params=claims_params,
             )
 
+        logger.debug("Decoded id_token JWT %r; validating", claims)
+
         claims.validate(leeway=120)  # allows 2 min of clock skew
         return UserInfo(claims)
 
@@ -640,7 +710,7 @@ class OidcProvider:
 
           - ``client_id``: the client ID set in ``oidc_config.client_id``
           - ``response_type``: ``code``
-          - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback``
+          - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/client/oidc/callback``
           - ``scope``: the list of scopes set in ``oidc_config.scopes``
           - ``state``: a random string
           - ``nonce``: a random string
@@ -678,14 +748,18 @@ class OidcProvider:
                 ui_auth_session_id=ui_auth_session_id,
             ),
         )
-        request.addCookie(
-            SESSION_COOKIE_NAME,
-            cookie,
-            path="/_synapse/oidc",
-            max_age="3600",
-            httpOnly=True,
-            sameSite="lax",
-        )
+
+        # Set the cookies. See the comments on _SESSION_COOKIES for why there are two.
+        #
+        # we have to build the header by hand rather than calling request.addCookie
+        # because the latter does not support SameSite=None
+        # (https://twistedmatrix.com/trac/ticket/10088)
+
+        for cookie_name, options in _SESSION_COOKIES:
+            request.cookies.append(
+                b"%s=%s; Max-Age=3600; %s"
+                % (cookie_name, cookie.encode("utf-8"), options)
+            )
 
         metadata = await self.load_metadata()
         authorization_endpoint = metadata.get("authorization_endpoint")
@@ -720,7 +794,7 @@ class OidcProvider:
     async def handle_oidc_callback(
         self, request: SynapseRequest, session_data: "OidcSessionData", code: str
     ) -> None:
-        """Handle an incoming request to /_synapse/oidc/callback
+        """Handle an incoming request to /_synapse/client/oidc/callback
 
         By this time we have already validated the session on the synapse side, and
         now need to do the provider-specific operations. This includes:
@@ -741,19 +815,18 @@ class OidcProvider:
         """
         # Exchange the code with the provider
         try:
-            logger.debug("Exchanging code")
+            logger.debug("Exchanging OAuth2 code for a token")
             token = await self._exchange_code(code)
         except OidcError as e:
-            logger.exception("Could not exchange code")
+            logger.exception("Could not exchange OAuth2 code")
             self._sso_handler.render_error(request, e.error, e.error_description)
             return
 
-        logger.debug("Successfully obtained OAuth2 access token")
+        logger.debug("Successfully obtained OAuth2 token data: %r", token)
 
         # Now that we have a token, get the userinfo, either by decoding the
         # `id_token` or by fetching the `userinfo_endpoint`.
         if self._uses_userinfo:
-            logger.debug("Fetching userinfo")
             try:
                 userinfo = await self._fetch_userinfo(token)
             except Exception as e:
@@ -761,7 +834,6 @@ class OidcProvider:
                 self._sso_handler.render_error(request, "fetch_error", str(e))
                 return
         else:
-            logger.debug("Extracting userinfo from id_token")
             try:
                 userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
             except Exception as e:
@@ -954,7 +1026,9 @@ class OidcSessionTokenGenerator:
             A signed macaroon token with the session information.
         """
         macaroon = pymacaroons.Macaroon(
-            location=self._server_name, identifier="key", key=self._macaroon_secret_key,
+            location=self._server_name,
+            identifier="key",
+            key=self._macaroon_secret_key,
         )
         macaroon.add_first_party_caveat("gen = 1")
         macaroon.add_first_party_caveat("type = session")
@@ -1074,7 +1148,8 @@ class OidcSessionData:
 
 
 UserAttributeDict = TypedDict(
-    "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
+    "UserAttributeDict",
+    {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]},
 )
 C = TypeVar("C")
 
@@ -1153,11 +1228,12 @@ def jinja_finalize(thing):
 env = Environment(finalize=jinja_finalize)
 
 
-@attr.s
+@attr.s(slots=True, frozen=True)
 class JinjaOidcMappingConfig:
     subject_claim = attr.ib(type=str)
     localpart_template = attr.ib(type=Optional[Template])
     display_name_template = attr.ib(type=Optional[Template])
+    email_template = attr.ib(type=Optional[Template])
     extra_attributes = attr.ib(type=Dict[str, Template])
 
 
@@ -1174,23 +1250,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
     def parse_config(config: dict) -> JinjaOidcMappingConfig:
         subject_claim = config.get("subject_claim", "sub")
 
-        localpart_template = None  # type: Optional[Template]
-        if "localpart_template" in config:
+        def parse_template_config(option_name: str) -> Optional[Template]:
+            if option_name not in config:
+                return None
             try:
-                localpart_template = env.from_string(config["localpart_template"])
+                return env.from_string(config[option_name])
             except Exception as e:
-                raise ConfigError(
-                    "invalid jinja template", path=["localpart_template"]
-                ) from e
+                raise ConfigError("invalid jinja template", path=[option_name]) from e
 
-        display_name_template = None  # type: Optional[Template]
-        if "display_name_template" in config:
-            try:
-                display_name_template = env.from_string(config["display_name_template"])
-            except Exception as e:
-                raise ConfigError(
-                    "invalid jinja template", path=["display_name_template"]
-                ) from e
+        localpart_template = parse_template_config("localpart_template")
+        display_name_template = parse_template_config("display_name_template")
+        email_template = parse_template_config("email_template")
 
         extra_attributes = {}  # type Dict[str, Template]
         if "extra_attributes" in config:
@@ -1210,6 +1280,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
             subject_claim=subject_claim,
             localpart_template=localpart_template,
             display_name_template=display_name_template,
+            email_template=email_template,
             extra_attributes=extra_attributes,
         )
 
@@ -1231,16 +1302,23 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
             # a usable mxid.
             localpart += str(failures) if failures else ""
 
-        display_name = None  # type: Optional[str]
-        if self._config.display_name_template is not None:
-            display_name = self._config.display_name_template.render(
-                user=userinfo
-            ).strip()
+        def render_template_field(template: Optional[Template]) -> Optional[str]:
+            if template is None:
+                return None
+            return template.render(user=userinfo).strip()
 
-            if display_name == "":
-                display_name = None
+        display_name = render_template_field(self._config.display_name_template)
+        if display_name == "":
+            display_name = None
 
-        return UserAttributeDict(localpart=localpart, display_name=display_name)
+        emails = []  # type: List[str]
+        email = render_template_field(self._config.email_template)
+        if email:
+            emails.append(email)
+
+        return UserAttributeDict(
+            localpart=localpart, display_name=display_name, emails=emails
+        )
 
     async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
         extras = {}  # type: Dict[str, str]
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 5372753707..059064a4eb 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -197,7 +197,8 @@ class PaginationHandler:
             stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
 
             r = await self.store.get_room_event_before_stream_ordering(
-                room_id, stream_ordering,
+                room_id,
+                stream_ordering,
             )
             if not r:
                 logger.warning(
@@ -223,7 +224,12 @@ class PaginationHandler:
             # the background so that it's not blocking any other operation apart from
             # other purges in the same room.
             run_as_background_process(
-                "_purge_history", self._purge_history, purge_id, room_id, token, True,
+                "_purge_history",
+                self._purge_history,
+                purge_id,
+                room_id,
+                token,
+                True,
             )
 
     def start_purge_history(
@@ -389,7 +395,9 @@ class PaginationHandler:
                         )
 
                 await self.hs.get_federation_handler().maybe_backfill(
-                    room_id, curr_topo, limit=pagin_config.limit,
+                    room_id,
+                    curr_topo,
+                    limit=pagin_config.limit,
                 )
 
             to_room_key = None
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 22d1e9d35c..fb85b19770 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -349,10 +349,13 @@ class PresenceHandler(BasePresenceHandler):
                 [self.user_to_current_state[user_id] for user_id in unpersisted]
             )
 
-    async def _update_states(self, new_states):
+    async def _update_states(self, new_states: Iterable[UserPresenceState]) -> None:
         """Updates presence of users. Sets the appropriate timeouts. Pokes
         the notifier and federation if and only if the changed presence state
         should be sent to clients/servers.
+
+        Args:
+            new_states: The new user presence state updates to process.
         """
         now = self.clock.time_msec()
 
@@ -368,7 +371,7 @@ class PresenceHandler(BasePresenceHandler):
             new_states_dict = {}
             for new_state in new_states:
                 new_states_dict[new_state.user_id] = new_state
-            new_state = new_states_dict.values()
+            new_states = new_states_dict.values()
 
             for new_state in new_states:
                 user_id = new_state.user_id
@@ -635,8 +638,7 @@ class PresenceHandler(BasePresenceHandler):
             self.external_process_last_updated_ms.pop(process_id, None)
 
     async def current_state_for_user(self, user_id):
-        """Get the current presence state for a user.
-        """
+        """Get the current presence state for a user."""
         res = await self.current_state_for_users([user_id])
         return res[user_id]
 
@@ -658,17 +660,6 @@ class PresenceHandler(BasePresenceHandler):
 
         self._push_to_remotes(states)
 
-    async def notify_for_states(self, state, stream_id):
-        parties = await get_interested_parties(self.store, [state])
-        room_ids_to_states, users_to_states = parties
-
-        self.notifier.on_new_event(
-            "presence_key",
-            stream_id,
-            rooms=room_ids_to_states.keys(),
-            users=[UserID.from_string(u) for u in users_to_states],
-        )
-
     def _push_to_remotes(self, states):
         """Sends state updates to remote servers.
 
@@ -678,8 +669,7 @@ class PresenceHandler(BasePresenceHandler):
         self.federation.send_presence(states)
 
     async def incoming_presence(self, origin, content):
-        """Called when we receive a `m.presence` EDU from a remote server.
-        """
+        """Called when we receive a `m.presence` EDU from a remote server."""
         if not self._presence_enabled:
             return
 
@@ -729,8 +719,7 @@ class PresenceHandler(BasePresenceHandler):
             await self._update_states(updates)
 
     async def set_state(self, target_user, state, ignore_status_msg=False):
-        """Set the presence state of the user.
-        """
+        """Set the presence state of the user."""
         status_msg = state.get("status_msg", None)
         presence = state["presence"]
 
@@ -758,8 +747,7 @@ class PresenceHandler(BasePresenceHandler):
         await self._update_states([prev_state.copy_and_replace(**new_fields)])
 
     async def is_visible(self, observed_user, observer_user):
-        """Returns whether a user can see another user's presence.
-        """
+        """Returns whether a user can see another user's presence."""
         observer_room_ids = await self.store.get_rooms_for_user(
             observer_user.to_string()
         )
@@ -953,8 +941,7 @@ class PresenceHandler(BasePresenceHandler):
 
 
 def should_notify(old_state, new_state):
-    """Decides if a presence state change should be sent to interested parties.
-    """
+    """Decides if a presence state change should be sent to interested parties."""
     if old_state == new_state:
         return False
 
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index c02b951031..2f62d84fb5 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -207,7 +207,8 @@ class ProfileHandler(BaseHandler):
         # This must be done by the target user himself.
         if by_admin:
             requester = create_requester(
-                target_user, authenticated_entity=requester.authenticated_entity,
+                target_user,
+                authenticated_entity=requester.authenticated_entity,
             )
 
         await self.store.set_profile_displayname(
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index cc21fc2284..6a6c528849 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -49,15 +49,15 @@ class ReceiptsHandler(BaseHandler):
             )
         else:
             hs.get_federation_registry().register_instances_for_edu(
-                "m.receipt", hs.config.worker.writers.receipts,
+                "m.receipt",
+                hs.config.worker.writers.receipts,
             )
 
         self.clock = self.hs.get_clock()
         self.state = hs.get_state_handler()
 
     async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
-        """Called when we receive an EDU of type m.receipt from a remote HS.
-        """
+        """Called when we receive an EDU of type m.receipt from a remote HS."""
         receipts = []
         for room_id, room_values in content.items():
             for receipt_type, users in room_values.items():
@@ -83,8 +83,7 @@ class ReceiptsHandler(BaseHandler):
         await self._handle_new_receipts(receipts)
 
     async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
-        """Takes a list of receipts, stores them and informs the notifier.
-        """
+        """Takes a list of receipts, stores them and informs the notifier."""
         min_batch_id = None  # type: Optional[int]
         max_batch_id = None  # type: Optional[int]
 
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index a2cf0f6f3e..3cda89657e 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -14,8 +14,9 @@
 # limitations under the License.
 
 """Contains functions for registering clients."""
+
 import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 
 from synapse import types
 from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
@@ -61,8 +62,8 @@ class RegistrationHandler(BaseHandler):
             self._register_device_client = RegisterDeviceReplicationServlet.make_client(
                 hs
             )
-            self._post_registration_client = ReplicationPostRegisterActionsServlet.make_client(
-                hs
+            self._post_registration_client = (
+                ReplicationPostRegisterActionsServlet.make_client(hs)
             )
         else:
             self.device_handler = hs.get_device_handler()
@@ -152,7 +153,7 @@ class RegistrationHandler(BaseHandler):
         user_type: Optional[str] = None,
         default_display_name: Optional[str] = None,
         address: Optional[str] = None,
-        bind_emails: List[str] = [],
+        bind_emails: Iterable[str] = [],
         by_admin: bool = False,
         user_agent_ips: Optional[List[Tuple[str, str]]] = None,
     ) -> str:
@@ -188,12 +189,15 @@ class RegistrationHandler(BaseHandler):
         self.check_registration_ratelimit(address)
 
         result = await self.spam_checker.check_registration_for_spam(
-            threepid, localpart, user_agent_ips or [],
+            threepid,
+            localpart,
+            user_agent_ips or [],
         )
 
         if result == RegistrationBehaviour.DENY:
             logger.info(
-                "Blocked registration of %r", localpart,
+                "Blocked registration of %r",
+                localpart,
             )
             # We return a 429 to make it not obvious that they've been
             # denied.
@@ -202,7 +206,8 @@ class RegistrationHandler(BaseHandler):
         shadow_banned = result == RegistrationBehaviour.SHADOW_BAN
         if shadow_banned:
             logger.info(
-                "Shadow banning registration of %r", localpart,
+                "Shadow banning registration of %r",
+                localpart,
             )
 
         # do not check_auth_blocking if the call is coming through the Admin API
@@ -368,7 +373,9 @@ class RegistrationHandler(BaseHandler):
                     config["room_alias_name"] = room_alias.localpart
 
                     info, _ = await room_creation_handler.create_room(
-                        fake_requester, config=config, ratelimit=False,
+                        fake_requester,
+                        config=config,
+                        ratelimit=False,
                     )
 
                     # If the room does not require an invite, but another user
@@ -693,6 +700,8 @@ class RegistrationHandler(BaseHandler):
             access_token: The access token of the newly logged in device, or
                 None if `inhibit_login` enabled.
         """
+        # TODO: 3pid registration can actually happen on the workers. Consider
+        # refactoring it.
         if self.hs.config.worker_app:
             await self._post_registration_client(
                 user_id=user_id, auth_result=auth_result, access_token=access_token
@@ -750,7 +759,10 @@ class RegistrationHandler(BaseHandler):
             return
 
         await self._auth_handler.add_threepid(
-            user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
+            user_id,
+            threepid["medium"],
+            threepid["address"],
+            threepid["validated_at"],
         )
 
         # And we add an email pusher for them by default, but only
@@ -802,5 +814,8 @@ class RegistrationHandler(BaseHandler):
             raise
 
         await self._auth_handler.add_threepid(
-            user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
+            user_id,
+            threepid["medium"],
+            threepid["address"],
+            threepid["validated_at"],
         )
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index ee27d99135..a488df10d6 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -38,6 +38,7 @@ from synapse.api.filtering import Filter
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.events import EventBase
 from synapse.events.utils import copy_power_levels_contents
+from synapse.rest.admin._base import assert_user_is_admin
 from synapse.storage.state import StateFilter
 from synapse.types import (
     JsonDict,
@@ -126,6 +127,10 @@ class RoomCreationHandler(BaseHandler):
 
         self.third_party_event_rules = hs.get_third_party_event_rules()
 
+        self._invite_burst_count = (
+            hs.config.ratelimiting.rc_invites_per_room.burst_count
+        )
+
     async def upgrade_room(
         self, requester: Requester, old_room_id: str, new_version: RoomVersion
     ) -> str:
@@ -193,7 +198,9 @@ class RoomCreationHandler(BaseHandler):
         if r is None:
             raise NotFoundError("Unknown room id %s" % (old_room_id,))
         new_room_id = await self._generate_room_id(
-            creator_id=user_id, is_public=r["is_public"], room_version=new_version,
+            creator_id=user_id,
+            is_public=r["is_public"],
+            room_version=new_version,
         )
 
         logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)
@@ -231,7 +238,9 @@ class RoomCreationHandler(BaseHandler):
 
         # now send the tombstone
         await self.event_creation_handler.handle_new_client_event(
-            requester=requester, event=tombstone_event, context=tombstone_context,
+            requester=requester,
+            event=tombstone_event,
+            context=tombstone_context,
         )
 
         old_room_state = await tombstone_context.get_current_state_ids()
@@ -252,7 +261,10 @@ class RoomCreationHandler(BaseHandler):
         # finally, shut down the PLs in the old room, and update them in the new
         # room.
         await self._update_upgraded_room_pls(
-            requester, old_room_id, new_room_id, old_room_state,
+            requester,
+            old_room_id,
+            new_room_id,
+            old_room_state,
         )
 
         return new_room_id
@@ -420,17 +432,20 @@ class RoomCreationHandler(BaseHandler):
 
         # Copy over user power levels now as this will not be possible with >100PL users once
         # the room has been created
-
         # Calculate the minimum power level needed to clone the room
         event_power_levels = power_levels.get("events", {})
-        state_default = power_levels.get("state_default", 0)
-        ban = power_levels.get("ban")
+        state_default = power_levels.get("state_default", 50)
+        ban = power_levels.get("ban", 50)
         needed_power_level = max(state_default, ban, max(event_power_levels.values()))
 
+        # Get the user's current power level, this matches the logic in get_user_power_level,
+        # but without the entire state map.
+        user_power_levels = power_levels.setdefault("users", {})
+        users_default = power_levels.get("users_default", 0)
+        current_power_level = user_power_levels.get(user_id, users_default)
         # Raise the requester's power level in the new room if necessary
-        current_power_level = power_levels["users"][user_id]
         if current_power_level < needed_power_level:
-            power_levels["users"][user_id] = needed_power_level
+            user_power_levels[user_id] = needed_power_level
 
         await self._send_events_for_new_room(
             requester,
@@ -562,7 +577,7 @@ class RoomCreationHandler(BaseHandler):
         ratelimit: bool = True,
         creator_join_profile: Optional[JsonDict] = None,
     ) -> Tuple[dict, int]:
-        """ Creates a new room.
+        """Creates a new room.
 
         Args:
             requester:
@@ -662,6 +677,9 @@ class RoomCreationHandler(BaseHandler):
             invite_3pid_list = []
             invite_list = []
 
+        if len(invite_list) + len(invite_3pid_list) > self._invite_burst_count:
+            raise SynapseError(400, "Cannot invite so many users at once")
+
         await self.event_creation_handler.assert_accepted_privacy_policy(requester)
 
         power_level_content_override = config.get("power_level_content_override")
@@ -680,7 +698,9 @@ class RoomCreationHandler(BaseHandler):
         is_public = visibility == "public"
 
         room_id = await self._generate_room_id(
-            creator_id=user_id, is_public=is_public, room_version=room_version,
+            creator_id=user_id,
+            is_public=is_public,
+            room_version=room_version,
         )
 
         # Check whether this visibility value is blocked by a third party module
@@ -821,7 +841,7 @@ class RoomCreationHandler(BaseHandler):
         if room_alias:
             result["room_alias"] = room_alias.to_string()
 
-        # Always wait for room creation to progate before returning
+        # Always wait for room creation to propagate before returning
         await self._replication.wait_for_stream_position(
             self.hs.config.worker.events_shard_config.get_instance(room_id),
             "events",
@@ -873,7 +893,10 @@ class RoomCreationHandler(BaseHandler):
                 _,
                 last_stream_id,
             ) = await self.event_creation_handler.create_and_send_nonmember_event(
-                creator, event, ratelimit=False, ignore_shadow_ban=True,
+                creator,
+                event,
+                ratelimit=False,
+                ignore_shadow_ban=True,
             )
             return last_stream_id
 
@@ -973,7 +996,10 @@ class RoomCreationHandler(BaseHandler):
         return last_sent_stream_id
 
     async def _generate_room_id(
-        self, creator_id: str, is_public: bool, room_version: RoomVersion,
+        self,
+        creator_id: str,
+        is_public: bool,
+        room_version: RoomVersion,
     ):
         # autogen room IDs and try to create it. We may clash, so just
         # try a few times till one goes through, giving up eventually.
@@ -997,41 +1023,51 @@ class RoomCreationHandler(BaseHandler):
 class RoomContextHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
+        self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
 
     async def get_event_context(
         self,
-        user: UserID,
+        requester: Requester,
         room_id: str,
         event_id: str,
         limit: int,
         event_filter: Optional[Filter],
+        use_admin_priviledge: bool = False,
     ) -> Optional[JsonDict]:
         """Retrieves events, pagination tokens and state around a given event
         in a room.
 
         Args:
-            user
+            requester
             room_id
             event_id
             limit: The maximum number of events to return in total
                 (excluding state).
             event_filter: the filter to apply to the events returned
                 (excluding the target event_id)
-
+            use_admin_priviledge: if `True`, return all events, regardless
+                of whether `user` has access to them. To be used **ONLY**
+                from the admin API.
         Returns:
             dict, or None if the event isn't found
         """
+        user = requester.user
+        if use_admin_priviledge:
+            await assert_user_is_admin(self.auth, requester.user)
+
         before_limit = math.floor(limit / 2.0)
         after_limit = limit - before_limit
 
         users = await self.store.get_users_in_room(room_id)
         is_peeking = user.to_string() not in users
 
-        def filter_evts(events):
-            return filter_events_for_client(
+        async def filter_evts(events):
+            if use_admin_priviledge:
+                return events
+            return await filter_events_for_client(
                 self.storage, user.to_string(), events, is_peeking=is_peeking
             )
 
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index e001e418f9..1660921306 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -85,6 +85,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
         )
 
+        self._invites_per_room_limiter = Ratelimiter(
+            clock=self.clock,
+            rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
+            burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
+        )
+        self._invites_per_user_limiter = Ratelimiter(
+            clock=self.clock,
+            rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
+            burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
+        )
+
         # This is only used to get at ratelimit function, and
         # maybe_kick_guest_users. It's fine there are multiple of these as
         # it doesn't store state.
@@ -144,6 +155,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         """
         raise NotImplementedError()
 
+    def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str):
+        """Ratelimit invites by room and by target user.
+
+        If room ID is missing then we just rate limit by target user.
+        """
+        if room_id:
+            self._invites_per_room_limiter.ratelimit(room_id)
+
+        self._invites_per_user_limiter.ratelimit(invitee_user_id)
+
     async def _local_membership_update(
         self,
         requester: Requester,
@@ -170,7 +191,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # do it up front for efficiency.)
         if txn_id and requester.access_token_id:
             existing_event_id = await self.store.get_event_id_from_transaction_id(
-                room_id, requester.user.to_string(), requester.access_token_id, txn_id,
+                room_id,
+                requester.user.to_string(),
+                requester.access_token_id,
+                txn_id,
             )
             if existing_event_id:
                 event_pos = await self.store.get_position_for_event(existing_event_id)
@@ -217,7 +241,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     )
 
         result_event = await self.event_creation_handler.handle_new_client_event(
-            requester, event, context, extra_users=[target], ratelimit=ratelimit,
+            requester,
+            event,
+            context,
+            extra_users=[target],
+            ratelimit=ratelimit,
         )
 
         if event.membership == Membership.LEAVE:
@@ -387,8 +415,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 raise SynapseError(403, "This room has been blocked on this server")
 
         if effective_membership_state == Membership.INVITE:
+            target_id = target.to_string()
+            if ratelimit:
+                # Don't ratelimit application services.
+                if not requester.app_service or requester.app_service.is_rate_limited():
+                    self.ratelimit_invite(room_id, target_id)
+
             # block any attempts to invite the server notices mxid
-            if target.to_string() == self._server_notices_mxid:
+            if target_id == self._server_notices_mxid:
                 raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
 
             block_invite = False
@@ -412,7 +446,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     block_invite = True
 
                 if not await self.spam_checker.user_may_invite(
-                    requester.user.to_string(), target.to_string(), room_id
+                    requester.user.to_string(), target_id, room_id
                 ):
                     logger.info("Blocking invite due to spam checker")
                     block_invite = True
@@ -556,7 +590,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     # send the rejection to the inviter's HS (with fallback to
                     # local event)
                     return await self.remote_reject_invite(
-                        invite.event_id, txn_id, requester, content,
+                        invite.event_id,
+                        txn_id,
+                        requester,
+                        content,
                     )
 
                 # the inviter was on our server, but has now left. Carry on
@@ -1029,8 +1066,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         user: UserID,
         content: dict,
     ) -> Tuple[str, int]:
-        """Implements RoomMemberHandler._remote_join
-        """
+        """Implements RoomMemberHandler._remote_join"""
         # filter ourselves out of remote_room_hosts: do_invite_join ignores it
         # and if it is the only entry we'd like to return a 404 rather than a
         # 500.
@@ -1184,7 +1220,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         event.internal_metadata.out_of_band_membership = True
 
         result_event = await self.event_creation_handler.handle_new_client_event(
-            requester, event, context, extra_users=[UserID.from_string(target_user)],
+            requester,
+            event,
+            context,
+            extra_users=[UserID.from_string(target_user)],
         )
         # we know it was persisted, so must have a stream ordering
         assert result_event.internal_metadata.stream_ordering
@@ -1192,8 +1231,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         return result_event.event_id, result_event.internal_metadata.stream_ordering
 
     async def _user_left_room(self, target: UserID, room_id: str) -> None:
-        """Implements RoomMemberHandler._user_left_room
-        """
+        """Implements RoomMemberHandler._user_left_room"""
         user_left_room(self.distributor, target, room_id)
 
     async def forget(self, user: UserID, room_id: str) -> None:
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index f2e88f6a5b..108730a7a1 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -44,8 +44,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
         user: UserID,
         content: dict,
     ) -> Tuple[str, int]:
-        """Implements RoomMemberHandler._remote_join
-        """
+        """Implements RoomMemberHandler._remote_join"""
         if len(remote_room_hosts) == 0:
             raise SynapseError(404, "No known servers")
 
@@ -80,8 +79,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
         return ret["event_id"], ret["stream_id"]
 
     async def _user_left_room(self, target: UserID, room_id: str) -> None:
-        """Implements RoomMemberHandler._user_left_room
-        """
+        """Implements RoomMemberHandler._user_left_room"""
         await self._notify_change_client(
             user_id=target.to_string(), room_id=room_id, change="left"
         )
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 38461cf79d..a9645b77d8 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -23,7 +23,6 @@ from saml2.client import Saml2Client
 
 from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
-from synapse.config.saml2_config import SamlAttributeRequirement
 from synapse.handlers._base import BaseHandler
 from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.servlet import parse_string
@@ -78,9 +77,10 @@ class SamlHandler(BaseHandler):
         # user-facing name of this auth provider
         self.idp_name = "SAML"
 
-        # we do not currently support icons for SAML auth, but this is required by
+        # we do not currently support icons/brands for SAML auth, but this is required by
         # the SsoIdentityProvider protocol type.
         self.idp_icon = None
+        self.idp_brand = None
 
         # a map from saml session id to Saml2SessionData object
         self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]
@@ -121,7 +121,8 @@ class SamlHandler(BaseHandler):
 
         now = self.clock.time_msec()
         self._outstanding_requests_dict[reqid] = Saml2SessionData(
-            creation_time=now, ui_auth_session_id=ui_auth_session_id,
+            creation_time=now,
+            ui_auth_session_id=ui_auth_session_id,
         )
 
         for key, value in info["headers"]:
@@ -132,7 +133,7 @@ class SamlHandler(BaseHandler):
         raise Exception("prepare_for_authenticate didn't return a Location header")
 
     async def handle_saml_response(self, request: SynapseRequest) -> None:
-        """Handle an incoming request to /_matrix/saml2/authn_response
+        """Handle an incoming request to /_synapse/client/saml2/authn_response
 
         Args:
             request: the incoming request from the browser. We'll
@@ -238,12 +239,10 @@ class SamlHandler(BaseHandler):
 
         # Ensure that the attributes of the logged in user meet the required
         # attributes.
-        for requirement in self._saml2_attribute_requirements:
-            if not _check_attribute_requirement(saml2_auth.ava, requirement):
-                self._sso_handler.render_error(
-                    request, "unauthorised", "You are not authorised to log in here."
-                )
-                return
+        if not self._sso_handler.check_required_attributes(
+            request, saml2_auth.ava, self._saml2_attribute_requirements
+        ):
+            return
 
         # Call the mapper to register/login the user
         try:
@@ -372,21 +371,6 @@ class SamlHandler(BaseHandler):
             del self._outstanding_requests_dict[reqid]
 
 
-def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool:
-    values = ava.get(req.attribute, [])
-    for v in values:
-        if v == req.value:
-            return True
-
-    logger.info(
-        "SAML2 attribute %s did not match required value '%s' (was '%s')",
-        req.attribute,
-        req.value,
-        values,
-    )
-    return False
-
-
 DOT_REPLACE_PATTERN = re.compile(
     ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
 )
@@ -467,7 +451,8 @@ class DefaultSamlMappingProvider:
             mxid_source = saml_response.ava[self._mxid_source_attribute][0]
         except KeyError:
             logger.warning(
-                "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
+                "SAML2 response lacks a '%s' attestation",
+                self._mxid_source_attribute,
             )
             raise SynapseError(
                 400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 66f1bbcfc4..94062e79cb 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -15,23 +15,28 @@
 
 import itertools
 import logging
-from typing import Iterable
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
 
 from unpaddedbase64 import decode_base64, encode_base64
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.filtering import Filter
+from synapse.events import EventBase
 from synapse.storage.state import StateFilter
+from synapse.types import JsonDict, UserID
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class SearchHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
@@ -87,13 +92,15 @@ class SearchHandler(BaseHandler):
 
         return historical_room_ids
 
-    async def search(self, user, content, batch=None):
+    async def search(
+        self, user: UserID, content: JsonDict, batch: Optional[str] = None
+    ) -> JsonDict:
         """Performs a full text search for a user.
 
         Args:
-            user (UserID)
-            content (dict): Search parameters
-            batch (str): The next_batch parameter. Used for pagination.
+            user
+            content: Search parameters
+            batch: The next_batch parameter. Used for pagination.
 
         Returns:
             dict to be returned to the client with results of search
@@ -186,7 +193,7 @@ class SearchHandler(BaseHandler):
         # If doing a subset of all rooms seearch, check if any of the rooms
         # are from an upgraded room, and search their contents as well
         if search_filter.rooms:
-            historical_room_ids = []
+            historical_room_ids = []  # type: List[str]
             for room_id in search_filter.rooms:
                 # Add any previous rooms to the search if they exist
                 ids = await self.get_old_rooms_from_upgraded_room(room_id)
@@ -209,8 +216,10 @@ class SearchHandler(BaseHandler):
 
         rank_map = {}  # event_id -> rank of event
         allowed_events = []
-        room_groups = {}  # Holds result of grouping by room, if applicable
-        sender_group = {}  # Holds result of grouping by sender, if applicable
+        # Holds result of grouping by room, if applicable
+        room_groups = {}  # type: Dict[str, JsonDict]
+        # Holds result of grouping by sender, if applicable
+        sender_group = {}  # type: Dict[str, JsonDict]
 
         # Holds the next_batch for the entire result set if one of those exists
         global_next_batch = None
@@ -254,7 +263,7 @@ class SearchHandler(BaseHandler):
                 s["results"].append(e.event_id)
 
         elif order_by == "recent":
-            room_events = []
+            room_events = []  # type: List[EventBase]
             i = 0
 
             pagination_token = batch_token
@@ -418,13 +427,10 @@ class SearchHandler(BaseHandler):
 
         state_results = {}
         if include_state:
-            rooms = {e.room_id for e in allowed_events}
-            for room_id in rooms:
+            for room_id in {e.room_id for e in allowed_events}:
                 state = await self.state_handler.get_current_state(room_id)
                 state_results[room_id] = list(state.values())
 
-            state_results.values()
-
         # We're now about to serialize the events. We should not make any
         # blocking calls after this. Otherwise the 'age' will be wrong
 
@@ -448,9 +454,9 @@ class SearchHandler(BaseHandler):
 
         if state_results:
             s = {}
-            for room_id, state in state_results.items():
+            for room_id, state_events in state_results.items():
                 s[room_id] = await self._event_serializer.serialize_events(
-                    state, time_now
+                    state_events, time_now
                 )
 
             rooms_cat_res["state"] = s
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index a5d67f828f..84af2dde7e 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -13,24 +13,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import Codes, StoreError, SynapseError
 from synapse.types import Requester
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class SetPasswordHandler(BaseHandler):
     """Handler which deals with changing user account passwords"""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self._auth_handler = hs.get_auth_handler()
         self._device_handler = hs.get_device_handler()
-        self._password_policy_handler = hs.get_password_policy_handler()
 
     async def set_password(
         self,
@@ -38,7 +40,7 @@ class SetPasswordHandler(BaseHandler):
         password_hash: str,
         logout_devices: bool,
         requester: Optional[Requester] = None,
-    ):
+    ) -> None:
         if not self.hs.config.password_localdb_enabled:
             raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
 
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index d493327a10..514b1f69d8 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -14,21 +14,34 @@
 # limitations under the License.
 import abc
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Mapping, Optional
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Set,
+)
 from urllib.parse import urlencode
 
 import attr
 from typing_extensions import NoReturn, Protocol
 
 from twisted.web.http import Request
+from twisted.web.iweb import IRequest
 
 from synapse.api.constants import LoginType
-from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
+from synapse.config.sso import SsoAttributeRequirement
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http import get_request_user_agent
-from synapse.http.server import respond_with_html
+from synapse.http.server import respond_with_html, respond_with_redirect
 from synapse.http.site import SynapseRequest
-from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
+from synapse.types import Collection, JsonDict, UserID, contains_invalid_mxid_characters
 from synapse.util.async_helpers import Linearizer
 from synapse.util.stringutils import random_string
 
@@ -80,6 +93,11 @@ class SsoIdentityProvider(Protocol):
         """Optional MXC URI for user-facing icon"""
         return None
 
+    @property
+    def idp_brand(self) -> Optional[str]:
+        """Optional branding identifier"""
+        return None
+
     @abc.abstractmethod
     async def handle_redirect_request(
         self,
@@ -109,7 +127,7 @@ class UserAttributes:
     # enter one.
     localpart = attr.ib(type=Optional[str])
     display_name = attr.ib(type=Optional[str], default=None)
-    emails = attr.ib(type=List[str], default=attr.Factory(list))
+    emails = attr.ib(type=Collection[str], default=attr.Factory(list))
 
 
 @attr.s(slots=True)
@@ -124,7 +142,7 @@ class UsernameMappingSession:
 
     # attributes returned by the ID mapper
     display_name = attr.ib(type=Optional[str])
-    emails = attr.ib(type=List[str])
+    emails = attr.ib(type=Collection[str])
 
     # An optional dictionary of extra attributes to be provided to the client in the
     # login response.
@@ -136,6 +154,12 @@ class UsernameMappingSession:
     # expiry time for the session, in milliseconds
     expiry_time_ms = attr.ib(type=int)
 
+    # choices made by the user
+    chosen_localpart = attr.ib(type=Optional[str], default=None)
+    use_display_name = attr.ib(type=bool, default=True)
+    emails_to_use = attr.ib(type=Collection[str], default=())
+    terms_accepted_version = attr.ib(type=Optional[str], default=None)
+
 
 # the HTTP cookie used to track the mapping session id
 USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
@@ -170,6 +194,8 @@ class SsoHandler:
         # map from idp_id to SsoIdentityProvider
         self._identity_providers = {}  # type: Dict[str, SsoIdentityProvider]
 
+        self._consent_at_registration = hs.config.consent.user_consent_at_registration
+
     def register_identity_provider(self, p: SsoIdentityProvider):
         p_id = p.idp_id
         assert p_id not in self._identity_providers
@@ -235,7 +261,10 @@ class SsoHandler:
         respond_with_html(request, code, html)
 
     async def handle_redirect_request(
-        self, request: SynapseRequest, client_redirect_url: bytes,
+        self,
+        request: SynapseRequest,
+        client_redirect_url: bytes,
+        idp_id: Optional[str],
     ) -> str:
         """Handle a request to /login/sso/redirect
 
@@ -243,6 +272,7 @@ class SsoHandler:
             request: incoming HTTP request
             client_redirect_url: the URL that we should redirect the
                 client to after login.
+            idp_id: optional identity provider chosen by the client
 
         Returns:
              the URI to redirect to
@@ -252,10 +282,19 @@ class SsoHandler:
                 400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED
             )
 
+        # if the client chose an IdP, use that
+        idp = None  # type: Optional[SsoIdentityProvider]
+        if idp_id:
+            idp = self._identity_providers.get(idp_id)
+            if not idp:
+                raise NotFoundError("Unknown identity provider")
+
         # if we only have one auth provider, redirect to it directly
-        if len(self._identity_providers) == 1:
-            ap = next(iter(self._identity_providers.values()))
-            return await ap.handle_redirect_request(request, client_redirect_url)
+        elif len(self._identity_providers) == 1:
+            idp = next(iter(self._identity_providers.values()))
+
+        if idp:
+            return await idp.handle_redirect_request(request, client_redirect_url)
 
         # otherwise, redirect to the IDP picker
         return "/_synapse/client/pick_idp?" + urlencode(
@@ -288,7 +327,8 @@ class SsoHandler:
 
         # Check if we already have a mapping for this user.
         previously_registered_user_id = await self._store.get_user_by_external_id(
-            auth_provider_id, remote_user_id,
+            auth_provider_id,
+            remote_user_id,
         )
 
         # A match was found, return the user ID.
@@ -369,13 +409,16 @@ class SsoHandler:
                 to an additional page. (e.g. to prompt for more information)
 
         """
+        new_user = False
+
         # grab a lock while we try to find a mapping for this user. This seems...
         # optimistic, especially for implementations that end up redirecting to
         # interstitial pages.
         with await self._mapping_lock.queue(auth_provider_id):
             # first of all, check if we already have a mapping for this user
             user_id = await self.get_sso_user_by_remote_user_id(
-                auth_provider_id, remote_user_id,
+                auth_provider_id,
+                remote_user_id,
             )
 
             # Check for grandfathering of users.
@@ -409,13 +452,19 @@ class SsoHandler:
                     get_request_user_agent(request),
                     request.getClientIP(),
                 )
+                new_user = True
 
         await self._auth_handler.complete_sso_login(
-            user_id, request, client_redirect_url, extra_login_attributes
+            user_id,
+            request,
+            client_redirect_url,
+            extra_login_attributes,
+            new_user=new_user,
         )
 
     async def _call_attribute_mapper(
-        self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
+        self,
+        sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
     ) -> UserAttributes:
         """Call the attribute mapper function in a loop, until we get a unique userid"""
         for i in range(self._MAP_USERNAME_RETRIES):
@@ -501,7 +550,7 @@ class SsoHandler:
         logger.info("Recorded registration session id %s", session_id)
 
         # Set the cookie and redirect to the username picker
-        e = RedirectException(b"/_synapse/client/pick_username")
+        e = RedirectException(b"/_synapse/client/pick_username/account_details")
         e.cookies.append(
             b"%s=%s; path=/"
             % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii"))
@@ -586,7 +635,8 @@ class SsoHandler:
         """
 
         user_id = await self.get_sso_user_by_remote_user_id(
-            auth_provider_id, remote_user_id,
+            auth_provider_id,
+            remote_user_id,
         )
 
         user_id_to_verify = await self._auth_handler.get_session_data(
@@ -625,12 +675,34 @@ class SsoHandler:
 
         # render an error page.
         html = self._bad_user_template.render(
-            server_name=self._server_name, user_id_to_verify=user_id_to_verify,
+            server_name=self._server_name,
+            user_id_to_verify=user_id_to_verify,
         )
         respond_with_html(request, 200, html)
 
+    def get_mapping_session(self, session_id: str) -> UsernameMappingSession:
+        """Look up the given username mapping session
+
+        If it is not found, raises a SynapseError with an http code of 400
+
+        Args:
+            session_id: session to look up
+        Returns:
+            active mapping session
+        Raises:
+            SynapseError if the session is not found/has expired
+        """
+        self._expire_old_sessions()
+        session = self._username_mapping_sessions.get(session_id)
+        if session:
+            return session
+        logger.info("Couldn't find session id %s", session_id)
+        raise SynapseError(400, "unknown session")
+
     async def check_username_availability(
-        self, localpart: str, session_id: str,
+        self,
+        localpart: str,
+        session_id: str,
     ) -> bool:
         """Handle an "is username available" callback check
 
@@ -645,12 +717,7 @@ class SsoHandler:
 
         # make sure that there is a valid mapping session, to stop people dictionary-
         # scanning for accounts
-
-        self._expire_old_sessions()
-        session = self._username_mapping_sessions.get(session_id)
-        if not session:
-            logger.info("Couldn't find session id %s", session_id)
-            raise SynapseError(400, "unknown session")
+        self.get_mapping_session(session_id)
 
         logger.info(
             "[session %s] Checking for availability of username %s",
@@ -667,7 +734,12 @@ class SsoHandler:
         return not user_infos
 
     async def handle_submit_username_request(
-        self, request: SynapseRequest, localpart: str, session_id: str
+        self,
+        request: SynapseRequest,
+        session_id: str,
+        localpart: str,
+        use_display_name: bool,
+        emails_to_use: Iterable[str],
     ) -> None:
         """Handle a request to the username-picker 'submit' endpoint
 
@@ -677,21 +749,104 @@ class SsoHandler:
             request: HTTP request
             localpart: localpart requested by the user
             session_id: ID of the username mapping session, extracted from a cookie
+            use_display_name: whether the user wants to use the suggested display name
+            emails_to_use: emails that the user would like to use
         """
-        self._expire_old_sessions()
-        session = self._username_mapping_sessions.get(session_id)
-        if not session:
-            logger.info("Couldn't find session id %s", session_id)
-            raise SynapseError(400, "unknown session")
+        try:
+            session = self.get_mapping_session(session_id)
+        except SynapseError as e:
+            self.render_error(request, "bad_session", e.msg, code=e.code)
+            return
+
+        # update the session with the user's choices
+        session.chosen_localpart = localpart
+        session.use_display_name = use_display_name
+
+        emails_from_idp = set(session.emails)
+        filtered_emails = set()  # type: Set[str]
+
+        # we iterate through the list rather than just building a set intersection, so
+        # that we can log attempts to use unknown addresses
+        for email in emails_to_use:
+            if email in emails_from_idp:
+                filtered_emails.add(email)
+            else:
+                logger.warning(
+                    "[session %s] ignoring user request to use unknown email address %r",
+                    session_id,
+                    email,
+                )
+        session.emails_to_use = filtered_emails
+
+        # we may now need to collect consent from the user, in which case, redirect
+        # to the consent-extraction-unit
+        if self._consent_at_registration:
+            redirect_url = b"/_synapse/client/new_user_consent"
+
+        # otherwise, redirect to the completion page
+        else:
+            redirect_url = b"/_synapse/client/sso_register"
+
+        respond_with_redirect(request, redirect_url)
+
+    async def handle_terms_accepted(
+        self, request: Request, session_id: str, terms_version: str
+    ):
+        """Handle a request to the new-user 'consent' endpoint
+
+        Will serve an HTTP response to the request.
+
+        Args:
+            request: HTTP request
+            session_id: ID of the username mapping session, extracted from a cookie
+            terms_version: the version of the terms which the user viewed and consented
+                to
+        """
+        logger.info(
+            "[session %s] User consented to terms version %s",
+            session_id,
+            terms_version,
+        )
+        try:
+            session = self.get_mapping_session(session_id)
+        except SynapseError as e:
+            self.render_error(request, "bad_session", e.msg, code=e.code)
+            return
+
+        session.terms_accepted_version = terms_version
+
+        # we're done; now we can register the user
+        respond_with_redirect(request, b"/_synapse/client/sso_register")
+
+    async def register_sso_user(self, request: Request, session_id: str) -> None:
+        """Called once we have all the info we need to register a new user.
 
-        logger.info("[session %s] Registering localpart %s", session_id, localpart)
+        Does so and serves an HTTP response
+
+        Args:
+            request: HTTP request
+            session_id: ID of the username mapping session, extracted from a cookie
+        """
+        try:
+            session = self.get_mapping_session(session_id)
+        except SynapseError as e:
+            self.render_error(request, "bad_session", e.msg, code=e.code)
+            return
+
+        logger.info(
+            "[session %s] Registering localpart %s",
+            session_id,
+            session.chosen_localpart,
+        )
 
         attributes = UserAttributes(
-            localpart=localpart,
-            display_name=session.display_name,
-            emails=session.emails,
+            localpart=session.chosen_localpart,
+            emails=session.emails_to_use,
         )
 
+        if session.use_display_name:
+            attributes.display_name = session.display_name
+
         # the following will raise a 400 error if the username has been taken in the
         # meantime.
         user_id = await self._register_mapped_user(
@@ -702,7 +857,12 @@ class SsoHandler:
             request.getClientIP(),
         )
 
-        logger.info("[session %s] Registered userid %s", session_id, user_id)
+        logger.info(
+            "[session %s] Registered userid %s with attributes %s",
+            session_id,
+            user_id,
+            attributes,
+        )
 
         # delete the mapping session and the cookie
         del self._username_mapping_sessions[session_id]
@@ -715,11 +875,21 @@ class SsoHandler:
             path=b"/",
         )
 
+        auth_result = {}
+        if session.terms_accepted_version:
+            # TODO: make this less awful.
+            auth_result[LoginType.TERMS] = True
+
+        await self._registration_handler.post_registration_actions(
+            user_id, auth_result, access_token=None
+        )
+
         await self._auth_handler.complete_sso_login(
             user_id,
             request,
             session.client_redirect_url,
             session.extra_login_attributes,
+            new_user=True,
         )
 
     def _expire_old_sessions(self):
@@ -733,3 +903,82 @@ class SsoHandler:
         for session_id in to_expire:
             logger.info("Expiring mapping session %s", session_id)
             del self._username_mapping_sessions[session_id]
+
+    def check_required_attributes(
+        self,
+        request: SynapseRequest,
+        attributes: Mapping[str, List[Any]],
+        attribute_requirements: Iterable[SsoAttributeRequirement],
+    ) -> bool:
+        """
+        Confirm that the required attributes were present in the SSO response.
+
+        If all requirements are met, this will return True.
+
+        If any requirement is not met, then the request will be finalized by
+        showing an error page to the user and False will be returned.
+
+        Args:
+            request: The request to (potentially) respond to.
+            attributes: The attributes from the SSO IdP.
+            attribute_requirements: The requirements that attributes must meet.
+
+        Returns:
+            True if all requirements are met, False if any attribute fails to
+            meet the requirement.
+
+        """
+        # Ensure that the attributes of the logged in user meet the required
+        # attributes.
+        for requirement in attribute_requirements:
+            if not _check_attribute_requirement(attributes, requirement):
+                self.render_error(
+                    request, "unauthorised", "You are not authorised to log in here."
+                )
+                return False
+
+        return True
+
+
+def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
+    """Extract the session ID from the cookie
+
+    Raises a SynapseError if the cookie isn't found
+    """
+    session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
+    if not session_id:
+        raise SynapseError(code=400, msg="missing session_id")
+    return session_id.decode("ascii", errors="replace")
+
+
+def _check_attribute_requirement(
+    attributes: Mapping[str, List[Any]], req: SsoAttributeRequirement
+) -> bool:
+    """Check if SSO attributes meet the proper requirements.
+
+    Args:
+        attributes: A mapping of attributes to an iterable of one or more values.
+        req: The configured requirement to check.
+
+    Returns:
+        True if the required attribute was found and had a proper value.
+    """
+    if req.attribute not in attributes:
+        logger.info("SSO attribute missing: %s", req.attribute)
+        return False
+
+    # If the requirement is None, the attribute existing is enough.
+    if req.value is None:
+        return True
+
+    values = attributes[req.attribute]
+    if req.value in values:
+        return True
+
+    logger.info(
+        "SSO attribute %s did not match required value '%s' (was '%s')",
+        req.attribute,
+        req.value,
+        values,
+    )
+    return False
diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
index fb4f70e8e2..b3f9875358 100644
--- a/synapse/handlers/state_deltas.py
+++ b/synapse/handlers/state_deltas.py
@@ -14,15 +14,25 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
 
 class StateDeltasHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
 
-    async def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
+    async def _get_key_change(
+        self,
+        prev_event_id: Optional[str],
+        event_id: Optional[str],
+        key_name: str,
+        public_value: str,
+    ) -> Optional[bool]:
         """Given two events check if the `key_name` field in content changed
         from not matching `public_value` to doing so.
 
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index dc62b21c06..924281144c 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -12,13 +12,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 from collections import Counter
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
+
+from typing_extensions import Counter as CounterType
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.metrics import event_processing_positions
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -31,7 +37,7 @@ class StatsHandler:
     Heavily derived from UserDirectoryHandler
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
@@ -44,7 +50,7 @@ class StatsHandler:
         self.stats_enabled = hs.config.stats_enabled
 
         # The current position in the current_state_delta stream
-        self.pos = None
+        self.pos = None  # type: Optional[int]
 
         # Guard to ensure we only process deltas one at a time
         self._is_processing = False
@@ -56,9 +62,8 @@ class StatsHandler:
             # we start populating stats
             self.clock.call_later(0, self.notify_new_event)
 
-    def notify_new_event(self):
-        """Called when there may be more deltas to process
-        """
+    def notify_new_event(self) -> None:
+        """Called when there may be more deltas to process"""
         if not self.stats_enabled or self._is_processing:
             return
 
@@ -72,7 +77,7 @@ class StatsHandler:
 
         run_as_background_process("stats.notify_new_event", process)
 
-    async def _unsafe_process(self):
+    async def _unsafe_process(self) -> None:
         # If self.pos is None then means we haven't fetched it from DB
         if self.pos is None:
             self.pos = await self.store.get_stats_positions()
@@ -110,10 +115,10 @@ class StatsHandler:
             )
 
             for room_id, fields in room_count.items():
-                room_deltas.setdefault(room_id, {}).update(fields)
+                room_deltas.setdefault(room_id, Counter()).update(fields)
 
             for user_id, fields in user_count.items():
-                user_deltas.setdefault(user_id, {}).update(fields)
+                user_deltas.setdefault(user_id, Counter()).update(fields)
 
             logger.debug("room_deltas: %s", room_deltas)
             logger.debug("user_deltas: %s", user_deltas)
@@ -131,19 +136,20 @@ class StatsHandler:
 
             self.pos = max_pos
 
-    async def _handle_deltas(self, deltas):
+    async def _handle_deltas(
+        self, deltas: Iterable[JsonDict]
+    ) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]:
         """Called with the state deltas to process
 
         Returns:
-            tuple[dict[str, Counter], dict[str, counter]]
             Two dicts: the room deltas and the user deltas,
             mapping from room/user ID to changes in the various fields.
         """
 
-        room_to_stats_deltas = {}
-        user_to_stats_deltas = {}
+        room_to_stats_deltas = {}  # type: Dict[str, CounterType[str]]
+        user_to_stats_deltas = {}  # type: Dict[str, CounterType[str]]
 
-        room_to_state_updates = {}
+        room_to_state_updates = {}  # type: Dict[str, Dict[str, Any]]
 
         for delta in deltas:
             typ = delta["type"]
@@ -173,7 +179,7 @@ class StatsHandler:
                 )
                 continue
 
-            event_content = {}
+            event_content = {}  # type: JsonDict
 
             sender = None
             if event_id is not None:
@@ -257,13 +263,13 @@ class StatsHandler:
                     )
 
                     if has_changed_joinedness:
-                        delta = +1 if membership == Membership.JOIN else -1
+                        membership_delta = +1 if membership == Membership.JOIN else -1
 
                         user_to_stats_deltas.setdefault(user_id, Counter())[
                             "joined_rooms"
-                        ] += delta
+                        ] += membership_delta
 
-                        room_stats_delta["local_users_in_room"] += delta
+                        room_stats_delta["local_users_in_room"] += membership_delta
 
             elif typ == EventTypes.Create:
                 room_state["is_federatable"] = (
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 5c7590f38e..4e8ed7b33f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -339,8 +339,7 @@ class SyncHandler:
         since_token: Optional[StreamToken] = None,
         full_state: bool = False,
     ) -> SyncResult:
-        """Get the sync for client needed to match what the server has now.
-        """
+        """Get the sync for client needed to match what the server has now."""
         return await self.generate_sync_result(sync_config, since_token, full_state)
 
     async def push_rules_for_user(self, user: UserID) -> JsonDict:
@@ -564,7 +563,7 @@ class SyncHandler:
         stream_position: StreamToken,
         state_filter: StateFilter = StateFilter.all(),
     ) -> StateMap[str]:
-        """ Get the room state at a particular stream position
+        """Get the room state at a particular stream position
 
         Args:
             room_id: room for which to get state
@@ -598,7 +597,7 @@ class SyncHandler:
         state: MutableStateMap[EventBase],
         now_token: StreamToken,
     ) -> Optional[JsonDict]:
-        """ Works out a room summary block for this room, summarising the number
+        """Works out a room summary block for this room, summarising the number
         of joined members in the room, and providing the 'hero' members if the
         room has no name so clients can consistently name rooms.  Also adds
         state events to 'state' if needed to describe the heroes.
@@ -743,7 +742,7 @@ class SyncHandler:
         now_token: StreamToken,
         full_state: bool,
     ) -> MutableStateMap[EventBase]:
-        """ Works out the difference in state between the start of the timeline
+        """Works out the difference in state between the start of the timeline
         and the previous sync.
 
         Args:
@@ -820,8 +819,10 @@ class SyncHandler:
                 )
             elif batch.limited:
                 if batch:
-                    state_at_timeline_start = await self.state_store.get_state_ids_for_event(
-                        batch.events[0].event_id, state_filter=state_filter
+                    state_at_timeline_start = (
+                        await self.state_store.get_state_ids_for_event(
+                            batch.events[0].event_id, state_filter=state_filter
+                        )
                     )
                 else:
                     # We can get here if the user has ignored the senders of all
@@ -955,8 +956,7 @@ class SyncHandler:
         since_token: Optional[StreamToken] = None,
         full_state: bool = False,
     ) -> SyncResult:
-        """Generates a sync result.
-        """
+        """Generates a sync result."""
         # NB: The now_token gets changed by some of the generate_sync_* methods,
         # this is due to some of the underlying streams not supporting the ability
         # to query up to a given point.
@@ -1030,8 +1030,8 @@ class SyncHandler:
             one_time_key_counts = await self.store.count_e2e_one_time_keys(
                 user_id, device_id
             )
-            unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types(
-                user_id, device_id
+            unused_fallback_key_types = (
+                await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
             )
 
         logger.debug("Fetching group data")
@@ -1176,8 +1176,10 @@ class SyncHandler:
             # weren't in the previous sync *or* they left and rejoined.
             users_that_have_changed.update(newly_joined_or_invited_users)
 
-            user_signatures_changed = await self.store.get_users_whose_signatures_changed(
-                user_id, since_token.device_list_key
+            user_signatures_changed = (
+                await self.store.get_users_whose_signatures_changed(
+                    user_id, since_token.device_list_key
+                )
             )
             users_that_have_changed.update(user_signatures_changed)
 
@@ -1393,8 +1395,10 @@ class SyncHandler:
                         logger.debug("no-oping sync")
                         return set(), set(), set(), set()
 
-        ignored_account_data = await self.store.get_global_account_data_by_type_for_user(
-            AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
+        ignored_account_data = (
+            await self.store.get_global_account_data_by_type_for_user(
+                AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
+            )
         )
 
         # If there is ignored users account data and it matches the proper type,
@@ -1499,8 +1503,7 @@ class SyncHandler:
     async def _get_rooms_changed(
         self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
     ) -> _RoomChanges:
-        """Gets the the changes that have happened since the last sync.
-        """
+        """Gets the changes that have happened since the last sync."""
         user_id = sync_result_builder.sync_config.user.to_string()
         since_token = sync_result_builder.since_token
         now_token = sync_result_builder.now_token
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index e919a8f9ed..096d199f4c 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -15,13 +15,13 @@
 import logging
 import random
 from collections import namedtuple
-from typing import TYPE_CHECKING, List, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api.errors import AuthError, ShadowBanError, SynapseError
 from synapse.appservice import ApplicationService
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.streams import TypingStream
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
@@ -61,23 +61,23 @@ class FollowerTypingHandler:
 
         if hs.config.worker.writers.typing != hs.get_instance_name():
             hs.get_federation_registry().register_instance_for_edu(
-                "m.typing", hs.config.worker.writers.typing,
+                "m.typing",
+                hs.config.worker.writers.typing,
             )
 
         # map room IDs to serial numbers
-        self._room_serials = {}
+        self._room_serials = {}  # type: Dict[str, int]
         # map room IDs to sets of users currently typing
-        self._room_typing = {}
+        self._room_typing = {}  # type: Dict[str, Set[str]]
 
-        self._member_last_federation_poke = {}
+        self._member_last_federation_poke = {}  # type: Dict[RoomMember, int]
         self.wheel_timer = WheelTimer(bucket_size=5000)
         self._latest_room_serial = 0
 
         self.clock.looping_call(self._handle_timeouts, 5000)
 
-    def _reset(self):
-        """Reset the typing handler's data caches.
-        """
+    def _reset(self) -> None:
+        """Reset the typing handler's data caches."""
         # map room IDs to serial numbers
         self._room_serials = {}
         # map room IDs to sets of users currently typing
@@ -86,7 +86,7 @@ class FollowerTypingHandler:
         self._member_last_federation_poke = {}
         self.wheel_timer = WheelTimer(bucket_size=5000)
 
-    def _handle_timeouts(self):
+    def _handle_timeouts(self) -> None:
         logger.debug("Checking for typing timeouts")
 
         now = self.clock.time_msec()
@@ -96,7 +96,7 @@ class FollowerTypingHandler:
         for member in members:
             self._handle_timeout_for_member(now, member)
 
-    def _handle_timeout_for_member(self, now: int, member: RoomMember):
+    def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
         if not self.is_typing(member):
             # Nothing to do if they're no longer typing
             return
@@ -114,10 +114,10 @@ class FollowerTypingHandler:
         # each person typing.
         self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
 
-    def is_typing(self, member):
+    def is_typing(self, member: RoomMember) -> bool:
         return member.user_id in self._room_typing.get(member.room_id, [])
 
-    async def _push_remote(self, member, typing):
+    async def _push_remote(self, member: RoomMember, typing: bool) -> None:
         if not self.federation:
             return
 
@@ -148,9 +148,8 @@ class FollowerTypingHandler:
 
     def process_replication_rows(
         self, token: int, rows: List[TypingStream.TypingStreamRow]
-    ):
-        """Should be called whenever we receive updates for typing stream.
-        """
+    ) -> None:
+        """Should be called whenever we receive updates for typing stream."""
 
         if self._latest_room_serial > token:
             # The master has gone backwards. To prevent inconsistent data, just
@@ -178,7 +177,7 @@ class FollowerTypingHandler:
 
     async def _send_changes_in_typing_to_remotes(
         self, room_id: str, prev_typing: Set[str], now_typing: Set[str]
-    ):
+    ) -> None:
         """Process a change in typing of a room from replication, sending EDUs
         for any local users.
         """
@@ -194,12 +193,12 @@ class FollowerTypingHandler:
             if self.is_mine_id(user_id):
                 await self._push_remote(RoomMember(room_id, user_id), False)
 
-    def get_current_token(self):
+    def get_current_token(self) -> int:
         return self._latest_room_serial
 
 
 class TypingWriterHandler(FollowerTypingHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         assert hs.config.worker.writers.typing == hs.get_instance_name()
@@ -213,14 +212,15 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
-        self._member_typing_until = {}  # clock time we expect to stop
+        # clock time we expect to stop
+        self._member_typing_until = {}  # type: Dict[RoomMember, int]
 
         # caches which room_ids changed at which serials
         self._typing_stream_change_cache = StreamChangeCache(
             "TypingStreamChangeCache", self._latest_room_serial
         )
 
-    def _handle_timeout_for_member(self, now: int, member: RoomMember):
+    def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
         super()._handle_timeout_for_member(now, member)
 
         if not self.is_typing(member):
@@ -233,7 +233,9 @@ class TypingWriterHandler(FollowerTypingHandler):
             self._stopped_typing(member)
             return
 
-    async def started_typing(self, target_user, requester, room_id, timeout):
+    async def started_typing(
+        self, target_user: UserID, requester: Requester, room_id: str, timeout: int
+    ) -> None:
         target_user_id = target_user.to_string()
         auth_user_id = requester.user.to_string()
 
@@ -263,11 +265,13 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         if was_present:
             # No point sending another notification
-            return None
+            return
 
         self._push_update(member=member, typing=True)
 
-    async def stopped_typing(self, target_user, requester, room_id):
+    async def stopped_typing(
+        self, target_user: UserID, requester: Requester, room_id: str
+    ) -> None:
         target_user_id = target_user.to_string()
         auth_user_id = requester.user.to_string()
 
@@ -290,23 +294,23 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         self._stopped_typing(member)
 
-    def user_left_room(self, user, room_id):
+    def user_left_room(self, user: UserID, room_id: str) -> None:
         user_id = user.to_string()
         if self.is_mine_id(user_id):
             member = RoomMember(room_id=room_id, user_id=user_id)
             self._stopped_typing(member)
 
-    def _stopped_typing(self, member):
+    def _stopped_typing(self, member: RoomMember) -> None:
         if member.user_id not in self._room_typing.get(member.room_id, set()):
             # No point
-            return None
+            return
 
         self._member_typing_until.pop(member, None)
         self._member_last_federation_poke.pop(member, None)
 
         self._push_update(member=member, typing=False)
 
-    def _push_update(self, member, typing):
+    def _push_update(self, member: RoomMember, typing: bool) -> None:
         if self.hs.is_mine_id(member.user_id):
             # Only send updates for changes to our own users.
             run_as_background_process(
@@ -315,7 +319,7 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         self._push_update_local(member=member, typing=typing)
 
-    async def _recv_edu(self, origin, content):
+    async def _recv_edu(self, origin: str, content: JsonDict) -> None:
         room_id = content["room_id"]
         user_id = content["user_id"]
 
@@ -340,7 +344,7 @@ class TypingWriterHandler(FollowerTypingHandler):
             self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT)
             self._push_update_local(member=member, typing=content["typing"])
 
-    def _push_update_local(self, member, typing):
+    def _push_update_local(self, member: RoomMember, typing: bool) -> None:
         room_set = self._room_typing.setdefault(member.room_id, set())
         if typing:
             room_set.add(member.user_id)
@@ -386,7 +390,7 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
             last_id
-        )
+        )  # type: Optional[Iterable[str]]
 
         if changed_rooms is None:
             changed_rooms = self._room_serials
@@ -412,13 +416,13 @@ class TypingWriterHandler(FollowerTypingHandler):
 
     def process_replication_rows(
         self, token: int, rows: List[TypingStream.TypingStreamRow]
-    ):
+    ) -> None:
         # The writing process should never get updates from replication.
         raise Exception("Typing writer instance got typing info over replication")
 
 
 class TypingNotificationEventSource:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.clock = hs.get_clock()
         # We can't call get_typing_handler here because there's a cycle:
@@ -427,7 +431,7 @@ class TypingNotificationEventSource:
         #
         self.get_typing_handler = hs.get_typing_handler
 
-    def _make_event_for(self, room_id):
+    def _make_event_for(self, room_id: str) -> JsonDict:
         typing = self.get_typing_handler()._room_typing[room_id]
         return {
             "type": "m.typing",
@@ -462,7 +466,9 @@ class TypingNotificationEventSource:
 
             return (events, handler._latest_room_serial)
 
-    async def get_new_events(self, from_key, room_ids, **kwargs):
+    async def get_new_events(
+        self, from_key: int, room_ids: Iterable[str], **kwargs
+    ) -> Tuple[List[JsonDict], int]:
         with Measure(self.clock, "typing.get_new_events"):
             from_key = int(from_key)
             handler = self.get_typing_handler()
@@ -478,5 +484,5 @@ class TypingNotificationEventSource:
 
             return (events, handler._latest_room_serial)
 
-    def get_current_key(self):
+    def get_current_key(self) -> int:
         return self.get_typing_handler()._latest_room_serial
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index d4651c8348..1a8340000a 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -97,8 +97,7 @@ class UserDirectoryHandler(StateDeltasHandler):
         return results
 
     def notify_new_event(self) -> None:
-        """Called when there may be more deltas to process
-        """
+        """Called when there may be more deltas to process"""
         if not self.update_user_directory:
             return
 
@@ -134,8 +133,7 @@ class UserDirectoryHandler(StateDeltasHandler):
             )
 
     async def handle_user_deactivated(self, user_id: str) -> None:
-        """Called when a user ID is deactivated
-        """
+        """Called when a user ID is deactivated"""
         # FIXME(#3714): We should probably do this in the same worker as all
         # the other changes.
         await self.store.remove_from_user_dir(user_id)
@@ -145,7 +143,7 @@ class UserDirectoryHandler(StateDeltasHandler):
         if self.pos is None:
             self.pos = await self.store.get_user_directory_stream_pos()
 
-        # If still None then the initial background update hasn't happened yet
+        # If still None then the initial background update hasn't happened yet.
         if self.pos is None:
             return None
 
@@ -176,8 +174,7 @@ class UserDirectoryHandler(StateDeltasHandler):
                 await self.store.update_user_directory_stream_pos(max_pos)
 
     async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None:
-        """Called with the state deltas to process
-        """
+        """Called with the state deltas to process"""
         for delta in deltas:
             typ = delta["type"]
             state_key = delta["state_key"]
@@ -233,6 +230,11 @@ class UserDirectoryHandler(StateDeltasHandler):
 
                     if change:  # The user joined
                         event = await self.store.get_event(event_id, allow_none=True)
+                        # It isn't expected for this event to not exist, but we
+                        # don't want the entire background process to break.
+                        if event is None:
+                            continue
+
                         profile = ProfileInfo(
                             avatar_url=event.content.get("avatar_url"),
                             display_name=event.content.get("displayname"),
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index 4bc3cb53f0..c658862fe6 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -54,8 +54,7 @@ class QuieterFileBodyProducer(FileBodyProducer):
 
 
 def get_request_user_agent(request: IRequest, default: str = "") -> str:
-    """Return the last User-Agent header, or the given default.
-    """
+    """Return the last User-Agent header, or the given default."""
     # There could be raw utf-8 bytes in the User-Agent header.
 
     # N.B. if you don't do this, the logger explodes cryptically
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 37ccf5ab98..e54d9bd213 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -56,7 +56,7 @@ from twisted.web.client import (
 )
 from twisted.web.http import PotentialDataLoss
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IBodyProducer, IResponse
+from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
 
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
 from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
@@ -398,7 +398,8 @@ class SimpleHttpClient:
                 body_producer = None
                 if data is not None:
                     body_producer = QuieterFileBodyProducer(
-                        BytesIO(data), cooperator=self._cooperator,
+                        BytesIO(data),
+                        cooperator=self._cooperator,
                     )
 
                 request_deferred = treq.request(
@@ -407,13 +408,18 @@ class SimpleHttpClient:
                     agent=self.agent,
                     data=body_producer,
                     headers=headers,
+                    # Avoid buffering the body in treq since we do not reuse
+                    # response bodies.
+                    unbuffered=True,
                     **self._extra_treq_args,
                 )  # type: defer.Deferred
 
                 # we use our own timeout mechanism rather than treq's as a workaround
                 # for https://twistedmatrix.com/trac/ticket/9534.
                 request_deferred = timeout_deferred(
-                    request_deferred, 60, self.hs.get_reactor(),
+                    request_deferred,
+                    60,
+                    self.hs.get_reactor(),
                 )
 
                 # turn timeouts into RequestTimedOutErrors
@@ -699,18 +705,6 @@ class SimpleHttpClient:
 
         resp_headers = dict(response.headers.getAllRawHeaders())
 
-        if (
-            b"Content-Length" in resp_headers
-            and max_size
-            and int(resp_headers[b"Content-Length"][0]) > max_size
-        ):
-            logger.warning("Requested URL is too large > %r bytes" % (max_size,))
-            raise SynapseError(
-                502,
-                "Requested file is too large > %r bytes" % (max_size,),
-                Codes.TOO_LARGE,
-            )
-
         if response.code > 299:
             logger.warning("Got %d when downloading %s" % (response.code, url))
             raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN)
@@ -777,7 +771,9 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
         # in the meantime.
         if self.max_size is not None and self.length >= self.max_size:
             self.deferred.errback(BodyExceededMaxSize())
-            self.transport.loseConnection()
+            # Close the connection (forcefully) since all the data will get
+            # discarded anyway.
+            self.transport.abortConnection()
 
     def connectionLost(self, reason: Failure) -> None:
         # If the maximum size was already exceeded, there's nothing to do.
@@ -811,6 +807,11 @@ def read_body_with_max_size(
     Returns:
         A Deferred which resolves to the length of the read body.
     """
+    # If the Content-Length header gives a size larger than the maximum allowed
+    # size, do not bother downloading the body.
+    if max_size is not None and response.length != UNKNOWN_LENGTH:
+        if response.length > max_size:
+            return defer.fail(BodyExceededMaxSize())
 
     d = defer.Deferred()
     response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size))
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 4c06a117d3..2e83fa6773 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -195,8 +195,7 @@ class MatrixFederationAgent:
 
 @implementer(IAgentEndpointFactory)
 class MatrixHostnameEndpointFactory:
-    """Factory for MatrixHostnameEndpoint for parsing to an Agent.
-    """
+    """Factory for MatrixHostnameEndpoint for parsing to an Agent."""
 
     def __init__(
         self,
@@ -261,8 +260,7 @@ class MatrixHostnameEndpoint:
         self._srv_resolver = srv_resolver
 
     def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
-        """Implements IStreamClientEndpoint interface
-        """
+        """Implements IStreamClientEndpoint interface"""
 
         return run_in_background(self._do_connect, protocol_factory)
 
@@ -323,12 +321,19 @@ class MatrixHostnameEndpoint:
         if port or _is_ip_literal(host):
             return [Server(host, port or 8448)]
 
+        logger.debug("Looking up SRV record for %s", host.decode(errors="replace"))
         server_list = await self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
 
         if server_list:
+            logger.debug(
+                "Got %s from SRV lookup for %s",
+                ", ".join(map(str, server_list)),
+                host.decode(errors="replace"),
+            )
             return server_list
 
         # No SRV records, so we fallback to host and 8448
+        logger.debug("No SRV records for %s", host.decode(errors="replace"))
         return [Server(host, 8448)]
 
 
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index b3b6dbcab0..4def7d7633 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -81,8 +81,7 @@ class WellKnownLookupResult:
 
 
 class WellKnownResolver:
-    """Handles well-known lookups for matrix servers.
-    """
+    """Handles well-known lookups for matrix servers."""
 
     def __init__(
         self,
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 19293bf673..cde42e9f5e 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -254,7 +254,8 @@ class MatrixFederationHttpClient:
         # Use a BlacklistingAgentWrapper to prevent circumventing the IP
         # blacklist via IP literals in server names
         self.agent = BlacklistingAgentWrapper(
-            self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist,
+            self.agent,
+            ip_blacklist=hs.config.federation_ip_range_blacklist,
         )
 
         self.clock = hs.get_clock()
@@ -652,7 +653,7 @@ class MatrixFederationHttpClient:
         backoff_on_404: bool = False,
         try_trailing_slash_on_400: bool = False,
     ) -> Union[JsonDict, list]:
-        """ Sends the specified json data using PUT
+        """Sends the specified json data using PUT
 
         Args:
             destination: The remote server to send the HTTP request to.
@@ -740,7 +741,7 @@ class MatrixFederationHttpClient:
         ignore_backoff: bool = False,
         args: Optional[QueryArgs] = None,
     ) -> Union[JsonDict, list]:
-        """ Sends the specified json data using POST
+        """Sends the specified json data using POST
 
         Args:
             destination: The remote server to send the HTTP request to.
@@ -799,7 +800,11 @@ class MatrixFederationHttpClient:
             _sec_timeout = self.default_timeout
 
         body = await _handle_json_response(
-            self.reactor, _sec_timeout, request, response, start_ms,
+            self.reactor,
+            _sec_timeout,
+            request,
+            response,
+            start_ms,
         )
         return body
 
@@ -813,7 +818,7 @@ class MatrixFederationHttpClient:
         ignore_backoff: bool = False,
         try_trailing_slash_on_400: bool = False,
     ) -> Union[JsonDict, list]:
-        """ GETs some json from the given host homeserver and path
+        """GETs some json from the given host homeserver and path
 
         Args:
             destination: The remote server to send the HTTP request to.
@@ -994,7 +999,10 @@ class MatrixFederationHttpClient:
         except BodyExceededMaxSize:
             msg = "Requested file is too large > %r bytes" % (max_size,)
             logger.warning(
-                "{%s} [%s] %s", request.txn_id, request.destination, msg,
+                "{%s} [%s] %s",
+                request.txn_id,
+                request.destination,
+                msg,
             )
             raise SynapseError(502, msg, Codes.TOO_LARGE)
         except Exception as e:
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 7c5defec82..0ec5d941b8 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -213,8 +213,7 @@ class RequestMetrics:
         self.update_metrics()
 
     def update_metrics(self):
-        """Updates the in flight metrics with values from this request.
-        """
+        """Updates the in flight metrics with values from this request."""
         new_stats = self.start_context.get_resource_usage()
 
         diff = new_stats - self._request_stats
diff --git a/synapse/http/server.py b/synapse/http/server.py
index e464bfe6c7..845db9b78d 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -22,10 +22,22 @@ import types
 import urllib
 from http import HTTPStatus
 from io import BytesIO
-from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Pattern,
+    Tuple,
+    Union,
+)
 
 import jinja2
 from canonicaljson import iterencode_canonical_json
+from typing_extensions import Protocol
 from zope.interface import implementer
 
 from twisted.internet import defer, interfaces
@@ -64,8 +76,7 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
 
 
 def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
-    """Sends a JSON error response to clients.
-    """
+    """Sends a JSON error response to clients."""
 
     if f.check(SynapseError):
         error_code = f.value.code
@@ -94,12 +105,17 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
                 pass
     else:
         respond_with_json(
-            request, error_code, error_dict, send_cors=True,
+            request,
+            error_code,
+            error_dict,
+            send_cors=True,
         )
 
 
 def return_html_error(
-    f: failure.Failure, request: Request, error_template: Union[str, jinja2.Template],
+    f: failure.Failure,
+    request: Request,
+    error_template: Union[str, jinja2.Template],
 ) -> None:
     """Sends an HTML error page corresponding to the given failure.
 
@@ -168,24 +184,39 @@ def wrap_async_request_handler(h):
     return preserve_fn(wrapped_async_request_handler)
 
 
-class HttpServer:
-    """ Interface for registering callbacks on a HTTP server
-    """
+# Type of a callback method for processing requests.
+# It is actually called with a SynapseRequest and a kwargs dict for the params,
+# but I can't figure out how to represent that.
+ServletCallback = Callable[
+    ..., Union[None, Awaitable[None], Tuple[int, Any], Awaitable[Tuple[int, Any]]]
+]
 
-    def register_paths(self, method, path_patterns, callback):
-        """ Register a callback that gets fired if we receive a http request
+
+class HttpServer(Protocol):
+    """Interface for registering callbacks on a HTTP server"""
+
+    def register_paths(
+        self,
+        method: str,
+        path_patterns: Iterable[Pattern],
+        callback: ServletCallback,
+        servlet_classname: str,
+    ) -> None:
+        """Register a callback that gets fired if we receive a http request
         with the given method for a path that matches the given regex.
 
         If the regex contains groups these gets passed to the callback via
         an unpacked tuple.
 
         Args:
-            method (str): The method to listen to.
-            path_patterns (list<SRE_Pattern>): The regex used to match requests.
-            callback (function): The function to fire if we receive a matched
+            method: The HTTP method to listen to.
+            path_patterns: The regex used to match requests.
+            callback: The function to fire if we receive a matched
                 request. The first argument will be the request object and
                 subsequent arguments will be any matched groups from the regex.
-                This should return a tuple of (code, response).
+                This should return either a tuple of (code, response), or None.
+            servlet_classname: The name of the handler to be used in prometheus
+                and opentracing logs.
         """
         pass
 
@@ -207,8 +238,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
         self._extract_context = extract_context
 
     def render(self, request):
-        """ This gets called by twisted every time someone sends us a request.
-        """
+        """This gets called by twisted every time someone sends us a request."""
         defer.ensureDeferred(self._async_render_wrapper(request))
         return NOT_DONE_YET
 
@@ -259,13 +289,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
 
     @abc.abstractmethod
     def _send_response(
-        self, request: SynapseRequest, code: int, response_object: Any,
+        self,
+        request: SynapseRequest,
+        code: int,
+        response_object: Any,
     ) -> None:
         raise NotImplementedError()
 
     @abc.abstractmethod
     def _send_error_response(
-        self, f: failure.Failure, request: SynapseRequest,
+        self,
+        f: failure.Failure,
+        request: SynapseRequest,
     ) -> None:
         raise NotImplementedError()
 
@@ -280,10 +315,12 @@ class DirectServeJsonResource(_AsyncResource):
         self.canonical_json = canonical_json
 
     def _send_response(
-        self, request: Request, code: int, response_object: Any,
+        self,
+        request: Request,
+        code: int,
+        response_object: Any,
     ):
-        """Implements _AsyncResource._send_response
-        """
+        """Implements _AsyncResource._send_response"""
         # TODO: Only enable CORS for the requests that need it.
         respond_with_json(
             request,
@@ -294,15 +331,16 @@ class DirectServeJsonResource(_AsyncResource):
         )
 
     def _send_error_response(
-        self, f: failure.Failure, request: SynapseRequest,
+        self,
+        f: failure.Failure,
+        request: SynapseRequest,
     ) -> None:
-        """Implements _AsyncResource._send_error_response
-        """
+        """Implements _AsyncResource._send_error_response"""
         return_json_error(f, request)
 
 
 class JsonResource(DirectServeJsonResource):
-    """ This implements the HttpServer interface and provides JSON support for
+    """This implements the HttpServer interface and provides JSON support for
     Resources.
 
     Register callbacks via register_paths()
@@ -354,7 +392,7 @@ class JsonResource(DirectServeJsonResource):
 
     def _get_handler_for_request(
         self, request: SynapseRequest
-    ) -> Tuple[Callable, str, Dict[str, str]]:
+    ) -> Tuple[ServletCallback, str, Dict[str, str]]:
         """Finds a callback method to handle the given request.
 
         Returns:
@@ -415,10 +453,12 @@ class DirectServeHtmlResource(_AsyncResource):
     ERROR_TEMPLATE = HTML_ERROR_TEMPLATE
 
     def _send_response(
-        self, request: SynapseRequest, code: int, response_object: Any,
+        self,
+        request: SynapseRequest,
+        code: int,
+        response_object: Any,
     ):
-        """Implements _AsyncResource._send_response
-        """
+        """Implements _AsyncResource._send_response"""
         # We expect to get bytes for us to write
         assert isinstance(response_object, bytes)
         html_bytes = response_object
@@ -426,10 +466,11 @@ class DirectServeHtmlResource(_AsyncResource):
         respond_with_html_bytes(request, 200, html_bytes)
 
     def _send_error_response(
-        self, f: failure.Failure, request: SynapseRequest,
+        self,
+        f: failure.Failure,
+        request: SynapseRequest,
     ) -> None:
-        """Implements _AsyncResource._send_error_response
-        """
+        """Implements _AsyncResource._send_error_response"""
         return_html_error(f, request, self.ERROR_TEMPLATE)
 
 
@@ -506,7 +547,9 @@ class _ByteProducer:
     min_chunk_size = 1024
 
     def __init__(
-        self, request: Request, iterator: Iterator[bytes],
+        self,
+        request: Request,
+        iterator: Iterator[bytes],
     ):
         self._request = request
         self._iterator = iterator
@@ -626,7 +669,10 @@ def respond_with_json(
 
 
 def respond_with_json_bytes(
-    request: Request, code: int, json_bytes: bytes, send_cors: bool = False,
+    request: Request,
+    code: int,
+    json_bytes: bytes,
+    send_cors: bool = False,
 ):
     """Sends encoded JSON in response to the given request.
 
@@ -733,8 +779,15 @@ def set_clickjacking_protection_headers(request: Request):
     request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';")
 
 
+def respond_with_redirect(request: Request, url: bytes) -> None:
+    """Write a 302 response to the request, if it is still alive."""
+    logger.debug("Redirect to %s", url.decode("utf-8"))
+    request.redirect(url)
+    finish_request(request)
+
+
 def finish_request(request: Request):
-    """ Finish writing the response to the request.
+    """Finish writing the response to the request.
 
     Twisted throws a RuntimeException if the connection closed before the
     response was written but doesn't provide a convenient or reliable way to
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index b361b7cbaf..0e637f4701 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -258,7 +258,7 @@ def assert_params_in_dict(body, required):
 
 class RestServlet:
 
-    """ A Synapse REST Servlet.
+    """A Synapse REST Servlet.
 
     An implementing class can either provide its own custom 'register' method,
     or use the automatic pattern handling provided by the base class.
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 12ec3f851f..4a4fb5ef26 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -249,8 +249,7 @@ class SynapseRequest(Request):
         )
 
     def _finished_processing(self):
-        """Log the completion of this request and update the metrics
-        """
+        """Log the completion of this request and update the metrics"""
         assert self.logcontext is not None
         usage = self.logcontext.get_resource_usage()
 
@@ -276,7 +275,8 @@ class SynapseRequest(Request):
             # authenticated (e.g. and admin is puppetting a user) then we log both.
             if self.requester.user.to_string() != authenticated_entity:
                 authenticated_entity = "{},{}".format(
-                    authenticated_entity, self.requester.user.to_string(),
+                    authenticated_entity,
+                    self.requester.user.to_string(),
                 )
         elif self.requester is not None:
             # This shouldn't happen, but we log it so we don't lose information
@@ -322,8 +322,7 @@ class SynapseRequest(Request):
             logger.warning("Failed to stop metrics: %r", e)
 
     def _should_log_request(self) -> bool:
-        """Whether we should log at INFO that we processed the request.
-        """
+        """Whether we should log at INFO that we processed the request."""
         if self.path == b"/health":
             return False
 
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index fb937b3f28..f8e9112b56 100644
--- a/synapse/logging/_remote.py
+++ b/synapse/logging/_remote.py
@@ -174,7 +174,9 @@ class RemoteHandler(logging.Handler):
 
             # Make a new producer and start it.
             self._producer = LogProducer(
-                buffer=self._buffer, transport=result.transport, format=self.format,
+                buffer=self._buffer,
+                transport=result.transport,
+                format=self.format,
             )
             result.transport.registerProducer(self._producer, True)
             self._producer.resumeProducing()
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index 14d9c104c2..3e054f615c 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -60,7 +60,10 @@ def parse_drain_configs(
             )
 
         # Either use the default formatter or the tersejson one.
-        if logging_type in (DrainType.CONSOLE_JSON, DrainType.FILE_JSON,):
+        if logging_type in (
+            DrainType.CONSOLE_JSON,
+            DrainType.FILE_JSON,
+        ):
             formatter = "json"  # type: Optional[str]
         elif logging_type in (
             DrainType.CONSOLE_JSON_TERSE,
@@ -131,7 +134,9 @@ def parse_drain_configs(
             )
 
 
-def setup_structured_logging(log_config: dict,) -> dict:
+def setup_structured_logging(
+    log_config: dict,
+) -> dict:
     """
     Convert a legacy structured logging configuration (from Synapse < v1.23.0)
     to one compatible with the new standard library handlers.
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index c2db8b45f3..78e27bfb00 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -338,7 +338,10 @@ class LoggingContext:
         if self.previous_context != old_context:
             logcontext_error(
                 "Expected previous context %r, found %r"
-                % (self.previous_context, old_context,)
+                % (
+                    self.previous_context,
+                    old_context,
+                )
             )
         return self
 
@@ -562,7 +565,7 @@ class LoggingContextFilter(logging.Filter):
 class PreserveLoggingContext:
     """Context manager which replaces the logging context
 
-     The previous logging context is restored on exit."""
+    The previous logging context is restored on exit."""
 
     __slots__ = ["_old_context", "_new_context"]
 
@@ -585,7 +588,10 @@ class PreserveLoggingContext:
             else:
                 logcontext_error(
                     "Expected logging context %s but found %s"
-                    % (self._new_context, context,)
+                    % (
+                        self._new_context,
+                        context,
+                    )
                 )
 
 
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index ab586c318c..10bd4a1461 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -238,8 +238,7 @@ try:
 
     @attr.s(slots=True, frozen=True)
     class _WrappedRustReporter:
-        """Wrap the reporter to ensure `report_span` never throws.
-        """
+        """Wrap the reporter to ensure `report_span` never throws."""
 
         _reporter = attr.ib(type=Reporter, default=attr.Factory(Reporter))
 
@@ -326,8 +325,7 @@ def noop_context_manager(*args, **kwargs):
 
 
 def init_tracer(hs: "HomeServer"):
-    """Set the whitelists and initialise the JaegerClient tracer
-    """
+    """Set the whitelists and initialise the JaegerClient tracer"""
     global opentracing
     if not hs.config.opentracer_enabled:
         # We don't have a tracer
@@ -384,7 +382,7 @@ def whitelisted_homeserver(destination):
 
     Args:
         destination (str)
-        """
+    """
 
     if _homeserver_whitelist:
         return _homeserver_whitelist.match(destination)
@@ -791,7 +789,7 @@ def tag_args(func):
 
     @wraps(func)
     def _tag_args_inner(*args, **kwargs):
-        argspec = inspect.getargspec(func)
+        argspec = inspect.getfullargspec(func)
         for i, arg in enumerate(argspec.args[1:]):
             set_tag("ARG_" + arg, args[i])
         set_tag("args", args[len(argspec.args) :])
diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py
index becf66dd86..fd3543ab04 100644
--- a/synapse/logging/utils.py
+++ b/synapse/logging/utils.py
@@ -43,8 +43,7 @@ def _log_debug_as_f(f, msg, msg_args):
 
 
 def log_function(f):
-    """ Function decorator that logs every call to that function.
-    """
+    """Function decorator that logs every call to that function."""
     func_name = f.__name__
 
     @wraps(f)
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index cbf0dbb871..a8cb49d5b4 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -155,8 +155,7 @@ class InFlightGauge:
             self._registrations.setdefault(key, set()).add(callback)
 
     def unregister(self, key, callback):
-        """Registers that we've exited a block with labels `key`.
-        """
+        """Registers that we've exited a block with labels `key`."""
 
         with self._lock:
             self._registrations.setdefault(key, set()).discard(callback)
@@ -402,7 +401,9 @@ class PyPyGCStats:
         #     Total time spent in GC:  0.073                  # s.total_gc_time
 
         pypy_gc_time = CounterMetricFamily(
-            "pypy_gc_time_seconds_total", "Total time spent in PyPy GC", labels=[],
+            "pypy_gc_time_seconds_total",
+            "Total time spent in PyPy GC",
+            labels=[],
         )
         pypy_gc_time.add_metric([], s.total_gc_time / 1000)
         yield pypy_gc_time
diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py
index 734271e765..71320a1402 100644
--- a/synapse/metrics/_exposition.py
+++ b/synapse/metrics/_exposition.py
@@ -216,7 +216,7 @@ class MetricsHandler(BaseHTTPRequestHandler):
     @classmethod
     def factory(cls, registry):
         """Returns a dynamic MetricsHandler class tied
-           to the passed registry.
+        to the passed registry.
         """
         # This implementation relies on MetricsHandler.registry
         #  (defined above and defaulted to REGISTRY).
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 70e0fa45d9..b56986d8e7 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -208,7 +208,8 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
                     return await maybe_awaitable(func(*args, **kwargs))
             except Exception:
                 logger.exception(
-                    "Background process '%s' threw an exception", desc,
+                    "Background process '%s' threw an exception",
+                    desc,
                 )
             finally:
                 _background_process_in_flight_count.labels(desc).dec()
@@ -249,8 +250,7 @@ class BackgroundProcessLoggingContext(LoggingContext):
         self._proc = _BackgroundProcess(name, self)
 
     def start(self, rusage: "Optional[resource._RUsage]"):
-        """Log context has started running (again).
-        """
+        """Log context has started running (again)."""
 
         super().start(rusage)
 
@@ -261,8 +261,7 @@ class BackgroundProcessLoggingContext(LoggingContext):
             _background_processes_active_since_last_scrape.add(self._proc)
 
     def __exit__(self, type, value, traceback) -> None:
-        """Log context has finished.
-        """
+        """Log context has finished."""
 
         super().__exit__(type, value, traceback)
 
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 72ab5750cc..2e3b311c4a 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -275,11 +275,17 @@ class ModuleApi:
                 redirect them directly if whitelisted).
         """
         self._auth_handler._complete_sso_login(
-            registered_user_id, request, client_redirect_url,
+            registered_user_id,
+            request,
+            client_redirect_url,
         )
 
     async def complete_sso_login_async(
-        self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
+        self,
+        registered_user_id: str,
+        request: SynapseRequest,
+        client_redirect_url: str,
+        new_user: bool = False,
     ):
         """Complete a SSO login by redirecting the user to a page to confirm whether they
         want their access token sent to `client_redirect_url`, or redirect them to that
@@ -291,9 +297,11 @@ class ModuleApi:
             request: The request to respond to.
             client_redirect_url: The URL to which to offer to redirect the user (or to
                 redirect them directly if whitelisted).
+            new_user: set to true to use wording for the consent appropriate to a user
+                who has just registered.
         """
         await self._auth_handler.complete_sso_login(
-            registered_user_id, request, client_redirect_url,
+            registered_user_id, request, client_redirect_url, new_user=new_user
         )
 
     @defer.inlineCallbacks
@@ -346,7 +354,10 @@ class ModuleApi:
             event,
             _,
         ) = await self._hs.get_event_creation_handler().create_and_send_nonmember_event(
-            requester, event_dict, ratelimit=False, ignore_shadow_ban=True,
+            requester,
+            event_dict,
+            ratelimit=False,
+            ignore_shadow_ban=True,
         )
 
         return event
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 0745899b48..1374aae490 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -75,7 +75,7 @@ def count(func: Callable[[T], bool], it: Iterable[T]) -> int:
 
 
 class _NotificationListener:
-    """ This represents a single client connection to the events stream.
+    """This represents a single client connection to the events stream.
     The events stream handler will have yielded to the deferred, so to
     notify the handler it is sufficient to resolve the deferred.
     """
@@ -119,7 +119,10 @@ class _NotifierUserStream:
             self.notify_deferred = ObservableDeferred(defer.Deferred())
 
     def notify(
-        self, stream_key: str, stream_id: Union[int, RoomStreamToken], time_now_ms: int,
+        self,
+        stream_key: str,
+        stream_id: Union[int, RoomStreamToken],
+        time_now_ms: int,
     ):
         """Notify any listeners for this user of a new event from an
         event source.
@@ -140,7 +143,7 @@ class _NotifierUserStream:
             noify_deferred.callback(self.current_token)
 
     def remove(self, notifier: "Notifier"):
-        """ Remove this listener from all the indexes in the Notifier
+        """Remove this listener from all the indexes in the Notifier
         it knows about.
         """
 
@@ -186,7 +189,7 @@ class _PendingRoomEventEntry:
 
 
 class Notifier:
-    """ This class is responsible for notifying any listeners when there are
+    """This class is responsible for notifying any listeners when there are
     new events available for it.
 
     Primarily used from the /events stream.
@@ -265,8 +268,7 @@ class Notifier:
         max_room_stream_token: RoomStreamToken,
         extra_users: Collection[UserID] = [],
     ):
-        """Unwraps event and calls `on_new_room_event_args`.
-        """
+        """Unwraps event and calls `on_new_room_event_args`."""
         self.on_new_room_event_args(
             event_pos=event_pos,
             room_id=event.room_id,
@@ -341,7 +343,10 @@ class Notifier:
 
         if users or rooms:
             self.on_new_event(
-                "room_key", max_room_stream_token, users=users, rooms=rooms,
+                "room_key",
+                max_room_stream_token,
+                users=users,
+                rooms=rooms,
             )
             self._on_updated_room_token(max_room_stream_token)
 
@@ -392,7 +397,7 @@ class Notifier:
         users: Collection[Union[str, UserID]] = [],
         rooms: Collection[str] = [],
     ):
-        """ Used to inform listeners that something has happened event wise.
+        """Used to inform listeners that something has happened event wise.
 
         Will wake up all listeners for the given users and rooms.
         """
@@ -418,7 +423,9 @@ class Notifier:
 
             # Notify appservices
             self._notify_app_services_ephemeral(
-                stream_key, new_token, users,
+                stream_key,
+                new_token,
+                users,
             )
 
     def on_new_replication_data(self) -> None:
@@ -502,7 +509,7 @@ class Notifier:
         is_guest: bool = False,
         explicit_room_id: str = None,
     ) -> EventStreamResult:
-        """ For the given user and rooms, return any new events for them. If
+        """For the given user and rooms, return any new events for them. If
         there are no new events wait for up to `timeout` milliseconds for any
         new events to happen before returning.
 
@@ -651,8 +658,7 @@ class Notifier:
             cb()
 
     def notify_remote_server_up(self, server: str):
-        """Notify any replication that a remote server has come back up
-        """
+        """Notify any replication that a remote server has come back up"""
         # We call federation_sender directly rather than registering as a
         # callback as a) we already have a reference to it and b) it introduces
         # circular dependencies.
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 9018f9e20b..c016a83909 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -144,8 +144,7 @@ class BulkPushRuleEvaluator:
 
     @lru_cache()
     def _get_rules_for_room(self, room_id: str) -> "RulesForRoom":
-        """Get the current RulesForRoom object for the given room id
-        """
+        """Get the current RulesForRoom object for the given room id"""
         # It's important that RulesForRoom gets added to self._get_rules_for_room.cache
         # before any lookup methods get called on it as otherwise there may be
         # a race if invalidate_all gets called (which assumes its in the cache)
@@ -252,7 +251,9 @@ class BulkPushRuleEvaluator:
         # notified for this event. (This will then get handled when we persist
         # the event)
         await self.store.add_push_actions_to_staging(
-            event.event_id, actions_by_user, count_as_unread,
+            event.event_id,
+            actions_by_user,
+            count_as_unread,
         )
 
 
@@ -524,7 +525,7 @@ class RulesForRoom:
 class _Invalidation:
     # _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules,
     # which means that it it is stored on the bulk_get_push_rules cache entry. In order
-    # to ensure that we don't accumulate lots of redunant callbacks on the cache entry,
+    # to ensure that we don't accumulate lots of redundant callbacks on the cache entry,
     # we need to ensure that two _Invalidation objects are "equal" if they refer to the
     # same `cache` and `room_id`.
     #
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 4ac1b31748..5fec2aaf5d 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -116,8 +116,7 @@ class EmailPusher(Pusher):
         self._is_processing = True
 
     def _resume_processing(self) -> None:
-        """Used by tests to resume processing of events after pausing.
-        """
+        """Used by tests to resume processing of events after pausing."""
         assert self._is_processing
         self._is_processing = False
         self._start_processing()
@@ -157,8 +156,10 @@ class EmailPusher(Pusher):
         being run.
         """
         start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
-        unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
-            self.user_id, start, self.max_stream_ordering
+        unprocessed = (
+            await self.store.get_unread_push_actions_for_user_in_range_for_email(
+                self.user_id, start, self.max_stream_ordering
+            )
         )
 
         soonest_due_at = None  # type: Optional[int]
@@ -222,12 +223,14 @@ class EmailPusher(Pusher):
         self, last_stream_ordering: int
     ) -> None:
         self.last_stream_ordering = last_stream_ordering
-        pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
-            self.app_id,
-            self.email,
-            self.user_id,
-            last_stream_ordering,
-            self.clock.time_msec(),
+        pusher_still_exists = (
+            await self.store.update_pusher_last_stream_ordering_and_success(
+                self.app_id,
+                self.email,
+                self.user_id,
+                last_stream_ordering,
+                self.clock.time_msec(),
+            )
         )
         if not pusher_still_exists:
             # The pusher has been deleted while we were processing, so
@@ -298,7 +301,8 @@ class EmailPusher(Pusher):
                     current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS
                 )
         self.throttle_params[room_id] = ThrottleParams(
-            self.clock.time_msec(), new_throttle_ms,
+            self.clock.time_msec(),
+            new_throttle_ms,
         )
         assert self.pusher_id is not None
         await self.store.set_throttle_params(
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index e048b0d59e..b9d3da2e0a 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -176,8 +176,10 @@ class HttpPusher(Pusher):
         Never call this directly: use _process which will only allow this to
         run once per pusher.
         """
-        unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http(
-            self.user_id, self.last_stream_ordering, self.max_stream_ordering
+        unprocessed = (
+            await self.store.get_unread_push_actions_for_user_in_range_for_http(
+                self.user_id, self.last_stream_ordering, self.max_stream_ordering
+            )
         )
 
         logger.info(
@@ -204,12 +206,14 @@ class HttpPusher(Pusher):
                 http_push_processed_counter.inc()
                 self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
                 self.last_stream_ordering = push_action["stream_ordering"]
-                pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
-                    self.app_id,
-                    self.pushkey,
-                    self.user_id,
-                    self.last_stream_ordering,
-                    self.clock.time_msec(),
+                pusher_still_exists = (
+                    await self.store.update_pusher_last_stream_ordering_and_success(
+                        self.app_id,
+                        self.pushkey,
+                        self.user_id,
+                        self.last_stream_ordering,
+                        self.clock.time_msec(),
+                    )
                 )
                 if not pusher_still_exists:
                     # The pusher has been deleted while we were processing, so
@@ -290,7 +294,8 @@ class HttpPusher(Pusher):
                     # for sanity, we only remove the pushkey if it
                     # was the one we actually sent...
                     logger.warning(
-                        ("Ignoring rejected pushkey %s because we didn't send it"), pk,
+                        ("Ignoring rejected pushkey %s because we didn't send it"),
+                        pk,
                     )
                 else:
                     logger.info("Pushkey %s was rejected: removing", pk)
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 4d875dcb91..d10201b6b3 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -34,6 +34,7 @@ from synapse.push.presentable_names import (
     descriptor_from_member_events,
     name_from_member_event,
 )
+from synapse.storage.state import StateFilter
 from synapse.types import StateMap, UserID
 from synapse.util.async_helpers import concurrently_execute
 from synapse.visibility import filter_events_for_client
@@ -110,6 +111,7 @@ class Mailer:
 
         self.sendmail = self.hs.get_sendmail()
         self.store = self.hs.get_datastore()
+        self.state_store = self.hs.get_storage().state
         self.macaroon_gen = self.hs.get_macaroon_generator()
         self.state_handler = self.hs.get_state_handler()
         self.storage = hs.get_storage()
@@ -217,7 +219,17 @@ class Mailer:
         push_actions: Iterable[Dict[str, Any]],
         reason: Dict[str, Any],
     ) -> None:
-        """Send email regarding a user's room notifications"""
+        """
+        Send email regarding a user's room notifications
+
+        Args:
+            app_id: The application receiving the notification.
+            user_id: The user receiving the notification.
+            email_address: The email address receiving the notification.
+            push_actions: All outstanding notifications.
+            reason: The notification that was ready and is the cause of an email
+                being sent.
+        """
         rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions])
 
         notif_events = await self.store.get_events(
@@ -241,7 +253,7 @@ class Mailer:
         except StoreError:
             user_display_name = user_id
 
-        async def _fetch_room_state(room_id):
+        async def _fetch_room_state(room_id: str) -> None:
             room_state = await self.store.get_current_state_ids(room_id)
             state_by_room[room_id] = room_state
 
@@ -255,7 +267,7 @@ class Mailer:
         rooms = []
 
         for r in rooms_in_order:
-            roomvars = await self.get_room_vars(
+            roomvars = await self._get_room_vars(
                 r, user_id, notifs_by_room[r], notif_events, state_by_room[r]
             )
             rooms.append(roomvars)
@@ -267,13 +279,25 @@ class Mailer:
             fallback_to_members=True,
         )
 
-        summary_text = await self.make_summary_text(
-            notifs_by_room, state_by_room, notif_events, user_id, reason
-        )
+        if len(notifs_by_room) == 1:
+            # Only one room has new stuff
+            room_id = list(notifs_by_room.keys())[0]
+
+            summary_text = await self._make_summary_text_single_room(
+                room_id,
+                notifs_by_room[room_id],
+                state_by_room[room_id],
+                notif_events,
+                user_id,
+            )
+        else:
+            summary_text = await self._make_summary_text(
+                notifs_by_room, state_by_room, notif_events, reason
+            )
 
         template_vars = {
             "user_display_name": user_display_name,
-            "unsubscribe_link": self.make_unsubscribe_link(
+            "unsubscribe_link": self._make_unsubscribe_link(
                 user_id, app_id, email_address
             ),
             "summary_text": summary_text,
@@ -337,7 +361,7 @@ class Mailer:
             )
         )
 
-    async def get_room_vars(
+    async def _get_room_vars(
         self,
         room_id: str,
         user_id: str,
@@ -345,6 +369,20 @@ class Mailer:
         notif_events: Dict[str, EventBase],
         room_state_ids: StateMap[str],
     ) -> Dict[str, Any]:
+        """
+        Generate the variables for notifications on a per-room basis.
+
+        Args:
+            room_id: The room ID
+            user_id: The user receiving the notification.
+            notifs: The outstanding push actions for this room.
+            notif_events: The events related to the above notifications.
+            room_state_ids: The event IDs of the current room state.
+
+        Returns:
+             A dictionary to be added to the template context.
+        """
+
         # Check if one of the notifs is an invite event for the user.
         is_invite = False
         for n in notifs:
@@ -361,12 +399,12 @@ class Mailer:
             "hash": string_ordinal_total(room_id),  # See sender avatar hash
             "notifs": [],
             "invite": is_invite,
-            "link": self.make_room_link(room_id),
+            "link": self._make_room_link(room_id),
         }  # type: Dict[str, Any]
 
         if not is_invite:
             for n in notifs:
-                notifvars = await self.get_notif_vars(
+                notifvars = await self._get_notif_vars(
                     n, user_id, notif_events[n["event_id"]], room_state_ids
                 )
 
@@ -393,13 +431,26 @@ class Mailer:
 
         return room_vars
 
-    async def get_notif_vars(
+    async def _get_notif_vars(
         self,
         notif: Dict[str, Any],
         user_id: str,
         notif_event: EventBase,
         room_state_ids: StateMap[str],
     ) -> Dict[str, Any]:
+        """
+        Generate the variables for a single notification.
+
+        Args:
+            notif: The outstanding notification for this room.
+            user_id: The user receiving the notification.
+            notif_event: The event related to the above notification.
+            room_state_ids: The event IDs of the current room state.
+
+        Returns:
+             A dictionary to be added to the template context.
+        """
+
         results = await self.store.get_events_around(
             notif["room_id"],
             notif["event_id"],
@@ -408,7 +459,7 @@ class Mailer:
         )
 
         ret = {
-            "link": self.make_notif_link(notif),
+            "link": self._make_notif_link(notif),
             "ts": notif["received_ts"],
             "messages": [],
         }
@@ -419,22 +470,51 @@ class Mailer:
         the_events.append(notif_event)
 
         for event in the_events:
-            messagevars = await self.get_message_vars(notif, event, room_state_ids)
+            messagevars = await self._get_message_vars(notif, event, room_state_ids)
             if messagevars is not None:
                 ret["messages"].append(messagevars)
 
         return ret
 
-    async def get_message_vars(
+    async def _get_message_vars(
         self, notif: Dict[str, Any], event: EventBase, room_state_ids: StateMap[str]
     ) -> Optional[Dict[str, Any]]:
+        """
+        Generate the variables for a single event, if possible.
+
+        Args:
+            notif: The outstanding notification for this room.
+            event: The event under consideration.
+            room_state_ids: The event IDs of the current room state.
+
+        Returns:
+             A dictionary to be added to the template context, or None if the
+             event cannot be processed.
+        """
         if event.type != EventTypes.Message and event.type != EventTypes.Encrypted:
             return None
 
-        sender_state_event_id = room_state_ids[("m.room.member", event.sender)]
-        sender_state_event = await self.store.get_event(sender_state_event_id)
-        sender_name = name_from_member_event(sender_state_event)
-        sender_avatar_url = sender_state_event.content.get("avatar_url")
+        # Get the sender's name and avatar from the room state.
+        type_state_key = ("m.room.member", event.sender)
+        sender_state_event_id = room_state_ids.get(type_state_key)
+        if sender_state_event_id:
+            sender_state_event = await self.store.get_event(
+                sender_state_event_id
+            )  # type: Optional[EventBase]
+        else:
+            # Attempt to check the historical state for the room.
+            historical_state = await self.state_store.get_state_for_event(
+                event.event_id, StateFilter.from_types((type_state_key,))
+            )
+            sender_state_event = historical_state.get(type_state_key)
+
+        if sender_state_event:
+            sender_name = name_from_member_event(sender_state_event)
+            sender_avatar_url = sender_state_event.content.get("avatar_url")
+        else:
+            # No state could be found, fallback to the MXID.
+            sender_name = event.sender
+            sender_avatar_url = None
 
         # 'hash' for deterministically picking default images: use
         # sender_hash % the number of default images to choose from
@@ -459,18 +539,25 @@ class Mailer:
         ret["msgtype"] = msgtype
 
         if msgtype == "m.text":
-            self.add_text_message_vars(ret, event)
+            self._add_text_message_vars(ret, event)
         elif msgtype == "m.image":
-            self.add_image_message_vars(ret, event)
+            self._add_image_message_vars(ret, event)
 
         if "body" in event.content:
             ret["body_text_plain"] = event.content["body"]
 
         return ret
 
-    def add_text_message_vars(
+    def _add_text_message_vars(
         self, messagevars: Dict[str, Any], event: EventBase
     ) -> None:
+        """
+        Potentially add a sanitised message body to the message variables.
+
+        Args:
+            messagevars: The template context to be modified.
+            event: The event under consideration.
+        """
         msgformat = event.content.get("format")
 
         messagevars["format"] = msgformat
@@ -483,149 +570,232 @@ class Mailer:
         elif body:
             messagevars["body_text_html"] = safe_text(body)
 
-    def add_image_message_vars(
+    def _add_image_message_vars(
         self, messagevars: Dict[str, Any], event: EventBase
     ) -> None:
         """
         Potentially add an image URL to the message variables.
+
+        Args:
+            messagevars: The template context to be modified.
+            event: The event under consideration.
         """
         if "url" in event.content:
             messagevars["image_url"] = event.content["url"]
 
-    async def make_summary_text(
+    async def _make_summary_text_single_room(
         self,
-        notifs_by_room: Dict[str, List[Dict[str, Any]]],
-        room_state_ids: Dict[str, StateMap[str]],
+        room_id: str,
+        notifs: List[Dict[str, Any]],
+        room_state_ids: StateMap[str],
         notif_events: Dict[str, EventBase],
         user_id: str,
-        reason: Dict[str, Any],
-    ):
-        if len(notifs_by_room) == 1:
-            # Only one room has new stuff
-            room_id = list(notifs_by_room.keys())[0]
+    ) -> str:
+        """
+        Make a summary text for the email when only a single room has notifications.
 
-            # If the room has some kind of name, use it, but we don't
-            # want the generated-from-names one here otherwise we'll
-            # end up with, "new message from Bob in the Bob room"
-            room_name = await calculate_room_name(
-                self.store, room_state_ids[room_id], user_id, fallback_to_members=False
-            )
+        Args:
+            room_id: The ID of the room.
+            notifs: The push actions for this room.
+            room_state_ids: The state map for the room.
+            notif_events: A map of event ID -> notification event.
+            user_id: The user receiving the notification.
+
+        Returns:
+            The summary text.
+        """
+        # If the room has some kind of name, use it, but we don't
+        # want the generated-from-names one here otherwise we'll
+        # end up with, "new message from Bob in the Bob room"
+        room_name = await calculate_room_name(
+            self.store, room_state_ids, user_id, fallback_to_members=False
+        )
 
-            # See if one of the notifs is an invite event for the user
-            invite_event = None
-            for n in notifs_by_room[room_id]:
-                ev = notif_events[n["event_id"]]
-                if ev.type == EventTypes.Member and ev.state_key == user_id:
-                    if ev.content.get("membership") == Membership.INVITE:
-                        invite_event = ev
-                        break
-
-            if invite_event:
-                inviter_member_event_id = room_state_ids[room_id].get(
-                    ("m.room.member", invite_event.sender)
+        # See if one of the notifs is an invite event for the user
+        invite_event = None
+        for n in notifs:
+            ev = notif_events[n["event_id"]]
+            if ev.type == EventTypes.Member and ev.state_key == user_id:
+                if ev.content.get("membership") == Membership.INVITE:
+                    invite_event = ev
+                    break
+
+        if invite_event:
+            inviter_member_event_id = room_state_ids.get(
+                ("m.room.member", invite_event.sender)
+            )
+            inviter_name = invite_event.sender
+            if inviter_member_event_id:
+                inviter_member_event = await self.store.get_event(
+                    inviter_member_event_id, allow_none=True
                 )
-                inviter_name = invite_event.sender
-                if inviter_member_event_id:
-                    inviter_member_event = await self.store.get_event(
-                        inviter_member_event_id, allow_none=True
-                    )
-                    if inviter_member_event:
-                        inviter_name = name_from_member_event(inviter_member_event)
-
-                if room_name is None:
-                    return self.email_subjects.invite_from_person % {
-                        "person": inviter_name,
-                        "app": self.app_name,
-                    }
-                else:
-                    return self.email_subjects.invite_from_person_to_room % {
-                        "person": inviter_name,
-                        "room": room_name,
-                        "app": self.app_name,
-                    }
+                if inviter_member_event:
+                    inviter_name = name_from_member_event(inviter_member_event)
+
+            if room_name is None:
+                return self.email_subjects.invite_from_person % {
+                    "person": inviter_name,
+                    "app": self.app_name,
+                }
+
+            return self.email_subjects.invite_from_person_to_room % {
+                "person": inviter_name,
+                "room": room_name,
+                "app": self.app_name,
+            }
 
+        if len(notifs) == 1:
+            # There is just the one notification, so give some detail
             sender_name = None
-            if len(notifs_by_room[room_id]) == 1:
-                # There is just the one notification, so give some detail
-                event = notif_events[notifs_by_room[room_id][0]["event_id"]]
-                if ("m.room.member", event.sender) in room_state_ids[room_id]:
-                    state_event_id = room_state_ids[room_id][
-                        ("m.room.member", event.sender)
-                    ]
-                    state_event = await self.store.get_event(state_event_id)
-                    sender_name = name_from_member_event(state_event)
-
-                if sender_name is not None and room_name is not None:
-                    return self.email_subjects.message_from_person_in_room % {
-                        "person": sender_name,
-                        "room": room_name,
-                        "app": self.app_name,
-                    }
-                elif sender_name is not None:
-                    return self.email_subjects.message_from_person % {
-                        "person": sender_name,
-                        "app": self.app_name,
-                    }
-            else:
-                # There's more than one notification for this room, so just
-                # say there are several
-                if room_name is not None:
-                    return self.email_subjects.messages_in_room % {
-                        "room": room_name,
-                        "app": self.app_name,
-                    }
-                else:
-                    # If the room doesn't have a name, say who the messages
-                    # are from explicitly to avoid, "messages in the Bob room"
-                    sender_ids = list(
-                        {
-                            notif_events[n["event_id"]].sender
-                            for n in notifs_by_room[room_id]
-                        }
-                    )
-
-                    member_events = await self.store.get_events(
-                        [
-                            room_state_ids[room_id][("m.room.member", s)]
-                            for s in sender_ids
-                        ]
-                    )
-
-                    return self.email_subjects.messages_from_person % {
-                        "person": descriptor_from_member_events(member_events.values()),
-                        "app": self.app_name,
-                    }
-        else:
-            # Stuff's happened in multiple different rooms
+            event = notif_events[notifs[0]["event_id"]]
+            if ("m.room.member", event.sender) in room_state_ids:
+                state_event_id = room_state_ids[("m.room.member", event.sender)]
+                state_event = await self.store.get_event(state_event_id)
+                sender_name = name_from_member_event(state_event)
+
+            if sender_name is not None and room_name is not None:
+                return self.email_subjects.message_from_person_in_room % {
+                    "person": sender_name,
+                    "room": room_name,
+                    "app": self.app_name,
+                }
+            elif sender_name is not None:
+                return self.email_subjects.message_from_person % {
+                    "person": sender_name,
+                    "app": self.app_name,
+                }
 
-            # ...but we still refer to the 'reason' room which triggered the mail
-            if reason["room_name"] is not None:
-                return self.email_subjects.messages_in_room_and_others % {
-                    "room": reason["room_name"],
+            # The sender is unknown, just use the room name (or ID).
+            return self.email_subjects.messages_in_room % {
+                "room": room_name or room_id,
+                "app": self.app_name,
+            }
+        else:
+            # There's more than one notification for this room, so just
+            # say there are several
+            if room_name is not None:
+                return self.email_subjects.messages_in_room % {
+                    "room": room_name,
                     "app": self.app_name,
                 }
+
+            return await self._make_summary_text_from_member_events(
+                room_id, notifs, room_state_ids, notif_events
+            )
+
+    async def _make_summary_text(
+        self,
+        notifs_by_room: Dict[str, List[Dict[str, Any]]],
+        room_state_ids: Dict[str, StateMap[str]],
+        notif_events: Dict[str, EventBase],
+        reason: Dict[str, Any],
+    ) -> str:
+        """
+        Make a summary text for the email when multiple rooms have notifications.
+
+        Args:
+            notifs_by_room: A map of room ID to the push actions for that room.
+            room_state_ids: A map of room ID to the state map for that room.
+            notif_events: A map of event ID -> notification event.
+            reason: The reason this notification is being sent.
+
+        Returns:
+            The summary text.
+        """
+        # Stuff's happened in multiple different rooms
+        # ...but we still refer to the 'reason' room which triggered the mail
+        if reason["room_name"] is not None:
+            return self.email_subjects.messages_in_room_and_others % {
+                "room": reason["room_name"],
+                "app": self.app_name,
+            }
+
+        room_id = reason["room_id"]
+        return await self._make_summary_text_from_member_events(
+            room_id, notifs_by_room[room_id], room_state_ids[room_id], notif_events
+        )
+
+    async def _make_summary_text_from_member_events(
+        self,
+        room_id: str,
+        notifs: List[Dict[str, Any]],
+        room_state_ids: StateMap[str],
+        notif_events: Dict[str, EventBase],
+    ) -> str:
+        """
+        Make a summary text for the email naming the senders of a room's notifications.
+
+        Args:
+            room_id: The ID of the room.
+            notifs: The push actions for this room.
+            room_state_ids: The state map for the room.
+            notif_events: A map of event ID -> notification event.
+
+        Returns:
+            The summary text.
+        """
+        # If the room doesn't have a name, say who the messages
+        # are from explicitly to avoid, "messages in the Bob room"
+
+        # Find the latest event ID for each sender; note that the notifications
+        # are already sorted by descending received_ts.
+        sender_ids = {}
+        for n in notifs:
+            sender = notif_events[n["event_id"]].sender
+            if sender not in sender_ids:
+                sender_ids[sender] = n["event_id"]
+
+        # Get the actual member events (in order to calculate a pretty name for
+        # the room).
+        member_event_ids = []
+        member_events = {}
+        for sender_id, event_id in sender_ids.items():
+            type_state_key = ("m.room.member", sender_id)
+            sender_state_event_id = room_state_ids.get(type_state_key)
+            if sender_state_event_id:
+                member_event_ids.append(sender_state_event_id)
             else:
-                # If the reason room doesn't have a name, say who the messages
-                # are from explicitly to avoid, "messages in the Bob room"
-                room_id = reason["room_id"]
-
-                sender_ids = list(
-                    {
-                        notif_events[n["event_id"]].sender
-                        for n in notifs_by_room[room_id]
-                    }
+                # Attempt to check the historical state for the room.
+                historical_state = await self.state_store.get_state_for_event(
+                    event_id, StateFilter.from_types((type_state_key,))
                 )
+                sender_state_event = historical_state.get(type_state_key)
+                if sender_state_event:
+                    member_events[event_id] = sender_state_event
+        member_events.update(await self.store.get_events(member_event_ids))
+
+        if not member_events:
+            # No member events were found! Maybe the room is empty?
+            # Fallback to the room ID (note that if there was a room name this
+            # would already have been used previously).
+            return self.email_subjects.messages_in_room % {
+                "room": room_id,
+                "app": self.app_name,
+            }
+
+        # There was a single sender.
+        if len(member_events) == 1:
+            return self.email_subjects.messages_from_person % {
+                "person": descriptor_from_member_events(member_events.values()),
+                "app": self.app_name,
+            }
+
+        # There was more than one sender, use the first one and a tweaked template.
+        return self.email_subjects.messages_from_person_and_others % {
+            "person": descriptor_from_member_events(list(member_events.values())[:1]),
+            "app": self.app_name,
+        }
 
-                member_events = await self.store.get_events(
-                    [room_state_ids[room_id][("m.room.member", s)] for s in sender_ids]
-                )
+    def _make_room_link(self, room_id: str) -> str:
+        """
+        Generate a link to open a room in the web client.
 
-                return self.email_subjects.messages_from_person_and_others % {
-                    "person": descriptor_from_member_events(member_events.values()),
-                    "app": self.app_name,
-                }
+        Args:
+            room_id: The room ID to generate a link to.
 
-    def make_room_link(self, room_id: str) -> str:
+        Returns:
+             A link to open a room in the web client.
+        """
         if self.hs.config.email_riot_base_url:
             base_url = "%s/#/room" % (self.hs.config.email_riot_base_url)
         elif self.app_name == "Vector":
@@ -635,7 +805,16 @@ class Mailer:
             base_url = "https://matrix.to/#"
         return "%s/%s" % (base_url, room_id)
 
-    def make_notif_link(self, notif: Dict[str, str]) -> str:
+    def _make_notif_link(self, notif: Dict[str, str]) -> str:
+        """
+        Generate a link to open an event in the web client.
+
+        Args:
+            notif: The notification to generate a link for.
+
+        Returns:
+             A link to open the notification in the web client.
+        """
         if self.hs.config.email_riot_base_url:
             return "%s/#/room/%s/%s" % (
                 self.hs.config.email_riot_base_url,
@@ -651,9 +830,20 @@ class Mailer:
         else:
             return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"])
 
-    def make_unsubscribe_link(
+    def _make_unsubscribe_link(
         self, user_id: str, app_id: str, email_address: str
     ) -> str:
+        """
+        Generate a link to unsubscribe from email notifications.
+
+        Args:
+            user_id: The user receiving the notification.
+            app_id: The application receiving the notification.
+            email_address: The email address receiving the notification.
+
+        Returns:
+             A link to unsubscribe from email notifications.
+        """
         params = {
             "access_token": self.macaroon_gen.generate_delete_pusher_token(user_id),
             "app_id": app_id,
@@ -668,6 +858,15 @@ class Mailer:
 
 
 def safe_markup(raw_html: str) -> jinja2.Markup:
+    """
+    Sanitise a raw HTML string to a set of allowed tags and attributes, and linkify any bare URLs.
+
+    Args:
+        raw_html: Unsafe HTML.
+
+    Returns:
+        A Markup object ready to safely use in a Jinja template.
+    """
     return jinja2.Markup(
         bleach.linkify(
             bleach.clean(
@@ -684,8 +883,13 @@ def safe_markup(raw_html: str) -> jinja2.Markup:
 
 def safe_text(raw_text: str) -> jinja2.Markup:
     """
-    Process text: treat it as HTML but escape any tags (ie. just escape the
-    HTML) then linkify it.
+    Sanitise text (escape any HTML tags), and then linkify any bare URLs.
+
+    Args:
+        raw_text: Unsafe text which might include HTML markup.
+
+    Returns:
+        A Markup object ready to safely use in a Jinja template.
     """
     return jinja2.Markup(
         bleach.linkify(bleach.clean(raw_text, tags=[], attributes={}, strip=False))
diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py
index 7e50341d74..04c2c1482c 100644
--- a/synapse/push/presentable_names.py
+++ b/synapse/push/presentable_names.py
@@ -17,7 +17,7 @@ import logging
 import re
 from typing import TYPE_CHECKING, Dict, Iterable, Optional
 
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.types import StateMap
 
@@ -63,7 +63,7 @@ async def calculate_room_name(
         m_room_name = await store.get_event(
             room_state_ids[(EventTypes.Name, "")], allow_none=True
         )
-        if m_room_name and m_room_name.content and m_room_name.content["name"]:
+        if m_room_name and m_room_name.content and m_room_name.content.get("name"):
             return m_room_name.content["name"]
 
     # does it have a canonical alias?
@@ -74,15 +74,11 @@ async def calculate_room_name(
         if (
             canon_alias
             and canon_alias.content
-            and canon_alias.content["alias"]
+            and canon_alias.content.get("alias")
             and _looks_like_an_alias(canon_alias.content["alias"])
         ):
             return canon_alias.content["alias"]
 
-    # at this point we're going to need to search the state by all state keys
-    # for an event type, so rearrange the data structure
-    room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
-
     if not fallback_to_members:
         return None
 
@@ -94,7 +90,7 @@ async def calculate_room_name(
 
     if (
         my_member_event is not None
-        and my_member_event.content["membership"] == "invite"
+        and my_member_event.content.get("membership") == Membership.INVITE
     ):
         if (EventTypes.Member, my_member_event.sender) in room_state_ids:
             inviter_member_event = await store.get_event(
@@ -111,6 +107,10 @@ async def calculate_room_name(
         else:
             return "Room Invite"
 
+    # at this point we're going to need to search the state by all state keys
+    # for an event type, so rearrange the data structure
+    room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
+
     # we're going to have to generate a name based on who's in the room,
     # so find out who is in the room that isn't the user.
     if EventTypes.Member in room_state_bytype_ids:
@@ -120,8 +120,8 @@ async def calculate_room_name(
         all_members = [
             ev
             for ev in member_events.values()
-            if ev.content["membership"] == "join"
-            or ev.content["membership"] == "invite"
+            if ev.content.get("membership") == Membership.JOIN
+            or ev.content.get("membership") == Membership.INVITE
         ]
         # Sort the member events oldest-first so the we name people in the
         # order the joined (it should at least be deterministic rather than
@@ -194,11 +194,7 @@ def descriptor_from_member_events(member_events: Iterable[EventBase]) -> str:
 
 
 def name_from_member_event(member_event: EventBase) -> str:
-    if (
-        member_event.content
-        and "displayname" in member_event.content
-        and member_event.content["displayname"]
-    ):
+    if member_event.content and member_event.content.get("displayname"):
         return member_event.content["displayname"]
     return member_event.state_key
 
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index eed16dbfb5..ae1145be0e 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -78,8 +78,7 @@ class PusherPool:
         self.pushers = {}  # type: Dict[str, Dict[str, Pusher]]
 
     def start(self) -> None:
-        """Starts the pushers off in a background process.
-        """
+        """Starts the pushers off in a background process."""
         if not self._should_start_pushers:
             logger.info("Not starting pushers because they are disabled in the config")
             return
@@ -297,8 +296,7 @@ class PusherPool:
         return pusher
 
     async def _start_pushers(self) -> None:
-        """Start all the pushers
-        """
+        """Start all the pushers"""
         pushers = await self.store.get_all_pushers()
 
         # Stagger starting up the pushers so we don't completely drown the
@@ -335,7 +333,8 @@ class PusherPool:
             return None
         except Exception:
             logger.exception(
-                "Couldn't start pusher id %i: caught Exception", pusher_config.id,
+                "Couldn't start pusher id %i: caught Exception",
+                pusher_config.id,
             )
             return None
 
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index bfd46a3730..8a2b73b75e 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -86,8 +86,12 @@ REQUIREMENTS = [
 
 CONDITIONAL_REQUIREMENTS = {
     "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"],
-    # we use execute_values with the fetch param, which arrived in psycopg 2.8.
-    "postgres": ["psycopg2>=2.8"],
+    "postgres": [
+        # we use execute_values with the fetch param, which arrived in psycopg 2.8.
+        "psycopg2>=2.8 ; platform_python_implementation != 'PyPy'",
+        "psycopg2cffi>=2.8 ; platform_python_implementation == 'PyPy'",
+        "psycopg2cffi-compat==1.1 ; platform_python_implementation == 'PyPy'",
+    ],
     # ACME support is required to provision TLS certificates from authorities
     # that use the protocol, such as Let's Encrypt.
     "acme": [
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 288727a566..8a3f113e76 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -273,7 +273,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
 
         http_server.register_paths(
-            method, [pattern], self._check_auth_and_handle, self.__class__.__name__,
+            method,
+            [pattern],
+            self._check_auth_and_handle,
+            self.__class__.__name__,
         )
 
     def _check_auth_and_handle(self, request, **kwargs):
diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py
index 52d32528ee..60899b6ad6 100644
--- a/synapse/replication/http/account_data.py
+++ b/synapse/replication/http/account_data.py
@@ -175,7 +175,11 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
         return {}
 
     async def _handle_request(self, request, user_id, room_id, tag):
-        max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,)
+        max_stream_id = await self.handler.remove_tag_from_room(
+            user_id,
+            room_id,
+            tag,
+        )
 
         return 200, {"max_stream_id": max_stream_id}
 
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 84e002f934..439881be67 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -160,7 +160,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
 
         # hopefully we're now on the master, so this won't recurse!
         event_id, stream_id = await self.member_handler.remote_reject_invite(
-            invite_event_id, txn_id, requester, event_content,
+            invite_event_id,
+            txn_id,
+            requester,
+            event_content,
         )
 
         return 200, {"event_id": event_id, "stream_id": stream_id}
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 7b12ec9060..d005f38767 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -22,8 +22,7 @@ logger = logging.getLogger(__name__)
 
 
 class ReplicationRegisterServlet(ReplicationEndpoint):
-    """Register a new user
-    """
+    """Register a new user"""
 
     NAME = "register_user"
     PATH_ARGS = ("user_id",)
@@ -97,8 +96,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
 
 
 class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
-    """Run any post registration actions
-    """
+    """Run any post registration actions"""
 
     NAME = "post_register"
     PATH_ARGS = ("user_id",)
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index ac532ed588..0a9da79c32 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -196,8 +196,7 @@ class ErrorCommand(_SimpleCommand):
 
 
 class PingCommand(_SimpleCommand):
-    """Sent by either side as a keep alive. The data is arbitrary (often timestamp)
-    """
+    """Sent by either side as a keep alive. The data is arbitrary (often timestamp)"""
 
     NAME = "PING"
 
diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py
new file mode 100644
index 0000000000..d89a36f25a
--- /dev/null
+++ b/synapse/replication/tcp/external_cache.py
@@ -0,0 +1,105 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, Any, Optional
+
+from prometheus_client import Counter
+
+from synapse.logging.context import make_deferred_yieldable
+from synapse.util import json_decoder, json_encoder
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+set_counter = Counter(
+    "synapse_external_cache_set",
+    "Number of times we set a cache",
+    labelnames=["cache_name"],
+)
+
+get_counter = Counter(
+    "synapse_external_cache_get",
+    "Number of times we get a cache",
+    labelnames=["cache_name", "hit"],
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+class ExternalCache:
+    """A cache backed by an external Redis. Does nothing if no Redis is
+    configured.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        self._redis_connection = hs.get_outbound_redis_connection()
+
+    def _get_redis_key(self, cache_name: str, key: str) -> str:
+        return "cache_v1:%s:%s" % (cache_name, key)
+
+    def is_enabled(self) -> bool:
+        """Whether the external cache is used or not.
+
+        It's safe to use the cache when this returns false, the methods will
+        just no-op, but the function is useful to avoid doing unnecessary work.
+        """
+        return self._redis_connection is not None
+
+    async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
+        """JSON-encode `value` and store it in the named cache for `expiry_ms` ms."""
+        if self._redis_connection is None:
+            return
+
+        set_counter.labels(cache_name).inc()
+
+        # txredisapi requires the value to be string, bytes or numbers, so we
+        # encode stuff in JSON.
+        encoded_value = json_encoder.encode(value)
+
+        logger.debug("Caching %s %s: %r", cache_name, key, encoded_value)
+
+        return await make_deferred_yieldable(
+            self._redis_connection.set(
+                self._get_redis_key(cache_name, key),
+                encoded_value,
+                pexpire=expiry_ms,
+            )
+        )
+
+    async def get(self, cache_name: str, key: str) -> Optional[Any]:
+        """Return the decoded value cached under `key`, or None if there is none."""
+        if self._redis_connection is None:
+            return None
+
+        result = await make_deferred_yieldable(
+            self._redis_connection.get(self._get_redis_key(cache_name, key))
+        )
+
+        logger.debug("Got cache result %s %s: %r", cache_name, key, result)
+
+        get_counter.labels(cache_name, result is not None).inc()
+
+        # Integer replies arrive already converted to int, so return them as-is.
+        # Do this before the emptiness check, else a cached 0 would read as a miss.
+        if isinstance(result, int):
+            return result
+
+        # Anything else that is falsy (None, an empty reply) is a miss.
+        if not result:
+            return None
+
+        return json_decoder.decode(result)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 317796d5e0..d1d00c3717 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 import logging
 from typing import (
+    TYPE_CHECKING,
     Any,
     Awaitable,
     Dict,
@@ -63,6 +64,9 @@ from synapse.replication.tcp.streams import (
     TypingStream,
 )
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -88,7 +92,7 @@ class ReplicationCommandHandler:
     back out to connections.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self._replication_data_handler = hs.get_replication_data_handler()
         self._presence_handler = hs.get_presence_handler()
         self._store = hs.get_datastore()
@@ -282,13 +286,6 @@ class ReplicationCommandHandler:
         if hs.config.redis.redis_enabled:
             from synapse.replication.tcp.redis import (
                 RedisDirectTcpReplicationClientFactory,
-                lazyConnection,
-            )
-
-            logger.info(
-                "Connecting to redis (host=%r port=%r)",
-                hs.config.redis_host,
-                hs.config.redis_port,
             )
 
             # First let's ensure that we have a ReplicationStreamer started.
@@ -299,20 +296,16 @@ class ReplicationCommandHandler:
             # connection after SUBSCRIBE is called).
 
             # First create the connection for sending commands.
-            outbound_redis_connection = lazyConnection(
-                reactor=hs.get_reactor(),
-                host=hs.config.redis_host,
-                port=hs.config.redis_port,
-                password=hs.config.redis.redis_password,
-                reconnect=True,
-            )
+            outbound_redis_connection = hs.get_outbound_redis_connection()
 
             # Now create the factory/connection for the subscription stream.
             self._factory = RedisDirectTcpReplicationClientFactory(
                 hs, outbound_redis_connection
             )
             hs.get_reactor().connectTCP(
-                hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory,
+                hs.config.redis.redis_host,
+                hs.config.redis.redis_port,
+                self._factory,
             )
         else:
             client_name = hs.get_instance_name()
@@ -322,13 +315,11 @@ class ReplicationCommandHandler:
             hs.get_reactor().connectTCP(host, port, self._factory)
 
     def get_streams(self) -> Dict[str, Stream]:
-        """Get a map from stream name to all streams.
-        """
+        """Get a map from stream name to all streams."""
         return self._streams
 
     def get_streams_to_replicate(self) -> List[Stream]:
-        """Get a list of streams that this instances replicates.
-        """
+        """Get a list of streams that this instance replicates."""
         return self._streams_to_replicate
 
     def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
@@ -349,7 +340,10 @@ class ReplicationCommandHandler:
             current_token = stream.current_token(self._instance_name)
             self.send_command(
                 PositionCommand(
-                    stream.NAME, self._instance_name, current_token, current_token,
+                    stream.NAME,
+                    self._instance_name,
+                    current_token,
+                    current_token,
                 )
             )
 
@@ -601,8 +595,7 @@ class ReplicationCommandHandler:
         self.send_command(cmd, ignore_conn=conn)
 
     def new_connection(self, connection: AbstractConnection):
-        """Called when we have a new connection.
-        """
+        """Called when we have a new connection."""
         self._connections.append(connection)
 
         # If we are connected to replication as a client (rather than a server)
@@ -629,8 +622,7 @@ class ReplicationCommandHandler:
             )
 
     def lost_connection(self, connection: AbstractConnection):
-        """Called when a connection is closed/lost.
-        """
+        """Called when a connection is closed/lost."""
         # we no longer need _streams_by_connection for this connection.
         streams = self._streams_by_connection.pop(connection, None)
         if streams:
@@ -687,15 +679,13 @@ class ReplicationCommandHandler:
     def send_user_sync(
         self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
     ):
-        """Poke the master that a user has started/stopped syncing.
-        """
+        """Poke the master that a user has started/stopped syncing."""
         self.send_command(
             UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
         )
 
     def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
-        """Poke the master to remove a pusher for a user
-        """
+        """Poke the master to remove a pusher for a user"""
         cmd = RemovePusherCommand(app_id, push_key, user_id)
         self.send_command(cmd)
 
@@ -708,8 +698,7 @@ class ReplicationCommandHandler:
         device_id: str,
         last_seen: int,
     ):
-        """Tell the master that the user made a request.
-        """
+        """Tell the master that the user made a request."""
         cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
         self.send_command(cmd)
 
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 804da994ea..e0b4ad314d 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -222,8 +222,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
                 self.send_error("ping timeout")
 
     def lineReceived(self, line: bytes):
-        """Called when we've received a line
-        """
+        """Called when we've received a line"""
         with PreserveLoggingContext(self._logging_context):
             self._parse_and_dispatch_line(line)
 
@@ -299,8 +298,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.on_connection_closed()
 
     def send_error(self, error_string, *args):
-        """Send an error to remote and close the connection.
-        """
+        """Send an error to remote and close the connection."""
         self.send_command(ErrorCommand(error_string % args))
         self.close()
 
@@ -341,8 +339,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.last_sent_command = self.clock.time_msec()
 
     def _queue_command(self, cmd):
-        """Queue the command until the connection is ready to write to again.
-        """
+        """Queue the command until the connection is ready to write to again."""
         logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd)
         self.pending_commands.append(cmd)
 
@@ -355,8 +352,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             self.close()
 
     def _send_pending_commands(self):
-        """Send any queued commandes
-        """
+        """Send any queued commands"""
         pending = self.pending_commands
         self.pending_commands = []
         for cmd in pending:
@@ -380,8 +376,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.state = ConnectionStates.PAUSED
 
     def resumeProducing(self):
-        """The remote has caught up after we started buffering!
-        """
+        """The remote has caught up after we started buffering!"""
         logger.info("[%s] Resume producing", self.id())
         self.state = ConnectionStates.ESTABLISHED
         self._send_pending_commands()
@@ -440,8 +435,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         return "%s-%s" % (self.name, self.conn_id)
 
     def lineLengthExceeded(self, line):
-        """Called when we receive a line that is above the maximum line length
-        """
+        """Called when we receive a line that is above the maximum line length"""
         self.send_error("Line length exceeded")
 
 
@@ -495,21 +489,18 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
             self.send_error("Wrong remote")
 
     def replicate(self):
-        """Send the subscription request to the server
-        """
+        """Send the subscription request to the server"""
         logger.info("[%s] Subscribing to replication streams", self.id())
 
         self.send_command(ReplicateCommand())
 
 
 class AbstractConnection(abc.ABC):
-    """An interface for replication connections.
-    """
+    """An interface for replication connections."""
 
     @abc.abstractmethod
     def send_command(self, cmd: Command):
-        """Send the command down the connection
-        """
+        """Send the command down the connection"""
         pass
 
 
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index bc6ba709a7..0e6155cf53 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -15,14 +15,16 @@
 
 import logging
 from inspect import isawaitable
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast
 
+import attr
 import txredisapi
 
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
 from synapse.metrics.background_process_metrics import (
     BackgroundProcessLoggingContext,
     run_as_background_process,
+    wrap_as_background_process,
 )
 from synapse.replication.tcp.commands import (
     Command,
@@ -41,6 +43,24 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+T = TypeVar("T")
+V = TypeVar("V")
+
+
+@attr.s
+class ConstantProperty(Generic[T, V]):
+    """A data descriptor that always returns the given constant; assignments
+    made through it are accepted but silently discarded.
+    """
+
+    constant = attr.ib()  # type: V
+
+    def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V:
+        return self.constant
+
+    def __set__(self, obj: Optional[T], value: V) -> None:
+        pass  # intentionally a no-op: the assigned value is dropped
+
 
 class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
     """Connection to redis subscribed to replication stream.
@@ -59,16 +79,16 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
     immediately after initialisation.
 
     Attributes:
-        handler: The command handler to handle incoming commands.
-        stream_name: The *redis* stream name to subscribe to and publish from
-            (not anything to do with Synapse replication streams).
-        outbound_redis_connection: The connection to redis to use to send
+        synapse_handler: The command handler to handle incoming commands.
+        synapse_stream_name: The *redis* stream name to subscribe to and publish
+            from (not anything to do with Synapse replication streams).
+        synapse_outbound_redis_connection: The connection to redis to use to send
             commands.
     """
 
-    handler = None  # type: ReplicationCommandHandler
-    stream_name = None  # type: str
-    outbound_redis_connection = None  # type: txredisapi.RedisProtocol
+    synapse_handler = None  # type: ReplicationCommandHandler
+    synapse_stream_name = None  # type: str
+    synapse_outbound_redis_connection = None  # type: txredisapi.RedisProtocol
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -88,23 +108,22 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         # it's important to make sure that we only send the REPLICATE command once we
         # have successfully subscribed to the stream - otherwise we might miss the
         # POSITION response sent back by the other end.
-        logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
-        await make_deferred_yieldable(self.subscribe(self.stream_name))
+        logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
+        await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
         logger.info(
             "Successfully subscribed to redis stream, sending REPLICATE command"
         )
-        self.handler.new_connection(self)
+        self.synapse_handler.new_connection(self)
         await self._async_send_command(ReplicateCommand())
         logger.info("REPLICATE successfully sent")
 
         # We send out our positions when there is a new connection in case the
         # other side missed updates. We do this for Redis connections as the
         # otherside won't know we've connected and so won't issue a REPLICATE.
-        self.handler.send_positions_to_connection(self)
+        self.synapse_handler.send_positions_to_connection(self)
 
     def messageReceived(self, pattern: str, channel: str, message: str):
-        """Received a message from redis.
-        """
+        """Received a message from redis."""
         with PreserveLoggingContext(self._logging_context):
             self._parse_and_dispatch_message(message)
 
@@ -117,7 +136,8 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
             cmd = parse_command_from_line(message)
         except Exception:
             logger.exception(
-                "Failed to parse replication line: %r", message,
+                "Failed to parse replication line: %r",
+                message,
             )
             return
 
@@ -137,7 +157,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
             cmd: received command
         """
 
-        cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
+        cmd_func = getattr(self.synapse_handler, "on_%s" % (cmd.NAME,), None)
         if not cmd_func:
             logger.warning("Unhandled command: %r", cmd)
             return
@@ -155,7 +175,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
     def connectionLost(self, reason):
         logger.info("Lost connection to redis")
         super().connectionLost(reason)
-        self.handler.lost_connection(self)
+        self.synapse_handler.lost_connection(self)
 
         # mark the logging context as finished
         self._logging_context.__exit__(None, None, None)
@@ -183,11 +203,58 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
 
         await make_deferred_yieldable(
-            self.outbound_redis_connection.publish(self.stream_name, encoded_string)
+            self.synapse_outbound_redis_connection.publish(
+                self.synapse_stream_name, encoded_string
+            )
         )
 
 
-class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
+class SynapseRedisFactory(txredisapi.RedisFactory):
+    """A RedisFactory that helps detect dead connections: once constructed it
+    periodically pings every connection in its pool, logging a warning on failure.
+    """
+
+    # Always retry connecting: txredisapi will stop after a failure during certain
+    # operations (e.g. AUTH), so ConstantProperty discards any assignment to this.
+    continueTrying = cast(bool, ConstantProperty(True))
+
+    def __init__(
+        self,
+        hs: "HomeServer",
+        uuid: str,
+        dbid: Optional[int],
+        poolsize: int,
+        isLazy: bool = False,
+        handler: Type = txredisapi.ConnectionHandler,
+        charset: str = "utf-8",
+        password: Optional[str] = None,
+        replyTimeout: int = 30,
+        convertNumbers: Optional[int] = True,
+    ):
+        super().__init__(
+            uuid=uuid,
+            dbid=dbid,
+            poolsize=poolsize,
+            isLazy=isLazy,
+            handler=handler,
+            charset=charset,
+            password=password,
+            replyTimeout=replyTimeout,
+            convertNumbers=convertNumbers,
+        )
+
+        hs.get_clock().looping_call(self._send_ping, 30 * 1000)
+
+    @wrap_as_background_process("redis_ping")
+    async def _send_ping(self) -> None:
+        for connection in self.pool:
+            try:
+                await make_deferred_yieldable(connection.ping())
+            except Exception:  # best-effort: a failed ping is only logged
+                logger.warning("Failed to send ping to a redis connection")
+
+class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
     """This is a reconnecting factory that connects to redis and immediately
     subscribes to a stream.
 
@@ -199,72 +266,68 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
     """
 
     maxDelay = 5
-    continueTrying = True
     protocol = RedisSubscriber
 
     def __init__(
         self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
     ):
 
-        super().__init__()
-
-        # This sets the password on the RedisFactory base class (as
-        # SubscriberFactory constructor doesn't pass it through).
-        self.password = hs.config.redis.redis_password
+        super().__init__(
+            hs,
+            uuid="subscriber",
+            dbid=None,
+            poolsize=1,
+            replyTimeout=30,
+            password=hs.config.redis.redis_password,
+        )
 
-        self.handler = hs.get_tcp_replication()
-        self.stream_name = hs.hostname
+        self.synapse_handler = hs.get_tcp_replication()
+        self.synapse_stream_name = hs.hostname
 
-        self.outbound_redis_connection = outbound_redis_connection
+        self.synapse_outbound_redis_connection = outbound_redis_connection
 
     def buildProtocol(self, addr):
-        p = super().buildProtocol(addr)  # type: RedisSubscriber
+        p = super().buildProtocol(addr)
+        p = cast(RedisSubscriber, p)
 
         # We do this here rather than add to the constructor of `RedisSubcriber`
         # as to do so would involve overriding `buildProtocol` entirely, however
         # the base method does some other things than just instantiating the
         # protocol.
-        p.handler = self.handler
-        p.outbound_redis_connection = self.outbound_redis_connection
-        p.stream_name = self.stream_name
-        p.password = self.password
+        p.synapse_handler = self.synapse_handler
+        p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
+        p.synapse_stream_name = self.synapse_stream_name
 
         return p
 
 
 def lazyConnection(
-    reactor,
+    hs: "HomeServer",
     host: str = "localhost",
     port: int = 6379,
     dbid: Optional[int] = None,
     reconnect: bool = True,
-    charset: str = "utf-8",
     password: Optional[str] = None,
-    connectTimeout: Optional[int] = None,
-    replyTimeout: Optional[int] = None,
-    convertNumbers: bool = True,
+    replyTimeout: int = 30,
 ) -> txredisapi.RedisProtocol:
-    """Equivalent to `txredisapi.lazyConnection`, except allows specifying a
-    reactor.
+    """Creates a connection to Redis that is lazily set up and reconnects if the
+    connection is lost.
     """
 
-    isLazy = True
-    poolsize = 1
-
     uuid = "%s:%d" % (host, port)
-    factory = txredisapi.RedisFactory(
-        uuid,
-        dbid,
-        poolsize,
-        isLazy,
-        txredisapi.ConnectionHandler,
-        charset,
-        password,
-        replyTimeout,
-        convertNumbers,
+    factory = SynapseRedisFactory(
+        hs,
+        uuid=uuid,
+        dbid=dbid,
+        poolsize=1,
+        isLazy=True,
+        handler=txredisapi.ConnectionHandler,
+        password=password,
+        replyTimeout=replyTimeout,
     )
     factory.continueTrying = reconnect
-    for x in range(poolsize):
-        reactor.connectTCP(host, port, factory, connectTimeout)
+
+    reactor = hs.get_reactor()
+    reactor.connectTCP(host, port, factory, 30)
 
     return factory.handler
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 1d4ceac0f1..2018f9f29e 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -36,8 +36,7 @@ logger = logging.getLogger(__name__)
 
 
 class ReplicationStreamProtocolFactory(Factory):
-    """Factory for new replication connections.
-    """
+    """Factory for new replication connections."""
 
     def __init__(self, hs):
         self.command_handler = hs.get_tcp_replication()
@@ -181,7 +180,8 @@ class ReplicationStreamer:
                             raise
 
                         logger.debug(
-                            "Sending %d updates", len(updates),
+                            "Sending %d updates",
+                            len(updates),
                         )
 
                         if updates:
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 61b282ab2d..38809b5b7c 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -183,7 +183,10 @@ class Stream:
             return [], upto_token, False
 
         updates, upto_token, limited = await self.update_function(
-            instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
+            instance_name,
+            from_token,
+            upto_token,
+            _STREAM_UPDATE_TARGET_ROW_COUNT,
         )
         return updates, upto_token, limited
 
@@ -339,8 +342,7 @@ class ReceiptsStream(Stream):
 
 
 class PushRulesStream(Stream):
-    """A user has changed their push rules
-    """
+    """A user has changed their push rules"""
 
     PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",))  # str
 
@@ -362,8 +364,7 @@ class PushRulesStream(Stream):
 
 
 class PushersStream(Stream):
-    """A user has added/changed/removed a pusher
-    """
+    """A user has added/changed/removed a pusher"""
 
     PushersStreamRow = namedtuple(
         "PushersStreamRow",
@@ -416,8 +417,7 @@ class CachesStream(Stream):
 
 
 class PublicRoomsStream(Stream):
-    """The public rooms list changed
-    """
+    """The public rooms list changed"""
 
     PublicRoomsStreamRow = namedtuple(
         "PublicRoomsStreamRow",
@@ -463,8 +463,7 @@ class DeviceListsStream(Stream):
 
 
 class ToDeviceStream(Stream):
-    """New to_device messages for a client
-    """
+    """New to_device messages for a client"""
 
     ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",))  # str
 
@@ -481,8 +480,7 @@ class ToDeviceStream(Stream):
 
 
 class TagAccountDataStream(Stream):
-    """Someone added/removed a tag for a room
-    """
+    """Someone added/removed a tag for a room"""
 
     TagAccountDataStreamRow = namedtuple(
         "TagAccountDataStreamRow", ("user_id", "room_id", "data")  # str  # str  # dict
@@ -501,8 +499,7 @@ class TagAccountDataStream(Stream):
 
 
 class AccountDataStream(Stream):
-    """Global or per room account data was changed
-    """
+    """Global or per room account data was changed"""
 
     AccountDataStreamRow = namedtuple(
         "AccountDataStream",
@@ -589,8 +586,7 @@ class GroupServerStream(Stream):
 
 
 class UserSignatureStream(Stream):
-    """A user has signed their own device with their user-signing key
-    """
+    """A user has signed their own device with their user-signing key"""
 
     UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id"))  # str
 
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 86a62b71eb..fa5e37ba7b 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -113,8 +113,7 @@ TypeToRow = {Row.TypeId: Row for Row in _EventRows}
 
 
 class EventsStream(Stream):
-    """We received a new event, or an event went from being an outlier to not
-    """
+    """We received a new event, or an event went from being an outlier to not"""
 
     NAME = "events"
 
diff --git a/synapse/res/templates/sso.css b/synapse/res/templates/sso.css
new file mode 100644
index 0000000000..338214f5d0
--- /dev/null
+++ b/synapse/res/templates/sso.css
@@ -0,0 +1,129 @@
+body, input, select, textarea {
+  font-family: "Inter", "Helvetica", "Arial", sans-serif;
+  font-size: 14px;
+  color: #17191C;
+}
+
+header, footer {
+  max-width: 480px;
+  width: 100%;
+  margin: 24px auto;
+  text-align: center;
+}
+
+@media screen and (min-width: 800px) {
+  header {
+    margin-top: 90px;
+  }
+}
+
+header {
+  min-height: 60px;
+}
+
+header p {
+  color: #737D8C;
+  line-height: 24px;
+}
+
+h1 {
+  font-size: 24px;
+}
+
+a {
+  color: #418DED;
+}
+
+.error_page h1 {
+  color: #FE2928;
+}
+
+h2 {
+  font-size: 14px;
+}
+
+h2 img {
+  vertical-align: middle;
+  margin-right: 8px;
+  width: 24px;
+  height: 24px;
+}
+
+label {
+  cursor: pointer;
+}
+
+main {
+  max-width: 360px;
+  width: 100%;
+  margin: 24px auto;
+}
+
+.primary-button {
+  border: none;
+  -webkit-appearance: none;
+  -moz-appearance: none;
+  appearance: none;
+  text-decoration: none;
+  padding: 12px;
+  color: white;
+  background-color: #418DED;
+  font-weight: bold;
+  display: block;
+  border-radius: 12px;
+  width: 100%;
+  box-sizing: border-box;
+  margin: 16px 0;
+  cursor: pointer;
+  text-align: center;
+}
+
+.profile {
+  display: flex;
+  flex-direction: column;
+  align-items: center;
+  justify-content: center;
+  margin: 24px;
+  padding: 13px;
+  border: 1px solid #E9ECF1;
+  border-radius: 4px;
+}
+
+.profile.with-avatar {
+  margin-top: 42px; /* (36px / 2) + 24px*/
+}
+
+.profile .avatar {
+  width: 36px;
+  height: 36px;
+  border-radius: 100%;
+  display: block;
+  margin-top: -32px;
+  margin-bottom: 8px;
+}
+
+.profile .display-name {
+  font-weight: bold;
+  margin-bottom: 4px;
+  font-size: 15px;
+  line-height: 18px;
+}
+.profile .user-id {
+  color: #737D8C;
+  font-size: 12px;
+  line-height: 12px;
+}
+
+footer {
+  margin-top: 80px;
+}
+
+footer svg {
+  display: block;
+  width: 46px;
+  margin: 0px auto 12px auto;
+}
+
+footer p {
+  color: #737D8C;
+}
\ No newline at end of file
diff --git a/synapse/res/templates/sso_account_deactivated.html b/synapse/res/templates/sso_account_deactivated.html
index 4eb8db9fb4..c3e4deed93 100644
--- a/synapse/res/templates/sso_account_deactivated.html
+++ b/synapse/res/templates/sso_account_deactivated.html
@@ -1,10 +1,25 @@
 <!DOCTYPE html>
 <html lang="en">
-<head>
-    <meta charset="UTF-8">
-    <title>SSO account deactivated</title>
-</head>
-    <body>
-        <p>This account has been deactivated.</p>
+    <head>
+        <meta charset="UTF-8">
+        <title>SSO account deactivated</title>
+        <meta name="viewport" content="width=device-width, user-scalable=no">
+        <style type="text/css">
+            {% include "sso.css" without context %}
+        </style>
+    </head>
+    <body class="error_page">
+        <header>
+            <h1>Your account has been deactivated</h1>
+            <p>
+                <strong>No account found</strong>
+            </p>
+            <p>
+                Your account might have been deactivated by the server administrator.
+                You can either try to create a new account or contact the server’s
+                administrator.
+            </p>
+        </header>
+        {% include "sso_footer.html" without context %}
     </body>
 </html>
diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html
new file mode 100644
index 0000000000..f4fdc40b22
--- /dev/null
+++ b/synapse/res/templates/sso_auth_account_details.html
@@ -0,0 +1,188 @@
+<!DOCTYPE html>
+<html lang="en">
+  <head>
+    <title>Create your account</title>
+    <meta charset="utf-8">
+    <meta name="viewport" content="width=device-width, user-scalable=no">
+    <script type="text/javascript">
+      let wasKeyboard = false;
+      document.addEventListener("mousedown", function() { wasKeyboard = false; });
+      document.addEventListener("keydown", function() { wasKeyboard = true; });
+      document.addEventListener("focusin", function() {
+        if (wasKeyboard) {
+          document.body.classList.add("keyboard-focus");
+        } else {
+          document.body.classList.remove("keyboard-focus");
+        }
+      });
+    </script>
+    <style type="text/css">
+      {% include "sso.css" without context %}
+
+      body.keyboard-focus :focus, body.keyboard-focus .username_input:focus-within {
+        outline: 3px solid #17191C;
+        outline-offset: 4px;
+      }
+
+      .username_input {
+        display: flex;
+        border: 2px solid #418DED;
+        border-radius: 8px;
+        padding: 12px;
+        position: relative;
+        margin: 16px 0;
+        align-items: center;
+        font-size: 12px;
+      }
+
+      .username_input.invalid {
+        border-color: #FE2928;
+      }
+
+      .username_input.invalid input, .username_input.invalid label {
+        color: #FE2928;
+      }
+
+      .username_input div, .username_input input {
+        line-height: 18px;
+        font-size: 14px;
+      }
+
+      .username_input label {
+        position: absolute;
+        top: -5px;
+        left: 14px;
+        font-size: 10px;
+        line-height: 10px;
+        background: white;
+        padding: 0 2px;
+      }
+
+      .username_input input {
+        flex: 1;
+        display: block;
+        min-width: 0;
+        border: none;
+      }
+
+      /* only clear the outline if we know it will be shown on the parent div using :focus-within */
+      @supports selector(:focus-within) {
+        .username_input input {
+          outline: none !important;
+        }
+      }
+
+      .username_input div {
+        color: #8D99A5;
+      }
+
+      .idp-pick-details {
+        border: 1px solid #E9ECF1;
+        border-radius: 8px;
+        margin: 24px 0;
+      }
+
+      .idp-pick-details h2 {
+        margin: 0;
+        padding: 8px 12px;
+      }
+
+      .idp-pick-details .idp-detail {
+        border-top: 1px solid #E9ECF1;
+        padding: 12px;
+        display: block;
+      }
+      .idp-pick-details .check-row {
+        display: flex;
+        align-items: center;
+      }
+
+      .idp-pick-details .check-row .name {
+        flex: 1;
+      }
+
+      .idp-pick-details .use, .idp-pick-details .idp-value {
+        color: #737D8C;
+      }
+
+      .idp-pick-details .idp-value {
+        margin: 0;
+        margin-top: 8px;
+      }
+
+      .idp-pick-details .avatar {
+        width: 53px;
+        height: 53px;
+        border-radius: 100%;
+        display: block;
+        margin-top: 8px;
+      }
+
+      output {
+        padding: 0 14px;
+        display: block;
+      }
+
+      output.error {
+        color: #FE2928;
+      }
+    </style>
+  </head>
+  <body>
+    <header>
+      <h1>Your account is nearly ready</h1>
+      <p>Check your details before creating an account on {{ server_name }}</p>
+    </header>
+    <main>
+      <form method="post" class="form__input" id="form">
+        <div class="username_input" id="username_input">
+          <label for="field-username">Username</label>
+          <div class="prefix">@</div>
+          <input type="text" name="username" id="field-username" autofocus>
+          <div class="postfix">:{{ server_name }}</div>
+        </div>
+        <output for="username_input" id="field-username-output"></output>
+        <input type="submit" value="Continue" class="primary-button">
+        {% if user_attributes.avatar_url or user_attributes.display_name or user_attributes.emails %}
+        <section class="idp-pick-details">
+          <h2>{# the icon is optional: only emit the <img> when one is set, as the IdP picker does #}{% if idp.idp_icon %}<img src="{{ idp.idp_icon | mxc_to_http(24, 24) }}"/>{% endif %}Information from {{ idp.idp_name }}</h2>
+          {% if user_attributes.avatar_url %}
+          <label class="idp-detail idp-avatar" for="idp-avatar">
+            <div class="check-row">
+              <span class="name">Avatar</span>
+              <span class="use">Use</span>
+              <input type="checkbox" name="use_avatar" id="idp-avatar" value="true" checked>
+            </div>
+            <img src="{{ user_attributes.avatar_url }}" class="avatar" />
+          </label>
+          {% endif %}
+          {% if user_attributes.display_name %}
+          <label class="idp-detail" for="idp-displayname">
+            <div class="check-row">
+              <span class="name">Display name</span>
+              <span class="use">Use</span>
+              <input type="checkbox" name="use_display_name" id="idp-displayname" value="true" checked>
+            </div>
+            <p class="idp-value">{{ user_attributes.display_name }}</p>
+          </label>
+          {% endif %}
+          {% for email in user_attributes.emails %}
+          <label class="idp-detail" for="idp-email{{ loop.index }}">
+            <div class="check-row">
+              <span class="name">E-mail</span>
+              <span class="use">Use</span>
+              <input type="checkbox" name="use_email" id="idp-email{{ loop.index }}" value="{{ email }}" checked>
+            </div>
+            <p class="idp-value">{{ email }}</p>
+          </label>
+          {% endfor %}
+        </section>
+        {% endif %}
+      </form>
+    </main>
+    {% include "sso_footer.html" without context %}
+    <script type="text/javascript">
+      {% include "sso_auth_account_details.js" without context %}
+    </script>
+  </body>
+</html>
diff --git a/synapse/res/templates/sso_auth_account_details.js b/synapse/res/templates/sso_auth_account_details.js
new file mode 100644
index 0000000000..3c45df9078
--- /dev/null
+++ b/synapse/res/templates/sso_auth_account_details.js
@@ -0,0 +1,116 @@
+const usernameField = document.getElementById("field-username");
+const usernameOutput = document.getElementById("field-username-output");
+const form = document.getElementById("form");
+
+// needed to validate on change event when no input was changed
+let needsValidation = true;
+let isValid = false;
+
+// Despite the name this debounces: each call restarts the timer, so fn runs
+// once, `wait` ms after the most recent call, with that call's arguments.
+function throttle(fn, wait) {
+    let pendingTimer;
+    const debounced = function(...callArgs) {
+        clearTimeout(pendingTimer); // no-op when nothing is queued
+        pendingTimer = setTimeout(() => fn.apply(null, callArgs), wait);
+    };
+    // Drop the queued invocation, if any, without running fn.
+    debounced.cancelQueued = function() {
+        clearTimeout(pendingTimer);
+    };
+    return debounced;
+}
+
+function checkUsernameAvailable(username) {
+    let check_uri = 'check?username=' + encodeURIComponent(username);
+    return fetch(check_uri, {
+        // include the cookie
+        "credentials": "same-origin",
+    }).then(function(response) {
+        if(!response.ok) {
+            // for non-200 responses, raise the body of the response as an exception
+            return response.text().then((text) => { throw new Error(text); });
+        } else {
+            return response.json();
+        }
+    }).then(function(json) {
+        if(json.error) {
+            return {message: json.error};
+        } else if(json.available) {
+            return {available: true};
+        } else {
+            return {message: username + " is not available, please choose another."};
+        }
+    });
+}
+
+const allowedUsernameCharacters = new RegExp("^[a-z0-9\\.\\_\\-\\/\\=]+$");
+const allowedCharactersString = "lowercase letters, digits, ., _, -, /, =";
+
+function reportError(error) {
+    throttledCheckUsernameAvailable.cancelQueued();
+    usernameOutput.innerText = error;
+    usernameOutput.classList.add("error");
+    usernameField.parentElement.classList.add("invalid");
+    usernameField.focus();
+}
+
+function validateUsername(username) {
+    isValid = false;
+    needsValidation = false;
+    usernameOutput.innerText = "";
+    usernameField.parentElement.classList.remove("invalid");
+    usernameOutput.classList.remove("error");
+    if (!username) {
+        return reportError("Please provide a username");
+    }
+    if (username.length > 255) {
+        return reportError("Too long, please choose something shorter");
+    }
+    if (!allowedUsernameCharacters.test(username)) {
+        return reportError("Invalid username, please only use " + allowedCharactersString);
+    }
+    usernameOutput.innerText = "Checking if username is available …";
+    throttledCheckUsernameAvailable(username);
+}
+
+// Runs through throttle() with a 500ms wait; records the server's verdict in isValid.
+const throttledCheckUsernameAvailable = throttle(function(username) {
+    const handleError = function(err) {
+        // don't prevent form submission on error
+        usernameOutput.innerText = "";
+        isValid = true;
+    };
+    try {
+        checkUsernameAvailable(username).then(function(result) {
+            // Ignore stale answers: the field was edited while the request was in flight.
+            if (usernameField.value !== username) return;
+            if (!result.available) return reportError(result.message);
+            isValid = true;
+            usernameOutput.innerText = "";
+        }, function(err) { if (usernameField.value === username) handleError(err); });
+    } catch (err) {
+        handleError(err);
+    }
+}, 500);
+
+form.addEventListener("submit", function(evt) {
+    if (needsValidation) {
+        validateUsername(usernameField.value);
+        evt.preventDefault();
+        return;
+    }
+    if (!isValid) {
+        evt.preventDefault();
+        usernameField.focus();
+        return;
+    }
+});
+usernameField.addEventListener("input", function(evt) {
+    validateUsername(usernameField.value);
+});
+usernameField.addEventListener("change", function(evt) {
+    if (needsValidation) {
+        validateUsername(usernameField.value);
+    }
+});
diff --git a/synapse/res/templates/sso_auth_bad_user.html b/synapse/res/templates/sso_auth_bad_user.html
index 3611191bf9..da579ffe69 100644
--- a/synapse/res/templates/sso_auth_bad_user.html
+++ b/synapse/res/templates/sso_auth_bad_user.html
@@ -1,18 +1,26 @@
-<html>
-<head>
-    <title>Authentication Failed</title>
-</head>
-    <body>
-        <div>
+<!DOCTYPE html>
+<html lang="en">
+    <head>
+        <meta charset="UTF-8">
+        <title>Authentication failed</title>
+        <meta name="viewport" content="width=device-width, user-scalable=no">
+        <style type="text/css">
+            {% include "sso.css" without context %}
+        </style>
+    </head>
+    <body class="error_page">
+        <header>
+            <h1>That doesn't look right</h1>
             <p>
-                We were unable to validate your <tt>{{server_name | e}}</tt> account via
-                single-sign-on (SSO), because the SSO Identity Provider returned
-                different details than when you logged in.
+                <strong>We were unable to validate your {{ server_name }} account</strong>
+                via single&nbsp;sign&#8209;on&nbsp;(SSO), because the SSO Identity
+                Provider returned different details than when you logged in.
             </p>
             <p>
                 Try the operation again, and ensure that you use the same details on
                 the Identity Provider as when you log into your account.
             </p>
-        </div>
+        </header>
+        {% include "sso_footer.html" without context %}
     </body>
 </html>
diff --git a/synapse/res/templates/sso_auth_confirm.html b/synapse/res/templates/sso_auth_confirm.html
index 0d9de9d465..f9d0456f0a 100644
--- a/synapse/res/templates/sso_auth_confirm.html
+++ b/synapse/res/templates/sso_auth_confirm.html
@@ -1,14 +1,29 @@
-<html>
-<head>
-    <title>Authentication</title>
-</head>
+<!DOCTYPE html>
+<html lang="en">
+    <head>
+        <meta charset="UTF-8">
+        <title>Confirm it's you</title>
+        <meta name="viewport" content="width=device-width, user-scalable=no">
+        <style type="text/css">
+            {% include "sso.css" without context %}
+        </style>
+    </head>
     <body>
-        <div>
+        <header>
+            <h1>Confirm it's you to continue</h1>
             <p>
-                A client is trying to {{ description | e }}. To confirm this action,
-                <a href="{{ redirect_url | e }}">re-authenticate with single sign-on</a>.
-                If you did not expect this, your account may be compromised!
+                A client is trying to {{ description }}. To confirm this action
+                re-authorize your account with single sign-on.
             </p>
-        </div>
+            <p><strong>
+                If you did not expect this, your account may be compromised.
+            </strong></p>
+        </header>
+        <main>
+            <a href="{{ redirect_url }}" class="primary-button">
+                Continue with {{ idp.idp_name }}
+            </a>
+        </main>
+        {% include "sso_footer.html" without context %}
     </body>
 </html>
diff --git a/synapse/res/templates/sso_auth_success.html b/synapse/res/templates/sso_auth_success.html
index 03f1419467..1ed3967e87 100644
--- a/synapse/res/templates/sso_auth_success.html
+++ b/synapse/res/templates/sso_auth_success.html
@@ -1,18 +1,28 @@
-<html>
-<head>
-    <title>Authentication Successful</title>
-    <script>
-    if (window.onAuthDone) {
-        window.onAuthDone();
-    } else if (window.opener && window.opener.postMessage) {
-        window.opener.postMessage("authDone", "*");
-    }
-    </script>
-</head>
+<!DOCTYPE html>
+<html lang="en">
+    <head>
+        <meta charset="UTF-8">
+        <title>Authentication successful</title>
+        <meta name="viewport" content="width=device-width, user-scalable=no">
+        <style type="text/css">
+            {% include "sso.css" without context %}
+        </style>
+        <script>
+            if (window.onAuthDone) {
+                window.onAuthDone();
+            } else if (window.opener && window.opener.postMessage) {
+                window.opener.postMessage("authDone", "*");
+            }
+        </script>
+    </head>
     <body>
-        <div>
-            <p>Thank you</p>
-            <p>You may now close this window and return to the application</p>
-        </div>
+        <header>
+            <h1>Thank you</h1>
+            <p>
+                Now we know it’s you, you can close this window and return to the
+                application.
+            </p>
+        </header>
+        {% include "sso_footer.html" without context %}
     </body>
 </html>
diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html
index 944bc9c9ca..472309c350 100644
--- a/synapse/res/templates/sso_error.html
+++ b/synapse/res/templates/sso_error.html
@@ -1,53 +1,69 @@
 <!DOCTYPE html>
 <html lang="en">
-<head>
-    <meta charset="UTF-8">
-    <title>SSO error</title>
-</head>
-<body>
+    <head>
+        <meta charset="UTF-8">
+        <title>Authentication failed</title>
+        <meta name="viewport" content="width=device-width, user-scalable=no">
+        <style type="text/css">
+            {% include "sso.css" without context %}
+
+            #error_code {
+                margin-top: 56px;
+            }
+        </style>
+    </head>
+    <body class="error_page">
 {# If an error of unauthorised is returned it means we have actively rejected their login #}
 {% if error == "unauthorised" %}
-    <p>You are not allowed to log in here.</p>
+        <header>
+            <p>You are not allowed to log in here.</p>
+        </header>
 {% else %}
-    <p>
-        There was an error during authentication:
-    </p>
-    <div id="errormsg" style="margin:20px 80px">{{ error_description | e }}</div>
-    <p>
-        If you are seeing this page after clicking a link sent to you via email, make
-        sure you only click the confirmation link once, and that you open the
-        validation link in the same client you're logging in from.
-    </p>
-    <p>
-        Try logging in again from your Matrix client and if the problem persists
-        please contact the server's administrator.
-    </p>
-    <p>Error: <code>{{ error }}</code></p>
+        <header>
+            <h1>There was an error</h1>
+            <p>
+                <strong id="errormsg">{{ error_description }}</strong>
+            </p>
+            <p>
+                If you are seeing this page after clicking a link sent to you via email,
+                make sure you only click the confirmation link once, and that you open
+                the validation link in the same client you're logging in from.
+            </p>
+            <p>
+                Try logging in again from your Matrix client and if the problem persists
+                please contact the server's administrator.
+            </p>
+            <div id="error_code">
+                <p><strong>Error code</strong></p>
+                <p>{{ error }}</p>
+            </div>
+        </header>
+        {% include "sso_footer.html" without context %}
 
-    <script type="text/javascript">
-        // Error handling to support Auth0 errors that we might get through a GET request
-        // to the validation endpoint. If an error is provided, it's either going to be
-        // located in the query string or in a query string-like URI fragment.
-        // We try to locate the error from any of these two locations, but if we can't
-        // we just don't print anything specific.
-        let searchStr = "";
-        if (window.location.search) {
-            // window.location.searchParams isn't always defined when
-            // window.location.search is, so it's more reliable to parse the latter.
-            searchStr = window.location.search;
-        } else if (window.location.hash) {
-            // Replace the # with a ? so that URLSearchParams does the right thing and
-            // doesn't parse the first parameter incorrectly.
-            searchStr = window.location.hash.replace("#", "?");
-        }
+        <script type="text/javascript">
+            // Error handling to support Auth0 errors that we might get through a GET request
+            // to the validation endpoint. If an error is provided, it's either going to be
+            // located in the query string or in a query string-like URI fragment.
+            // We try to locate the error from any of these two locations, but if we can't
+            // we just don't print anything specific.
+            let searchStr = "";
+            if (window.location.search) {
+                // window.location.searchParams isn't always defined when
+                // window.location.search is, so it's more reliable to parse the latter.
+                searchStr = window.location.search;
+            } else if (window.location.hash) {
+                // Replace the # with a ? so that URLSearchParams does the right thing and
+                // doesn't parse the first parameter incorrectly.
+                searchStr = window.location.hash.replace("#", "?");
+            }
 
-        // We might end up with no error in the URL, so we need to check if we have one
-        // to print one.
-        let errorDesc = new URLSearchParams(searchStr).get("error_description")
-        if (errorDesc) {
-            document.getElementById("errormsg").innerText = errorDesc;
-        }
-    </script>
+            // We might end up with no error in the URL, so we need to check if we have one
+            // to print one.
+            let errorDesc = new URLSearchParams(searchStr).get("error_description")
+            if (errorDesc) {
+                document.getElementById("errormsg").innerText = errorDesc;
+            }
+        </script>
 {% endif %}
 </body>
 </html>
diff --git a/synapse/res/templates/sso_footer.html b/synapse/res/templates/sso_footer.html
new file mode 100644
index 0000000000..588a3d508d
--- /dev/null
+++ b/synapse/res/templates/sso_footer.html
@@ -0,0 +1,19 @@
+<footer>
+	<svg role="img" aria-label="[Matrix logo]" viewBox="0 0 200 85" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
+          <g id="parent" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
+              <g id="child" transform="translate(-122.000000, -6.000000)" fill="#000000" fill-rule="nonzero">
+                  <g id="matrix-logo" transform="translate(122.000000, 6.000000)">
+                      <polygon id="left-bracket" points="2.24708861 1.93811009 2.24708861 82.7268844 8.10278481 82.7268844 8.10278481 84.6652459 0 84.6652459 0 0 8.10278481 0 8.10278481 1.93811009"></polygon>
+                      <path d="M24.8073418,27.5493174 L24.8073418,31.6376991 L24.924557,31.6376991 C26.0227848,30.0814294 27.3455696,28.8730642 28.8951899,28.0163743 C30.4437975,27.1611927 32.2189873,26.7318422 34.218481,26.7318422 C36.1394937,26.7318422 37.8946835,27.102622 39.4825316,27.8416679 C41.0708861,28.5819706 42.276962,29.8856073 43.1005063,31.7548404 C44.0017722,30.431345 45.2270886,29.2629486 46.7767089,28.2506569 C48.3253165,27.2388679 50.158481,26.7318422 52.2764557,26.7318422 C53.8843038,26.7318422 55.3736709,26.9269101 56.7473418,27.3162917 C58.1189873,27.7056734 59.295443,28.3285835 60.2759494,29.185022 C61.255443,30.0422147 62.02,31.1615927 62.5701266,32.5426532 C63.1187342,33.9262275 63.3936709,35.5898349 63.3936709,37.5372459 L63.3936709,57.7443688 L55.0410127,57.7441174 L55.0410127,40.6319376 C55.0410127,39.6201486 55.0020253,38.6661761 54.9232911,37.7700202 C54.8440506,36.8751211 54.6293671,36.0968606 54.2764557,35.4339817 C53.9232911,34.772611 53.403038,34.2464807 52.7177215,33.8568477 C52.0313924,33.4689743 51.0997468,33.2731523 49.9235443,33.2731523 C48.7473418,33.2731523 47.7962025,33.4983853 47.0706329,33.944578 C46.344557,34.393033 45.7764557,34.9774826 45.3650633,35.6969211 C44.9534177,36.4181193 44.6787342,37.2353431 44.5417722,38.150855 C44.4037975,39.0653615 44.3356962,39.9904257 44.3356962,40.9247908 L44.3356962,57.7443688 L35.9835443,57.7443688 L35.9835443,40.8079009 C35.9835443,39.9124991 35.963038,39.0263982 35.9253165,38.150855 C35.8853165,37.2743064 35.7192405,36.4666349 35.424557,35.7263321 C35.1303797,34.9872862 34.64,34.393033 33.9539241,33.944578 C33.2675949,33.4983853 32.2579747,33.2731523 30.9248101,33.2731523 C30.5321519,33.2731523 30.0126582,33.3608826 29.3663291,33.5365945 C28.7192405,33.7118037 28.0913924,34.0433688 27.4840506,34.5292789 C26.875443,35.0164459 26.3564557,35.7172826 25.9250633,36.6315376 C25.4934177,37.5470495 25.2779747,38.7436 25.2779747,40.2229486 L25.2779747,57.7441174 L16.9260759,57.7443688 
L16.9260759,27.5493174 L24.8073418,27.5493174 Z" id="m"></path>
+                      <path d="M68.7455696,31.9886202 C69.6075949,30.7033339 70.7060759,29.672189 72.0397468,28.8926716 C73.3724051,28.1141596 74.8716456,27.5596239 76.5387342,27.2283101 C78.2050633,26.8977505 79.8817722,26.7315908 81.5678481,26.7315908 C83.0974684,26.7315908 84.6458228,26.8391798 86.2144304,27.0525982 C87.7827848,27.2675248 89.2144304,27.6865688 90.5086076,28.3087248 C91.8025316,28.9313835 92.8610127,29.7983798 93.6848101,30.9074514 C94.5083544,32.0170257 94.92,33.4870734 94.92,35.3173431 L94.92,51.026844 C94.92,52.3913138 94.998481,53.6941963 95.1556962,54.9400165 C95.3113924,56.1865908 95.5863291,57.120956 95.9787342,57.7436147 L87.5091139,57.7436147 C87.3518987,57.276055 87.2240506,56.7996972 87.1265823,56.3125303 C87.0278481,55.8266202 86.9592405,55.3301523 86.9207595,54.8236294 C85.5873418,56.1865908 84.0182278,57.1405633 82.2156962,57.6857982 C80.4113924,58.2295248 78.5683544,58.503022 76.6860759,58.503022 C75.2346835,58.503022 73.8817722,58.3275615 72.6270886,57.9776459 C71.3718987,57.6269761 70.2744304,57.082244 69.3334177,56.3411872 C68.3921519,55.602644 67.656962,54.6680275 67.1275949,53.5390972 C66.5982278,52.410167 66.3331646,51.065556 66.3331646,49.5087835 C66.3331646,47.7961578 66.6367089,46.384178 67.2455696,45.2756092 C67.8529114,44.1652807 68.6367089,43.2799339 69.5987342,42.6173064 C70.5589873,41.9556844 71.6567089,41.4592165 72.8924051,41.1284055 C74.1273418,40.7978459 75.3721519,40.5356606 76.6270886,40.3398385 C77.8820253,40.1457761 79.116962,39.9896716 80.3329114,39.873033 C81.5483544,39.7558917 82.6270886,39.5804312 83.5681013,39.3469028 C84.5093671,39.1133743 85.2536709,38.7732624 85.8032911,38.3250587 C86.3513924,37.8773578 86.6063291,37.2252881 86.5678481,36.3680954 C86.5678481,35.4731963 86.4210127,34.7620532 86.1268354,34.2366771 C85.8329114,33.7113009 85.4405063,33.3018092 84.9506329,33.0099615 C84.4602532,32.7181138 83.8916456,32.5232972 83.2450633,32.4255119 C82.5977215,32.3294862 81.9010127,32.2797138 
81.156962,32.2797138 C79.5098734,32.2797138 78.2159494,32.6303835 77.2746835,33.3312202 C76.3339241,34.0320569 75.7837975,35.2007046 75.6275949,36.8354037 L67.275443,36.8354037 C67.3924051,34.8892495 67.8817722,33.2726495 68.7455696,31.9886202 Z M85.2440506,43.6984752 C84.7149367,43.873433 84.1460759,44.0189798 83.5387342,44.1361211 C82.9306329,44.253011 82.2936709,44.350545 81.6270886,44.4279688 C80.96,44.5066495 80.2934177,44.6034294 79.6273418,44.7203193 C78.9994937,44.8362037 78.3820253,44.9933138 77.7749367,45.1871248 C77.1663291,45.3829468 76.636962,45.6451321 76.1865823,45.9759431 C75.7349367,46.3070055 75.3724051,46.7263009 75.0979747,47.2313156 C74.8232911,47.7375872 74.6863291,48.380356 74.6863291,49.1588679 C74.6863291,49.8979138 74.8232911,50.5218294 75.0979747,51.026844 C75.3724051,51.5338697 75.7455696,51.9328037 76.2159494,52.2246514 C76.6863291,52.5164991 77.2349367,52.7213706 77.8632911,52.8375064 C78.4898734,52.9546477 79.136962,53.012967 79.8037975,53.012967 C81.4506329,53.012967 82.724557,52.740978 83.6273418,52.1952404 C84.5288608,51.6507596 85.1949367,50.9981872 85.6270886,50.2382771 C86.0579747,49.4793725 86.323038,48.7119211 86.4212658,47.9321523 C86.518481,47.1536404 86.5681013,46.5304789 86.5681013,46.063422 L86.5681013,42.9677248 C86.2146835,43.2799339 85.7736709,43.5230147 85.2440506,43.6984752 Z" id="a"></path>
+                      <path d="M116.917975,27.5493174 L116.917975,33.0976917 L110.801266,33.0976917 L110.801266,48.0492936 C110.801266,49.4502128 111.036203,50.3850807 111.507089,50.8518862 C111.976962,51.3191945 112.918734,51.5527229 114.33038,51.5527229 C114.801013,51.5527229 115.251392,51.5336183 115.683038,51.4944037 C116.114177,51.4561945 116.526076,51.3968697 116.917975,51.3194459 L116.917975,57.7438661 C116.212152,57.860756 115.427595,57.9381798 114.565316,57.9778972 C113.702785,58.0153523 112.859747,58.0357138 112.036203,58.0357138 C110.742278,58.0357138 109.516456,57.9477321 108.36,57.7722716 C107.202785,57.5975651 106.183544,57.2577046 105.301519,56.7509303 C104.418987,56.2454128 103.722785,55.5242147 103.213418,54.5898495 C102.703038,53.6562385 102.448608,52.4292716 102.448608,50.9099541 L102.448608,33.0976917 L97.3903797,33.0976917 L97.3903797,27.5493174 L102.448608,27.5493174 L102.448608,18.4967596 L110.801013,18.4967596 L110.801013,27.5493174 L116.917975,27.5493174 Z" id="t"></path>
+                      <path d="M128.857975,27.5493174 L128.857975,33.1565138 L128.975696,33.1565138 C129.367089,32.2213945 129.896203,31.3559064 130.563544,30.557033 C131.23038,29.7596679 131.99443,29.0776844 132.857215,28.5130936 C133.719241,27.9495083 134.641266,27.5113596 135.622532,27.1988991 C136.601772,26.8879468 137.622025,26.7315908 138.681013,26.7315908 C139.229873,26.7315908 139.836962,26.8296275 140.504304,27.0239413 L140.504304,34.7336477 C140.111646,34.6552183 139.641013,34.586844 139.092658,34.5290275 C138.543291,34.4704569 138.014177,34.4410459 137.504304,34.4410459 C135.974937,34.4410459 134.681013,34.6949358 133.622785,35.2004532 C132.564051,35.7067248 131.711392,36.397255 131.064051,37.2735523 C130.417215,38.1501009 129.955443,39.1714422 129.681266,40.3398385 C129.407089,41.5074807 129.269873,42.7736624 129.269873,44.1361211 L129.269873,57.7438661 L120.917722,57.7438661 L120.917722,27.5493174 L128.857975,27.5493174 Z" id="r"></path>
+                      <path d="M144.033165,22.8767376 L144.033165,16.0435798 L152.386076,16.0435798 L152.386076,22.8767376 L144.033165,22.8767376 Z M152.386076,27.5493174 L152.386076,57.7438661 L144.033165,57.7438661 L144.033165,27.5493174 L152.386076,27.5493174 Z" id="i"></path>
+                      <polygon id="x" points="156.738228 27.5493174 166.266582 27.5493174 171.619494 35.4337303 176.913418 27.5493174 186.147848 27.5493174 176.148861 41.6831927 187.383544 57.7441174 177.85443 57.7441174 171.501772 48.2245028 165.148861 57.7441174 155.797468 57.7441174 166.737468 41.8589046"></polygon>
+                      <polygon id="right-bracket" points="197.580759 82.7268844 197.580759 1.93811009 191.725063 1.93811009 191.725063 0 199.828354 0 199.828354 84.6652459 191.725063 84.6652459 191.725063 82.7268844"></polygon>
+                  </g>
+              </g>
+          </g>
+      </svg>
+      <p>An open network for secure, decentralized communication.<br>© 2021 The Matrix.org Foundation C.I.C.</p>
+</footer>
\ No newline at end of file
diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html
index 5b38481012..53b82db84e 100644
--- a/synapse/res/templates/sso_login_idp_picker.html
+++ b/synapse/res/templates/sso_login_idp_picker.html
@@ -2,30 +2,60 @@
 <html lang="en">
     <head>
         <meta charset="UTF-8">
-        <link rel="stylesheet" href="/_matrix/static/client/login/style.css">
-        <title>{{server_name | e}} Login</title>
+        <title>Choose identity provider</title>
+        <style type="text/css">
+          {% include "sso.css" without context %}
+
+          .providers {
+            list-style: none;
+            padding: 0;
+          }
+
+          .providers li {
+            margin: 12px;
+          }
+
+          .providers a {
+            display: block;
+            border-radius: 4px;
+            border: 1px solid #17191C;
+            padding: 8px;
+            text-align: center;
+            text-decoration: none;
+            color: #17191C;
+            display: flex;
+            align-items: center;
+            font-weight: bold;
+          }
+
+          .providers a img {
+            width: 24px;
+            height: 24px;
+          }
+          .providers a span {
+            flex: 1;
+          }
+        </style>
     </head>
     <body>
-        <div id="container">
-            <h1 id="title">{{server_name | e}} Login</h1>
-            <div class="login_flow">
-                <p>Choose one of the following identity providers:</p>
-            <form>
-                <input type="hidden" name="redirectUrl" value="{{redirect_url | e}}">
-                <ul class="radiobuttons">
-{% for p in providers %}
-                    <li>
-                        <input type="radio" name="idp" id="prov{{loop.index}}" value="{{p.idp_id}}">
-                        <label for="prov{{loop.index}}">{{p.idp_name | e}}</label>
-{% if p.idp_icon %}
-                        <img src="{{p.idp_icon | mxc_to_http(32, 32)}}"/>
-{% endif %}
-                    </li>
-{% endfor %}
-                </ul>
-                <input type="submit" class="button button--full-width" id="button-submit" value="Submit">
-            </form>
-            </div>
-        </div>
+        <header>
+            <h1>Log in to {{ server_name }} </h1>
+            <p>Choose an identity provider to log in</p>
+        </header>
+        <main>
+            <ul class="providers">
+                {% for p in providers %}
+                <li>
+                    <a href="pick_idp?idp={{ p.idp_id | urlencode }}&amp;redirectUrl={{ redirect_url | urlencode }}">{# both query values are URL-encoded; &amp; is the HTML-escaped separator #}
+                        {% if p.idp_icon %}
+                        <img src="{{ p.idp_icon | mxc_to_http(32, 32) }}"/>
+                        {% endif %}
+                        <span>{{ p.idp_name }}</span>
+                    </a>
+                </li>
+                {% endfor %}
+            </ul>
+        </main>
+        {% include "sso_footer.html" without context %}
     </body>
 </html>
diff --git a/synapse/res/templates/sso_new_user_consent.html b/synapse/res/templates/sso_new_user_consent.html
new file mode 100644
index 0000000000..68c8b9f33a
--- /dev/null
+++ b/synapse/res/templates/sso_new_user_consent.html
@@ -0,0 +1,32 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <title>Agree to terms and conditions</title>
+    <meta name="viewport" content="width=device-width, user-scalable=no">
+    <style type="text/css">
+      {% include "sso.css" without context %}
+
+      #consent_form {
+        margin-top: 56px;
+      }
+    </style>
+</head>
+    <body>
+        <header>
+            <h1>Your account is nearly ready</h1>
+            <p>Agree to the terms to create your account.</p>
+        </header>
+        <main>
+            {% include "sso_partial_profile.html" %}
+            <form method="post" action="{{my_url}}" id="consent_form">
+                <p>
+                    <input id="accepted_version" type="checkbox" name="accepted_version" value="{{ consent_version }}" required>
+                    <label for="accepted_version">I have read and agree to the <a href="{{ terms_url }}" target="_blank" rel="noopener">terms and conditions</a>.</label>
+                </p>
+                <input type="submit" class="primary-button" value="Continue"/>
+            </form>
+        </main>
+        {% include "sso_footer.html" without context %}
+    </body>
+</html>
diff --git a/synapse/res/templates/sso_partial_profile.html b/synapse/res/templates/sso_partial_profile.html
new file mode 100644
index 0000000000..c9c76c455e
--- /dev/null
+++ b/synapse/res/templates/sso_partial_profile.html
@@ -0,0 +1,19 @@
+{# html fragment to be included in SSO pages, to show the user's profile #}
+
+<div class="profile{% if user_profile.avatar_url %} with-avatar{% endif %}">
+    {% if user_profile.avatar_url %}
+    <img src="{{ user_profile.avatar_url | mxc_to_http(64, 64) }}" class="avatar" />
+    {% endif %}
+    {# users who signed up with SSO will have a display_name of some sort;
+       however that is not the case for users who signed up via other
+       methods, so we need to handle that.
+    #}
+    {% if user_profile.display_name %}
+        <div class="display-name">{{ user_profile.display_name }}</div>
+    {% else %}
+        {# split the userid on ':', take the part before the first ':',
+           and then remove the leading '@'. #}
+        <div class="display-name">{{ user_id.split(":")[0][1:] }}</div>
+    {% endif %}
+    <div class="user-id">{{ user_id }}</div>
+</div>
diff --git a/synapse/res/templates/sso_redirect_confirm.html b/synapse/res/templates/sso_redirect_confirm.html
index 20a15e1e74..1b01471ac8 100644
--- a/synapse/res/templates/sso_redirect_confirm.html
+++ b/synapse/res/templates/sso_redirect_confirm.html
@@ -2,13 +2,39 @@
 <html lang="en">
 <head>
     <meta charset="UTF-8">
-    <title>SSO redirect confirmation</title>
+    <title>Continue to your account</title>
+    <meta name="viewport" content="width=device-width, user-scalable=no">
+    <style type="text/css">
+      {% include "sso.css" without context %}
+
+      .confirm-trust {
+        margin: 34px 0;
+        color: #8D99A5;
+      }
+      .confirm-trust strong {
+        color: #17191C;
+      }
+
+      .confirm-trust::before {
+        content: "";
+        background-image: url('');
+        background-repeat: no-repeat;
+        width: 24px;
+        height: 24px;
+        display: block;
+        float: left;
+      }
+    </style>
 </head>
     <body>
-        <p>The application at <span style="font-weight:bold">{{ display_url | e }}</span> is requesting full access to your <span style="font-weight:bold">{{ server_name }}</span> Matrix account.</p>
-        <p>If you don't recognise this address, you should ignore this and close this tab.</p>
-        <p>
-            <a href="{{ redirect_url | e }}">I trust this address</a>
-        </p>
+        <header>
+            <h1>Continue to your account</h1>
+        </header>
+        <main>
+            {% include "sso_partial_profile.html" %}
+            <p class="confirm-trust">Continuing will grant <strong>{{ display_url }}</strong> access to your account.</p>
+            <a href="{{ redirect_url }}" class="primary-button">Continue</a>
+        </main>
+        {% include "sso_footer.html" without context %}
     </body>
-</html>
\ No newline at end of file
+</html>
diff --git a/synapse/res/username_picker/index.html b/synapse/res/username_picker/index.html
deleted file mode 100644
index 37ea8bb6d8..0000000000
--- a/synapse/res/username_picker/index.html
+++ /dev/null
@@ -1,19 +0,0 @@
-<!DOCTYPE html>
-<html lang="en">
-  <head>
-    <title>Synapse Login</title>
-    <link rel="stylesheet" href="style.css" type="text/css" />
-  </head>
-  <body>
-    <div class="card">
-      <form method="post" class="form__input" id="form" action="submit">
-        <label for="field-username">Please pick your username:</label>
-        <input type="text" name="username" id="field-username" autofocus="">
-        <input type="submit" class="button button--full-width" id="button-submit" value="Submit">
-      </form>
-      <!-- this is used for feedback -->
-      <div role=alert class="tooltip hidden" id="message"></div>
-      <script src="script.js"></script>
-    </div>
-  </body>
-</html>
diff --git a/synapse/res/username_picker/script.js b/synapse/res/username_picker/script.js
deleted file mode 100644
index 416a7c6f41..0000000000
--- a/synapse/res/username_picker/script.js
+++ /dev/null
@@ -1,95 +0,0 @@
-let inputField = document.getElementById("field-username");
-let inputForm = document.getElementById("form");
-let submitButton = document.getElementById("button-submit");
-let message = document.getElementById("message");
-
-// Submit username and receive response
-function showMessage(messageText) {
-    // Unhide the message text
-    message.classList.remove("hidden");
-
-    message.textContent = messageText;
-};
-
-function doSubmit() {
-    showMessage("Success. Please wait a moment for your browser to redirect.");
-
-    // remove the event handler before re-submitting the form.
-    delete inputForm.onsubmit;
-    inputForm.submit();
-}
-
-function onResponse(response) {
-    // Display message
-    showMessage(response);
-
-    // Enable submit button and input field
-    submitButton.classList.remove('button--disabled');
-    submitButton.value = "Submit";
-};
-
-let allowedUsernameCharacters = RegExp("[^a-z0-9\\.\\_\\=\\-\\/]");
-function usernameIsValid(username) {
-    return !allowedUsernameCharacters.test(username);
-}
-let allowedCharactersString = "lowercase letters, digits, ., _, -, /, =";
-
-function buildQueryString(params) {
-    return Object.keys(params)
-        .map(k => encodeURIComponent(k) + '=' + encodeURIComponent(params[k]))
-        .join('&');
-}
-
-function submitUsername(username) {
-    if(username.length == 0) {
-        onResponse("Please enter a username.");
-        return;
-    }
-    if(!usernameIsValid(username)) {
-        onResponse("Invalid username. Only the following characters are allowed: " + allowedCharactersString);
-        return;
-    }
-
-    // if this browser doesn't support fetch, skip the availability check.
-    if(!window.fetch) {
-        doSubmit();
-        return;
-    }
-
-    let check_uri = 'check?' + buildQueryString({"username": username});
-    fetch(check_uri, {
-        // include the cookie
-        "credentials": "same-origin",
-    }).then((response) => {
-        if(!response.ok) {
-            // for non-200 responses, raise the body of the response as an exception
-            return response.text().then((text) => { throw text; });
-        } else {
-            return response.json();
-        }
-    }).then((json) => {
-        if(json.error) {
-            throw json.error;
-        } else if(json.available) {
-            doSubmit();
-        } else {
-            onResponse("This username is not available, please choose another.");
-        }
-    }).catch((err) => {
-        onResponse("Error checking username availability: " + err);
-    });
-}
-
-function clickSubmit() {
-    event.preventDefault();
-    if(submitButton.classList.contains('button--disabled')) { return; }
-
-    // Disable submit button and input field
-    submitButton.classList.add('button--disabled');
-
-    // Submit username
-    submitButton.value = "Checking...";
-    submitUsername(inputField.value);
-};
-
-inputForm.onsubmit = clickSubmit;
diff --git a/synapse/res/username_picker/style.css b/synapse/res/username_picker/style.css
deleted file mode 100644
index 745bd4c684..0000000000
--- a/synapse/res/username_picker/style.css
+++ /dev/null
@@ -1,27 +0,0 @@
-input[type="text"] {
-  font-size: 100%;
-  background-color: #ededf0;
-  border: 1px solid #fff;
-  border-radius: .2em;
-  padding: .5em .9em;
-  display: block;
-  width: 26em;
-}
-
-.button--disabled {
-  border-color: #fff;
-  background-color: transparent;
-  color: #000;
-  text-transform: none;
-}
-
-.hidden {
-  display: none;
-}
-
-.tooltip {
-  background-color: #f9f9fa;
-  padding: 1em;
-  margin: 1em 0;
-}
-
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 6f7dc06503..8457db1e22 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -1,6 +1,8 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2018-2019 New Vector Ltd
+# Copyright 2020, 2021 The Matrix.org Foundation C.I.C.
+
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -36,11 +38,14 @@ from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_medi
 from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
 from synapse.rest.admin.rooms import (
     DeleteRoomRestServlet,
+    ForwardExtremitiesRestServlet,
     JoinRoomAliasServlet,
     ListRoomRestServlet,
     MakeRoomAdminRestServlet,
+    RoomEventContextServlet,
     RoomMembersRestServlet,
     RoomRestServlet,
+    RoomStateRestServlet,
     ShutdownRoomRestServlet,
 )
 from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
@@ -51,6 +56,7 @@ from synapse.rest.admin.users import (
     PushersRestServlet,
     ResetPasswordRestServlet,
     SearchUsersRestServlet,
+    ShadowBanRestServlet,
     UserAdminServlet,
     UserMediaRestServlet,
     UserMembershipRestServlet,
@@ -209,6 +215,7 @@ def register_servlets(hs, http_server):
     """
     register_servlets_for_client_rest_resource(hs, http_server)
     ListRoomRestServlet(hs).register(http_server)
+    RoomStateRestServlet(hs).register(http_server)
     RoomRestServlet(hs).register(http_server)
     RoomMembersRestServlet(hs).register(http_server)
     DeleteRoomRestServlet(hs).register(http_server)
@@ -230,6 +237,9 @@ def register_servlets(hs, http_server):
     EventReportsRestServlet(hs).register(http_server)
     PushersRestServlet(hs).register(http_server)
     MakeRoomAdminRestServlet(hs).register(http_server)
+    ShadowBanRestServlet(hs).register(http_server)
+    ForwardExtremitiesRestServlet(hs).register(http_server)
+    RoomEventContextServlet(hs).register(http_server)
 
 
 def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py
index d0c86b204a..ebc587aa06 100644
--- a/synapse/rest/admin/groups.py
+++ b/synapse/rest/admin/groups.py
@@ -22,8 +22,7 @@ logger = logging.getLogger(__name__)
 
 
 class DeleteGroupAdminRestServlet(RestServlet):
-    """Allows deleting of local groups
-    """
+    """Allows deleting of local groups"""
 
     PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)")
 
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 8720b1401f..b996862c05 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -119,8 +119,7 @@ class QuarantineMediaByID(RestServlet):
 
 
 class ProtectMediaByID(RestServlet):
-    """Protect local media from being quarantined.
-    """
+    """Protect local media from being quarantined."""
 
     PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
 
@@ -141,8 +140,7 @@ class ProtectMediaByID(RestServlet):
 
 
 class ListMediaInRoom(RestServlet):
-    """Lists all of the media in a given room.
-    """
+    """Lists all of the media in a given room."""
 
     PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
 
@@ -180,8 +178,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
 
 
 class DeleteMediaByID(RestServlet):
-    """Delete local media by a given ID. Removes it from this server.
-    """
+    """Delete local media by a given ID. Removes it from this server."""
 
     PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
 
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index ab7cc9102a..1a3a36f6cf 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,9 +15,11 @@
 import logging
 from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional, Tuple
+from urllib import parse as urlparse
 
 from synapse.api.constants import EventTypes, JoinRules, Membership
 from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
+from synapse.api.filtering import Filter
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
@@ -33,6 +35,7 @@ from synapse.rest.admin._base import (
 )
 from synapse.storage.databases.main.room import RoomSortOrder
 from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester
+from synapse.util import json_decoder
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -292,6 +295,45 @@ class RoomMembersRestServlet(RestServlet):
         return 200, ret
 
 
+class RoomStateRestServlet(RestServlet):
+    """
+    Get full state within a room.
+    """
+
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/state")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+        self._event_serializer = hs.get_event_client_serializer()
+
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        ret = await self.store.get_room(room_id)
+        if not ret:
+            raise NotFoundError("Room not found")
+
+        event_ids = await self.store.get_current_state_ids(room_id)
+        events = await self.store.get_events(event_ids.values())
+        now = self.clock.time_msec()
+        room_state = await self._event_serializer.serialize_events(
+            events.values(),
+            now,
+            # We don't bother bundling aggregations in when asked for state
+            # events, as clients won't use them.
+            bundle_aggregations=False,
+        )
+        ret = {"state": room_state}
+
+        return 200, ret
+
+
 class JoinRoomAliasServlet(RestServlet):
 
     PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
@@ -431,7 +473,18 @@ class MakeRoomAdminRestServlet(RestServlet):
             if not admin_users:
                 raise SynapseError(400, "No local admin user in room")
 
-            admin_user_id = admin_users[-1]
+            admin_user_id = None
+
+            for admin_user in reversed(admin_users):
+                if room_state.get((EventTypes.Member, admin_user)):
+                    admin_user_id = admin_user
+                    break
+
+            if not admin_user_id:
+                raise SynapseError(
+                    400,
+                    "No local admin user in room",
+                )
 
             pl_content = power_levels.content
         else:
@@ -440,7 +493,8 @@ class MakeRoomAdminRestServlet(RestServlet):
             admin_user_id = create_event.sender
             if not self.is_mine_id(admin_user_id):
                 raise SynapseError(
-                    400, "No local admin user in room",
+                    400,
+                    "No local admin user in room",
                 )
 
         # Grant the user power equal to the room admin by attempting to send an
@@ -450,7 +504,8 @@ class MakeRoomAdminRestServlet(RestServlet):
         new_pl_content["users"][user_to_add] = new_pl_content["users"][admin_user_id]
 
         fake_requester = create_requester(
-            admin_user_id, authenticated_entity=requester.authenticated_entity,
+            admin_user_id,
+            authenticated_entity=requester.authenticated_entity,
         )
 
         try:
@@ -499,3 +554,122 @@ class MakeRoomAdminRestServlet(RestServlet):
         )
 
         return 200, {}
+
+
+class ForwardExtremitiesRestServlet(RestServlet):
+    """Allows a server admin to get or clear forward extremities.
+
+    Clearing does not require restarting the server.
+
+        Clear forward extremities:
+        DELETE /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+
+        Get forward extremities:
+        GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+    """
+
+    PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.room_member_handler = hs.get_room_member_handler()
+        self.store = hs.get_datastore()
+
+    async def resolve_room_id(self, room_identifier: str) -> str:
+        """Resolve to a room ID, if necessary."""
+        if RoomID.is_valid(room_identifier):
+            resolved_room_id = room_identifier
+        elif RoomAlias.is_valid(room_identifier):
+            room_alias = RoomAlias.from_string(room_identifier)
+            room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
+            resolved_room_id = room_id.to_string()
+        else:
+            raise SynapseError(
+                400, "%s was not legal room ID or room alias" % (room_identifier,)
+            )
+        if not resolved_room_id:
+            raise SynapseError(
+                400, "Unknown room ID or room alias %s" % room_identifier
+            )
+        return resolved_room_id
+
+    async def on_DELETE(self, request, room_identifier):
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        room_id = await self.resolve_room_id(room_identifier)
+
+        deleted_count = await self.store.delete_forward_extremities_for_room(room_id)
+        return 200, {"deleted": deleted_count}
+
+    async def on_GET(self, request, room_identifier):
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        room_id = await self.resolve_room_id(room_identifier)
+
+        extremities = await self.store.get_forward_extremities_for_room(room_id)
+        return 200, {"count": len(extremities), "results": extremities}
+
+
+class RoomEventContextServlet(RestServlet):
+    """
+    Provide the context for an event.
+    This API is designed to be used when system administrators wish to look at
+    an abuse report and understand what happened during and immediately prior
+    to this event.
+    """
+
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$")
+
+    def __init__(self, hs):
+        super().__init__()
+        self.clock = hs.get_clock()
+        self.room_context_handler = hs.get_room_context_handler()
+        self._event_serializer = hs.get_event_client_serializer()
+        self.auth = hs.get_auth()
+
+    async def on_GET(self, request, room_id, event_id):
+        requester = await self.auth.get_user_by_req(request, allow_guest=False)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        limit = parse_integer(request, "limit", default=10)
+
+        # picking the API shape for symmetry with /messages
+        filter_str = parse_string(request, b"filter", encoding="utf-8")
+        if filter_str:
+            filter_json = urlparse.unquote(filter_str)
+            event_filter = Filter(
+                json_decoder.decode(filter_json)
+            )  # type: Optional[Filter]
+        else:
+            event_filter = None
+
+        results = await self.room_context_handler.get_event_context(
+            requester,
+            room_id,
+            event_id,
+            limit,
+            event_filter,
+            use_admin_priviledge=True,
+        )
+
+        if not results:
+            raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
+
+        time_now = self.clock.time_msec()
+        results["events_before"] = await self._event_serializer.serialize_events(
+            results["events_before"], time_now
+        )
+        results["event"] = await self._event_serializer.serialize_event(
+            results["event"], time_now
+        )
+        results["events_after"] = await self._event_serializer.serialize_events(
+            results["events_after"], time_now
+        )
+        results["state"] = await self._event_serializer.serialize_events(
+            results["state"], time_now
+        )
+
+        return 200, results
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index f39e3d6d5c..998a0ef671 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -83,17 +83,32 @@ class UsersRestServletV2(RestServlet):
     The parameter `deactivated` can be used to include deactivated users.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         start = parse_integer(request, "from", default=0)
         limit = parse_integer(request, "limit", default=100)
+
+        if start < 0:
+            raise SynapseError(
+                400,
+                "Query parameter from must be a string representing a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        if limit < 0:
+            raise SynapseError(
+                400,
+                "Query parameter limit must be a string representing a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
         user_id = parse_string(request, "user_id", default=None)
         name = parse_string(request, "name", default=None)
         guests = parse_boolean(request, "guests", default=True)
@@ -103,7 +118,7 @@ class UsersRestServletV2(RestServlet):
             start, limit, user_id, name, guests, deactivated
         )
         ret = {"users": users, "total": total}
-        if len(users) >= limit:
+        if (start + limit) < total:
             ret["next_token"] = str(start + len(users))
 
         return 200, ret
@@ -564,7 +579,7 @@ class ResetPasswordRestServlet(RestServlet):
             }
         Returns:
             200 OK with empty object if success otherwise an error.
-        """
+    """
 
     PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
 
@@ -737,7 +752,7 @@ class PushersRestServlet(RestServlet):
 
     Returns:
         pushers: Dictionary containing pushers information.
-        total: Number of pushers in dictonary `pushers`.
+        total: Number of pushers in dictionary `pushers`.
     """
 
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")
@@ -875,3 +890,39 @@ class UserTokenRestServlet(RestServlet):
         )
 
         return 200, {"access_token": token}
+
+
+class ShadowBanRestServlet(RestServlet):
+    """An admin API for shadow-banning a user.
+
+    A shadow-banned user receives successful responses to their client-server
+    API requests, but the events are not propagated into rooms.
+
+    Shadow-banning a user should be used as a tool of last resort and may lead
+    to confusing or broken behaviour for the client.
+
+    Example:
+
+        POST /_synapse/admin/v1/users/@test:example.com/shadow_ban
+        {}
+
+        200 OK
+        {}
+    """
+
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    async def on_POST(self, request, user_id):
+        await assert_requester_is_admin(self.auth, request)
+
+        if not self.hs.is_mine_id(user_id):
+            raise SynapseError(400, "Only local users can be shadow-banned")
+
+        await self.store.set_shadow_banned(UserID.from_string(user_id), True)
+
+        return 200, {}
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index be938df962..6e2fbedd99 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.appservice import ApplicationService
-from synapse.http.server import finish_request
+from synapse.handlers.sso import SsoIdentityProvider
+from synapse.http.server import HttpServer, finish_request
 from synapse.http.servlet import (
     RestServlet,
     parse_json_object_from_request,
@@ -60,11 +61,14 @@ class LoginRestServlet(RestServlet):
         self.saml2_enabled = hs.config.saml2_enabled
         self.cas_enabled = hs.config.cas_enabled
         self.oidc_enabled = hs.config.oidc_enabled
+        self._msc2858_enabled = hs.config.experimental.msc2858_enabled
 
         self.auth = hs.get_auth()
 
         self.auth_handler = self.hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
+        self._sso_handler = hs.get_sso_handler()
+
         self._well_known_builder = WellKnownBuilder(hs)
         self._address_ratelimiter = Ratelimiter(
             clock=hs.get_clock(),
@@ -89,8 +93,17 @@ class LoginRestServlet(RestServlet):
             flows.append({"type": LoginRestServlet.CAS_TYPE})
 
         if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
-            flows.append({"type": LoginRestServlet.SSO_TYPE})
-            # While its valid for us to advertise this login type generally,
+            sso_flow = {"type": LoginRestServlet.SSO_TYPE}  # type: JsonDict
+
+            if self._msc2858_enabled:
+                sso_flow["org.matrix.msc2858.identity_providers"] = [
+                    _get_auth_flow_dict_for_idp(idp)
+                    for idp in self._sso_handler.get_identity_providers().values()
+                ]
+
+            flows.append(sso_flow)
+
+            # While it's valid for us to advertise this login type generally,
             # synapse currently only gives out these tokens as part of the
             # SSO login flow.
             # Generally we don't want to advertise login flows that clients
@@ -297,7 +310,9 @@ class LoginRestServlet(RestServlet):
         except jwt.PyJWTError as e:
             # A JWT error occurred, return some info back to the client.
             raise LoginError(
-                403, "JWT validation failed: %s" % (str(e),), errcode=Codes.FORBIDDEN,
+                403,
+                "JWT validation failed: %s" % (str(e),),
+                errcode=Codes.FORBIDDEN,
             )
 
         user = payload.get("sub", None)
@@ -311,8 +326,22 @@ class LoginRestServlet(RestServlet):
         return result
 
 
+def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
+    """Return an entry for the login flow dict
+
+    Returns an entry suitable for inclusion in "identity_providers" in the
+    response to GET /_matrix/client/r0/login
+    """
+    e = {"id": idp.idp_id, "name": idp.idp_name}  # type: JsonDict
+    if idp.idp_icon:
+        e["icon"] = idp.idp_icon
+    if idp.idp_brand:
+        e["brand"] = idp.idp_brand
+    return e
+
+
 class SsoRedirectServlet(RestServlet):
-    PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
+    PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
 
     def __init__(self, hs: "HomeServer"):
         # make sure that the relevant handlers are instantiated, so that they
@@ -324,13 +353,33 @@ class SsoRedirectServlet(RestServlet):
         if hs.config.oidc_enabled:
             hs.get_oidc_handler()
         self._sso_handler = hs.get_sso_handler()
+        self._msc2858_enabled = hs.config.experimental.msc2858_enabled
+
+    def register(self, http_server: HttpServer) -> None:
+        super().register(http_server)
+        if self._msc2858_enabled:
+            # expose additional endpoint for MSC2858 support
+            http_server.register_paths(
+                "GET",
+                client_patterns(
+                    "/org.matrix.msc2858/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$",
+                    releases=(),
+                    unstable=True,
+                ),
+                self.on_GET,
+                self.__class__.__name__,
+            )
 
-    async def on_GET(self, request: SynapseRequest):
+    async def on_GET(
+        self, request: SynapseRequest, idp_id: Optional[str] = None
+    ) -> None:
         client_redirect_url = parse_string(
             request, "redirectUrl", required=True, encoding=None
         )
         sso_url = await self._sso_handler.handle_redirect_request(
-            request, client_redirect_url
+            request,
+            client_redirect_url,
+            idp_id,
         )
         logger.info("Redirecting to %s", sso_url)
         request.redirect(sso_url)
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 85a66458c5..717c5f2b10 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -60,7 +60,9 @@ class ProfileDisplaynameRestServlet(RestServlet):
             new_name = content["displayname"]
         except Exception:
             raise SynapseError(
-                code=400, msg="Unable to parse name", errcode=Codes.BAD_JSON,
+                code=400,
+                msg="Unable to parse name",
+                errcode=Codes.BAD_JSON,
             )
 
         await self.profile_handler.set_displayname(user, requester, new_name, is_admin)
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 89823fcc39..0c148a213d 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -159,7 +159,9 @@ class PushersRemoveRestServlet(RestServlet):
         self.notifier.on_new_replication_data()
 
         respond_with_html_bytes(
-            request, 200, PushersRemoveRestServlet.SUCCESS_HTML,
+            request,
+            200,
+            PushersRemoveRestServlet.SUCCESS_HTML,
         )
         return None
 
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index f95627ee61..9a1df30c29 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -362,7 +362,9 @@ class PublicRoomListRestServlet(TransactionRestServlet):
                 parse_and_validate_server_name(server)
             except ValueError:
                 raise SynapseError(
-                    400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
+                    400,
+                    "Invalid server name: %s" % (server,),
+                    Codes.INVALID_PARAM,
                 )
 
             try:
@@ -413,7 +415,9 @@ class PublicRoomListRestServlet(TransactionRestServlet):
                 parse_and_validate_server_name(server)
             except ValueError:
                 raise SynapseError(
-                    400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
+                    400,
+                    "Invalid server name: %s" % (server,),
+                    Codes.INVALID_PARAM,
                 )
 
             try:
@@ -650,7 +654,7 @@ class RoomEventContextServlet(RestServlet):
             event_filter = None
 
         results = await self.room_context_handler.get_event_context(
-            requester.user, room_id, event_id, limit, event_filter
+            requester, room_id, event_id, limit, event_filter
         )
 
         if not results:
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 65e68d641b..adf1d39728 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -54,7 +54,7 @@ logger = logging.getLogger(__name__)
 class EmailPasswordRequestTokenRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/password/email/requestToken$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.datastore = hs.get_datastore()
@@ -103,6 +103,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)
 
+        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
         # The email will be sent to the stored address.
         # This avoids a potential account hijack by requesting a password reset to
         # an email address which is controlled by the attacker but which, after
@@ -191,7 +193,10 @@ class PasswordRestServlet(RestServlet):
             requester = await self.auth.get_user_by_req(request)
             try:
                 params, session_id = await self.auth_handler.validate_user_via_ui_auth(
-                    requester, request, body, "modify your account password",
+                    requester,
+                    request,
+                    body,
+                    "modify your account password",
                 )
             except InteractiveAuthIncompleteError as e:
                 # The user needs to provide more steps to complete auth, but
@@ -310,7 +315,10 @@ class DeactivateAccountRestServlet(RestServlet):
             return 200, {}
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester, request, body, "deactivate your account",
+            requester,
+            request,
+            body,
+            "deactivate your account",
         )
         result = await self._deactivate_account_handler.deactivate_account(
             requester.user.to_string(),
@@ -379,6 +387,8 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
         if next_link:
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)
@@ -430,7 +440,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
 class MsisdnThreepidRequestTokenRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         super().__init__()
         self.store = self.hs.get_datastore()
@@ -458,6 +468,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(
+            request, "msisdn", msisdn
+        )
+
         if next_link:
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)
@@ -695,7 +709,10 @@ class ThreepidAddRestServlet(RestServlet):
         assert_valid_client_secret(client_secret)
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester, request, body, "add a third-party identifier to your account",
+            requester,
+            request,
+            body,
+            "add a third-party identifier to your account",
         )
 
         validation_session = await self.identity_handler.validate_threepid_session(
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 314e01dfe4..3d07aadd39 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -83,7 +83,10 @@ class DeleteDevicesRestServlet(RestServlet):
         assert_params_in_dict(body, ["devices"])
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester, request, body, "remove device(s) from your account",
+            requester,
+            request,
+            body,
+            "remove device(s) from your account",
         )
 
         await self.device_handler.delete_devices(
@@ -129,7 +132,10 @@ class DeviceRestServlet(RestServlet):
                 raise
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester, request, body, "remove a device from your account",
+            requester,
+            request,
+            body,
+            "remove a device from your account",
         )
 
         await self.device_handler.delete_device(requester.user.to_string(), device_id)
@@ -206,7 +212,9 @@ class DehydratedDeviceServlet(RestServlet):
 
         if "device_data" not in submission:
             raise errors.SynapseError(
-                400, "device_data missing", errcode=errors.Codes.MISSING_PARAM,
+                400,
+                "device_data missing",
+                errcode=errors.Codes.MISSING_PARAM,
             )
         elif not isinstance(submission["device_data"], dict):
             raise errors.SynapseError(
@@ -259,11 +267,15 @@ class ClaimDehydratedDeviceServlet(RestServlet):
 
         if "device_id" not in submission:
             raise errors.SynapseError(
-                400, "device_id missing", errcode=errors.Codes.MISSING_PARAM,
+                400,
+                "device_id missing",
+                errcode=errors.Codes.MISSING_PARAM,
             )
         elif not isinstance(submission["device_id"], str):
             raise errors.SynapseError(
-                400, "device_id must be a string", errcode=errors.Codes.INVALID_PARAM,
+                400,
+                "device_id must be a string",
+                errcode=errors.Codes.INVALID_PARAM,
             )
 
         result = await self.device_handler.rehydrate_device(
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 5b5da71815..d3434225cb 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -16,13 +16,29 @@
 
 import logging
 from functools import wraps
-
-from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.types import GroupID
+from typing import TYPE_CHECKING, Optional, Tuple
+
+from twisted.web.http import Request
+
+from synapse.api.constants import (
+    MAX_GROUP_CATEGORYID_LENGTH,
+    MAX_GROUP_ROLEID_LENGTH,
+    MAX_GROUPID_LENGTH,
+)
+from synapse.api.errors import Codes, SynapseError
+from synapse.handlers.groups_local import GroupsLocalHandler
+from synapse.http.servlet import (
+    RestServlet,
+    assert_params_in_dict,
+    parse_json_object_from_request,
+)
+from synapse.types import GroupID, JsonDict
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -33,7 +49,7 @@ def _validate_group_id(f):
     """
 
     @wraps(f)
-    def wrapper(self, request, group_id, *args, **kwargs):
+    def wrapper(self, request: Request, group_id: str, *args, **kwargs):
         if not GroupID.is_valid(group_id):
             raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
 
@@ -43,19 +59,18 @@ def _validate_group_id(f):
 
 
 class GroupServlet(RestServlet):
-    """Get the group profile
-    """
+    """Get the group profile"""
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id):
+    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -66,11 +81,17 @@ class GroupServlet(RestServlet):
         return 200, group_description
 
     @_validate_group_id
-    async def on_POST(self, request, group_id):
+    async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert_params_in_dict(
+            content, ("name", "avatar_url", "short_description", "long_description")
+        )
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot create group profiles."
         await self.groups_handler.update_group_profile(
             group_id, requester_user_id, content
         )
@@ -79,19 +100,18 @@ class GroupServlet(RestServlet):
 
 
 class GroupSummaryServlet(RestServlet):
-    """Get the full group summary
-    """
+    """Get the full group summary"""
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id):
+    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -116,18 +136,34 @@ class GroupSummaryRoomsCatServlet(RestServlet):
         "/rooms/(?P<room_id>[^/]*)$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, category_id, room_id):
+    async def on_PUT(
+        self, request: Request, group_id: str, category_id: Optional[str], room_id: str
+    ):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
+        if category_id == "":
+            raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM)
+
+        if category_id and len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
+            raise SynapseError(
+                400,
+                "category_id may not be longer than %s characters"
+                % (MAX_GROUP_CATEGORYID_LENGTH,),
+                Codes.INVALID_PARAM,
+            )
+
         content = parse_json_object_from_request(request)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify group summaries."
         resp = await self.groups_handler.update_group_summary_room(
             group_id,
             requester_user_id,
@@ -139,10 +175,15 @@ class GroupSummaryRoomsCatServlet(RestServlet):
         return 200, resp
 
     @_validate_group_id
-    async def on_DELETE(self, request, group_id, category_id, room_id):
+    async def on_DELETE(
+        self, request: Request, group_id: str, category_id: str, room_id: str
+    ):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify group summaries."
         resp = await self.groups_handler.delete_group_summary_room(
             group_id, requester_user_id, room_id=room_id, category_id=category_id
         )
@@ -151,21 +192,22 @@ class GroupSummaryRoomsCatServlet(RestServlet):
 
 
 class GroupCategoryServlet(RestServlet):
-    """Get/add/update/delete a group category
-    """
+    """Get/add/update/delete a group category"""
 
     PATTERNS = client_patterns(
         "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id, category_id):
+    async def on_GET(
+        self, request: Request, group_id: str, category_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -176,11 +218,27 @@ class GroupCategoryServlet(RestServlet):
         return 200, category
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, category_id):
+    async def on_PUT(
+        self, request: Request, group_id: str, category_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
+        if not category_id:
+            raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM)
+
+        if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
+            raise SynapseError(
+                400,
+                "category_id may not be longer than %s characters"
+                % (MAX_GROUP_CATEGORYID_LENGTH,),
+                Codes.INVALID_PARAM,
+            )
+
         content = parse_json_object_from_request(request)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify group categories."
         resp = await self.groups_handler.update_group_category(
             group_id, requester_user_id, category_id=category_id, content=content
         )
@@ -188,10 +246,15 @@ class GroupCategoryServlet(RestServlet):
         return 200, resp
 
     @_validate_group_id
-    async def on_DELETE(self, request, group_id, category_id):
+    async def on_DELETE(
+        self, request: Request, group_id: str, category_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify group categories."
         resp = await self.groups_handler.delete_group_category(
             group_id, requester_user_id, category_id=category_id
         )
@@ -200,19 +263,18 @@ class GroupCategoryServlet(RestServlet):
 
 
 class GroupCategoriesServlet(RestServlet):
-    """Get all group categories
-    """
+    """Get all group categories"""
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id):
+    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -224,19 +286,20 @@ class GroupCategoriesServlet(RestServlet):
 
 
 class GroupRoleServlet(RestServlet):
-    """Get/add/update/delete a group role
-    """
+    """Get/add/update/delete a group role"""
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id, role_id):
+    async def on_GET(
+        self, request: Request, group_id: str, role_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -247,11 +310,27 @@ class GroupRoleServlet(RestServlet):
         return 200, category
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, role_id):
+    async def on_PUT(
+        self, request: Request, group_id: str, role_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
+        if not role_id:
+            raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM)
+
+        if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
+            raise SynapseError(
+                400,
+                "role_id may not be longer than %s characters"
+                % (MAX_GROUP_ROLEID_LENGTH,),
+                Codes.INVALID_PARAM,
+            )
+
         content = parse_json_object_from_request(request)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify group roles."
         resp = await self.groups_handler.update_group_role(
             group_id, requester_user_id, role_id=role_id, content=content
         )
@@ -259,10 +338,15 @@ class GroupRoleServlet(RestServlet):
         return 200, resp
 
     @_validate_group_id
-    async def on_DELETE(self, request, group_id, role_id):
+    async def on_DELETE(
+        self, request: Request, group_id: str, role_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify group roles."
         resp = await self.groups_handler.delete_group_role(
             group_id, requester_user_id, role_id=role_id
         )
@@ -271,19 +355,18 @@ class GroupRoleServlet(RestServlet):
 
 
 class GroupRolesServlet(RestServlet):
-    """Get all group roles
-    """
+    """Get all group roles"""
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id):
+    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -308,18 +391,34 @@ class GroupSummaryUsersRoleServlet(RestServlet):
         "/users/(?P<user_id>[^/]*)$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, role_id, user_id):
+    async def on_PUT(
+        self, request: Request, group_id: str, role_id: Optional[str], user_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
+        if role_id == "":
+            raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM)
+
+        if role_id and len(role_id) > MAX_GROUP_ROLEID_LENGTH:
+            raise SynapseError(
+                400,
+                "role_id may not be longer than %s characters"
+                % (MAX_GROUP_ROLEID_LENGTH,),
+                Codes.INVALID_PARAM,
+            )
+
         content = parse_json_object_from_request(request)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify group summaries."
         resp = await self.groups_handler.update_group_summary_user(
             group_id,
             requester_user_id,
@@ -331,10 +430,15 @@ class GroupSummaryUsersRoleServlet(RestServlet):
         return 200, resp
 
     @_validate_group_id
-    async def on_DELETE(self, request, group_id, role_id, user_id):
+    async def on_DELETE(
+        self, request: Request, group_id: str, role_id: str, user_id: str
+    ):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify group summaries."
         resp = await self.groups_handler.delete_group_summary_user(
             group_id, requester_user_id, user_id=user_id, role_id=role_id
         )
@@ -343,19 +447,18 @@ class GroupSummaryUsersRoleServlet(RestServlet):
 
 
 class GroupRoomServlet(RestServlet):
-    """Get all rooms in a group
-    """
+    """Get all rooms in a group"""
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id):
+    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -367,19 +470,18 @@ class GroupRoomServlet(RestServlet):
 
 
 class GroupUsersServlet(RestServlet):
-    """Get all users in a group
-    """
+    """Get all users in a group"""
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id):
+    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -391,19 +493,18 @@ class GroupUsersServlet(RestServlet):
 
 
 class GroupInvitedUsersServlet(RestServlet):
-    """Get users invited to a group
-    """
+    """Get users invited to a group"""
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id):
+    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -415,23 +516,25 @@ class GroupInvitedUsersServlet(RestServlet):
 
 
 class GroupSettingJoinPolicyServlet(RestServlet):
-    """Set group join policy
-    """
+    """Set group join policy"""
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id):
+    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
 
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify group join policy."
         result = await self.groups_handler.set_group_join_policy(
             group_id, requester_user_id, content
         )
@@ -440,19 +543,18 @@ class GroupSettingJoinPolicyServlet(RestServlet):
 
 
 class GroupCreateServlet(RestServlet):
-    """Create a group
-    """
+    """Create a group"""
 
     PATTERNS = client_patterns("/create_group$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
         self.server_name = hs.hostname
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -461,6 +563,19 @@ class GroupCreateServlet(RestServlet):
         localpart = content.pop("localpart")
         group_id = GroupID(localpart, self.server_name).to_string()
 
+        if not localpart:
+            raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM)
+
+        if len(group_id) > MAX_GROUPID_LENGTH:
+            raise SynapseError(
+                400,
+                "Group ID may not be longer than %s characters" % (MAX_GROUPID_LENGTH,),
+                Codes.INVALID_PARAM,
+            )
+
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot create groups."
         result = await self.groups_handler.create_group(
             group_id, requester_user_id, content
         )
@@ -469,25 +584,29 @@ class GroupCreateServlet(RestServlet):
 
 
 class GroupAdminRoomsServlet(RestServlet):
-    """Add a room to the group
-    """
+    """Add a room to the group"""
 
     PATTERNS = client_patterns(
         "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, room_id):
+    async def on_PUT(
+        self, request: Request, group_id: str, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify rooms in a group."
         result = await self.groups_handler.add_room_to_group(
             group_id, requester_user_id, room_id, content
         )
@@ -495,10 +614,15 @@ class GroupAdminRoomsServlet(RestServlet):
         return 200, result
 
     @_validate_group_id
-    async def on_DELETE(self, request, group_id, room_id):
+    async def on_DELETE(
+        self, request: Request, group_id: str, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify rooms in a group."
         result = await self.groups_handler.remove_room_from_group(
             group_id, requester_user_id, room_id
         )
@@ -507,26 +631,30 @@ class GroupAdminRoomsServlet(RestServlet):
 
 
 class GroupAdminRoomsConfigServlet(RestServlet):
-    """Update the config of a room in a group
-    """
+    """Update the config of a room in a group"""
 
     PATTERNS = client_patterns(
         "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
         "/config/(?P<config_key>[^/]*)$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, room_id, config_key):
+    async def on_PUT(
+        self, request: Request, group_id: str, room_id: str, config_key: str
+    ):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify rooms in a group."
         result = await self.groups_handler.update_room_in_group(
             group_id, requester_user_id, room_id, config_key, content
         )
@@ -535,14 +663,13 @@ class GroupAdminRoomsConfigServlet(RestServlet):
 
 
 class GroupAdminUsersInviteServlet(RestServlet):
-    """Invite a user to the group
-    """
+    """Invite a user to the group"""
 
     PATTERNS = client_patterns(
         "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
@@ -551,12 +678,15 @@ class GroupAdminUsersInviteServlet(RestServlet):
         self.is_mine_id = hs.is_mine_id
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, user_id):
+    async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
         config = content.get("config", {})
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot invite users to a group."
         result = await self.groups_handler.invite(
             group_id, user_id, requester_user_id, config
         )
@@ -565,25 +695,27 @@ class GroupAdminUsersInviteServlet(RestServlet):
 
 
 class GroupAdminUsersKickServlet(RestServlet):
-    """Kick a user from the group
-    """
+    """Kick a user from the group"""
 
     PATTERNS = client_patterns(
         "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, user_id):
+    async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot kick users from a group."
         result = await self.groups_handler.remove_user_from_group(
             group_id, user_id, requester_user_id, content
         )
@@ -592,23 +724,25 @@ class GroupAdminUsersKickServlet(RestServlet):
 
 
 class GroupSelfLeaveServlet(RestServlet):
-    """Leave a joined group
-    """
+    """Leave a joined group"""
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id):
+    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot leave a group for a users."
         result = await self.groups_handler.remove_user_from_group(
             group_id, requester_user_id, requester_user_id, content
         )
@@ -617,23 +751,25 @@ class GroupSelfLeaveServlet(RestServlet):
 
 
 class GroupSelfJoinServlet(RestServlet):
-    """Attempt to join a group, or knock
-    """
+    """Attempt to join a group, or knock"""
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id):
+    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot join a user to a group."
         result = await self.groups_handler.join_group(
             group_id, requester_user_id, content
         )
@@ -642,23 +778,25 @@ class GroupSelfJoinServlet(RestServlet):
 
 
 class GroupSelfAcceptInviteServlet(RestServlet):
-    """Accept a group invite
-    """
+    """Accept a group invite"""
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id):
+    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot accept an invite to a group."
         result = await self.groups_handler.accept_invite(
             group_id, requester_user_id, content
         )
@@ -667,19 +805,18 @@ class GroupSelfAcceptInviteServlet(RestServlet):
 
 
 class GroupSelfUpdatePublicityServlet(RestServlet):
-    """Update whether we publicise a users membership of a group
-    """
+    """Update whether we publicise a user's membership of a group"""
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id):
+    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -691,19 +828,18 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
 
 
 class PublicisedGroupsForUserServlet(RestServlet):
-    """Get the list of groups a user is advertising
-    """
+    """Get the list of groups a user is advertising"""
 
     PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.groups_handler = hs.get_groups_local_handler()
 
-    async def on_GET(self, request, user_id):
+    async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
         await self.auth.get_user_by_req(request, allow_guest=True)
 
         result = await self.groups_handler.get_publicised_groups_for_user(user_id)
@@ -712,19 +848,18 @@ class PublicisedGroupsForUserServlet(RestServlet):
 
 
 class PublicisedGroupsForUsersServlet(RestServlet):
-    """Get the list of groups a user is advertising
-    """
+    """Get the list of groups a user is advertising"""
 
     PATTERNS = client_patterns("/publicised_groups$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.groups_handler = hs.get_groups_local_handler()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
         await self.auth.get_user_by_req(request, allow_guest=True)
 
         content = parse_json_object_from_request(request)
@@ -736,18 +871,17 @@ class PublicisedGroupsForUsersServlet(RestServlet):
 
 
 class GroupsForUserServlet(RestServlet):
-    """Get all groups the logged in user is joined to
-    """
+    """Get all groups the logged in user is joined to"""
 
     PATTERNS = client_patterns("/joined_groups$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -756,7 +890,7 @@ class GroupsForUserServlet(RestServlet):
         return 200, result
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server):
     GroupServlet(hs).register(http_server)
     GroupSummaryServlet(hs).register(http_server)
     GroupInvitedUsersServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index a6134ead8a..f092e5b3a2 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -271,7 +271,10 @@ class SigningKeyUploadServlet(RestServlet):
         body = parse_json_object_from_request(request)
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester, request, body, "add a device signing key to your account",
+            requester,
+            request,
+            body,
+            "add a device signing key to your account",
         )
 
         result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index b093183e79..8f68d8dfc8 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -126,6 +126,8 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
         existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
             "email", email
         )
@@ -191,6 +193,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
             body, ["client_secret", "country", "phone_number", "send_attempt"]
         )
         client_secret = body["client_secret"]
+        assert_valid_client_secret(client_secret)
         country = body["country"]
         phone_number = body["phone_number"]
         send_attempt = body["send_attempt"]
@@ -205,6 +208,10 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(
+            request, "msisdn", msisdn
+        )
+
         existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
             "msisdn", msisdn
         )
@@ -287,6 +294,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
 
         sid = parse_string(request, "sid", required=True)
         client_secret = parse_string(request, "client_secret", required=True)
+        assert_valid_client_secret(client_secret)
         token = parse_string(request, "token", required=True)
 
         # Attempt to validate a 3PID session
@@ -514,7 +522,10 @@ class RegisterRestServlet(RestServlet):
         # not this will raise a user-interactive auth error.
         try:
             auth_result, params, session_id = await self.auth_handler.check_ui_auth(
-                self._registration_flows, request, body, "register a new account",
+                self._registration_flows,
+                request,
+                body,
+                "register a new account",
             )
         except InteractiveAuthIncompleteError as e:
             # The user needs to provide more steps to complete auth.
@@ -657,7 +668,9 @@ class RegisterRestServlet(RestServlet):
             username, as_token
         )
         return await self._create_registration_details(
-            user_id, body, is_appservice_ghost=True,
+            user_id,
+            body,
+            is_appservice_ghost=True,
         )
 
     async def _create_registration_details(
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index 18c75738f8..fe765da23c 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -244,7 +244,9 @@ class RelationAggregationPaginationServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         await self.auth.check_user_in_room_or_world_readable(
-            room_id, requester.user.to_string(), allow_departed_users=True,
+            room_id,
+            requester.user.to_string(),
+            allow_departed_users=True,
         )
 
         # This checks that a) the event exists and b) the user is allowed to
@@ -322,7 +324,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         await self.auth.check_user_in_room_or_world_readable(
-            room_id, requester.user.to_string(), allow_departed_users=True,
+            room_id,
+            requester.user.to_string(),
+            allow_departed_users=True,
         )
 
         # This checks that a) the event exists and b) the user is allowed to
diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
index bf030e0ff4..147920767f 100644
--- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
 
 
 class RoomUpgradeRestServlet(RestServlet):
-    """Handler for room uprade requests.
+    """Handler for room upgrade requests.
 
     Handles requests of the form:
 
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index b3e4d5612e..8b9ef26cf2 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -100,6 +100,7 @@ class ConsentResource(DirectServeHtmlResource):
 
         consent_template_directory = hs.config.user_consent_template_dir
 
+        # TODO: switch to synapse.util.templates.build_jinja_env
         loader = jinja2.FileSystemLoader(consent_template_directory)
         self._jinja_env = jinja2.Environment(
             loader=loader, autoescape=jinja2.select_autoescape(["html", "htm", "xml"])
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 31a41e4a27..90bbeca679 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -137,7 +137,7 @@ def add_file_headers(
         # section 3.6 [2] to be a `token` or a `quoted-string`, where a `token`
         # is (essentially) a single US-ASCII word, and a `quoted-string` is a
         # US-ASCII string surrounded by double-quotes, using backslash as an
-        # escape charater. Note that %-encoding is *not* permitted.
+        # escape character. Note that %-encoding is *not* permitted.
         #
         # `filename*` is defined to be an `ext-value`, which is defined in
         # RFC5987 section 3.2.1 [3] to be `charset "'" [ language ] "'" value-chars`,
@@ -300,6 +300,7 @@ class FileInfo:
         thumbnail_height (int)
         thumbnail_method (str)
         thumbnail_type (str): Content type of thumbnail, e.g. image/png
+        thumbnail_length (int): The size of the media file, in bytes.
     """
 
     def __init__(
@@ -312,6 +313,7 @@ class FileInfo:
         thumbnail_height=None,
         thumbnail_method=None,
         thumbnail_type=None,
+        thumbnail_length=None,
     ):
         self.server_name = server_name
         self.file_id = file_id
@@ -321,6 +323,7 @@ class FileInfo:
         self.thumbnail_height = thumbnail_height
         self.thumbnail_method = thumbnail_method
         self.thumbnail_type = thumbnail_type
+        self.thumbnail_length = thumbnail_length
 
 
 def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 3ed219ae43..48f4433155 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -51,7 +51,8 @@ class DownloadResource(DirectServeJsonResource):
             b" object-src 'self';",
         )
         request.setHeader(
-            b"Referrer-Policy", b"no-referrer",
+            b"Referrer-Policy",
+            b"no-referrer",
         )
         server_name, media_id, name = parse_media_id(request)
         if server_name == self.server_name:
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 4c9946a616..a0162d4255 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -184,7 +184,7 @@ class MediaRepository:
     async def get_local_media(
         self, request: Request, media_id: str, name: Optional[str]
     ) -> None:
-        """Responds to reqests for local media, if exists, or returns 404.
+        """Responds to requests for local media, if exists, or returns 404.
 
         Args:
             request: The incoming request.
@@ -306,7 +306,7 @@ class MediaRepository:
         media_info = await self.store.get_cached_remote_media(server_name, media_id)
 
         # file_id is the ID we use to track the file locally. If we've already
-        # seen the file then reuse the existing ID, otherwise genereate a new
+        # seen the file then reuse the existing ID, otherwise generate a new
         # one.
 
         # If we have an entry in the DB, try and look for it
@@ -325,7 +325,10 @@ class MediaRepository:
         # Failed to find the file anywhere, lets download it.
 
         try:
-            media_info = await self._download_remote_file(server_name, media_id,)
+            media_info = await self._download_remote_file(
+                server_name,
+                media_id,
+            )
         except SynapseError:
             raise
         except Exception as e:
@@ -351,7 +354,11 @@ class MediaRepository:
         responder = await self.media_storage.fetch_media(file_info)
         return responder, media_info
 
-    async def _download_remote_file(self, server_name: str, media_id: str,) -> dict:
+    async def _download_remote_file(
+        self,
+        server_name: str,
+        media_id: str,
+    ) -> dict:
         """Attempt to download the remote file from the given server name,
         using the given file_id as the local id.
 
@@ -773,7 +780,11 @@ class MediaRepository:
                         )
                     except Exception as e:
                         thumbnail_exists = await self.store.get_remote_media_thumbnail(
-                            server_name, media_id, t_width, t_height, t_type,
+                            server_name,
+                            media_id,
+                            t_width,
+                            t_height,
+                            t_type,
                         )
                         if not thumbnail_exists:
                             raise e
@@ -832,7 +843,10 @@ class MediaRepository:
         return await self._remove_local_media_from_disk([media_id])
 
     async def delete_old_local_media(
-        self, before_ts: int, size_gt: int = 0, keep_profiles: bool = True,
+        self,
+        before_ts: int,
+        size_gt: int = 0,
+        keep_profiles: bool = True,
     ) -> Tuple[List[str], int]:
         """
         Delete local or remote media from this server by size and timestamp. Removes
@@ -849,7 +863,9 @@ class MediaRepository:
             A tuple of (list of deleted media IDs, total deleted media IDs).
         """
         old_media = await self.store.get_local_media_before(
-            before_ts, size_gt, keep_profiles,
+            before_ts,
+            size_gt,
+            keep_profiles,
         )
         return await self._remove_local_media_from_disk(old_media)
 
@@ -927,10 +943,10 @@ class MediaRepositoryResource(Resource):
 
            <thumbnail>
 
-    The thumbnail methods are "crop" and "scale". "scale" trys to return an
+    The thumbnail methods are "crop" and "scale". "scale" tries to return an
     image where either the width or the height is smaller than the requested
     size. The client should then scale and letterbox the image if it needs to
-    fit within a given rectangle. "crop" trys to return an image where the
+    fit within a given rectangle. "crop" tries to return an image where the
     width and height are close to the requested size and the aspect matches
     the requested size. The client should scale the image if it needs to fit
     within a given rectangle.
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 89cdd605aa..1057e638be 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -16,13 +16,17 @@ import contextlib
 import logging
 import os
 import shutil
-from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
+from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Sequence
+
+import attr
 
 from twisted.internet.defer import Deferred
 from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender
 
+from synapse.api.errors import NotFoundError
 from synapse.logging.context import defer_to_thread, make_deferred_yieldable
+from synapse.util import Clock
 from synapse.util.file_consumer import BackgroundFileConsumer
 
 from ._base import FileInfo, Responder
@@ -58,6 +62,8 @@ class MediaStorage:
         self.local_media_directory = local_media_directory
         self.filepaths = filepaths
         self.storage_providers = storage_providers
+        self.spam_checker = hs.get_spam_checker()
+        self.clock = hs.get_clock()
 
     async def store_file(self, source: IO, file_info: FileInfo) -> str:
         """Write `source` to the on disk media store, and also any other
@@ -79,8 +85,7 @@ class MediaStorage:
         return fname
 
     async def write_to_file(self, source: IO, output: IO):
-        """Asynchronously write the `source` to `output`.
-        """
+        """Asynchronously write the `source` to `output`."""
         await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
 
     @contextlib.contextmanager
@@ -127,18 +132,29 @@ class MediaStorage:
                     f.flush()
                     f.close()
 
+                    spam = await self.spam_checker.check_media_file_for_spam(
+                        ReadableFileWrapper(self.clock, fname), file_info
+                    )
+                    if spam:
+                        logger.info("Blocking media due to spam checker")
+                        # Note that we'll delete the stored media, due to the
+                        # try/except below. The media also won't be stored in
+                        # the DB.
+                        raise SpamMediaException()
+
                     for provider in self.storage_providers:
                         await provider.store_file(path, file_info)
 
                     finished_called[0] = True
 
                 yield f, fname, finish
-        except Exception:
+        except Exception as e:
             try:
                 os.remove(fname)
             except Exception:
                 pass
-            raise
+
+            raise e from None
 
         if not finished_called:
             raise Exception("Finished callback not called")
@@ -302,3 +318,38 @@ class FileResponder(Responder):
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.open_file.close()
+
+
+class SpamMediaException(NotFoundError):
+    """The media was blocked by a spam checker, so we simply 404 the request (in
+    the same way as if it was quarantined).
+    """
+
+
+@attr.s(slots=True)
+class ReadableFileWrapper:
+    """Wrapper that allows reading a file in chunks, yielding to the reactor,
+    and writing to a callback.
+
+    This is a simplified `FileSender` that takes an IO object rather than an
+    `IConsumer`.
+    """
+
+    CHUNK_SIZE = 2 ** 14
+
+    clock = attr.ib(type=Clock)
+    path = attr.ib(type=str)
+
+    async def write_chunks_to(self, callback: Callable[[bytes], None]):
+        """Reads the file in chunks and calls the callback with each chunk."""
+
+        with open(self.path, "rb") as file:
+            while True:
+                chunk = file.read(self.CHUNK_SIZE)
+                if not chunk:
+                    break
+
+                callback(chunk)
+
+                # We yield to the reactor by sleeping for 0 seconds.
+                await self.clock.sleep(0)
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index a632099167..6104ef4e46 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -58,7 +58,10 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-_charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
+_charset_match = re.compile(br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9-]+)"?', flags=re.I)
+_xml_encoding_match = re.compile(
+    br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9-]+)"', flags=re.I
+)
 _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
 
 OG_TAG_NAME_MAXLEN = 50
@@ -300,24 +303,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             with open(media_info["filename"], "rb") as file:
                 body = file.read()
 
-            encoding = None
-
-            # Let's try and figure out if it has an encoding set in a meta tag.
-            # Limit it to the first 1kb, since it ought to be in the meta tags
-            # at the top.
-            match = _charset_match.search(body[:1000])
-
-            # If we find a match, it should take precedence over the
-            # Content-Type header, so set it here.
-            if match:
-                encoding = match.group(1).decode("ascii")
-
-            # If we don't find a match, we'll look at the HTTP Content-Type, and
-            # if that doesn't exist, we'll fall back to UTF-8.
-            if not encoding:
-                content_match = _content_type_match.match(media_info["media_type"])
-                encoding = content_match.group(1) if content_match else "utf-8"
-
+            encoding = get_html_media_encoding(body, media_info["media_type"])
             og = decode_and_calc_og(body, media_info["uri"], encoding)
 
             # pre-cache the image for posterity
@@ -386,7 +372,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         """
         Check whether the URL should be downloaded as oEmbed content instead.
 
-        Params:
+        Args:
             url: The URL to check.
 
         Returns:
@@ -403,7 +389,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         """
         Request content from an oEmbed endpoint.
 
-        Params:
+        Args:
             endpoint: The oEmbed API endpoint.
             url: The URL to pass to the API.
 
@@ -594,8 +580,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         )
 
     async def _expire_url_cache_data(self) -> None:
-        """Clean up expired url cache content, media and thumbnails.
-        """
+        """Clean up expired url cache content, media and thumbnails."""
         # TODO: Delete from backup media store
 
         assert self._worker_run_media_background_jobs
@@ -689,30 +674,101 @@ class PreviewUrlResource(DirectServeJsonResource):
             logger.debug("No media removed from url cache")
 
 
+def get_html_media_encoding(body: bytes, content_type: str) -> str:
+    """
+    Get the encoding of the body based on the (presumably) HTML body or media_type.
+
+    The precedence used for finding a character encoding is:
+
+    1. meta tag with a charset declared.
+    2. The XML document's character encoding attribute.
+    3. The Content-Type header.
+    4. Fallback to UTF-8.
+
+    Args:
+        body: The HTML document, as bytes.
+        content_type: The Content-Type header.
+
+    Returns:
+        The character encoding of the body, as a string.
+    """
+    # Limit searches to the first 1kb, since it ought to be at the top.
+    body_start = body[:1024]
+
+    # Let's try and figure out if it has an encoding set in a meta tag.
+    match = _charset_match.search(body_start)
+    if match:
+        return match.group(1).decode("ascii")
+
+    # TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
+
+    # If we didn't find a match, see if it is an XML document with an encoding.
+    match = _xml_encoding_match.match(body_start)
+    if match:
+        return match.group(1).decode("ascii")
+
+    # If we don't find a match, we'll look at the HTTP Content-Type, and
+    # if that doesn't exist, we'll fall back to UTF-8.
+    content_match = _content_type_match.match(content_type)
+    if content_match:
+        return content_match.group(1)
+
+    return "utf-8"
+
+
 def decode_and_calc_og(
     body: bytes, media_uri: str, request_encoding: Optional[str] = None
 ) -> Dict[str, Optional[str]]:
+    """
+    Calculate metadata for an HTML document.
+
+    This uses lxml to parse the HTML document into the OG response. If errors
+    occur during processing of the document, an empty response is returned.
+
+    Args:
+        body: The HTML document, as bytes.
+        media_uri: The URI used to download the body.
+        request_encoding: The character encoding of the body, as a string.
+
+    Returns:
+        The OG response as a dictionary.
+    """
     # If there's no body, nothing useful is going to be found.
     if not body:
         return {}
 
     from lxml import etree
 
+    # Create an HTML parser. If this fails, log and return no metadata.
     try:
         parser = etree.HTMLParser(recover=True, encoding=request_encoding)
-        tree = etree.fromstring(body, parser)
-        og = _calc_og(tree, media_uri)
+    except LookupError:
+        # blindly consider the encoding as utf-8.
+        parser = etree.HTMLParser(recover=True, encoding="utf-8")
+    except Exception as e:
+        logger.warning("Unable to create HTML parser: %s" % (e,))
+        return {}
+
+    def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]:
+        # Attempt to parse the body; parse errors propagate to the caller.
+        tree = etree.fromstring(body_attempt, parser)
+
+        # The data was successfully parsed, but no tree was found.
+        if tree is None:
+            return {}
+
+        return _calc_og(tree, media_uri)
+
+    # Attempt to parse the body. If this fails, log and return no metadata.
+    try:
+        return _attempt_calc_og(body)
     except UnicodeDecodeError:
         # blindly try decoding the body as utf-8, which seems to fix
         # the charset mismatches on https://google.com
-        parser = etree.HTMLParser(recover=True, encoding=request_encoding)
-        tree = etree.fromstring(body.decode("utf-8", "ignore"), parser)
-        og = _calc_og(tree, media_uri)
-
-    return og
+        return _attempt_calc_og(body.decode("utf-8", "ignore"))
 
 
-def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]:
+def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
     # suck our tree into lxml and define our OG response.
 
     # if we see any image URLs in the OG response, then spider them
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index d6880f2e6e..d653a58be9 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -16,7 +16,7 @@
 
 
 import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from twisted.web.http import Request
 
@@ -106,31 +106,17 @@ class ThumbnailResource(DirectServeJsonResource):
             return
 
         thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
-
-        if thumbnail_infos:
-            thumbnail_info = self._select_thumbnail(
-                width, height, method, m_type, thumbnail_infos
-            )
-
-            file_info = FileInfo(
-                server_name=None,
-                file_id=media_id,
-                url_cache=media_info["url_cache"],
-                thumbnail=True,
-                thumbnail_width=thumbnail_info["thumbnail_width"],
-                thumbnail_height=thumbnail_info["thumbnail_height"],
-                thumbnail_type=thumbnail_info["thumbnail_type"],
-                thumbnail_method=thumbnail_info["thumbnail_method"],
-            )
-
-            t_type = file_info.thumbnail_type
-            t_length = thumbnail_info["thumbnail_length"]
-
-            responder = await self.media_storage.fetch_media(file_info)
-            await respond_with_responder(request, responder, t_type, t_length)
-        else:
-            logger.info("Couldn't find any generated thumbnails")
-            respond_404(request)
+        await self._select_and_respond_with_thumbnail(
+            request,
+            width,
+            height,
+            method,
+            m_type,
+            thumbnail_infos,
+            media_id,
+            url_cache=media_info["url_cache"],
+            server_name=None,
+        )
 
     async def _select_or_generate_local_thumbnail(
         self,
@@ -276,26 +262,64 @@ class ThumbnailResource(DirectServeJsonResource):
         thumbnail_infos = await self.store.get_remote_media_thumbnails(
             server_name, media_id
         )
+        await self._select_and_respond_with_thumbnail(
+            request,
+            width,
+            height,
+            method,
+            m_type,
+            thumbnail_infos,
+            media_info["filesystem_id"],
+            url_cache=None,
+            server_name=server_name,
+        )
 
+    async def _select_and_respond_with_thumbnail(
+        self,
+        request: Request,
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+        thumbnail_infos: List[Dict[str, Any]],
+        file_id: str,
+        url_cache: Optional[str] = None,
+        server_name: Optional[str] = None,
+    ) -> None:
+        """
+        Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
+
+        Args:
+            request: The incoming request.
+            desired_width: The desired width; the returned thumbnail may be larger than this.
+            desired_height: The desired height; the returned thumbnail may be larger than this.
+            desired_method: The desired method used to generate the thumbnail.
+            desired_type: The desired content-type of the thumbnail.
+            thumbnail_infos: A list of dictionaries of candidate thumbnails.
+            file_id: The ID of the media that a thumbnail is being requested for.
+            url_cache: The URL cache value.
+            server_name: The server name, if this is a remote thumbnail.
+        """
         if thumbnail_infos:
-            thumbnail_info = self._select_thumbnail(
-                width, height, method, m_type, thumbnail_infos
+            file_info = self._select_thumbnail(
+                desired_width,
+                desired_height,
+                desired_method,
+                desired_type,
+                thumbnail_infos,
+                file_id,
+                url_cache,
+                server_name,
             )
-            file_info = FileInfo(
-                server_name=server_name,
-                file_id=media_info["filesystem_id"],
-                thumbnail=True,
-                thumbnail_width=thumbnail_info["thumbnail_width"],
-                thumbnail_height=thumbnail_info["thumbnail_height"],
-                thumbnail_type=thumbnail_info["thumbnail_type"],
-                thumbnail_method=thumbnail_info["thumbnail_method"],
-            )
-
-            t_type = file_info.thumbnail_type
-            t_length = thumbnail_info["thumbnail_length"]
+            if not file_info:
+                logger.info("Couldn't find a thumbnail matching the desired inputs")
+                respond_404(request)
+                return
 
             responder = await self.media_storage.fetch_media(file_info)
-            await respond_with_responder(request, responder, t_type, t_length)
+            await respond_with_responder(
+                request, responder, file_info.thumbnail_type, file_info.thumbnail_length
+            )
         else:
             logger.info("Failed to find any generated thumbnails")
             respond_404(request)
@@ -306,67 +330,117 @@ class ThumbnailResource(DirectServeJsonResource):
         desired_height: int,
         desired_method: str,
         desired_type: str,
-        thumbnail_infos,
-    ) -> dict:
+        thumbnail_infos: List[Dict[str, Any]],
+        file_id: str,
+        url_cache: Optional[str],
+        server_name: Optional[str],
+    ) -> Optional[FileInfo]:
+        """
+        Choose an appropriate thumbnail from the previously generated thumbnails.
+
+        Args:
+            desired_width: The desired width; the returned thumbnail may be larger than this.
+            desired_height: The desired height; the returned thumbnail may be larger than this.
+            desired_method: The desired method used to generate the thumbnail.
+            desired_type: The desired content-type of the thumbnail.
+            thumbnail_infos: A list of dictionaries of candidate thumbnails.
+            file_id: The ID of the media that a thumbnail is being requested for.
+            url_cache: The URL cache value.
+            server_name: The server name, if this is a remote thumbnail.
+
+        Returns:
+             The thumbnail which best matches the desired parameters, or None if none match.
+        """
+        desired_method = desired_method.lower()
+
+        # The chosen thumbnail.
+        thumbnail_info = None
+
         d_w = desired_width
         d_h = desired_height
 
-        if desired_method.lower() == "crop":
+        if desired_method == "crop":
+            # Thumbnails at least as large as the desired width or the desired height.
             crop_info_list = []
+            # Other thumbnails.
             crop_info_list2 = []
             for info in thumbnail_infos:
+                # Skip thumbnails generated with different methods.
+                if info["thumbnail_method"] != "crop":
+                    continue
+
                 t_w = info["thumbnail_width"]
                 t_h = info["thumbnail_height"]
-                t_method = info["thumbnail_method"]
-                if t_method == "crop":
-                    aspect_quality = abs(d_w * t_h - d_h * t_w)
-                    min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
-                    size_quality = abs((d_w - t_w) * (d_h - t_h))
-                    type_quality = desired_type != info["thumbnail_type"]
-                    length_quality = info["thumbnail_length"]
-                    if t_w >= d_w or t_h >= d_h:
-                        crop_info_list.append(
-                            (
-                                aspect_quality,
-                                min_quality,
-                                size_quality,
-                                type_quality,
-                                length_quality,
-                                info,
-                            )
+                aspect_quality = abs(d_w * t_h - d_h * t_w)
+                min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
+                size_quality = abs((d_w - t_w) * (d_h - t_h))
+                type_quality = desired_type != info["thumbnail_type"]
+                length_quality = info["thumbnail_length"]
+                if t_w >= d_w or t_h >= d_h:
+                    crop_info_list.append(
+                        (
+                            aspect_quality,
+                            min_quality,
+                            size_quality,
+                            type_quality,
+                            length_quality,
+                            info,
                         )
-                    else:
-                        crop_info_list2.append(
-                            (
-                                aspect_quality,
-                                min_quality,
-                                size_quality,
-                                type_quality,
-                                length_quality,
-                                info,
-                            )
+                    )
+                else:
+                    crop_info_list2.append(
+                        (
+                            aspect_quality,
+                            min_quality,
+                            size_quality,
+                            type_quality,
+                            length_quality,
+                            info,
                         )
+                    )
             if crop_info_list:
-                return min(crop_info_list)[-1]
-            else:
-                return min(crop_info_list2)[-1]
-        else:
+                thumbnail_info = min(crop_info_list)[-1]
+            elif crop_info_list2:
+                thumbnail_info = min(crop_info_list2)[-1]
+        elif desired_method == "scale":
+            # Thumbnails at least as large as the desired width or the desired height.
             info_list = []
+            # Other thumbnails.
             info_list2 = []
+
             for info in thumbnail_infos:
+                # Skip thumbnails generated with different methods.
+                if info["thumbnail_method"] != "scale":
+                    continue
+
                 t_w = info["thumbnail_width"]
                 t_h = info["thumbnail_height"]
-                t_method = info["thumbnail_method"]
                 size_quality = abs((d_w - t_w) * (d_h - t_h))
                 type_quality = desired_type != info["thumbnail_type"]
                 length_quality = info["thumbnail_length"]
-                if t_method == "scale" and (t_w >= d_w or t_h >= d_h):
+                if t_w >= d_w or t_h >= d_h:
                     info_list.append((size_quality, type_quality, length_quality, info))
-                elif t_method == "scale":
+                else:
                     info_list2.append(
                         (size_quality, type_quality, length_quality, info)
                     )
             if info_list:
-                return min(info_list)[-1]
-            else:
-                return min(info_list2)[-1]
+                thumbnail_info = min(info_list)[-1]
+            elif info_list2:
+                thumbnail_info = min(info_list2)[-1]
+
+        if thumbnail_info:
+            return FileInfo(
+                file_id=file_id,
+                url_cache=url_cache,
+                server_name=server_name,
+                thumbnail=True,
+                thumbnail_width=thumbnail_info["thumbnail_width"],
+                thumbnail_height=thumbnail_info["thumbnail_height"],
+                thumbnail_type=thumbnail_info["thumbnail_type"],
+                thumbnail_method=thumbnail_info["thumbnail_method"],
+                thumbnail_length=thumbnail_info["thumbnail_length"],
+            )
+
+        # No matching thumbnail was found.
+        return None
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 6da76ae994..1136277794 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -22,6 +22,7 @@ from twisted.web.http import Request
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.servlet import parse_string
+from synapse.rest.media.v1.media_storage import SpamMediaException
 
 if TYPE_CHECKING:
     from synapse.app.homeserver import HomeServer
@@ -86,9 +87,14 @@ class UploadResource(DirectServeJsonResource):
         #     disposition = headers.getRawHeaders(b"Content-Disposition")[0]
         # TODO(markjh): parse content-dispostion
 
-        content_uri = await self.media_repo.create_content(
-            media_type, upload_name, request.content, content_length, requester.user
-        )
+        try:
+            content_uri = await self.media_repo.create_content(
+                media_type, upload_name, request.content, content_length, requester.user
+            )
+        except SpamMediaException:
+            # For uploading of media we want to respond with a 400, instead of
+            # the default 404, as a 404 would just be confusing.
+            raise SynapseError(400, "Bad content")
 
         logger.info("Uploaded content with URI %r", content_uri)
 
diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py
index c0b733488b..9eeb970580 100644
--- a/synapse/rest/synapse/client/__init__.py
+++ b/synapse/rest/synapse/client/__init__.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,3 +12,56 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from typing import TYPE_CHECKING, Mapping
+
+from twisted.web.resource import Resource
+
+from synapse.rest.synapse.client.new_user_consent import NewUserConsentResource
+from synapse.rest.synapse.client.pick_idp import PickIdpResource
+from synapse.rest.synapse.client.pick_username import pick_username_resource
+from synapse.rest.synapse.client.sso_register import SsoRegisterResource
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resource]:
+    """Builds a resource tree to include synapse-specific client resources
+
+    These are resources which should be loaded on all workers which expose a C-S API:
+    i.e. the main process, and any generic workers so configured.
+
+    Returns:
+         map from path to Resource.
+    """
+    resources = {
+        # SSO bits. These are always loaded, whether or not SSO login is actually
+        # enabled (they just won't work very well if it's not)
+        "/_synapse/client/pick_idp": PickIdpResource(hs),
+        "/_synapse/client/pick_username": pick_username_resource(hs),
+        "/_synapse/client/new_user_consent": NewUserConsentResource(hs),
+        "/_synapse/client/sso_register": SsoRegisterResource(hs),
+    }
+
+    # provider-specific SSO bits. Only load these if they are enabled, since they
+    # rely on optional dependencies.
+    if hs.config.oidc_enabled:
+        from synapse.rest.synapse.client.oidc import OIDCResource
+
+        resources["/_synapse/client/oidc"] = OIDCResource(hs)
+
+    if hs.config.saml2_enabled:
+        from synapse.rest.synapse.client.saml2 import SAML2Resource
+
+        res = SAML2Resource(hs)
+        resources["/_synapse/client/saml2"] = res
+
+        # This is also mounted under '/_matrix' for backwards-compatibility.
+        # To be removed in Synapse v1.32.0.
+        resources["/_matrix/saml2"] = res
+
+    return resources
+
+
+__all__ = ["build_synapse_client_resource_tree"]
diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py
new file mode 100644
index 0000000000..b2e0f93810
--- /dev/null
+++ b/synapse/rest/synapse/client/new_user_consent.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
+from synapse.http.server import DirectServeHtmlResource, respond_with_html
+from synapse.http.servlet import parse_string
+from synapse.types import UserID
+from synapse.util.templates import build_jinja_env
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class NewUserConsentResource(DirectServeHtmlResource):
+    """A resource which collects consent to the server's terms from a new user
+
+    This resource gets mounted at /_synapse/client/new_user_consent, and is shown
+    when we are automatically creating a new user due to an SSO login.
+
+    It shows a template which prompts the user to go and read the Ts and Cs, and click
+    a clickybox if they have done so.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._sso_handler = hs.get_sso_handler()
+        self._server_name = hs.hostname
+        self._consent_version = hs.config.consent.user_consent_version
+
+        def template_search_dirs():
+            if hs.config.sso.sso_template_dir:
+                yield hs.config.sso.sso_template_dir
+            yield hs.config.sso.default_template_dir
+
+        self._jinja_env = build_jinja_env(template_search_dirs(), hs.config)
+
+    async def _async_render_GET(self, request: Request) -> None:
+        try:
+            session_id = get_username_mapping_session_cookie_from_request(request)
+            session = self._sso_handler.get_mapping_session(session_id)
+        except SynapseError as e:
+            logger.warning("Error fetching session: %s", e)
+            self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+            return
+
+        user_id = UserID(session.chosen_localpart, self._server_name)
+        user_profile = {
+            "display_name": session.display_name,
+        }
+
+        template_params = {
+            "user_id": user_id.to_string(),
+            "user_profile": user_profile,
+            "consent_version": self._consent_version,
+            "terms_url": "/_matrix/consent?v=%s" % (self._consent_version,),
+        }
+
+        template = self._jinja_env.get_template("sso_new_user_consent.html")
+        html = template.render(template_params)
+        respond_with_html(request, 200, html)
+
+    async def _async_render_POST(self, request: Request):
+        try:
+            session_id = get_username_mapping_session_cookie_from_request(request)
+        except SynapseError as e:
+            logger.warning("Error fetching session cookie: %s", e)
+            self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+            return
+
+        try:
+            accepted_version = parse_string(request, "accepted_version", required=True)
+        except SynapseError as e:
+            self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code)
+            return
+
+        await self._sso_handler.handle_terms_accepted(
+            request, session_id, accepted_version
+        )
diff --git a/synapse/rest/oidc/__init__.py b/synapse/rest/synapse/client/oidc/__init__.py
index d958dd65bb..64c0deb75d 100644
--- a/synapse/rest/oidc/__init__.py
+++ b/synapse/rest/synapse/client/oidc/__init__.py
@@ -12,11 +12,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import logging
 
 from twisted.web.resource import Resource
 
-from synapse.rest.oidc.callback_resource import OIDCCallbackResource
+from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource
 
 logger = logging.getLogger(__name__)
 
@@ -25,3 +26,6 @@ class OIDCResource(Resource):
     def __init__(self, hs):
         Resource.__init__(self)
         self.putChild(b"callback", OIDCCallbackResource(hs))
+
+
+__all__ = ["OIDCResource"]
diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/synapse/client/oidc/callback_resource.py
index f7a0bc4bdb..1af33f0a45 100644
--- a/synapse/rest/oidc/callback_resource.py
+++ b/synapse/rest/synapse/client/oidc/callback_resource.py
@@ -12,19 +12,30 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import logging
+from typing import TYPE_CHECKING
 
 from synapse.http.server import DirectServeHtmlResource
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class OIDCCallbackResource(DirectServeHtmlResource):
     isLeaf = 1
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self._oidc_handler = hs.get_oidc_handler()
 
     async def _async_render_GET(self, request):
         await self._oidc_handler.handle_oidc_callback(request)
+
+    async def _async_render_POST(self, request):
+        # the auth response can be returned via an x-www-form-urlencoded form instead
+        # of GET params, as per
+        # https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html.
+        await self._oidc_handler.handle_oidc_callback(request)
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py
index d3b6803e65..96077cfcd1 100644
--- a/synapse/rest/synapse/client/pick_username.py
+++ b/synapse/rest/synapse/client/pick_username.py
@@ -12,42 +12,42 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING
 
-import pkg_resources
+import logging
+from typing import TYPE_CHECKING, List
 
 from twisted.web.http import Request
 from twisted.web.resource import Resource
-from twisted.web.static import File
 
 from synapse.api.errors import SynapseError
-from synapse.handlers.sso import USERNAME_MAPPING_SESSION_COOKIE_NAME
-from synapse.http.server import DirectServeHtmlResource, DirectServeJsonResource
-from synapse.http.servlet import parse_string
+from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
+from synapse.http.server import (
+    DirectServeHtmlResource,
+    DirectServeJsonResource,
+    respond_with_html,
+)
+from synapse.http.servlet import parse_boolean, parse_string
 from synapse.http.site import SynapseRequest
+from synapse.util.templates import build_jinja_env
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
+logger = logging.getLogger(__name__)
+
 
 def pick_username_resource(hs: "HomeServer") -> Resource:
     """Factory method to generate the username picker resource.
 
-    This resource gets mounted under /_synapse/client/pick_username. The top-level
-    resource is just a File resource which serves up the static files in the resources
-    "res" directory, but it has a couple of children:
-
-    * "submit", which does the mechanics of registering the new user, and redirects the
-      browser back to the client URL
+    This resource gets mounted under /_synapse/client/pick_username and has two
+    children:
 
-    * "check": checks if a userid is free.
+      * "account_details": renders the form and handles the POSTed response
+      * "check": a JSON endpoint which checks if a userid is free.
     """
 
-    # XXX should we make this path customisable so that admins can restyle it?
-    base_path = pkg_resources.resource_filename("synapse", "res/username_picker")
-
-    res = File(base_path)
-    res.putChild(b"submit", SubmitResource(hs))
+    res = Resource()
+    res.putChild(b"account_details", AccountDetailsResource(hs))
     res.putChild(b"check", AvailabilityCheckResource(hs))
 
     return res
@@ -61,28 +61,71 @@ class AvailabilityCheckResource(DirectServeJsonResource):
     async def _async_render_GET(self, request: Request):
         localpart = parse_string(request, "username", required=True)
 
-        session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
-        if not session_id:
-            raise SynapseError(code=400, msg="missing session_id")
+        session_id = get_username_mapping_session_cookie_from_request(request)
 
         is_available = await self._sso_handler.check_username_availability(
-            localpart, session_id.decode("ascii", errors="replace")
+            localpart, session_id
         )
         return 200, {"available": is_available}
 
 
-class SubmitResource(DirectServeHtmlResource):
+class AccountDetailsResource(DirectServeHtmlResource):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
         self._sso_handler = hs.get_sso_handler()
 
-    async def _async_render_POST(self, request: SynapseRequest):
-        localpart = parse_string(request, "username", required=True)
+        def template_search_dirs():
+            if hs.config.sso.sso_template_dir:
+                yield hs.config.sso.sso_template_dir
+            yield hs.config.sso.default_template_dir
+
+        self._jinja_env = build_jinja_env(template_search_dirs(), hs.config)
+
+    async def _async_render_GET(self, request: Request) -> None:
+        try:
+            session_id = get_username_mapping_session_cookie_from_request(request)
+            session = self._sso_handler.get_mapping_session(session_id)
+        except SynapseError as e:
+            logger.warning("Error fetching session: %s", e)
+            self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+            return
+
+        idp_id = session.auth_provider_id
+        template_params = {
+            "idp": self._sso_handler.get_identity_providers()[idp_id],
+            "user_attributes": {
+                "display_name": session.display_name,
+                "emails": session.emails,
+            },
+        }
+
+        template = self._jinja_env.get_template("sso_auth_account_details.html")
+        html = template.render(template_params)
+        respond_with_html(request, 200, html)
 
-        session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
-        if not session_id:
-            raise SynapseError(code=400, msg="missing session_id")
+    async def _async_render_POST(self, request: SynapseRequest):
+        try:
+            session_id = get_username_mapping_session_cookie_from_request(request)
+        except SynapseError as e:
+            logger.warning("Error fetching session cookie: %s", e)
+            self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+            return
+
+        try:
+            localpart = parse_string(request, "username", required=True)
+            use_display_name = parse_boolean(request, "use_display_name", default=False)
+
+            try:
+                emails_to_use = [
+                    val.decode("utf-8") for val in request.args.get(b"use_email", [])
+                ]  # type: List[str]
+            except ValueError:
+                raise SynapseError(400, "Query parameter use_email must be utf-8")
+        except SynapseError as e:
+            logger.warning("[session %s] bad param: %s", session_id, e)
+            self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code)
+            return
 
         await self._sso_handler.handle_submit_username_request(
-            request, localpart, session_id.decode("ascii", errors="replace")
+            request, session_id, localpart, use_display_name, emails_to_use
         )
diff --git a/synapse/rest/saml2/__init__.py b/synapse/rest/synapse/client/saml2/__init__.py
index 68da37ca6a..3e8235ee1e 100644
--- a/synapse/rest/saml2/__init__.py
+++ b/synapse/rest/synapse/client/saml2/__init__.py
@@ -12,12 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import logging
 
 from twisted.web.resource import Resource
 
-from synapse.rest.saml2.metadata_resource import SAML2MetadataResource
-from synapse.rest.saml2.response_resource import SAML2ResponseResource
+from synapse.rest.synapse.client.saml2.metadata_resource import SAML2MetadataResource
+from synapse.rest.synapse.client.saml2.response_resource import SAML2ResponseResource
 
 logger = logging.getLogger(__name__)
 
@@ -27,3 +28,6 @@ class SAML2Resource(Resource):
         Resource.__init__(self)
         self.putChild(b"metadata.xml", SAML2MetadataResource(hs))
         self.putChild(b"authn_response", SAML2ResponseResource(hs))
+
+
+__all__ = ["SAML2Resource"]
diff --git a/synapse/rest/saml2/metadata_resource.py b/synapse/rest/synapse/client/saml2/metadata_resource.py
index 1e8526e22e..1e8526e22e 100644
--- a/synapse/rest/saml2/metadata_resource.py
+++ b/synapse/rest/synapse/client/saml2/metadata_resource.py
diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/synapse/client/saml2/response_resource.py
index f6668fb5e3..f6668fb5e3 100644
--- a/synapse/rest/saml2/response_resource.py
+++ b/synapse/rest/synapse/client/saml2/response_resource.py
diff --git a/synapse/rest/synapse/client/sso_register.py b/synapse/rest/synapse/client/sso_register.py
new file mode 100644
index 0000000000..dfefeb7796
--- /dev/null
+++ b/synapse/rest/synapse/client/sso_register.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
+from synapse.http.server import DirectServeHtmlResource
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class SsoRegisterResource(DirectServeHtmlResource):
+    """A resource which completes SSO registration
+
+    This resource gets mounted at /_synapse/client/sso_register, and is shown
+    after we collect username and/or consent for a new SSO user. It (finally) registers
+    the user, and confirms the redirect to the client.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._sso_handler = hs.get_sso_handler()
+
+    async def _async_render_GET(self, request: Request) -> None:
+        try:
+            session_id = get_username_mapping_session_cookie_from_request(request)
+        except SynapseError as e:
+            logger.warning("Error fetching session cookie: %s", e)
+            self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+            return
+        await self._sso_handler.register_sso_user(request, session_id)
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index 241fe746d9..f591cc6c5c 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -34,6 +34,10 @@ class WellKnownBuilder:
         self._config = hs.config
 
     def get_well_known(self):
+        # if we don't have a public_baseurl, we can't help much here.
+        if self._config.public_baseurl is None:
+            return None
+
         result = {"m.homeserver": {"base_url": self._config.public_baseurl}}
 
         if self._config.default_identity_server:
diff --git a/synapse/server.py b/synapse/server.py
index 9cdda83aa1..6b3892e3cd 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -25,7 +25,17 @@ import abc
 import functools
 import logging
 import os
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import twisted.internet.base
 import twisted.internet.tcp
@@ -103,6 +113,7 @@ from synapse.notifier import Notifier
 from synapse.push.action_generator import ActionGenerator
 from synapse.push.pusherpool import PusherPool
 from synapse.replication.tcp.client import ReplicationDataHandler
+from synapse.replication.tcp.external_cache import ExternalCache
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.resource import ReplicationStreamer
 from synapse.replication.tcp.streams import STREAMS_MAP, Stream
@@ -128,6 +139,8 @@ from synapse.util.stringutils import random_string
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
+    from txredisapi import RedisProtocol
+
     from synapse.handlers.oidc_handler import OidcHandler
     from synapse.handlers.saml_handler import SamlHandler
 
@@ -585,7 +598,9 @@ class HomeServer(metaclass=abc.ABCMeta):
         return UserDirectoryHandler(self)
 
     @cache_in_self
-    def get_groups_local_handler(self):
+    def get_groups_local_handler(
+        self,
+    ) -> Union[GroupsLocalWorkerHandler, GroupsLocalHandler]:
         if self.config.worker_app:
             return GroupsLocalWorkerHandler(self)
         else:
@@ -716,6 +731,33 @@ class HomeServer(metaclass=abc.ABCMeta):
     def get_account_data_handler(self) -> AccountDataHandler:
         return AccountDataHandler(self)
 
+    @cache_in_self
+    def get_external_cache(self) -> ExternalCache:
+        return ExternalCache(self)
+
+    @cache_in_self
+    def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]:
+        if not self.config.redis.redis_enabled:
+            return None
+
+        # We only want to import redis module if we're using it, as we have
+        # `txredisapi` as an optional dependency.
+        from synapse.replication.tcp.redis import lazyConnection
+
+        logger.info(
+            "Connecting to redis (host=%r port=%r) for external cache",
+            self.config.redis_host,
+            self.config.redis_port,
+        )
+
+        return lazyConnection(
+            hs=self,
+            host=self.config.redis_host,
+            port=self.config.redis_port,
+            password=self.config.redis.redis_password,
+            reconnect=True,
+        )
+
     async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
         return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
 
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index 8dd01fce76..6652451346 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
 
 
 class ResourceLimitsServerNotices:
-    """ Keeps track of whether the server has reached it's resource limit and
+    """Keeps track of whether the server has reached its resource limit and
     ensures that the client is kept up to date.
     """
 
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 84f59c7d85..c3d6e80c49 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -310,6 +310,7 @@ class StateHandler:
             state_group_before_event = None
             state_group_before_event_prev_group = None
             deltas_to_state_group_before_event = None
+            entry = None
 
         else:
             # otherwise, we'll need to resolve the state across the prev_events.
@@ -340,9 +341,13 @@ class StateHandler:
                 current_state_ids=state_ids_before_event,
             )
 
-            # XXX: can we update the state cache entry for the new state group? or
-            # could we set a flag on resolve_state_groups_for_events to tell it to
-            # always make a state group?
+            # Assign the new state group to the cached state entry.
+            #
+            # Note that this can race in that we could generate multiple state
+            # groups for the same state entry, but that is just inefficient
+            # rather than dangerous.
+            if entry and entry.state_group is None:
+                entry.state_group = state_group_before_event
 
         #
         # now if it's not a state event, we're done
@@ -393,7 +398,7 @@ class StateHandler:
     async def resolve_state_groups_for_events(
         self, room_id: str, event_ids: Iterable[str]
     ) -> _StateCacheEntry:
-        """ Given a list of event_ids this method fetches the state at each
+        """Given a list of event_ids this method fetches the state at each
         event, resolves conflicts between them and returns them.
 
         Args:
@@ -565,7 +570,9 @@ class StateResolutionHandler:
                 return cache
 
             logger.info(
-                "Resolving state for %s with groups %s", room_id, list(group_names),
+                "Resolving state for %s with groups %s",
+                room_id,
+                list(group_names),
             )
 
             state_groups_histogram.observe(len(state_groups_ids))
@@ -610,7 +617,7 @@ class StateResolutionHandler:
             event_map:
                 a dict from event_id to event, for any events that we happen to
                 have in flight (eg, those currently being persisted). This will be
-                used as a starting point fof finding the state we need; any missing
+                used as a starting point for finding the state we need; any missing
                 events will be requested via state_map_factory.
 
                 If None, all events will be fetched via state_res_store.
@@ -651,11 +658,15 @@ class StateResolutionHandler:
             return
 
         self._report_biggest(
-            lambda i: i.cpu_time, "CPU time", _biggest_room_by_cpu_counter,
+            lambda i: i.cpu_time,
+            "CPU time",
+            _biggest_room_by_cpu_counter,
         )
 
         self._report_biggest(
-            lambda i: i.db_time, "DB time", _biggest_room_by_db_counter,
+            lambda i: i.db_time,
+            "DB time",
+            _biggest_room_by_db_counter,
         )
 
         self._state_res_metrics.clear()
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 85edae053d..ce255da6fd 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -95,7 +95,11 @@ async def resolve_events_with_store(
         if event.room_id != room_id:
             raise Exception(
                 "Attempting to state-resolve for room %s with event %s which is in %s"
-                % (room_id, event.event_id, event.room_id,)
+                % (
+                    room_id,
+                    event.event_id,
+                    event.room_id,
+                )
             )
 
     # get the ids of the auth events which allow us to authenticate the
@@ -119,7 +123,11 @@ async def resolve_events_with_store(
         if event.room_id != room_id:
             raise Exception(
                 "Attempting to state-resolve for room %s with event %s which is in %s"
-                % (room_id, event.event_id, event.room_id,)
+                % (
+                    room_id,
+                    event.event_id,
+                    event.room_id,
+                )
             )
 
     state_map.update(state_map_new)
@@ -243,7 +251,7 @@ def _resolve_with_state(
 def _resolve_state_events(
     conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase]
 ) -> StateMap[EventBase]:
-    """ This is where we actually decide which of the conflicted state to
+    """This is where we actually decide which of the conflicted state to
     use.
 
     We resolve conflicts in the following order:
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index e585954bd8..e73a548ee4 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -118,7 +118,11 @@ async def resolve_events_with_store(
         if event.room_id != room_id:
             raise Exception(
                 "Attempting to state-resolve for room %s with event %s which is in %s"
-                % (room_id, event.event_id, event.room_id,)
+                % (
+                    room_id,
+                    event.event_id,
+                    event.room_id,
+                )
             )
 
     full_conflicted_set = {eid for eid in full_conflicted_set if eid in event_map}
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index c0d9d1240f..a3c52695e9 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -43,8 +43,7 @@ __all__ = ["Databases", "DataStore"]
 
 
 class Storage:
-    """The high level interfaces for talking to various storage layers.
-    """
+    """The high level interfaces for talking to various storage layers."""
 
     def __init__(self, hs: "HomeServer", stores: Databases):
         # We include the main data store here mainly so that we don't have to
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 29b8ca676a..329660cf0f 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -77,7 +77,7 @@ class BackgroundUpdatePerformance:
 
 
 class BackgroundUpdater:
-    """ Background updates are updates to the database that run in the
+    """Background updates are updates to the database that run in the
     background. Each update processes a batch of data at once. We attempt to
     limit the impact of each update by monitoring how long each batch takes to
     process and autotuning the batch size.
@@ -158,8 +158,7 @@ class BackgroundUpdater:
         return False
 
     async def has_completed_background_update(self, update_name: str) -> bool:
-        """Check if the given background update has finished running.
-        """
+        """Check if the given background update has finished running."""
         if self._all_done:
             return True
 
@@ -198,7 +197,8 @@ class BackgroundUpdater:
 
         if not self._current_background_update:
             all_pending_updates = await self.db_pool.runInteraction(
-                "background_updates", get_background_updates_txn,
+                "background_updates",
+                get_background_updates_txn,
             )
             if not all_pending_updates:
                 # no work left to do
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a19d65ad23..4646926449 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -85,8 +85,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
 def make_pool(
     reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
 ) -> adbapi.ConnectionPool:
-    """Get the connection pool for the database.
-    """
+    """Get the connection pool for the database."""
 
     # By default enable `cp_reconnect`. We need to fiddle with db_args in case
     # someone has explicitly set `cp_reconnect`.
@@ -158,8 +157,8 @@ class LoggingDatabaseConnection:
     def commit(self) -> None:
         self.conn.commit()
 
-    def rollback(self, *args, **kwargs) -> None:
-        self.conn.rollback(*args, **kwargs)
+    def rollback(self) -> None:
+        self.conn.rollback()
 
     def __enter__(self) -> "Connection":
         self.conn.__enter__()
@@ -244,12 +243,15 @@ class LoggingTransaction:
         assert self.exception_callbacks is not None
         self.exception_callbacks.append((callback, args, kwargs))
 
+    def fetchone(self) -> Optional[Tuple]:
+        return self.txn.fetchone()
+
+    def fetchmany(self, size: Optional[int] = None) -> List[Tuple]:
+        return self.txn.fetchmany(size=size)
+
     def fetchall(self) -> List[Tuple]:
         return self.txn.fetchall()
 
-    def fetchone(self) -> Tuple:
-        return self.txn.fetchone()
-
     def __iter__(self) -> Iterator[Tuple]:
         return self.txn.__iter__()
 
@@ -262,13 +264,18 @@ class LoggingTransaction:
         return self.txn.description
 
     def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
+        """Similar to `executemany`, except `txn.rowcount` will not be correct
+        afterwards.
+
+        More efficient than `executemany` on PostgreSQL.
+        """
+
         if isinstance(self.database_engine, PostgresEngine):
             from psycopg2.extras import execute_batch  # type: ignore
 
             self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
         else:
-            for val in args:
-                self.execute(sql, val)
+            self.executemany(sql, args)
 
     def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
         """Corresponds to psycopg2.extras.execute_values. Only available when
@@ -424,8 +431,7 @@ class DatabasePool:
         )
 
     def is_running(self) -> bool:
-        """Is the database pool currently running
-        """
+        """Is the database pool currently running"""
         return self._db_pool.running
 
     async def _check_safe_to_upsert(self) -> None:
@@ -538,7 +544,11 @@ class DatabasePool:
                     # This can happen if the database disappears mid
                     # transaction.
                     transaction_logger.warning(
-                        "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
+                        "[TXN OPERROR] {%s} %s %d/%d",
+                        name,
+                        e,
+                        i,
+                        N,
                     )
                     if i < N:
                         i += 1
@@ -559,7 +569,9 @@ class DatabasePool:
                                 conn.rollback()
                             except self.engine.module.Error as e1:
                                 transaction_logger.warning(
-                                    "[TXN EROLL] {%s} %s", name, e1,
+                                    "[TXN EROLL] {%s} %s",
+                                    name,
+                                    e1,
                                 )
                             continue
                     raise
@@ -749,6 +761,7 @@ class DatabasePool:
         Returns:
             A list of dicts where the key is the column header.
         """
+        assert cursor.description is not None, "cursor.description was None"
         col_headers = [intern(str(column[0])) for column in cursor.description]
         results = [dict(zip(col_headers, row)) for row in cursor]
         return results
@@ -888,7 +901,7 @@ class DatabasePool:
             ", ".join("?" for _ in keys[0]),
         )
 
-        txn.executemany(sql, vals)
+        txn.execute_batch(sql, vals)
 
     async def simple_upsert(
         self,
@@ -1397,7 +1410,10 @@ class DatabasePool:
 
     @staticmethod
     def simple_select_onecol_txn(
-        txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any], retcol: str,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: str,
     ) -> List[Any]:
         sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
 
@@ -1707,7 +1723,11 @@ class DatabasePool:
             desc: description of the transaction, for logging and metrics
         """
         await self.runInteraction(
-            desc, self.simple_delete_one_txn, table, keyvalues, db_autocommit=True,
+            desc,
+            self.simple_delete_one_txn,
+            table,
+            keyvalues,
+            db_autocommit=True,
         )
 
     @staticmethod
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 0c24325011..e84f8b42f7 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -56,7 +56,10 @@ class Databases:
                     database_config.databases,
                 )
                 prepare_database(
-                    db_conn, engine, hs.config, databases=database_config.databases,
+                    db_conn,
+                    engine,
+                    hs.config,
+                    databases=database_config.databases,
                 )
 
                 database = DatabasePool(hs, database_config, engine)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index ae561a2da3..70b49854cf 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -43,6 +43,7 @@ from .end_to_end_keys import EndToEndKeyStore
 from .event_federation import EventFederationStore
 from .event_push_actions import EventPushActionsStore
 from .events_bg_updates import EventsBackgroundUpdatesStore
+from .events_forward_extremities import EventForwardExtremitiesStore
 from .filtering import FilteringStore
 from .group_server import GroupServerStore
 from .keys import KeyStore
@@ -118,6 +119,7 @@ class DataStore(
     UIAuthStore,
     CacheInvalidationWorkerStore,
     ServerMetricsStore,
+    EventForwardExtremitiesStore,
 ):
     def __init__(self, database: DatabasePool, db_conn, hs):
         self.hs = hs
@@ -338,7 +340,7 @@ class DataStore(
             count = txn.fetchone()[0]
 
             sql = (
-                "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url "
+                "SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url "
                 + sql_base
                 + " ORDER BY u.name LIMIT ? OFFSET ?"
             )
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index e550cbc866..03a38422a1 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -73,8 +73,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
         return self.services_cache
 
     def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
-        """Check if the user is one associated with an app service (exclusively)
-        """
+        """Check if the user is one associated with an app service (exclusively)"""
         if self.exclusive_user_regex:
             return bool(self.exclusive_user_regex.match(user_id))
         else:
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index ea1e8fb580..6d18e692b0 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -280,8 +280,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
         return batch_size
 
     async def _devices_last_seen_update(self, progress, batch_size):
-        """Background update to insert last seen info into devices table
-        """
+        """Background update to insert last seen info into devices table"""
 
         last_user_id = progress.get("last_user_id", "")
         last_device_id = progress.get("last_device_id", "")
@@ -363,8 +362,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
 
     @wrap_as_background_process("prune_old_user_ips")
     async def _prune_old_user_ips(self):
-        """Removes entries in user IPs older than the configured period.
-        """
+        """Removes entries in user IPs older than the configured period."""
 
         if self.user_ips_max_age is None:
             # Nothing to do
@@ -565,7 +563,11 @@ class ClientIpStore(ClientIpWorkerStore):
         results = {}
 
         for key in self._batch_row_update:
-            uid, access_token, ip, = key
+            (
+                uid,
+                access_token,
+                ip,
+            ) = key
             if uid == user_id:
                 user_agent, _, last_seen = self._batch_row_update[key]
                 results[(access_token, ip)] = (user_agent, last_seen)
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 31f70ac5ef..45ca6620a8 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -450,7 +450,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 },
             )
 
-            # Add the messages to the approriate local device inboxes so that
+            # Add the messages to the appropriate local device inboxes so that
             # they'll be sent to the devices when they next sync.
             self._add_messages_to_local_device_inbox_txn(
                 txn, stream_id, local_messages_by_user_then_device
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9097677648..d327e9aa0b 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -315,7 +315,8 @@ class DeviceWorkerStore(SQLBaseStore):
 
             # make sure we go through the devices in stream order
             device_ids = sorted(
-                user_devices.keys(), key=lambda i: query_map[(user_id, i)][0],
+                user_devices.keys(),
+                key=lambda i: query_map[(user_id, i)][0],
             )
 
             for device_id in device_ids:
@@ -366,8 +367,7 @@ class DeviceWorkerStore(SQLBaseStore):
     async def mark_as_sent_devices_by_remote(
         self, destination: str, stream_id: int
     ) -> None:
-        """Mark that updates have successfully been sent to the destination.
-        """
+        """Mark that updates have successfully been sent to the destination."""
         await self.db_pool.runInteraction(
             "mark_as_sent_devices_by_remote",
             self._mark_as_sent_devices_by_remote_txn,
@@ -681,7 +681,8 @@ class DeviceWorkerStore(SQLBaseStore):
         return results
 
     async def get_user_ids_requiring_device_list_resync(
-        self, user_ids: Optional[Collection[str]] = None,
+        self,
+        user_ids: Optional[Collection[str]] = None,
     ) -> Set[str]:
         """Given a list of remote users return the list of users that we
         should resync the device lists for. If None is given instead of a list,
@@ -721,8 +722,7 @@ class DeviceWorkerStore(SQLBaseStore):
         )
 
     async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
-        """Mark that we no longer track device lists for remote user.
-        """
+        """Mark that we no longer track device lists for remote user."""
 
         def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
             self.db_pool.simple_delete_txn(
@@ -897,12 +897,13 @@ class DeviceWorkerStore(SQLBaseStore):
                 DELETE FROM device_lists_outbound_last_success
                 WHERE destination = ? AND user_id = ?
             """
-            txn.executemany(sql, ((row[0], row[1]) for row in rows))
+            txn.execute_batch(sql, ((row[0], row[1]) for row in rows))
 
             logger.info("Pruned %d device list outbound pokes", count)
 
         await self.db_pool.runInteraction(
-            "_prune_old_outbound_device_pokes", _prune_txn,
+            "_prune_old_outbound_device_pokes",
+            _prune_txn,
         )
 
 
@@ -943,7 +944,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
 
         # clear out duplicate device list outbound pokes
         self.db_pool.updates.register_background_update_handler(
-            BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
+            BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
+            self._remove_duplicate_outbound_pokes,
         )
 
         # a pair of background updates that were added during the 1.14 release cycle,
@@ -1004,17 +1006,23 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
             row = None
             for row in rows:
                 self.db_pool.simple_delete_txn(
-                    txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS},
+                    txn,
+                    "device_lists_outbound_pokes",
+                    {x: row[x] for x in KEY_COLS},
                 )
 
                 row["sent"] = False
                 self.db_pool.simple_insert_txn(
-                    txn, "device_lists_outbound_pokes", row,
+                    txn,
+                    "device_lists_outbound_pokes",
+                    row,
                 )
 
             if row:
                 self.db_pool.updates._background_update_progress_txn(
-                    txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row},
+                    txn,
+                    BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
+                    {"last_row": row},
                 )
 
             return len(rows)
@@ -1286,7 +1294,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         # we've done a full resync, so we remove the entry that says we need
         # to resync
         self.db_pool.simple_delete_txn(
-            txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id},
+            txn,
+            table="device_lists_remote_resync",
+            keyvalues={"user_id": user_id},
         )
 
     async def add_device_change_to_streams(
@@ -1336,14 +1346,16 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         stream_ids: List[str],
     ):
         txn.call_after(
-            self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
+            self._device_list_stream_cache.entity_has_changed,
+            user_id,
+            stream_ids[-1],
         )
 
         min_stream_id = stream_ids[0]
 
         # Delete older entries in the table, as we really only care about
         # when the latest change happened.
-        txn.executemany(
+        txn.execute_batch(
             """
             DELETE FROM device_lists_stream
             WHERE user_id = ? AND device_id = ? AND stream_id < ?
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index e5060d4c46..267b948397 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -85,7 +85,7 @@ class DirectoryStore(DirectoryWorkerStore):
         servers: Iterable[str],
         creator: Optional[str] = None,
     ) -> None:
-        """ Creates an association between a room alias and room_id/servers
+        """Creates an association between a room alias and room_id/servers
 
         Args:
             room_alias: The alias to create.
@@ -160,7 +160,10 @@ class DirectoryStore(DirectoryWorkerStore):
         return room_id
 
     async def update_aliases_for_room(
-        self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
+        self,
+        old_room_id: str,
+        new_room_id: str,
+        creator: Optional[str] = None,
     ) -> None:
         """Repoint all of the aliases for a given room, to a different room.
 
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index c128889bf9..f1e7859d26 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -361,7 +361,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
     async def count_e2e_one_time_keys(
         self, user_id: str, device_id: str
     ) -> Dict[str, int]:
-        """ Count the number of one time keys the server has for a device
+        """Count the number of one time keys the server has for a device
         Returns:
             A mapping from algorithm to number of keys for that algorithm.
         """
@@ -494,7 +494,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         )
 
     def _get_bare_e2e_cross_signing_keys_bulk_txn(
-        self, txn: Connection, user_ids: List[str],
+        self,
+        txn: Connection,
+        user_ids: List[str],
     ) -> Dict[str, Dict[str, dict]]:
         """Returns the cross-signing keys for a set of users.  The output of this
         function should be passed to _get_e2e_cross_signing_signatures_txn if
@@ -556,7 +558,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         return result
 
     def _get_e2e_cross_signing_signatures_txn(
-        self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str,
+        self,
+        txn: Connection,
+        keys: Dict[str, Dict[str, dict]],
+        from_user_id: str,
     ) -> Dict[str, Dict[str, dict]]:
         """Returns the cross-signing signatures made by a user on a set of keys.
 
@@ -634,7 +639,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
     async def get_e2e_cross_signing_keys_bulk(
         self, user_ids: List[str], from_user_id: Optional[str] = None
-    ) -> Dict[str, Dict[str, dict]]:
+    ) -> Dict[str, Optional[Dict[str, dict]]]:
         """Returns the cross-signing keys for a set of users.
 
         Args:
@@ -724,7 +729,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
     async def claim_e2e_one_time_keys(
         self, query_list: Iterable[Tuple[str, str, str]]
-    ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
+    ) -> Dict[str, Dict[str, Dict[str, str]]]:
         """Take a list of one time keys out of the database.
 
         Args:
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 8326640d20..18ddb92fcc 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -71,7 +71,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         return await self.get_events_as_list(event_ids)
 
     async def get_auth_chain_ids(
-        self, event_ids: Collection[str], include_given: bool = False,
+        self,
+        event_ids: Collection[str],
+        include_given: bool = False,
     ) -> List[str]:
         """Get auth events for given event_ids. The events *must* be state events.
 
@@ -273,7 +275,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
                     # origin chain.
                     if origin_sequence_number <= chains.get(origin_chain_id, 0):
                         chains[target_chain_id] = max(
-                            target_sequence_number, chains.get(target_chain_id, 0),
+                            target_sequence_number,
+                            chains.get(target_chain_id, 0),
                         )
 
                 seen_chains.add(target_chain_id)
@@ -371,7 +374,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         # and state sets {A} and {B} then walking the auth chains of A and B
         # would immediately show that C is reachable by both. However, if we
         # stopped at C then we'd only reach E via the auth chain of B and so E
-        # would errornously get included in the returned difference.
+        # would erroneously get included in the returned difference.
         #
         # The other thing that we do is limit the number of auth chains we walk
         # at once, due to practical limits (i.e. we can only query the database
@@ -497,7 +500,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
                         a_ids = new_aids
 
-                # Mark that the auth event is reachable by the approriate sets.
+                # Mark that the auth event is reachable by the appropriate sets.
                 sets.intersection_update(event_to_missing_sets[event_id])
 
             search.sort()
@@ -632,8 +635,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         )
 
     async def get_min_depth(self, room_id: str) -> int:
-        """For the given room, get the minimum depth we have seen for it.
-        """
+        """For the given room, get the minimum depth we have seen for it."""
         return await self.db_pool.runInteraction(
             "get_min_depth", self._get_min_depth_interaction, room_id
         )
@@ -858,12 +860,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             )
 
         await self.db_pool.runInteraction(
-            "_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn,
+            "_delete_old_forward_extrem_cache",
+            _delete_old_forward_extrem_cache_txn,
         )
 
 
 class EventFederationStore(EventFederationWorkerStore):
-    """ Responsible for storing and serving up the various graphs associated
+    """Responsible for storing and serving up the various graphs associated
     with an event. Including the main event graph and the auth chains for an
     event.
 
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 1b657191a9..78245ad5bd 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -54,8 +54,7 @@ def _serialize_action(actions, is_highlight):
 
 
 def _deserialize_action(actions, is_highlight):
-    """Custom deserializer for actions. This allows us to "compress" common actions
-    """
+    """Custom deserializer for actions. This allows us to "compress" common actions"""
     if actions:
         return db_to_json(actions)
 
@@ -91,7 +90,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
     @cached(num_args=3, tree=True, max_entries=5000)
     async def get_unread_event_push_actions_by_room_for_user(
-        self, room_id: str, user_id: str, last_read_event_id: Optional[str],
+        self,
+        room_id: str,
+        user_id: str,
+        last_read_event_id: Optional[str],
     ) -> Dict[str, int]:
         """Get the notification count, the highlight count and the unread message count
         for a given user in a given room after the given read receipt.
@@ -120,13 +122,19 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         )
 
     def _get_unread_counts_by_receipt_txn(
-        self, txn, room_id, user_id, last_read_event_id,
+        self,
+        txn,
+        room_id,
+        user_id,
+        last_read_event_id,
     ):
         stream_ordering = None
 
         if last_read_event_id is not None:
             stream_ordering = self.get_stream_id_for_event_txn(
-                txn, last_read_event_id, allow_none=True,
+                txn,
+                last_read_event_id,
+                allow_none=True,
             )
 
         if stream_ordering is None:
@@ -487,7 +495,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 VALUES (?, ?, ?, ?, ?, ?)
             """
 
-            txn.executemany(
+            txn.execute_batch(
                 sql,
                 (
                     _gen_entry(user_id, actions)
@@ -803,7 +811,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             ],
         )
 
-        txn.executemany(
+        txn.execute_batch(
             """
                 UPDATE event_push_summary
                 SET notif_count = ?, unread_count = ?, stream_ordering = ?
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 3216b3f3c8..287606cb4f 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -399,7 +399,9 @@ class PersistEventsStore:
         self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
 
     def _persist_event_auth_chain_txn(
-        self, txn: LoggingTransaction, events: List[EventBase],
+        self,
+        txn: LoggingTransaction,
+        events: List[EventBase],
     ) -> None:
 
         # We only care about state events, so this if there are no state events.
@@ -470,11 +472,16 @@ class PersistEventsStore:
         event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
 
         self._add_chain_cover_index(
-            txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
+            txn,
+            self.db_pool,
+            event_to_room_id,
+            event_to_types,
+            event_to_auth_chain,
         )
 
-    @staticmethod
+    @classmethod
     def _add_chain_cover_index(
+        cls,
         txn,
         db_pool: DatabasePool,
         event_to_room_id: Dict[str, str],
@@ -516,7 +523,10 @@ class PersistEventsStore:
             # simple_select_many, but this case happens rarely and almost always
             # with a single row.)
             auth_events = db_pool.simple_select_onecol_txn(
-                txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id",
+                txn,
+                "event_auth",
+                keyvalues={"event_id": event_id},
+                retcol="auth_id",
             )
 
             events_to_calc_chain_id_for.add(event_id)
@@ -549,7 +559,9 @@ class PersistEventsStore:
                 WHERE
             """
             clause, args = make_in_list_sql_clause(
-                txn.database_engine, "event_id", missing_auth_chains,
+                txn.database_engine,
+                "event_id",
+                missing_auth_chains,
             )
             txn.execute(sql + clause, args)
 
@@ -614,60 +626,17 @@ class PersistEventsStore:
         if not events_to_calc_chain_id_for:
             return
 
-        # We now calculate the chain IDs/sequence numbers for the events. We
-        # do this by looking at the chain ID and sequence number of any auth
-        # event with the same type/state_key and incrementing the sequence
-        # number by one. If there was no match or the chain ID/sequence
-        # number is already taken we generate a new chain.
-        #
-        # We need to do this in a topologically sorted order as we want to
-        # generate chain IDs/sequence numbers of an event's auth events
-        # before the event itself.
-        chains_tuples_allocated = set()  # type: Set[Tuple[int, int]]
-        new_chain_tuples = {}  # type: Dict[str, Tuple[int, int]]
-        for event_id in sorted_topologically(
-            events_to_calc_chain_id_for, event_to_auth_chain
-        ):
-            existing_chain_id = None
-            for auth_id in event_to_auth_chain.get(event_id, []):
-                if event_to_types.get(event_id) == event_to_types.get(auth_id):
-                    existing_chain_id = chain_map[auth_id]
-                    break
-
-            new_chain_tuple = None
-            if existing_chain_id:
-                # We found a chain ID/sequence number candidate, check its
-                # not already taken.
-                proposed_new_id = existing_chain_id[0]
-                proposed_new_seq = existing_chain_id[1] + 1
-                if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
-                    already_allocated = db_pool.simple_select_one_onecol_txn(
-                        txn,
-                        table="event_auth_chains",
-                        keyvalues={
-                            "chain_id": proposed_new_id,
-                            "sequence_number": proposed_new_seq,
-                        },
-                        retcol="event_id",
-                        allow_none=True,
-                    )
-                    if already_allocated:
-                        # Mark it as already allocated so we don't need to hit
-                        # the DB again.
-                        chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
-                    else:
-                        new_chain_tuple = (
-                            proposed_new_id,
-                            proposed_new_seq,
-                        )
-
-            if not new_chain_tuple:
-                new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1)
-
-            chains_tuples_allocated.add(new_chain_tuple)
-
-            chain_map[event_id] = new_chain_tuple
-            new_chain_tuples[event_id] = new_chain_tuple
+        # Allocate chain ID/sequence numbers to each new event.
+        new_chain_tuples = cls._allocate_chain_ids(
+            txn,
+            db_pool,
+            event_to_room_id,
+            event_to_types,
+            event_to_auth_chain,
+            events_to_calc_chain_id_for,
+            chain_map,
+        )
+        chain_map.update(new_chain_tuples)
 
         db_pool.simple_insert_many_txn(
             txn,
@@ -746,7 +715,8 @@ class PersistEventsStore:
                 if chain_map[a_id][0] != chain_id
             }
             for start_auth_id, end_auth_id in itertools.permutations(
-                event_to_auth_chain.get(event_id, []), r=2,
+                event_to_auth_chain.get(event_id, []),
+                r=2,
             ):
                 if chain_links.exists_path_from(
                     chain_map[start_auth_id], chain_map[end_auth_id]
@@ -794,13 +764,143 @@ class PersistEventsStore:
             ],
         )
 
+    @staticmethod
+    def _allocate_chain_ids(
+        txn,
+        db_pool: DatabasePool,
+        event_to_room_id: Dict[str, str],
+        event_to_types: Dict[str, Tuple[str, str]],
+        event_to_auth_chain: Dict[str, List[str]],
+        events_to_calc_chain_id_for: Set[str],
+        chain_map: Dict[str, Tuple[int, int]],
+    ) -> Dict[str, Tuple[int, int]]:
+        """Allocates, but does not persist, chain ID/sequence numbers for the
+        events in `events_to_calc_chain_id_for`, returning a map from event ID
+        to (chain ID, sequence number). (c.f. _add_chain_cover_index for args)
+        """
+
+        # We now calculate the chain IDs/sequence numbers for the events. We do
+        # this by looking at the chain ID and sequence number of any auth event
+        # with the same type/state_key and incrementing the sequence number by
+        # one. If there was no match or the chain ID/sequence number is already
+        # taken we generate a new chain.
+        #
+        # We try to reduce the number of times that we hit the database by
+        # batching up calls, to make this more efficient when persisting large
+        # numbers of state events (e.g. during joins).
+        #
+        # We do this by:
+        #   1. Calculating for each event which auth event will be used to
+        #      inherit the chain ID, i.e. converting the auth chain graph to a
+        #      tree that we can allocate chains on. We also keep track of which
+        #      existing chain IDs have been referenced.
+        #   2. Fetching the max allocated sequence number for each referenced
+        #      existing chain ID, generating a map from chain ID to the max
+        #      allocated sequence number.
+        #   3. Iterating over the tree and allocating a chain ID/seq no. to the
+        #      new event, by incrementing the sequence number from the
+        #      referenced event's chain ID/seq no. and checking that the
+        #      incremented sequence number hasn't already been allocated (by
+        #      looking in the map generated in the previous step). We generate a
+        #      new chain if the sequence number has already been allocated.
+        #
+
+        existing_chains = set()  # type: Set[int]
+        tree = []  # type: List[Tuple[str, Optional[str]]]
+
+        # We need to do this in a topologically sorted order as we want to
+        # generate chain IDs/sequence numbers of an event's auth events before
+        # the event itself.
+        for event_id in sorted_topologically(
+            events_to_calc_chain_id_for, event_to_auth_chain
+        ):
+            for auth_id in event_to_auth_chain.get(event_id, []):
+                if event_to_types.get(event_id) == event_to_types.get(auth_id):
+                    existing_chain_id = chain_map.get(auth_id)
+                    if existing_chain_id:
+                        existing_chains.add(existing_chain_id[0])
+
+                    tree.append((event_id, auth_id))
+                    break
+            else:
+                tree.append((event_id, None))
+
+        # Fetch the current max sequence number for each existing referenced chain.
+        sql = """
+            SELECT chain_id, MAX(sequence_number) FROM event_auth_chains
+            WHERE %s
+            GROUP BY chain_id
+        """
+        clause, args = make_in_list_sql_clause(
+            db_pool.engine, "chain_id", existing_chains
+        )
+        txn.execute(sql % (clause,), args)
+
+        chain_to_max_seq_no = {row[0]: row[1] for row in txn}  # type: Dict[Any, int]
+
+        # Allocate chain IDs/sequence numbers to the new events.
+        #
+        # To reduce the number of calls to the database we don't allocate a
+        # chain ID number in the loop, instead we use a temporary `object()` for
+        # each new chain ID. Once we've done the loop we generate the necessary
+        # number of new chain IDs in one call, replacing all temporary
+        # objects with real allocated chain IDs.
+
+        unallocated_chain_ids = set()  # type: Set[object]
+        new_chain_tuples = {}  # type: Dict[str, Tuple[Any, int]]
+        for event_id, auth_event_id in tree:
+            # If we reference an auth_event_id we fetch the allocated chain ID,
+            # either from the existing `chain_map` or the newly generated
+            # `new_chain_tuples` map.
+            existing_chain_id = None
+            if auth_event_id:
+                existing_chain_id = new_chain_tuples.get(auth_event_id)
+                if not existing_chain_id:
+                    existing_chain_id = chain_map[auth_event_id]
+
+            new_chain_tuple = None  # type: Optional[Tuple[Any, int]]
+            if existing_chain_id:
+                # We found a chain ID/sequence number candidate, check it's
+                # not already taken.
+                proposed_new_id = existing_chain_id[0]
+                proposed_new_seq = existing_chain_id[1] + 1
+
+                if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq:
+                    new_chain_tuple = (
+                        proposed_new_id,
+                        proposed_new_seq,
+                    )
+
+            # If we need to start a new chain we allocate a temporary chain ID.
+            if not new_chain_tuple:
+                new_chain_tuple = (object(), 1)
+                unallocated_chain_ids.add(new_chain_tuple[0])
+
+            new_chain_tuples[event_id] = new_chain_tuple
+            chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
+
+        # Generate new chain IDs for all unallocated chain IDs.
+        newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
+            txn, len(unallocated_chain_ids)
+        )
+
+        # Map from potentially temporary chain ID to real chain ID
+        chain_id_to_allocated_map = dict(
+            zip(unallocated_chain_ids, newly_allocated_chain_ids)
+        )  # type: Dict[Any, int]
+        chain_id_to_allocated_map.update((c, c) for c in existing_chains)
+
+        return {
+            event_id: (chain_id_to_allocated_map[chain_id], seq)
+            for event_id, (chain_id, seq) in new_chain_tuples.items()
+        }
+
     def _persist_transaction_ids_txn(
         self,
         txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
     ):
-        """Persist the mapping from transaction IDs to event IDs (if defined).
-        """
+        """Persist the mapping from transaction IDs to event IDs (if defined)."""
 
         to_insert = []
         for event, _ in events_and_contexts:
@@ -820,7 +920,9 @@ class PersistEventsStore:
 
         if to_insert:
             self.db_pool.simple_insert_many_txn(
-                txn, table="event_txn_id", values=to_insert,
+                txn,
+                table="event_txn_id",
+                values=to_insert,
             )
 
     def _update_current_state_txn(
@@ -852,7 +954,9 @@ class PersistEventsStore:
                 txn.execute(sql, (stream_id, self._instance_name, room_id))
 
                 self.db_pool.simple_delete_txn(
-                    txn, table="current_state_events", keyvalues={"room_id": room_id},
+                    txn,
+                    table="current_state_events",
+                    keyvalues={"room_id": room_id},
                 )
             else:
                 # We're still in the room, so we update the current state as normal.
@@ -876,7 +980,7 @@ class PersistEventsStore:
                         WHERE room_id = ? AND type = ? AND state_key = ?
                     )
                 """
-                txn.executemany(
+                txn.execute_batch(
                     sql,
                     (
                         (
@@ -895,7 +999,7 @@ class PersistEventsStore:
                 )
                 # Now we actually update the current_state_events table
 
-                txn.executemany(
+                txn.execute_batch(
                     "DELETE FROM current_state_events"
                     " WHERE room_id = ? AND type = ? AND state_key = ?",
                     (
@@ -907,7 +1011,7 @@ class PersistEventsStore:
                 # We include the membership in the current state table, hence we do
                 # a lookup when we insert. This assumes that all events have already
                 # been inserted into room_memberships.
-                txn.executemany(
+                txn.execute_batch(
                     """INSERT INTO current_state_events
                         (room_id, type, state_key, event_id, membership)
                     VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -927,7 +1031,7 @@ class PersistEventsStore:
             # we have no record of the fact the user *was* a member of the
             # room but got, say, state reset out of it.
             if to_delete or to_insert:
-                txn.executemany(
+                txn.execute_batch(
                     "DELETE FROM local_current_membership"
                     " WHERE room_id = ? AND user_id = ?",
                     (
@@ -938,7 +1042,7 @@ class PersistEventsStore:
                 )
 
             if to_insert:
-                txn.executemany(
+                txn.execute_batch(
                     """INSERT INTO local_current_membership
                         (room_id, user_id, event_id, membership)
                     VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -961,7 +1065,7 @@ class PersistEventsStore:
             # Figure out the changes of membership to invalidate the
             # `get_rooms_for_user` cache.
             # We find out which membership events we may have deleted
-            # and which we have added, then we invlidate the caches for all
+            # and which we have added, then we invalidate the caches for all
             # those users.
             members_changed = {
                 state_key
@@ -1519,8 +1623,7 @@ class PersistEventsStore:
         )
 
     def _store_room_members_txn(self, txn, events, backfilled):
-        """Store a room member in the database.
-        """
+        """Store a room member in the database."""
 
         def str_or_none(val: Any) -> Optional[str]:
             return val if isinstance(val, str) else None
@@ -1738,7 +1841,7 @@ class PersistEventsStore:
         """
 
         if events_and_contexts:
-            txn.executemany(
+            txn.execute_batch(
                 sql,
                 (
                     (
@@ -1767,7 +1870,7 @@ class PersistEventsStore:
 
         # Now we delete the staging area for *all* events that were being
         # persisted.
-        txn.executemany(
+        txn.execute_batch(
             "DELETE FROM event_push_actions_staging WHERE event_id = ?",
             ((event.event_id,) for event, _ in all_events_and_contexts),
         )
@@ -1886,7 +1989,7 @@ class PersistEventsStore:
             " )"
         )
 
-        txn.executemany(
+        txn.execute_batch(
             query,
             [
                 (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
@@ -1900,7 +2003,7 @@ class PersistEventsStore:
             "DELETE FROM event_backward_extremities"
             " WHERE event_id = ? AND room_id = ?"
         )
-        txn.executemany(
+        txn.execute_batch(
             query,
             [
                 (ev.event_id, ev.room_id)
@@ -1912,8 +2015,7 @@ class PersistEventsStore:
 
 @attr.s(slots=True)
 class _LinkMap:
-    """A helper type for tracking links between chains.
-    """
+    """A helper type for tracking links between chains."""
 
     # Stores the set of links as nested maps: source chain ID -> target chain ID
     # -> source sequence number -> target sequence number.
@@ -2019,7 +2121,9 @@ class _LinkMap:
                 yield (src_chain, src_seq, target_chain, target_seq)
 
     def exists_path_from(
-        self, src_tuple: Tuple[int, int], target_tuple: Tuple[int, int],
+        self,
+        src_tuple: Tuple[int, int],
+        target_tuple: Tuple[int, int],
     ) -> bool:
         """Checks if there is a path between the source chain ID/sequence and
         target chain ID/sequence.
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index e46e44ba54..89274e75f7 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -32,8 +32,7 @@ logger = logging.getLogger(__name__)
 
 @attr.s(slots=True, frozen=True)
 class _CalculateChainCover:
-    """Return value for _calculate_chain_cover_txn.
-    """
+    """Return value for _calculate_chain_cover_txn."""
 
     # The last room_id/depth/stream processed.
     room_id = attr.ib(type=str)
@@ -127,11 +126,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         )
 
         self.db_pool.updates.register_background_update_handler(
-            "rejected_events_metadata", self._rejected_events_metadata,
+            "rejected_events_metadata",
+            self._rejected_events_metadata,
         )
 
         self.db_pool.updates.register_background_update_handler(
-            "chain_cover", self._chain_cover_index,
+            "chain_cover",
+            self._chain_cover_index,
         )
 
     async def _background_reindex_fields_sender(self, progress, batch_size):
@@ -139,8 +140,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
-        INSERT_CLUMP_SIZE = 1000
-
         def reindex_txn(txn):
             sql = (
                 "SELECT stream_ordering, event_id, json FROM events"
@@ -178,9 +177,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
 
-            for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
-                clump = update_rows[index : index + INSERT_CLUMP_SIZE]
-                txn.executemany(sql, clump)
+            txn.execute_batch(sql, update_rows)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
@@ -210,8 +207,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
-        INSERT_CLUMP_SIZE = 1000
-
         def reindex_search_txn(txn):
             sql = (
                 "SELECT stream_ordering, event_id FROM events"
@@ -256,9 +251,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
 
-            for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
-                clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
-                txn.executemany(sql, clump)
+            txn.execute_batch(sql, rows_to_update)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
@@ -470,8 +463,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         return num_handled
 
     async def _redactions_received_ts(self, progress, batch_size):
-        """Handles filling out the `received_ts` column in redactions.
-        """
+        """Handles filling out the `received_ts` column in redactions."""
         last_event_id = progress.get("last_event_id", "")
 
         def _redactions_received_ts_txn(txn):
@@ -526,8 +518,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         return count
 
     async def _event_fix_redactions_bytes(self, progress, batch_size):
-        """Undoes hex encoded censored redacted event JSON.
-        """
+        """Undoes hex encoded censored redacted event JSON."""
 
         def _event_fix_redactions_bytes_txn(txn):
             # This update is quite fast due to new index.
@@ -650,7 +641,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                 LIMIT ?
             """
 
-            txn.execute(sql, (last_event_id, batch_size,))
+            txn.execute(
+                sql,
+                (
+                    last_event_id,
+                    batch_size,
+                ),
+            )
 
             return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn]  # type: ignore
 
@@ -918,7 +915,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         # Annoyingly we need to gut wrench into the persit event store so that
         # we can reuse the function to calculate the chain cover for rooms.
         PersistEventsStore._add_chain_cover_index(
-            txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
+            txn,
+            self.db_pool,
+            event_to_room_id,
+            event_to_types,
+            event_to_auth_chain,
         )
 
         return _CalculateChainCover(
diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py
new file mode 100644
index 0000000000..b3703ae161
--- /dev/null
+++ b/synapse/storage/databases/main/events_forward_extremities.py
@@ -0,0 +1,133 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Dict, List
+
+from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
+
+logger = logging.getLogger(__name__)
+
+
+class EventForwardExtremitiesStore(SQLBaseStore):
+    """Storage methods for inspecting and pruning a room's forward extremities."""
+
+    async def delete_forward_extremities_for_room(self, room_id: str) -> int:
+        """Delete any extra forward extremities for a room.
+
+        The forward extremity with the highest stream ordering is kept; every
+        other forward extremity of the room is deleted.
+
+        Invalidates the "get_latest_event_ids_in_room" cache if any forward
+        extremities were deleted.
+
+        Args:
+            room_id: The room to prune forward extremities for.
+
+        Returns:
+            The number of forward extremities that were deleted.
+
+        Raises:
+            SynapseError: (400) if the room has no forward extremities.
+        """
+
+        def delete_forward_extremities_for_room_txn(txn):
+            # First we need to get the event_id to not delete
+            sql = """
+                SELECT event_id FROM event_forward_extremities
+                INNER JOIN events USING (room_id, event_id)
+                WHERE room_id = ?
+                ORDER BY stream_ordering DESC
+                LIMIT 1
+            """
+            txn.execute(sql, (room_id,))
+            rows = txn.fetchall()
+            try:
+                event_id = rows[0][0]
+            except IndexError:
+                # `rows` is a list, so an empty result raises IndexError (not
+                # KeyError) when we try to read the first row.
+                msg = "No forward extremity event found for room %s" % room_id
+                logger.warning(msg)
+                raise SynapseError(400, msg)
+
+            logger.debug(
+                "Found event_id %s as the forward extremity to keep for room %s",
+                event_id,
+                room_id,
+            )
+
+            # Now delete the extra forward extremities
+            sql = """
+                DELETE FROM event_forward_extremities
+                WHERE event_id != ? AND room_id = ?
+            """
+
+            txn.execute(sql, (event_id, room_id))
+
+            # Take a copy of the row count straight away: `txn.rowcount` only
+            # describes the most recent statement, so it would change if the
+            # cache invalidation below executes anything on this transaction.
+            deleted_count = txn.rowcount
+            logger.info(
+                "Deleted %s extra forward extremities for room %s",
+                deleted_count,
+                room_id,
+            )
+
+            if deleted_count > 0:
+                # Invalidate the cache
+                self._invalidate_cache_and_stream(
+                    txn,
+                    self.get_latest_event_ids_in_room,
+                    (room_id,),
+                )
+
+            return deleted_count
+
+        return await self.db_pool.runInteraction(
+            "delete_forward_extremities_for_room",
+            delete_forward_extremities_for_room_txn,
+        )
+
+    async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]:
+        """Get list of forward extremities for a room.
+
+        Args:
+            room_id: The room to look up.
+
+        Returns:
+            One dict per row of the query below (columns `event_id`,
+            `state_group`, `depth` and `received_ts`), as produced by
+            `db_pool.cursor_to_dict`.
+        """
+
+        def get_forward_extremities_for_room_txn(txn):
+            sql = """
+                SELECT event_id, state_group, depth, received_ts
+                FROM event_forward_extremities
+                INNER JOIN event_to_state_groups USING (event_id)
+                INNER JOIN events USING (room_id, event_id)
+                WHERE room_id = ?
+            """
+
+            txn.execute(sql, (room_id,))
+            return self.db_pool.cursor_to_dict(txn)
+
+        return await self.db_pool.runInteraction(
+            "get_forward_extremities_for_room",
+            get_forward_extremities_for_room_txn,
+        )
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 71d823be72..c8850a4707 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -120,7 +120,9 @@ class EventsWorkerStore(SQLBaseStore):
             # SQLite).
             if hs.get_instance_name() in hs.config.worker.writers.events:
                 self._stream_id_gen = StreamIdGenerator(
-                    db_conn, "events", "stream_ordering",
+                    db_conn,
+                    "events",
+                    "stream_ordering",
                 )
                 self._backfill_id_gen = StreamIdGenerator(
                     db_conn,
@@ -140,7 +142,8 @@ class EventsWorkerStore(SQLBaseStore):
         if hs.config.run_background_tasks:
             # We periodically clean out old transaction ID mappings
             self._clock.looping_call(
-                self._cleanup_old_transaction_ids, 5 * 60 * 1000,
+                self._cleanup_old_transaction_ids,
+                5 * 60 * 1000,
             )
 
         self._get_event_cache = LruCache(
@@ -1325,8 +1328,7 @@ class EventsWorkerStore(SQLBaseStore):
         return rows, to_token, True
 
     async def is_event_after(self, event_id1, event_id2):
-        """Returns True if event_id1 is after event_id2 in the stream
-        """
+        """Returns True if event_id1 is after event_id2 in the stream"""
         to_1, so_1 = await self.get_event_ordering(event_id1)
         to_2, so_2 = await self.get_event_ordering(event_id2)
         return (to_1, so_1) > (to_2, so_2)
@@ -1428,8 +1430,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     @wrap_as_background_process("_cleanup_old_transaction_ids")
     async def _cleanup_old_transaction_ids(self):
-        """Cleans out transaction id mappings older than 24hrs.
-        """
+        """Cleans out transaction id mappings older than 24hrs."""
 
         def _cleanup_old_transaction_ids_txn(txn):
             sql = """
@@ -1440,5 +1441,6 @@ class EventsWorkerStore(SQLBaseStore):
             txn.execute(sql, (one_day_ago,))
 
         return await self.db_pool.runInteraction(
-            "_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn,
+            "_cleanup_old_transaction_ids",
+            _cleanup_old_transaction_ids_txn,
         )
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 7218191965..ac07e0197b 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -14,7 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple
+
+from typing_extensions import TypedDict
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -26,6 +28,9 @@ from synapse.util import json_encoder
 _DEFAULT_CATEGORY_ID = ""
 _DEFAULT_ROLE_ID = ""
 
+# A room in a group.
+_RoomInGroup = TypedDict("_RoomInGroup", {"room_id": str, "is_public": bool})
+
 
 class GroupServerWorkerStore(SQLBaseStore):
     async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
@@ -72,7 +77,7 @@ class GroupServerWorkerStore(SQLBaseStore):
 
     async def get_rooms_in_group(
         self, group_id: str, include_private: bool = False
-    ) -> List[Dict[str, Union[str, bool]]]:
+    ) -> List[_RoomInGroup]:
         """Retrieve the rooms that belong to a given group. Does not return rooms that
         lack members.
 
@@ -123,7 +128,9 @@ class GroupServerWorkerStore(SQLBaseStore):
         )
 
     async def get_rooms_for_summary_by_category(
-        self, group_id: str, include_private: bool = False,
+        self,
+        group_id: str,
+        include_private: bool = False,
     ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
         """Get the rooms and categories that should be included in a summary request
 
@@ -368,8 +375,7 @@ class GroupServerWorkerStore(SQLBaseStore):
     async def is_user_invited_to_local_group(
         self, group_id: str, user_id: str
     ) -> Optional[bool]:
-        """Has the group server invited a user?
-        """
+        """Has the group server invited a user?"""
         return await self.db_pool.simple_select_one_onecol(
             table="group_invites",
             keyvalues={"group_id": group_id, "user_id": user_id},
@@ -427,8 +433,7 @@ class GroupServerWorkerStore(SQLBaseStore):
         )
 
     async def get_publicised_groups_for_user(self, user_id: str) -> List[str]:
-        """Get all groups a user is publicising
-        """
+        """Get all groups a user is publicising"""
         return await self.db_pool.simple_select_onecol(
             table="local_group_membership",
             keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
@@ -437,8 +442,7 @@ class GroupServerWorkerStore(SQLBaseStore):
         )
 
     async def get_attestations_need_renewals(self, valid_until_ms):
-        """Get all attestations that need to be renewed until givent time
-        """
+        """Get all attestations that need to be renewed until given time"""
 
         def _get_attestations_need_renewals_txn(txn):
             sql = """
@@ -781,8 +785,7 @@ class GroupServerStore(GroupServerWorkerStore):
         profile: Optional[JsonDict],
         is_public: Optional[bool],
     ) -> None:
-        """Add/update room category for group
-        """
+        """Add/update room category for group"""
         insertion_values = {}
         update_values = {"category_id": category_id}  # This cannot be empty
 
@@ -818,8 +821,7 @@ class GroupServerStore(GroupServerWorkerStore):
         profile: Optional[JsonDict],
         is_public: Optional[bool],
     ) -> None:
-        """Add/remove user role
-        """
+        """Add/remove user role"""
         insertion_values = {}
         update_values = {"role_id": role_id}  # This cannot be empty
 
@@ -1012,8 +1014,7 @@ class GroupServerStore(GroupServerWorkerStore):
         )
 
     async def add_group_invite(self, group_id: str, user_id: str) -> None:
-        """Record that the group server has invited a user
-        """
+        """Record that the group server has invited a user"""
         await self.db_pool.simple_insert(
             table="group_invites",
             values={"group_id": group_id, "user_id": user_id},
@@ -1156,8 +1157,7 @@ class GroupServerStore(GroupServerWorkerStore):
     async def update_group_publicity(
         self, group_id: str, user_id: str, publicise: bool
     ) -> None:
-        """Update whether the user is publicising their membership of the group
-        """
+        """Update whether the user is publicising their membership of the group"""
         await self.db_pool.simple_update_one(
             table="local_group_membership",
             keyvalues={"group_id": group_id, "user_id": user_id},
@@ -1300,8 +1300,7 @@ class GroupServerStore(GroupServerWorkerStore):
     async def update_attestation_renewal(
         self, group_id: str, user_id: str, attestation: dict
     ) -> None:
-        """Update an attestation that we have renewed
-        """
+        """Update an attestation that we have renewed"""
         await self.db_pool.simple_update_one(
             table="group_attestations_renewals",
             keyvalues={"group_id": group_id, "user_id": user_id},
@@ -1312,8 +1311,7 @@ class GroupServerStore(GroupServerWorkerStore):
     async def update_remote_attestion(
         self, group_id: str, user_id: str, attestation: dict
     ) -> None:
-        """Update an attestation that a remote has renewed
-        """
+        """Update an attestation that a remote has renewed"""
         await self.db_pool.simple_update_one(
             table="group_attestations_remote",
             keyvalues={"group_id": group_id, "user_id": user_id},
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 04ac2d0ced..d504323b03 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -33,8 +33,7 @@ db_binary_type = memoryview
 
 
 class KeyStore(SQLBaseStore):
-    """Persistence for signature verification keys
-    """
+    """Persistence for signature verification keys"""
 
     @cached()
     def _get_server_verify_key(self, server_name_and_key_id):
@@ -155,7 +154,7 @@ class KeyStore(SQLBaseStore):
         (server_name, key_id, from_server) triplet if one already existed.
         Args:
             server_name: The name of the server.
-            key_id: The identifer of the key this JSON is for.
+            key_id: The identifier of the key this JSON is for.
             from_server: The server this JSON was fetched from.
             ts_now_ms: The time now in milliseconds.
             ts_valid_until_ms: The time when this json stops being valid.
@@ -182,7 +181,7 @@ class KeyStore(SQLBaseStore):
     async def get_server_keys_json(
         self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
     ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]:
-        """Retrive the key json for a list of server_keys and key ids.
+        """Retrieve the key json for a list of server_keys and key ids.
         If no keys are found for a given server, key_id and source then
         that server, key_id, and source triplet entry will be an empty list.
         The JSON is returned as a byte array so that it can be efficiently
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 283c8a5e22..a0313c3ccf 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -169,7 +169,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         )
 
     async def get_local_media_before(
-        self, before_ts: int, size_gt: int, keep_profiles: bool,
+        self,
+        before_ts: int,
+        size_gt: int,
+        keep_profiles: bool,
     ) -> List[str]:
 
         # to find files that have never been accessed (last_access_ts IS NULL)
@@ -417,7 +420,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 " WHERE media_origin = ? AND media_id = ?"
             )
 
-            txn.executemany(
+            txn.execute_batch(
                 sql,
                 (
                     (time_ms, media_origin, media_id)
@@ -430,7 +433,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 " WHERE media_id = ?"
             )
 
-            txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
+            txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))
 
         return await self.db_pool.runInteraction(
             "update_cached_last_access_time", update_cache_txn
@@ -454,10 +457,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         )
 
     async def get_remote_media_thumbnail(
-        self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str,
+        self,
+        origin: str,
+        media_id: str,
+        t_width: int,
+        t_height: int,
+        t_type: str,
     ) -> Optional[Dict[str, Any]]:
-        """Fetch the thumbnail info of given width, height and type.
-        """
+        """Fetch the thumbnail info of given width, height and type."""
 
         return await self.db_pool.simple_select_one(
             table="remote_media_cache_thumbnails",
@@ -557,7 +564,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
 
         def _delete_url_cache_txn(txn):
-            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
 
         return await self.db_pool.runInteraction(
             "delete_url_cache", _delete_url_cache_txn
@@ -586,11 +593,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         def _delete_url_cache_media_txn(txn):
             sql = "DELETE FROM local_media_repository WHERE media_id = ?"
 
-            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
 
             sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
 
-            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
 
         return await self.db_pool.runInteraction(
             "delete_url_cache_media", _delete_url_cache_media_txn
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index ab18cc4d79..614a418a15 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -88,6 +88,62 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             (x[0] - 1) * x[1] for x in res if x[1]
         )
 
+    async def count_daily_e2ee_messages(self):
+        """
+        Returns an estimate of the number of end-to-end encrypted messages
+        (`m.room.encrypted` events) sent in the last day.
+
+        Counts events whose stream ordering is after `stream_ordering_day_ago`.
+        """
+
+        def _count_messages(txn):
+            sql = """
+                SELECT COALESCE(COUNT(*), 0) FROM events
+                WHERE type = 'm.room.encrypted'
+                AND stream_ordering > ?
+            """
+            txn.execute(sql, (self.stream_ordering_day_ago,))
+            (count,) = txn.fetchone()
+            return count
+
+        return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
+
+    async def count_daily_sent_e2ee_messages(self):
+        def _count_messages(txn):
+            # This is good enough as if you have silly characters in your own
+            # hostname then that's your own fault.
+            like_clause = "%:" + self.hs.hostname
+
+            sql = """
+                SELECT COALESCE(COUNT(*), 0) FROM events
+                WHERE type = 'm.room.encrypted'
+                    AND sender LIKE ?
+                AND stream_ordering > ?
+            """
+
+            txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
+            (count,) = txn.fetchone()
+            return count
+
+        return await self.db_pool.runInteraction(
+            "count_daily_sent_e2ee_messages", _count_messages
+        )
+
+    async def count_daily_active_e2ee_rooms(self):
+        def _count(txn):
+            sql = """
+                SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+                WHERE type = 'm.room.encrypted'
+                AND stream_ordering > ?
+            """
+            txn.execute(sql, (self.stream_ordering_day_ago,))
+            (count,) = txn.fetchone()
+            return count
+
+        return await self.db_pool.runInteraction(
+            "count_daily_active_e2ee_rooms", _count
+        )
+
     async def count_daily_messages(self):
         """
         Returns an estimate of the number of messages sent in the last day.
@@ -111,7 +167,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
     async def count_daily_sent_messages(self):
         def _count_messages(txn):
             # This is good enough as if you have silly characters in your own
-            # hostname then thats your own fault.
+            # hostname then that's your own fault.
             like_clause = "%:" + self.hs.hostname
 
             sql = """
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index dbbb99cb95..29edab34d4 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -130,7 +130,9 @@ class PresenceStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
+        cached_method_name="_get_presence_for_user",
+        list_name="user_ids",
+        num_args=1,
     )
     async def get_presence_for_users(self, user_ids):
         rows = await self.db_pool.simple_select_many_batch(
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 54ef0f1f54..ba01d3108a 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -118,8 +118,7 @@ class ProfileWorkerStore(SQLBaseStore):
             )
 
     async def is_subscribed_remote_profile_for_user(self, user_id):
-        """Check whether we are interested in a remote user's profile.
-        """
+        """Check whether we are interested in a remote user's profile."""
         res = await self.db_pool.simple_select_one_onecol(
             table="group_users",
             keyvalues={"user_id": user_id},
@@ -145,8 +144,7 @@ class ProfileWorkerStore(SQLBaseStore):
     async def get_remote_profile_cache_entries_that_expire(
         self, last_checked: int
     ) -> List[Dict[str, str]]:
-        """Get all users who haven't been checked since `last_checked`
-        """
+        """Get all users who haven't been checked since `last_checked`"""
 
         def _get_remote_profile_cache_entries_that_expire_txn(txn):
             sql = """
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 5d668aadb2..ecfc9f20b1 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -172,7 +172,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
         )
 
         # Update backward extremeties
-        txn.executemany(
+        txn.execute_batch(
             "INSERT INTO event_backward_extremities (room_id, event_id)"
             " VALUES (?, ?)",
             [(room_id, event_id) for event_id, in new_backwards_extrems],
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 711d5aa23d..9e58dc0e6a 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -168,7 +168,9 @@ class PushRulesWorkerStore(
             )
 
     @cachedList(
-        cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
+        cached_method_name="get_push_rules_for_user",
+        list_name="user_ids",
+        num_args=1,
     )
     async def bulk_get_push_rules(self, user_ids):
         if not user_ids:
@@ -195,7 +197,9 @@ class PushRulesWorkerStore(
             use_new_defaults = user_id in self._users_new_default_push_rules
 
             results[user_id] = _load_rules(
-                rules, enabled_map_by_user.get(user_id, {}), use_new_defaults,
+                rules,
+                enabled_map_by_user.get(user_id, {}),
+                use_new_defaults,
             )
 
         return results
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index bc7621b8d6..7cb69dd6bd 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -179,7 +179,9 @@ class PusherWorkerStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
+        cached_method_name="get_if_user_has_pusher",
+        list_name="user_ids",
+        num_args=1,
     )
     async def get_if_users_have_pushers(
         self, user_ids: Iterable[str]
@@ -263,7 +265,8 @@ class PusherWorkerStore(SQLBaseStore):
         params_by_room = {}
         for row in res:
             params_by_room[row["room_id"]] = ThrottleParams(
-                row["last_sent_ts"], row["throttle_ms"],
+                row["last_sent_ts"],
+                row["throttle_ms"],
             )
 
         return params_by_room
@@ -344,7 +347,9 @@ class PusherStore(PusherWorkerStore):
                 txn, self.get_if_user_has_pusher, (user_id,)
             )
 
-            self.db_pool.simple_delete_one_txn(
+            # It is expected that there is exactly one pusher to delete, but
+            # if it isn't there (or there are multiple) delete them all.
+            self.db_pool.simple_delete_txn(
                 txn,
                 "pushers",
                 {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index e4843a202c..43c852c96c 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -160,7 +160,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         Args:
             room_id: List of room_ids.
-            to_key: Max stream id to fetch receipts upto.
+            to_key: Max stream id to fetch receipts up to.
             from_key: Min stream id to fetch receipts from. None fetches
                 from the start.
 
@@ -189,7 +189,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         Args:
             room_ids: The room id.
-            to_key: Max stream id to fetch receipts upto.
+            to_key: Max stream id to fetch receipts up to.
             from_key: Min stream id to fetch receipts from. None fetches
                 from the start.
 
@@ -208,8 +208,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     async def _get_linearized_receipts_for_room(
         self, room_id: str, to_key: int, from_key: Optional[int] = None
     ) -> List[dict]:
-        """See get_linearized_receipts_for_room
-        """
+        """See get_linearized_receipts_for_room"""
 
         def f(txn):
             if from_key:
@@ -304,7 +303,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
         }
         return results
 
-    @cached(num_args=2,)
+    @cached(
+        num_args=2,
+    )
     async def get_linearized_receipts_for_all_rooms(
         self, to_key: int, from_key: Optional[int] = None
     ) -> Dict[str, JsonDict]:
@@ -312,7 +313,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         to a limit of the latest 100 read receipts.
 
         Args:
-            to_key: Max stream id to fetch receipts upto.
+            to_key: Max stream id to fetch receipts up to.
             from_key: Min stream id to fetch receipts from. None fetches
                 from the start.
 
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 8d05288ed4..d5b5507815 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -79,13 +79,16 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         # call `find_max_generated_user_id_localpart` each time, which is
         # expensive if there are many entries.
         self._user_id_seq = build_sequence_generator(
-            database.engine, find_max_generated_user_id_localpart, "user_id_seq",
+            database.engine,
+            find_max_generated_user_id_localpart,
+            "user_id_seq",
         )
 
         self._account_validity = hs.config.account_validity
         if hs.config.run_background_tasks and self._account_validity.enabled:
             self._clock.call_later(
-                0.0, self._set_expiration_date_when_missing,
+                0.0,
+                self._set_expiration_date_when_missing,
             )
 
         # Create a background job for culling expired 3PID validity tokens
@@ -110,6 +113,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 "creation_ts",
                 "user_type",
                 "deactivated",
+                "shadow_banned",
             ],
             allow_none=True,
             desc="get_user_by_id",
@@ -360,6 +364,37 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
 
+    async def set_shadow_banned(self, user: UserID, shadow_banned: bool) -> None:
+        """Sets whether a user is shadow-banned.
+
+        Args:
+            user: user ID of the user to update
+            shadow_banned: true iff the user is to be shadow-banned, false otherwise.
+        """
+
+        def set_shadow_banned_txn(txn):
+            user_id = user.to_string()
+            self.db_pool.simple_update_one_txn(
+                txn,
+                table="users",
+                keyvalues={"name": user_id},
+                updatevalues={"shadow_banned": shadow_banned},
+            )
+            # In order for this to apply immediately, clear the cache for this user.
+            tokens = self.db_pool.simple_select_onecol_txn(
+                txn,
+                table="access_tokens",
+                keyvalues={"user_id": user_id},
+                retcol="token",
+            )
+            for token in tokens:
+                self._invalidate_cache_and_stream(
+                    txn, self.get_user_by_access_token, (token,)
+                )
+            self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
+        await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
+
     def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
         sql = """
             SELECT users.name as user_id,
@@ -443,6 +478,26 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
 
+    async def record_user_external_id(
+        self, auth_provider: str, external_id: str, user_id: str
+    ) -> None:
+        """Record a mapping from an external user id to a mxid
+
+        Args:
+            auth_provider: identifier for the remote auth provider
+            external_id: id on that system
+            user_id: complete mxid that it is mapped to
+        """
+        await self.db_pool.simple_insert(
+            table="user_external_ids",
+            values={
+                "auth_provider": auth_provider,
+                "external_id": external_id,
+                "user_id": user_id,
+            },
+            desc="record_user_external_id",
+        )
+
     async def get_user_by_external_id(
         self, auth_provider: str, external_id: str
     ) -> Optional[str]:
@@ -1104,7 +1159,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
                 FROM user_threepids
             """
 
-            txn.executemany(sql, [(id_server,) for id_server in id_servers])
+            txn.execute_batch(sql, [(id_server,) for id_server in id_servers])
 
         if id_servers:
             await self.db_pool.runInteraction(
@@ -1371,26 +1426,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
         self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-    async def record_user_external_id(
-        self, auth_provider: str, external_id: str, user_id: str
-    ) -> None:
-        """Record a mapping from an external user id to a mxid
-
-        Args:
-            auth_provider: identifier for the remote auth provider
-            external_id: id on that system
-            user_id: complete mxid that it is mapped to
-        """
-        await self.db_pool.simple_insert(
-            table="user_external_ids",
-            values={
-                "auth_provider": auth_provider,
-                "external_id": external_id,
-                "user_id": user_id,
-            },
-            desc="record_user_external_id",
-        )
-
     async def user_set_password_hash(
         self, user_id: str, password_hash: Optional[str]
     ) -> None:
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index a9fcb5f59c..9cbcd53026 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -193,8 +193,7 @@ class RoomWorkerStore(SQLBaseStore):
         )
 
     async def get_room_count(self) -> int:
-        """Retrieve the total number of rooms.
-        """
+        """Retrieve the total number of rooms."""
 
         def f(txn):
             sql = "SELECT count(*)  FROM rooms"
@@ -517,7 +516,8 @@ class RoomWorkerStore(SQLBaseStore):
             return rooms, room_count[0]
 
         return await self.db_pool.runInteraction(
-            "get_rooms_paginate", _get_rooms_paginate_txn,
+            "get_rooms_paginate",
+            _get_rooms_paginate_txn,
         )
 
     @cached(max_entries=10000)
@@ -578,7 +578,8 @@ class RoomWorkerStore(SQLBaseStore):
             return self.db_pool.cursor_to_dict(txn)
 
         ret = await self.db_pool.runInteraction(
-            "get_retention_policy_for_room", get_retention_policy_for_room_txn,
+            "get_retention_policy_for_room",
+            get_retention_policy_for_room_txn,
         )
 
         # If we don't know this room ID, ret will be None, in this case return the default
@@ -707,7 +708,10 @@ class RoomWorkerStore(SQLBaseStore):
         return local_media_mxcs, remote_media_mxcs
 
     async def quarantine_media_by_id(
-        self, server_name: str, media_id: str, quarantined_by: str,
+        self,
+        server_name: str,
+        media_id: str,
+        quarantined_by: str,
     ) -> int:
         """quarantines a single local or remote media id
 
@@ -961,7 +965,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         self.config = hs.config
 
         self.db_pool.updates.register_background_update_handler(
-            "insert_room_retention", self._background_insert_retention,
+            "insert_room_retention",
+            self._background_insert_retention,
         )
 
         self.db_pool.updates.register_background_update_handler(
@@ -1033,7 +1038,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
                 return False
 
         end = await self.db_pool.runInteraction(
-            "insert_room_retention", _background_insert_retention_txn,
+            "insert_room_retention",
+            _background_insert_retention_txn,
         )
 
         if end:
@@ -1044,7 +1050,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
     async def _background_add_rooms_room_version_column(
         self, progress: dict, batch_size: int
     ):
-        """Background update to go and add room version inforamtion to `rooms`
+        """Background update to go and add room version information to `rooms`
         table from `current_state_events` table.
         """
 
@@ -1588,7 +1594,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                 LIMIT ?
                 OFFSET ?
             """.format(
-                where_clause=where_clause, order=order,
+                where_clause=where_clause,
+                order=order,
             )
 
             args += [limit, start]
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index dcdaf09682..a9216ca9ae 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -70,10 +70,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         ):
             self._known_servers_count = 1
             self.hs.get_clock().looping_call(
-                self._count_known_servers, 60 * 1000,
+                self._count_known_servers,
+                60 * 1000,
             )
             self.hs.get_clock().call_later(
-                1000, self._count_known_servers,
+                1000,
+                self._count_known_servers,
             )
             LaterGauge(
                 "synapse_federation_known_servers",
@@ -174,7 +176,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
     @cached(max_entries=100000)
     async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
-        """ Get the details of a room roughly suitable for use by the room
+        """Get the details of a room roughly suitable for use by the room
         summary extension to /sync. Useful when lazy loading room members.
         Args:
             room_id: The room ID to query
@@ -488,8 +490,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
     async def get_users_who_share_room_with_user(
         self, user_id: str, cache_context: _CacheContext
     ) -> Set[str]:
-        """Returns the set of users who share a room with `user_id`
-        """
+        """Returns the set of users who share a room with `user_id`"""
         room_ids = await self.get_rooms_for_user(
             user_id, on_invalidate=cache_context.invalidate
         )
@@ -618,7 +619,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
+        cached_method_name="_get_joined_profile_from_event_id",
+        list_name="event_ids",
     )
     async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
         """For given set of member event_ids check if they point to a join
@@ -802,8 +804,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
     async def get_membership_from_event_ids(
         self, member_event_ids: Iterable[str]
     ) -> List[dict]:
-        """Get user_id and membership of a set of event IDs.
-        """
+        """Get user_id and membership of a set of event IDs."""
 
         return await self.db_pool.simple_select_many_batch(
             table="room_memberships",
@@ -873,8 +874,6 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
             "max_stream_id_exclusive", self._stream_order_on_start + 1
         )
 
-        INSERT_CLUMP_SIZE = 1000
-
         def add_membership_profile_txn(txn):
             sql = """
                 SELECT stream_ordering, event_id, events.room_id, event_json.json
@@ -915,9 +914,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
                 UPDATE room_memberships SET display_name = ?, avatar_url = ?
                 WHERE event_id = ? AND room_id = ?
             """
-            for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
-                clump = to_update[index : index + INSERT_CLUMP_SIZE]
-                txn.executemany(to_update_sql, clump)
+            txn.execute_batch(to_update_sql, to_update)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
diff --git a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
index ad875c733a..3907189e29 100644
--- a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
+++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
@@ -23,5 +23,6 @@ def run_create(cur, database_engine, *args, **kwargs):
 
 def run_upgrade(cur, database_engine, *args, **kwargs):
     cur.execute(
-        "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),),
+        "UPDATE remote_media_cache SET last_access_ts = ?",
+        (int(time.time() * 1000),),
     )
diff --git a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
index f35c70b699..9e8f35c1d2 100644
--- a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
+++ b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
@@ -55,7 +55,7 @@ def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs
         # { "ignored_users": "@someone:example.org": {} }
         ignored_users = content.get("ignored_users", {})
         if isinstance(ignored_users, dict) and ignored_users:
-            cur.executemany(insert_sql, [(user_id, u) for u in ignored_users])
+            cur.execute_batch(insert_sql, [(user_id, u) for u in ignored_users])
 
     # Add indexes after inserting data for efficiency.
     logger.info("Adding constraints to ignored_users table")
diff --git a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
index a0411ede7e..308124e531 100644
--- a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
+++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
@@ -67,11 +67,6 @@ CREATE TABLE IF NOT EXISTS "user_threepids" ( user_id TEXT NOT NULL, medium TEXT
 CREATE INDEX user_threepids_user_id ON user_threepids(user_id);
 CREATE VIRTUAL TABLE event_search USING fts4 ( event_id, room_id, sender, key, value )
 /* event_search(event_id,room_id,sender,"key",value) */;
-CREATE TABLE IF NOT EXISTS 'event_search_content'(docid INTEGER PRIMARY KEY, 'c0event_id', 'c1room_id', 'c2sender', 'c3key', 'c4value');
-CREATE TABLE IF NOT EXISTS 'event_search_segments'(blockid INTEGER PRIMARY KEY, block BLOB);
-CREATE TABLE IF NOT EXISTS 'event_search_segdir'(level INTEGER,idx INTEGER,start_block INTEGER,leaves_end_block INTEGER,end_block INTEGER,root BLOB,PRIMARY KEY(level, idx));
-CREATE TABLE IF NOT EXISTS 'event_search_docsize'(docid INTEGER PRIMARY KEY, size BLOB);
-CREATE TABLE IF NOT EXISTS 'event_search_stat'(id INTEGER PRIMARY KEY, value BLOB);
 CREATE TABLE guest_access( event_id TEXT NOT NULL, room_id TEXT NOT NULL, guest_access TEXT NOT NULL, UNIQUE (event_id) );
 CREATE TABLE history_visibility( event_id TEXT NOT NULL, room_id TEXT NOT NULL, history_visibility TEXT NOT NULL, UNIQUE (event_id) );
 CREATE TABLE room_tags( user_id TEXT NOT NULL, room_id TEXT NOT NULL, tag     TEXT NOT NULL, content TEXT NOT NULL, CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag) );
@@ -149,11 +144,6 @@ CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_las
 CREATE TABLE user_directory_stream_pos ( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_id BIGINT, CHECK (Lock='X') );
 CREATE VIRTUAL TABLE user_directory_search USING fts4 ( user_id, value )
 /* user_directory_search(user_id,value) */;
-CREATE TABLE IF NOT EXISTS 'user_directory_search_content'(docid INTEGER PRIMARY KEY, 'c0user_id', 'c1value');
-CREATE TABLE IF NOT EXISTS 'user_directory_search_segments'(blockid INTEGER PRIMARY KEY, block BLOB);
-CREATE TABLE IF NOT EXISTS 'user_directory_search_segdir'(level INTEGER,idx INTEGER,start_block INTEGER,leaves_end_block INTEGER,end_block INTEGER,root BLOB,PRIMARY KEY(level, idx));
-CREATE TABLE IF NOT EXISTS 'user_directory_search_docsize'(docid INTEGER PRIMARY KEY, size BLOB);
-CREATE TABLE IF NOT EXISTS 'user_directory_search_stat'(id INTEGER PRIMARY KEY, value BLOB);
 CREATE TABLE blocked_rooms ( room_id TEXT NOT NULL, user_id TEXT NOT NULL );
 CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id);
 CREATE TABLE IF NOT EXISTS "local_media_repository_url_cache"( url TEXT, response_code INTEGER, etag TEXT, expires_ts BIGINT, og TEXT, media_id TEXT, download_ts BIGINT );
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index e34fce6281..f5e7d9ef98 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -24,6 +24,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import Collection
 
 logger = logging.getLogger(__name__)
 
@@ -63,7 +64,7 @@ class SearchWorkerStore(SQLBaseStore):
                 for entry in entries
             )
 
-            txn.executemany(sql, args)
+            txn.execute_batch(sql, args)
 
         elif isinstance(self.database_engine, Sqlite3Engine):
             sql = (
@@ -75,7 +76,7 @@ class SearchWorkerStore(SQLBaseStore):
                 for entry in entries
             )
 
-            txn.executemany(sql, args)
+            txn.execute_batch(sql, args)
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
@@ -460,7 +461,7 @@ class SearchStore(SearchBackgroundUpdateStore):
 
     async def search_rooms(
         self,
-        room_ids: List[str],
+        room_ids: Collection[str],
         search_term: str,
         keys: List[str],
         limit,
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 3c1e33819b..a7f371732f 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -52,8 +52,7 @@ class _GetStateGroupDelta(
 
 # this inherits from EventsWorkerStore because it calls self.get_events
 class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
-    """The parts of StateGroupStore that can be called from workers.
-    """
+    """The parts of StateGroupStore that can be called from workers."""
 
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
@@ -276,8 +275,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         num_args=1,
     )
     async def _get_state_group_for_events(self, event_ids):
-        """Returns mapping event_id -> state_group
-        """
+        """Returns mapping event_id -> state_group"""
         rows = await self.db_pool.simple_select_many_batch(
             table="event_to_state_groups",
             column="event_id",
@@ -338,7 +336,8 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
             columns=["state_group"],
         )
         self.db_pool.updates.register_background_update_handler(
-            self.DELETE_CURRENT_STATE_UPDATE_NAME, self._background_remove_left_rooms,
+            self.DELETE_CURRENT_STATE_UPDATE_NAME,
+            self._background_remove_left_rooms,
         )
 
     async def _background_remove_left_rooms(self, progress, batch_size):
@@ -487,7 +486,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
 
 
 class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
-    """ Keeps track of the state at a given event.
+    """Keeps track of the state at a given event.
 
     This is done by the concept of `state groups`. Every event is a assigned
     a state group (identified by an arbitrary string), which references a
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 356623fc6e..0dbb501f16 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -64,7 +64,7 @@ class StateDeltasStore(SQLBaseStore):
         def get_current_state_deltas_txn(txn):
             # First we calculate the max stream id that will give us less than
             # N results.
-            # We arbitarily limit to 100 stream_id entries to ensure we don't
+            # We arbitrarily limit to 100 stream_id entries to ensure we don't
             # select toooo many.
             sql = """
                 SELECT stream_id, count(*)
@@ -81,7 +81,7 @@ class StateDeltasStore(SQLBaseStore):
             for stream_id, count in txn:
                 total += count
                 if total > 100:
-                    # We arbitarily limit to 100 entries to ensure we don't
+                    # We arbitrarily limit to 100 entries to ensure we don't
                     # select toooo many.
                     logger.debug(
                         "Clipping current_state_delta_stream rows to stream_id %i",
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 0cdb3ec1f7..1c99393c65 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -15,11 +15,12 @@
 # limitations under the License.
 
 import logging
-from collections import Counter
 from enum import Enum
 from itertools import chain
 from typing import Any, Dict, List, Optional, Tuple
 
+from typing_extensions import Counter
+
 from twisted.internet.defer import DeferredLock
 
 from synapse.api.constants import EventTypes, Membership
@@ -319,7 +320,9 @@ class StatsStore(StateDeltasStore):
         return slice_list
 
     @cached()
-    async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
+    async def get_earliest_token_for_stats(
+        self, stats_type: str, id: str
+    ) -> Optional[int]:
         """
         Fetch the "earliest token". This is used by the room stats delta
         processor to ignore deltas that have been processed between the
@@ -339,7 +342,7 @@ class StatsStore(StateDeltasStore):
         )
 
     async def bulk_update_stats_delta(
-        self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int
+        self, ts: int, updates: Dict[str, Dict[str, Counter[str]]], stream_id: int
     ) -> None:
         """Bulk update stats tables for a given stream_id and updates the stats
         incremental position.
@@ -665,7 +668,7 @@ class StatsStore(StateDeltasStore):
 
     async def get_changes_room_total_events_and_bytes(
         self, min_pos: int, max_pos: int
-    ) -> Dict[str, Dict[str, int]]:
+    ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
         """Fetches the counts of events in the given range of stream IDs.
 
         Args:
@@ -683,18 +686,19 @@ class StatsStore(StateDeltasStore):
             max_pos,
         )
 
-    def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos):
+    def get_changes_room_total_events_and_bytes_txn(
+        self, txn, low_pos: int, high_pos: int
+    ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
         """Gets the total_events and total_event_bytes counts for rooms and
         senders, in a range of stream_orderings (including backfilled events).
 
         Args:
             txn
-            low_pos (int): Low stream ordering
-            high_pos (int): High stream ordering
+            low_pos: Low stream ordering
+            high_pos: High stream ordering
 
         Returns:
-            tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The
-            room and user deltas for total_events/total_event_bytes in the
+            The room and user deltas for total_events/total_event_bytes in the
             format of `stats_id` -> fields
         """
 
@@ -997,7 +1001,9 @@ class StatsStore(StateDeltasStore):
                 ORDER BY {order_by_column} {order}
                 LIMIT ? OFFSET ?
             """.format(
-                sql_base=sql_base, order_by_column=order_by_column, order=order,
+                sql_base=sql_base,
+                order_by_column=order_by_column,
+                order=order,
             )
 
             args += [limit, start]
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index e3b9ff5ca6..91f8abb67d 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -565,7 +565,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
                     AND e.stream_ordering > ? AND e.stream_ordering <= ?
                 ORDER BY e.stream_ordering ASC
             """
-            txn.execute(sql, (user_id, min_from_id, max_to_id,))
+            txn.execute(
+                sql,
+                (
+                    user_id,
+                    min_from_id,
+                    max_to_id,
+                ),
+            )
 
             rows = [
                 _EventDictReturn(event_id, None, stream_ordering)
@@ -695,7 +702,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             return "t%d-%d" % (topo, token)
 
     def get_stream_id_for_event_txn(
-        self, txn: LoggingTransaction, event_id: str, allow_none=False,
+        self,
+        txn: LoggingTransaction,
+        event_id: str,
+        allow_none=False,
     ) -> int:
         return self.db_pool.simple_select_one_onecol_txn(
             txn=txn,
@@ -706,8 +716,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         )
 
     async def get_position_for_event(self, event_id: str) -> PersistedEventPosition:
-        """Get the persisted position for an event
-        """
+        """Get the persisted position for an event"""
         row = await self.db_pool.simple_select_one(
             table="events",
             keyvalues={"event_id": event_id},
@@ -897,19 +906,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
     ) -> Tuple[int, List[EventBase]]:
         """Get all new events
 
-         Returns all events with from_id < stream_ordering <= current_id.
+        Returns all events with from_id < stream_ordering <= current_id.
 
-         Args:
-             from_id:  the stream_ordering of the last event we processed
-             current_id:  the stream_ordering of the most recently processed event
-             limit: the maximum number of events to return
+        Args:
+            from_id:  the stream_ordering of the last event we processed
+            current_id:  the stream_ordering of the most recently processed event
+            limit: the maximum number of events to return
 
-         Returns:
-             A tuple of (next_id, events), where `next_id` is the next value to
-             pass as `from_id` (it will either be the stream_ordering of the
-             last returned event, or, if fewer than `limit` events were found,
-             the `current_id`).
-         """
+        Returns:
+            A tuple of (next_id, events), where `next_id` is the next value to
+            pass as `from_id` (it will either be the stream_ordering of the
+            last returned event, or, if fewer than `limit` events were found,
+            the `current_id`).
+        """
 
         def get_all_new_events_stream_txn(txn):
             sql = (
@@ -1238,8 +1247,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
 
     @cached()
     async def get_id_for_instance(self, instance_name: str) -> int:
-        """Get a unique, immutable ID that corresponds to the given Synapse worker instance.
-        """
+        """Get a unique, immutable ID that corresponds to the given Synapse worker instance."""
 
         def _get_id_for_instance_txn(txn):
             instance_id = self.db_pool.simple_select_one_onecol_txn(
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index cea595ff19..b921d63d30 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -64,8 +64,7 @@ class TransactionWorkerStore(SQLBaseStore):
 
 
 class TransactionStore(TransactionWorkerStore):
-    """A collection of queries for handling PDUs.
-    """
+    """A collection of queries for handling PDUs."""
 
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
@@ -198,7 +197,7 @@ class TransactionStore(TransactionWorkerStore):
         retry_interval: int,
     ) -> None:
         """Sets the current retry timings for a given destination.
-        Both timings should be zero if retrying is no longer occuring.
+        Both timings should be zero if retrying is no longer occurring.
 
         Args:
             destination
@@ -299,7 +298,10 @@ class TransactionStore(TransactionWorkerStore):
             )
 
     async def store_destination_rooms_entries(
-        self, destinations: Iterable[str], room_id: str, stream_ordering: int,
+        self,
+        destinations: Iterable[str],
+        room_id: str,
+        stream_ordering: int,
     ) -> None:
         """
         Updates or creates `destination_rooms` entries in batch for a single event.
@@ -394,7 +396,9 @@ class TransactionStore(TransactionWorkerStore):
         )
 
     async def get_catch_up_room_event_ids(
-        self, destination: str, last_successful_stream_ordering: int,
+        self,
+        destination: str,
+        last_successful_stream_ordering: int,
     ) -> List[str]:
         """
         Returns at most 50 event IDs and their corresponding stream_orderings
@@ -418,7 +422,9 @@ class TransactionStore(TransactionWorkerStore):
 
     @staticmethod
     def _get_catch_up_room_event_ids_txn(
-        txn: LoggingTransaction, destination: str, last_successful_stream_ordering: int,
+        txn: LoggingTransaction,
+        destination: str,
+        last_successful_stream_ordering: int,
     ) -> List[str]:
         q = """
                 SELECT event_id FROM destination_rooms
@@ -429,7 +435,8 @@ class TransactionStore(TransactionWorkerStore):
                 LIMIT 50
             """
         txn.execute(
-            q, (destination, last_successful_stream_ordering),
+            q,
+            (destination, last_successful_stream_ordering),
         )
         event_ids = [row[0] for row in txn]
         return event_ids
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 79b7ece330..5473ec1485 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -44,7 +44,11 @@ class UIAuthWorkerStore(SQLBaseStore):
     """
 
     async def create_ui_auth_session(
-        self, clientdict: JsonDict, uri: str, method: str, description: str,
+        self,
+        clientdict: JsonDict,
+        uri: str,
+        method: str,
+        description: str,
     ) -> UIAuthSessionData:
         """
         Creates a new user interactive authentication session.
@@ -123,7 +127,10 @@ class UIAuthWorkerStore(SQLBaseStore):
         return UIAuthSessionData(session_id, **result)
 
     async def mark_ui_auth_stage_complete(
-        self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict],
+        self,
+        session_id: str,
+        stage_type: str,
+        result: Union[str, bool, JsonDict],
     ):
         """
         Mark a session stage as completed.
@@ -261,10 +268,12 @@ class UIAuthWorkerStore(SQLBaseStore):
         return serverdict.get(key, default)
 
     async def add_user_agent_ip_to_ui_auth_session(
-        self, session_id: str, user_agent: str, ip: str,
+        self,
+        session_id: str,
+        user_agent: str,
+        ip: str,
     ):
-        """Add the given user agent / IP to the tracking table
-        """
+        """Add the given user agent / IP to the tracking table"""
         await self.db_pool.simple_upsert(
             table="ui_auth_sessions_ips",
             keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
@@ -273,7 +282,8 @@ class UIAuthWorkerStore(SQLBaseStore):
         )
 
     async def get_user_agents_ips_to_ui_auth_session(
-        self, session_id: str,
+        self,
+        session_id: str,
     ) -> List[Tuple[str, str]]:
         """Get the given user agents / IPs used during the ui auth process
 
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index ef11f1c3b3..63f88eac51 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -336,8 +336,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         return len(users_to_work_on)
 
     async def is_room_world_readable_or_publicly_joinable(self, room_id):
-        """Check if the room is either world_readable or publically joinable
-        """
+        """Check if the room is either world_readable or publicly joinable"""
 
         # Create a state filter that only queries join and history state event
         types_to_filter = (
@@ -516,8 +515,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         )
 
     async def delete_all_from_user_dir(self) -> None:
-        """Delete the entire user directory
-        """
+        """Delete the entire user directory"""
 
         def _delete_all_from_user_dir_txn(txn):
             txn.execute("DELETE FROM user_directory")
@@ -540,7 +538,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             desc="get_user_in_directory",
         )
 
-    async def update_user_directory_stream_pos(self, stream_id: str) -> None:
+    async def update_user_directory_stream_pos(self, stream_id: int) -> None:
         await self.db_pool.simple_update_one(
             table="user_directory_stream_pos",
             keyvalues={},
@@ -709,7 +707,13 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
 
         return {row["room_id"] for row in rows}
 
-    async def get_user_directory_stream_pos(self) -> int:
+    async def get_user_directory_stream_pos(self) -> Optional[int]:
+        """
+        Get the stream ID of the user directory stream.
+
+        Returns:
+            The stream token or None if the initial background update hasn't happened yet.
+        """
         return await self.db_pool.simple_select_one_onecol(
             table="user_directory_stream_pos",
             keyvalues={},
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index acb24e33af..1fd333b707 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -27,7 +27,7 @@ MAX_STATE_DELTA_HOPS = 100
 
 
 class StateGroupBackgroundUpdateStore(SQLBaseStore):
-    """Defines functions related to state groups needed to run the state backgroud
+    """Defines functions related to state groups needed to run the state background
     updates.
     """
 
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 0e31cc811a..b16b9905d8 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -48,8 +48,7 @@ class _GetStateGroupDelta(
 
 
 class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
-    """A data store for fetching/storing state groups.
-    """
+    """A data store for fetching/storing state groups."""
 
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
@@ -89,7 +88,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             50000,
         )
         self._state_group_members_cache = DictionaryCache(
-            "*stateGroupMembersCache*", 500000,
+            "*stateGroupMembersCache*",
+            500000,
         )
 
         def get_max_state_group_txn(txn: Cursor):
@@ -565,11 +565,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             )
 
         logger.info("[purge] removing redundant state groups")
-        txn.executemany(
+        txn.execute_batch(
             "DELETE FROM state_groups_state WHERE state_group = ?",
             ((sg,) for sg in state_groups_to_delete),
         )
-        txn.executemany(
+        txn.execute_batch(
             "DELETE FROM state_groups WHERE id = ?",
             ((sg,) for sg in state_groups_to_delete),
         )
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 035f9ea6e9..d15ccfacde 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import platform
 
 from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
 from .postgres import PostgresEngine
@@ -28,11 +27,8 @@ def create_engine(database_config) -> BaseDatabaseEngine:
         return Sqlite3Engine(sqlite3, database_config)
 
     if name == "psycopg2":
-        # pypy requires psycopg2cffi rather than psycopg2
-        if platform.python_implementation() == "PyPy":
-            import psycopg2cffi as psycopg2  # type: ignore
-        else:
-            import psycopg2  # type: ignore
+        # Note that psycopg2cffi-compat provides the psycopg2 module on pypy.
+        import psycopg2  # type: ignore
 
         return PostgresEngine(psycopg2, database_config)
 
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index d6d632dc10..cca839c70f 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -94,14 +94,12 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
     @property
     @abc.abstractmethod
     def server_version(self) -> str:
-        """Gets a string giving the server version. For example: '3.22.0'
-        """
+        """Gets a string giving the server version. For example: '3.22.0'"""
         ...
 
     @abc.abstractmethod
     def in_transaction(self, conn: Connection) -> bool:
-        """Whether the connection is currently in a transaction.
-        """
+        """Whether the connection is currently in a transaction."""
         ...
 
     @abc.abstractmethod
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 7719ac32f7..80a3558aec 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -138,8 +138,7 @@ class PostgresEngine(BaseDatabaseEngine):
 
     @property
     def supports_using_any_list(self):
-        """Do we support using `a = ANY(?)` and passing a list
-        """
+        """Do we support using `a = ANY(?)` and passing a list"""
         return True
 
     def is_deadlock(self, error):
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 5db0f0b520..b87e7798da 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import platform
 import struct
 import threading
 import typing
@@ -28,7 +29,15 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
         super().__init__(database_module, database_config)
 
         database = database_config.get("args", {}).get("database")
-        self._is_in_memory = database in (None, ":memory:",)
+        self._is_in_memory = database in (
+            None,
+            ":memory:",
+        )
+
+        if platform.python_implementation() == "PyPy":
+            # pypy's sqlite3 module doesn't handle bytearrays, so convert them
+            # back to bytes.
+            database_module.register_adapter(bytearray, lambda array: bytes(array))
 
         # The current max state_group, or None if we haven't looked
         # in the DB yet.
@@ -57,8 +66,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
 
     @property
     def supports_using_any_list(self):
-        """Do we support using `a = ANY(?)` and passing a list
-        """
+        """Do we support using `a = ANY(?)` and passing a list"""
         return False
 
     def check_database(self, db_conn, allow_outdated_version: bool = False):
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 61fc49c69c..3a0d6fb32e 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -411,8 +411,8 @@ class EventsPersistenceStorage:
                         )
 
                     for room_id, ev_ctx_rm in events_by_room.items():
-                        latest_event_ids = await self.main_store.get_latest_event_ids_in_room(
-                            room_id
+                        latest_event_ids = (
+                            await self.main_store.get_latest_event_ids_in_room(room_id)
                         )
                         new_latest_event_ids = await self._calculate_new_extremities(
                             room_id, ev_ctx_rm, latest_event_ids
@@ -889,7 +889,8 @@ class EventsPersistenceStorage:
                 continue
 
             logger.debug(
-                "Not dropping as too new and not in new_senders: %s", new_senders,
+                "Not dropping as too new and not in new_senders: %s",
+                new_senders,
             )
 
             return new_latest_event_ids
@@ -1004,7 +1005,10 @@ class EventsPersistenceStorage:
 
         remote_event_ids = [
             event_id
-            for (typ, state_key,), event_id in current_state.items()
+            for (
+                typ,
+                state_key,
+            ), event_id in current_state.items()
             if typ == EventTypes.Member and not self.is_mine_id(state_key)
         ]
         rows = await self.main_store.get_membership_from_event_ids(remote_event_ids)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 566ea19bae..6c3c2da520 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -113,7 +113,7 @@ def prepare_database(
             # which should be empty.
             if config is None:
                 raise ValueError(
-                    "config==None in prepare_database, but databse is not empty"
+                    "config==None in prepare_database, but database is not empty"
                 )
 
             # if it's a worker app, refuse to upgrade the database, to avoid multiple
@@ -425,7 +425,10 @@ def _upgrade_existing_database(
             # We don't support using the same file name in the same delta version.
             raise PrepareDatabaseException(
                 "Found multiple delta files with the same name in v%d: %s"
-                % (v, duplicates,)
+                % (
+                    v,
+                    duplicates,
+                )
             )
 
         # We sort to ensure that we apply the delta files in a consistent
@@ -532,7 +535,8 @@ def _apply_module_schema_files(
         names_and_streams: the names and streams of schemas to be applied
     """
     cur.execute(
-        "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
+        "SELECT file FROM applied_module_schemas WHERE module_name = ?",
+        (modname,),
     )
     applied_deltas = {d for d, in cur}
     for (name, stream) in names_and_streams:
@@ -619,9 +623,9 @@ def _get_or_create_schema_state(
 
     txn.execute("SELECT version, upgraded FROM schema_version")
     row = txn.fetchone()
-    current_version = int(row[0]) if row else None
 
-    if current_version:
+    if row is not None:
+        current_version = int(row[0])
         txn.execute(
             "SELECT file FROM applied_schema_deltas WHERE version >= ?",
             (current_version,),
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index 6c359c1aae..3c4908865f 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -26,15 +26,13 @@ logger = logging.getLogger(__name__)
 
 
 class PurgeEventsStorage:
-    """High level interface for purging rooms and event history.
-    """
+    """High level interface for purging rooms and event history."""
 
     def __init__(self, hs: "HomeServer", stores: Databases):
         self.stores = stores
 
     async def purge_room(self, room_id: str) -> None:
-        """Deletes all record of a room
-        """
+        """Deletes all record of a room"""
 
         state_groups_to_delete = await self.stores.main.purge_room(room_id)
         await self.stores.state.purge_room_state(room_id, state_groups_to_delete)
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 31ccbf23dc..d179a41884 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -340,8 +340,7 @@ class StateFilter:
 
 
 class StateGroupStorage:
-    """High level interface to fetching state for event.
-    """
+    """High level interface to fetching state for event."""
 
     def __init__(self, hs: "HomeServer", stores: "Databases"):
         self.stores = stores
@@ -400,7 +399,7 @@ class StateGroupStorage:
     async def get_state_groups(
         self, room_id: str, event_ids: Iterable[str]
     ) -> Dict[int, List[EventBase]]:
-        """ Get the state groups for the given list of event_ids
+        """Get the state groups for the given list of event_ids
 
         Args:
             room_id: ID of the room for these events.
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 9cadcba18f..17291c9d5e 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Iterable, Iterator, List, Optional, Tuple
+from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
 
 from typing_extensions import Protocol
 
@@ -20,23 +20,44 @@ from typing_extensions import Protocol
 Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
 """
 
+_Parameters = Union[Sequence[Any], Mapping[str, Any]]
+
 
 class Cursor(Protocol):
-    def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
+    def execute(self, sql: str, parameters: _Parameters = ...) -> Any:
         ...
 
-    def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
+    def executemany(self, sql: str, parameters: Sequence[_Parameters]) -> Any:
         ...
 
-    def fetchall(self) -> List[Tuple]:
+    def fetchone(self) -> Optional[Tuple]:
+        ...
+
+    def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]:
         ...
 
-    def fetchone(self) -> Tuple:
+    def fetchall(self) -> List[Tuple]:
         ...
 
     @property
-    def description(self) -> Any:
-        return None
+    def description(
+        self,
+    ) -> Optional[
+        Sequence[
+            # Note that this is an approximate typing based on sqlite3 and other
+            # drivers, and may not be entirely accurate.
+            Tuple[
+                str,
+                Optional[Any],
+                Optional[int],
+                Optional[int],
+                Optional[int],
+                Optional[int],
+                Optional[int],
+            ]
+        ]
+    ]:
+        ...
 
     @property
     def rowcount(self) -> int:
@@ -59,7 +80,7 @@ class Connection(Protocol):
     def commit(self) -> None:
         ...
 
-    def rollback(self, *args, **kwargs) -> None:
+    def rollback(self) -> None:
         ...
 
     def __enter__(self) -> "Connection":
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index bb84c0d792..d4643c4fdf 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -15,12 +15,11 @@
 import heapq
 import logging
 import threading
-from collections import deque
+from collections import OrderedDict
 from contextlib import contextmanager
 from typing import Dict, List, Optional, Set, Tuple, Union
 
 import attr
-from typing_extensions import Deque
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.database import DatabasePool, LoggingTransaction
@@ -101,7 +100,13 @@ class StreamIdGenerator:
             self._current = (max if step > 0 else min)(
                 self._current, _load_current_id(db_conn, table, column, step)
             )
-        self._unfinished_ids = deque()  # type: Deque[int]
+
+        # We use this as an ordered set, as we want to efficiently append items,
+        # remove items and get the first item. Since we insert IDs in order, the
+        # insertion ordering will ensure it's in the correct ordering.
+        #
+        # The key and values are the same, but we never look at the values.
+        self._unfinished_ids = OrderedDict()  # type: OrderedDict[int, int]
 
     def get_next(self):
         """
@@ -113,7 +118,7 @@ class StreamIdGenerator:
             self._current += self._step
             next_id = self._current
 
-            self._unfinished_ids.append(next_id)
+            self._unfinished_ids[next_id] = next_id
 
         @contextmanager
         def manager():
@@ -121,7 +126,7 @@ class StreamIdGenerator:
                 yield next_id
             finally:
                 with self._lock:
-                    self._unfinished_ids.remove(next_id)
+                    self._unfinished_ids.pop(next_id)
 
         return _AsyncCtxManagerWrapper(manager())
 
@@ -140,7 +145,7 @@ class StreamIdGenerator:
             self._current += n * self._step
 
             for next_id in next_ids:
-                self._unfinished_ids.append(next_id)
+                self._unfinished_ids[next_id] = next_id
 
         @contextmanager
         def manager():
@@ -149,7 +154,7 @@ class StreamIdGenerator:
             finally:
                 with self._lock:
                     for next_id in next_ids:
-                        self._unfinished_ids.remove(next_id)
+                        self._unfinished_ids.pop(next_id)
 
         return _AsyncCtxManagerWrapper(manager())
 
@@ -162,7 +167,7 @@ class StreamIdGenerator:
         """
         with self._lock:
             if self._unfinished_ids:
-                return self._unfinished_ids[0] - self._step
+                return next(iter(self._unfinished_ids)) - self._step
 
             return self._current
 
@@ -240,7 +245,7 @@ class MultiWriterIdGenerator:
         # and b) noting that if we have seen a run of persisted positions
         # without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
         #
-        # Note: There is no guarentee that the IDs generated by the sequence
+        # Note: There is no guarantee that the IDs generated by the sequence
         # will be gapless; gaps can form when e.g. a transaction was rolled
         # back. This means that sometimes we won't be able to skip forward the
         # position even though everything has been persisted. However, since
@@ -272,7 +277,9 @@ class MultiWriterIdGenerator:
         self._load_current_ids(db_conn, tables)
 
     def _load_current_ids(
-        self, db_conn, tables: List[Tuple[str, str, str]],
+        self,
+        db_conn,
+        tables: List[Tuple[str, str, str]],
     ):
         cur = db_conn.cursor(txn_name="_load_current_ids")
 
@@ -359,7 +366,10 @@ class MultiWriterIdGenerator:
             rows.sort()
 
             with self._lock:
-                for (instance, stream_id,) in rows:
+                for (
+                    instance,
+                    stream_id,
+                ) in rows:
                     stream_id = self._return_factor * stream_id
                     self._add_persisted_position(stream_id)
 
@@ -413,7 +423,7 @@ class MultiWriterIdGenerator:
         # bother, as nothing will read it).
         #
         # We only do this on the success path so that the persisted current
-        # position points to a persited row with the correct instance name.
+        # position points to a persisted row with the correct instance name.
         if self._writers:
             txn.call_after(
                 run_as_background_process,
@@ -476,8 +486,7 @@ class MultiWriterIdGenerator:
         return self.get_persisted_upto_position()
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
-        """Returns the position of the given writer.
-        """
+        """Returns the position of the given writer."""
 
         # If we don't have an entry for the given instance name, we assume it's a
         # new writer.
@@ -504,7 +513,7 @@ class MultiWriterIdGenerator:
             }
 
     def advance(self, instance_name: str, new_id: int):
-        """Advance the postion of the named writer to the given ID, if greater
+        """Advance the position of the named writer to the given ID, if greater
         than existing entry.
         """
 
@@ -576,8 +585,7 @@ class MultiWriterIdGenerator:
                 break
 
     def _update_stream_positions_table_txn(self, txn: Cursor):
-        """Update the `stream_positions` table with newly persisted position.
-        """
+        """Update the `stream_positions` table with newly persisted position."""
 
         if not self._writers:
             return
@@ -617,8 +625,7 @@ class _AsyncCtxManagerWrapper:
 
 @attr.s(slots=True)
 class _MultiWriterCtxManager:
-    """Async context manager returned by MultiWriterIdGenerator
-    """
+    """Async context manager returned by MultiWriterIdGenerator"""
 
     id_gen = attr.ib(type=MultiWriterIdGenerator)
     multiple_ids = attr.ib(type=Optional[int], default=None)
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index c780ade077..3ea637b281 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -70,6 +70,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
         ...
 
     @abc.abstractmethod
+    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+        """Get the next `n` IDs in the sequence"""
+        ...
+
+    @abc.abstractmethod
     def check_consistency(
         self,
         db_conn: "LoggingDatabaseConnection",
@@ -101,7 +106,9 @@ class PostgresSequenceGenerator(SequenceGenerator):
 
     def get_next_id_txn(self, txn: Cursor) -> int:
         txn.execute("SELECT nextval(?)", (self._sequence_name,))
-        return txn.fetchone()[0]
+        fetch_res = txn.fetchone()
+        assert fetch_res is not None
+        return fetch_res[0]
 
     def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
         txn.execute(
@@ -117,8 +124,7 @@ class PostgresSequenceGenerator(SequenceGenerator):
         stream_name: Optional[str] = None,
         positive: bool = True,
     ):
-        """See SequenceGenerator.check_consistency for docstring.
-        """
+        """See SequenceGenerator.check_consistency for docstring."""
 
         txn = db_conn.cursor(txn_name="sequence.check_consistency")
 
@@ -142,7 +148,9 @@ class PostgresSequenceGenerator(SequenceGenerator):
         txn.execute(
             "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
         )
-        last_value, is_called = txn.fetchone()
+        fetch_res = txn.fetchone()
+        assert fetch_res is not None
+        last_value, is_called = fetch_res
 
         # If we have an associated stream check the stream_positions table.
         max_in_stream_positions = None
@@ -219,6 +227,17 @@ class LocalSequenceGenerator(SequenceGenerator):
             self._current_max_id += 1
             return self._current_max_id
 
+    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+        with self._lock:
+            if self._current_max_id is None:
+                assert self._callback is not None
+                self._current_max_id = self._callback(txn)
+                self._callback = None
+
+            first_id = self._current_max_id + 1
+            self._current_max_id += n
+            return [first_id + i for i in range(n)]
+
     def check_consistency(
         self,
         db_conn: Connection,
diff --git a/synapse/types.py b/synapse/types.py
index eafe729dfe..721343f0b5 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -469,8 +469,7 @@ class RoomStreamToken:
     )
 
     def __attrs_post_init__(self):
-        """Validates that both `topological` and `instance_map` aren't set.
-        """
+        """Validates that both `topological` and `instance_map` aren't set."""
 
         if self.instance_map and self.topological:
             raise ValueError(
@@ -498,7 +497,11 @@ class RoomStreamToken:
                     instance_name = await store.get_name_from_instance_id(instance_id)
                     instance_map[instance_name] = pos
 
-                return cls(topological=None, stream=stream, instance_map=instance_map,)
+                return cls(
+                    topological=None,
+                    stream=stream,
+                    instance_map=instance_map,
+                )
         except Exception:
             pass
         raise SynapseError(400, "Invalid token %r" % (string,))
@@ -675,7 +678,7 @@ class PersistedEventPosition:
         persisted in the same room after this position will be after the
         returned `RoomStreamToken`.
 
-        Note: no guarentees are made about ordering w.r.t. events in other
+        Note: no guarantees are made about ordering w.r.t. events in other
         rooms.
         """
         # Doing the naive thing satisfies the desired properties described in
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 9a873c8e8e..719e35b78d 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -252,8 +252,7 @@ class Linearizer:
         self.key_to_defer = {}  # type: Dict[Hashable, _LinearizerEntry]
 
     def is_queued(self, key: Hashable) -> bool:
-        """Checks whether there is a process queued up waiting
-        """
+        """Checks whether there is a process queued up waiting"""
         entry = self.key_to_defer.get(key)
         if not entry:
             # No entry so nothing is waiting.
@@ -452,7 +451,9 @@ R = TypeVar("R")
 
 
 def timeout_deferred(
-    deferred: defer.Deferred, timeout: float, reactor: IReactorTime,
+    deferred: defer.Deferred,
+    timeout: float,
+    reactor: IReactorTime,
 ) -> defer.Deferred:
     """The in built twisted `Deferred.addTimeout` fails to time out deferreds
     that have a canceller that throws exceptions. This method creates a new
@@ -497,7 +498,7 @@ def timeout_deferred(
     delayed_call = reactor.callLater(timeout, time_it_out)
 
     def convert_cancelled(value: failure.Failure):
-        # if the orgininal deferred was cancelled, and our timeout has fired, then
+        # if the original deferred was cancelled, and our timeout has fired, then
         # the reason it was cancelled was due to our timeout. Turn the CancelledError
         # into a TimeoutError.
         if timed_out[0] and value.check(CancelledError):
@@ -529,8 +530,7 @@ def timeout_deferred(
 
 @attr.s(slots=True, frozen=True)
 class DoneAwaitable:
-    """Simple awaitable that returns the provided value.
-    """
+    """Simple awaitable that returns the provided value."""
 
     value = attr.ib()
 
@@ -545,8 +545,7 @@ class DoneAwaitable:
 
 
 def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
-    """Convert a value to an awaitable if not already an awaitable.
-    """
+    """Convert a value to an awaitable if not already an awaitable."""
     if inspect.isawaitable(value):
         assert isinstance(value, Awaitable)
         return value
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 89f0b38535..e676c2cac4 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -149,8 +149,7 @@ KNOWN_KEYS = {
 
 
 def intern_string(string):
-    """Takes a (potentially) unicode string and interns it if it's ascii
-    """
+    """Takes a (potentially) unicode string and interns it if it's ascii"""
     if string is None:
         return None
 
@@ -161,8 +160,7 @@ def intern_string(string):
 
 
 def intern_dict(dictionary):
-    """Takes a dictionary and interns well known keys and their values
-    """
+    """Takes a dictionary and interns well known keys and their values"""
     return {
         KNOWN_KEYS.get(key, key): _intern_known_values(key, value)
         for key, value in dictionary.items()
diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py
new file mode 100644
index 0000000000..3ee0f2317a
--- /dev/null
+++ b/synapse/util/caches/cached_call.py
@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union
+
+from twisted.internet.defer import Deferred
+from twisted.python.failure import Failure
+
+from synapse.logging.context import make_deferred_yieldable, run_in_background
+
+TV = TypeVar("TV")
+
+
+class CachedCall(Generic[TV]):
+    """A wrapper for asynchronous calls whose results should be shared
+
+    This is useful for wrapping asynchronous functions, where there might be multiple
+    callers, but we only want to call the underlying function once (and have the result
+    returned to all callers).
+
+    Similar results can be achieved via a lock of some form, but that typically requires
+    more boilerplate (and ends up being less efficient).
+
+    Correctly handles Synapse logcontexts (logs and resource usage for the underlying
+    function are logged against the logcontext which is active when get() is first
+    called).
+
+    Example usage:
+
+        _cached_val = CachedCall(_load_prop)
+
+        async def handle_request() -> X:
+            # We can call this multiple times, but it will result in a single call to
+            # _load_prop().
+            return await _cached_val.get()
+
+        async def _load_prop() -> X:
+            await difficult_operation()
+
+
+    The implementation is deliberately single-shot (ie, once the call is initiated,
+    there is no way to ask for it to be re-run). This keeps the implementation and
+    semantics simple. If you want to make a new call, simply replace the whole
+    CachedCall object.
+    """
+
+    __slots__ = ["_callable", "_deferred", "_result"]
+
+    def __init__(self, f: Callable[[], Awaitable[TV]]):
+        """
+        Args:
+            f: The underlying function. Only one call to this function will be alive
+                at once (per instance of CachedCall)
+        """
+        self._callable = f  # type: Optional[Callable[[], Awaitable[TV]]]
+        self._deferred = None  # type: Optional[Deferred]
+        self._result = None  # type: Union[None, Failure, TV]
+
+    async def get(self) -> TV:
+        """Kick off the call if necessary, and return the result"""
+
+        # Fire off the callable now if this is our first time
+        if not self._deferred:
+            self._deferred = run_in_background(self._callable)
+
+            # we will never need the callable again, so make sure it can be GCed
+            self._callable = None
+
+            # once the deferred completes, store the result. We cannot simply leave the
+            # result in the deferred, since if it's a Failure, GCing the deferred
+            # would then log a critical error about unhandled Failures.
+            def got_result(r):
+                self._result = r
+
+            self._deferred.addBoth(got_result)
+
+        # TODO: consider cancellation semantics. Currently, if the call to get()
+        #    is cancelled, the underlying call will continue (and any future calls
+        #    will get the result/exception), which I think is *probably* ok, modulo
+        #    the fact the underlying call may be logged to a cancelled logcontext,
+        #    and any eventual exception may not be reported.
+
+        # we can now await the deferred, and once it completes, return the result.
+        await make_deferred_yieldable(self._deferred)
+
+        # I *think* this is the easiest way to correctly raise a Failure without having
+        # to gut-wrench into the implementation of Deferred.
+        d = Deferred()
+        d.callback(self._result)
+        return await d
+
+
+class RetryOnExceptionCachedCall(Generic[TV]):
+    """A wrapper around CachedCall which will retry the call if an exception is thrown
+
+    This is used in much the same way as CachedCall, but adds some extra functionality
+    so that if the underlying function throws an exception, then the next call to get()
+    will initiate another call to the underlying function. (Any calls to get() which
+    are already pending will raise the exception.)
+    """
+
+    __slots__ = ["_cachedcall"]
+
+    def __init__(self, f: Callable[[], Awaitable[TV]]):
+        async def _wrapper() -> TV:
+            try:
+                return await f()
+            except Exception:
+                # the call raised an exception: replace the underlying CachedCall to
+                # trigger another call next time get() is called
+                self._cachedcall = CachedCall(_wrapper)
+                raise
+
+        self._cachedcall = CachedCall(_wrapper)
+
+    async def get(self) -> TV:
+        return await self._cachedcall.get()
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index a924140cdf..4e84379914 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -122,7 +122,8 @@ class _LruCachedFunction(Generic[F]):
 
 
 def lru_cache(
-    max_entries: int = 1000, cache_context: bool = False,
+    max_entries: int = 1000,
+    cache_context: bool = False,
 ) -> Callable[[F], _LruCachedFunction[F]]:
     """A method decorator that applies a memoizing cache around the function.
 
@@ -156,7 +157,9 @@ def lru_cache(
 
     def func(orig: F) -> _LruCachedFunction[F]:
         desc = LruCacheDescriptor(
-            orig, max_entries=max_entries, cache_context=cache_context,
+            orig,
+            max_entries=max_entries,
+            cache_context=cache_context,
         )
         return cast(_LruCachedFunction[F], desc)
 
@@ -170,14 +173,18 @@ class LruCacheDescriptor(_CacheDescriptorBase):
         sentinel = object()
 
     def __init__(
-        self, orig, max_entries: int = 1000, cache_context: bool = False,
+        self,
+        orig,
+        max_entries: int = 1000,
+        cache_context: bool = False,
     ):
         super().__init__(orig, num_args=None, cache_context=cache_context)
         self.max_entries = max_entries
 
     def __get__(self, obj, owner):
         cache = LruCache(
-            cache_name=self.orig.__name__, max_size=self.max_entries,
+            cache_name=self.orig.__name__,
+            max_size=self.max_entries,
         )  # type: LruCache[CacheKey, Any]
 
         get_cache_key = self.cache_key_builder
@@ -212,7 +219,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
 
 
 class DeferredCacheDescriptor(_CacheDescriptorBase):
-    """ A method decorator that applies a memoizing cache around the function.
+    """A method decorator that applies a memoizing cache around the function.
 
     This caches deferreds, rather than the results themselves. Deferreds that
     fail are removed from the cache.
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index c541bf4579..644e9e778a 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -84,8 +84,7 @@ class StreamChangeCache:
         return False
 
     def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool:
-        """Returns True if the entity may have been updated since stream_pos
-        """
+        """Returns True if the entity may have been updated since stream_pos"""
         assert isinstance(stream_pos, int)
 
         if stream_pos < self._earliest_known_stream_pos:
@@ -133,8 +132,7 @@ class StreamChangeCache:
         return result
 
     def has_any_entity_changed(self, stream_pos: int) -> bool:
-        """Returns if any entity has changed
-        """
+        """Returns if any entity has changed"""
         assert type(stream_pos) is int
 
         if not self._cache:
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index a6ee9edaec..3c47285d05 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -108,7 +108,10 @@ class Signal:
                 return await maybe_awaitable(observer(*args, **kwargs))
             except Exception as e:
                 logger.warning(
-                    "%s signal observer %s failed: %r", self.name, observer, e,
+                    "%s signal observer %s failed: %r",
+                    self.name,
+                    observer,
+                    e,
                 )
 
         deferreds = [run_in_background(do, o) for o in self.observers]
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
index 733f5e26e6..68dc632491 100644
--- a/synapse/util/file_consumer.py
+++ b/synapse/util/file_consumer.py
@@ -83,15 +83,13 @@ class BackgroundFileConsumer:
             self._producer.resumeProducing()
 
     def unregisterProducer(self):
-        """Part of IProducer interface
-        """
+        """Part of IProducer interface"""
         self._producer = None
         if not self._finished_deferred.called:
             self._bytes_queue.put_nowait(None)
 
     def write(self, bytes):
-        """Part of IProducer interface
-        """
+        """Part of IProducer interface"""
         if self._write_exception:
             raise self._write_exception
 
@@ -107,8 +105,7 @@ class BackgroundFileConsumer:
             self._producer.pauseProducing()
 
     def _writer(self):
-        """This is run in a background thread to write to the file.
-        """
+        """This is run in a background thread to write to the file."""
         try:
             while self._producer or not self._bytes_queue.empty():
                 # If we've paused the producer check if we should resume the
@@ -135,13 +132,11 @@ class BackgroundFileConsumer:
             self._file_obj.close()
 
     def wait(self):
-        """Returns a deferred that resolves when finished writing to file
-        """
+        """Returns a deferred that resolves when finished writing to file"""
         return make_deferred_yieldable(self._finished_deferred)
 
     def _resume_paused_producer(self):
-        """Gets called if we should resume producing after being paused
-        """
+        """Gets called if we should resume producing after being paused"""
         if self._paused_producer and self._producer:
             self._paused_producer = False
             self._producer.resumeProducing()
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index 8d2411513f..98707c119d 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -62,7 +62,8 @@ def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]:
 
 
 def sorted_topologically(
-    nodes: Iterable[T], graph: Mapping[T, Collection[T]],
+    nodes: Iterable[T],
+    graph: Mapping[T, Collection[T]],
 ) -> Generator[T, None, None]:
     """Given a set of nodes and a graph, yield the nodes in toplogical order.
 
diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
index 50516926f3..e3a8ed5b2f 100644
--- a/synapse/util/jsonobject.py
+++ b/synapse/util/jsonobject.py
@@ -15,7 +15,7 @@
 
 
 class JsonEncodedObject:
-    """ A common base class for defining protocol units that are represented
+    """A common base class for defining protocol units that are represented
     as JSON.
 
     Attributes:
@@ -39,7 +39,7 @@ class JsonEncodedObject:
     """
 
     def __init__(self, **kwargs):
-        """ Takes the dict of `kwargs` and loads all keys that are *valid*
+        """Takes the dict of `kwargs` and loads all keys that are *valid*
         (i.e., are included in the `valid_keys` list) into the dictionary`
         instance variable.
 
@@ -61,7 +61,7 @@ class JsonEncodedObject:
                 self.unrecognized_keys[k] = v
 
     def get_dict(self):
-        """ Converts this protocol unit into a :py:class:`dict`, ready to be
+        """Converts this protocol unit into a :py:class:`dict`, ready to be
         encoded as JSON.
 
         The keys it encodes are: `valid_keys` - `internal_keys`
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index f4de6b9f54..1023c856d1 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -161,8 +161,7 @@ class Measure:
         return self._logging_context.get_resource_usage()
 
     def _update_in_flight(self, metrics):
-        """Gets called when processing in flight metrics
-        """
+        """Gets called when processing in flight metrics"""
         duration = self.clock.time() - self.start
 
         metrics.real_time_max = max(metrics.real_time_max, duration)
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 1ee61851e4..d184e2a90c 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -25,7 +25,7 @@ from synapse.config._util import json_error_to_config_error
 
 
 def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
-    """ Loads a synapse module with its config
+    """Loads a synapse module with its config
 
     Args:
         provider: a dict with keys 'module' (the module name) and 'config'
@@ -49,7 +49,8 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
     module = importlib.import_module(module)
     provider_class = getattr(module, clz)
 
-    module_config = provider.get("config")
+    # Load the module config. If it is missing or empty (e.g. None), use an empty dict
+    module_config = provider.get("config") or {}
     try:
         provider_config = provider_class.parse_config(module_config)
     except jsonschema.ValidationError as e:
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index 72574d3af2..d9f9ae99d6 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -204,16 +204,13 @@ def _check_yield_points(f: Callable, changes: List[str]):
                 # We don't raise here as its perfectly valid for contexts to
                 # change in a function, as long as it sets the correct context
                 # on resolving (which is checked separately).
-                err = (
-                    "%s changed context from %s to %s, happened between lines %d and %d in %s"
-                    % (
-                        frame.f_code.co_name,
-                        expected_context,
-                        current_context(),
-                        last_yield_line_no,
-                        frame.f_lineno,
-                        frame.f_code.co_filename,
-                    )
+                err = "%s changed context from %s to %s, happened between lines %d and %d in %s" % (
+                    frame.f_code.co_name,
+                    expected_context,
+                    current_context(),
+                    last_yield_line_no,
+                    frame.f_lineno,
+                    frame.f_code.co_filename,
                 )
                 changes.append(err)
 
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index f8038bf861..9ce7873ab5 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -25,7 +25,7 @@ from synapse.api.errors import Codes, SynapseError
 _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
 
 # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
-client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
+CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
 
 # https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris,
 # together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically
@@ -42,28 +42,31 @@ MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
 rand = random.SystemRandom()
 
 
-def random_string(length):
+def random_string(length: int) -> str:
     return "".join(rand.choice(string.ascii_letters) for _ in range(length))
 
 
-def random_string_with_symbols(length):
+def random_string_with_symbols(length: int) -> str:
     return "".join(rand.choice(_string_with_symbols) for _ in range(length))
 
 
-def is_ascii(s):
-    if isinstance(s, bytes):
-        try:
-            s.decode("ascii").encode("ascii")
-        except UnicodeDecodeError:
-            return False
-        except UnicodeEncodeError:
-            return False
-        return True
+def is_ascii(s: bytes) -> bool:
+    try:
+        s.decode("ascii").encode("ascii")
+    except UnicodeDecodeError:
+        return False
+    except UnicodeEncodeError:
+        return False
+    return True
 
 
-def assert_valid_client_secret(client_secret):
-    """Validate that a given string matches the client_secret regex defined by the spec"""
-    if client_secret_regex.match(client_secret) is None:
+def assert_valid_client_secret(client_secret: str) -> None:
+    """Validate that a given string is a valid client_secret as defined by the spec"""
+    if (
+        len(client_secret) <= 0
+        or len(client_secret) > 255
+        or CLIENT_SECRET_REGEX.match(client_secret) is None
+    ):
         raise SynapseError(
             400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM
         )
diff --git a/synapse/util/templates.py b/synapse/util/templates.py
new file mode 100644
index 0000000000..392dae4a40
--- /dev/null
+++ b/synapse/util/templates.py
@@ -0,0 +1,115 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for dealing with jinja2 templates"""
+
+import time
+import urllib.parse
+from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union
+
+import jinja2
+
+if TYPE_CHECKING:
+    from synapse.config.homeserver import HomeServerConfig
+
+
+def build_jinja_env(
+    template_search_directories: Iterable[str],
+    config: "HomeServerConfig",
+    autoescape: Union[bool, Callable[[str], bool], None] = None,
+) -> jinja2.Environment:
+    """Set up a Jinja2 environment to load templates from the given search path
+
+    The returned environment defines the following filters:
+        - format_ts: formats timestamps as strings in the server's local timezone
+             (XXX: why is that useful??)
+        - mxc_to_http: converts mxc:// URIs to HTTP thumbnail URLs. Args are:
+             (uri, width, height, resize_method="crop")
+
+    and the following global variables:
+        - server_name: matrix server name
+
+    Args:
+        template_search_directories: directories to search for templates
+
+        config: homeserver config, for things like `server_name` and `public_baseurl`
+
+        autoescape: whether template variables should be autoescaped. bool, or
+           a function mapping from template name to bool. Defaults to escaping templates
+           whose names end in .html, .xml or .htm.
+
+    Returns:
+        jinja environment
+    """
+
+    if autoescape is None:
+        autoescape = jinja2.select_autoescape()
+
+    loader = jinja2.FileSystemLoader(template_search_directories)
+    env = jinja2.Environment(loader=loader, autoescape=autoescape)
+
+    # Update the environment with our custom filters
+    env.filters.update(
+        {
+            "format_ts": _format_ts_filter,
+            "mxc_to_http": _create_mxc_to_http_filter(config.public_baseurl),
+        }
+    )
+
+    # common variables for all templates
+    env.globals.update({"server_name": config.server_name})
+
+    return env
+
+
+def _create_mxc_to_http_filter(
+    public_baseurl: Optional[str],
+) -> Callable[[str, int, int, str], str]:
+    """Create and return a jinja2 filter that converts MXC URIs to HTTP thumbnail URLs
+
+    Args:
+        public_baseurl: The public, accessible base URL of the homeserver
+    """
+
+    def mxc_to_http_filter(
+        value: str, width: int, height: int, resize_method: str = "crop"
+    ) -> str:
+        if not public_baseurl:
+            raise RuntimeError(
+                "public_baseurl must be set in the homeserver config to convert MXC URLs to HTTP URLs."
+            )
+
+        if value[0:6] != "mxc://":
+            return ""
+
+        server_and_media_id = value[6:]
+        fragment = None
+        if "#" in server_and_media_id:
+            server_and_media_id, fragment = server_and_media_id.split("#", 1)
+            fragment = "#" + fragment
+
+        params = {"width": width, "height": height, "method": resize_method}
+        return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
+            public_baseurl,
+            server_and_media_id,
+            urllib.parse.urlencode(params),
+            fragment or "",
+        )
+
+    return mxc_to_http_filter
+
+
+def _format_ts_filter(value: int, format: str):
+    return time.strftime(format, time.localtime(value / 1000))
diff --git a/synapse/visibility.py b/synapse/visibility.py
index ec50e7e977..e39d02602a 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -80,6 +80,7 @@ async def filter_events_for_client(
     events = [e for e in events if not e.internal_metadata.is_soft_failed()]
 
     types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
+
     event_id_to_state = await storage.state.get_state_for_events(
         frozenset(e.event_id for e in events),
         state_filter=StateFilter.from_types(types),
@@ -233,7 +234,7 @@ async def filter_events_for_client(
 
         elif visibility == HistoryVisibility.SHARED and is_peeking:
             # if the visibility is shared, users cannot see the event unless
-            # they have *subequently* joined the room (or were members at the
+            # they have *subsequently* joined the room (or were members at the
             # time, of course)
             #
             # XXX: if the user has subsequently joined and then left again,