summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/auth.py14
-rw-r--r--synapse/api/auth_blocking.py2
-rw-r--r--synapse/api/constants.py26
-rw-r--r--synapse/api/errors.py16
-rw-r--r--synapse/api/filtering.py10
-rw-r--r--synapse/api/presence.py (renamed from synapse/storage/presence.py)0
-rw-r--r--synapse/api/ratelimiting.py2
-rw-r--r--synapse/api/room_versions.py10
-rw-r--r--synapse/api/urls.py3
-rw-r--r--synapse/app/_base.py11
-rw-r--r--synapse/app/admin_cmd.py23
-rw-r--r--synapse/app/generic_worker.py2
-rw-r--r--synapse/app/homeserver.py28
-rw-r--r--synapse/appservice/__init__.py25
-rw-r--r--synapse/appservice/api.py19
-rw-r--r--synapse/appservice/scheduler.py8
-rw-r--r--synapse/config/_base.py125
-rw-r--r--synapse/config/_base.pyi1
-rw-r--r--synapse/config/cache.py2
-rw-r--r--synapse/config/emailconfig.py157
-rw-r--r--synapse/config/key.py2
-rw-r--r--synapse/config/logger.py25
-rw-r--r--synapse/config/metrics.py2
-rw-r--r--synapse/config/ratelimiting.py4
-rw-r--r--synapse/config/room.py2
-rw-r--r--synapse/config/room_directory.py2
-rw-r--r--synapse/config/saml2_config.py42
-rw-r--r--synapse/config/server.py77
-rw-r--r--synapse/config/sso.py37
-rw-r--r--synapse/config/workers.py37
-rw-r--r--synapse/crypto/context_factory.py8
-rw-r--r--synapse/crypto/keyring.py15
-rw-r--r--synapse/event_auth.py2
-rw-r--r--synapse/events/__init__.py8
-rw-r--r--synapse/events/builder.py23
-rw-r--r--synapse/events/spamcheck.py37
-rw-r--r--synapse/events/third_party_rules.py2
-rw-r--r--synapse/events/utils.py2
-rw-r--r--synapse/events/validator.py61
-rw-r--r--synapse/federation/federation_base.py2
-rw-r--r--synapse/federation/federation_server.py7
-rw-r--r--synapse/federation/persistence.py37
-rw-r--r--synapse/federation/send_queue.py6
-rw-r--r--synapse/federation/sender/__init__.py26
-rw-r--r--synapse/federation/sender/per_destination_queue.py177
-rw-r--r--synapse/federation/sender/transaction_manager.py33
-rw-r--r--synapse/federation/transport/client.py2
-rw-r--r--synapse/federation/transport/server.py4
-rw-r--r--synapse/federation/units.py4
-rw-r--r--synapse/groups/attestations.py4
-rw-r--r--synapse/groups/groups_server.py2
-rw-r--r--synapse/handlers/__init__.py2
-rw-r--r--synapse/handlers/_base.py2
-rw-r--r--synapse/handlers/account_data.py2
-rw-r--r--synapse/handlers/account_validity.py22
-rw-r--r--synapse/handlers/acme.py2
-rw-r--r--synapse/handlers/acme_issuing_service.py4
-rw-r--r--synapse/handlers/admin.py8
-rw-r--r--synapse/handlers/appservice.py2
-rw-r--r--synapse/handlers/auth.py112
-rw-r--r--synapse/handlers/cas_handler.py11
-rw-r--r--synapse/handlers/device.py18
-rw-r--r--synapse/handlers/devicemessage.py7
-rw-r--r--synapse/handlers/directory.py6
-rw-r--r--synapse/handlers/e2e_keys.py18
-rw-r--r--synapse/handlers/e2e_room_keys.py2
-rw-r--r--synapse/handlers/events.py53
-rw-r--r--synapse/handlers/federation.py148
-rw-r--r--synapse/handlers/groups_local.py2
-rw-r--r--synapse/handlers/identity.py5
-rw-r--r--synapse/handlers/initial_sync.py93
-rw-r--r--synapse/handlers/message.py145
-rw-r--r--synapse/handlers/oidc_handler.py39
-rw-r--r--synapse/handlers/pagination.py184
-rw-r--r--synapse/handlers/password_policy.py2
-rw-r--r--synapse/handlers/presence.py13
-rw-r--r--synapse/handlers/profile.py23
-rw-r--r--synapse/handlers/receipts.py17
-rw-r--r--synapse/handlers/register.py34
-rw-r--r--synapse/handlers/room.py179
-rw-r--r--synapse/handlers/room_member.py131
-rw-r--r--synapse/handlers/room_member_worker.py9
-rw-r--r--synapse/handlers/saml_handler.py165
-rw-r--r--synapse/handlers/state_deltas.py2
-rw-r--r--synapse/handlers/sync.py104
-rw-r--r--synapse/handlers/typing.py23
-rw-r--r--synapse/handlers/ui_auth/checkers.py5
-rw-r--r--synapse/handlers/user_directory.py8
-rw-r--r--synapse/http/client.py19
-rw-r--r--synapse/http/connectproxyclient.py2
-rw-r--r--synapse/http/federation/matrix_federation_agent.py10
-rw-r--r--synapse/http/federation/srv_resolver.py4
-rw-r--r--synapse/http/federation/well_known_resolver.py68
-rw-r--r--synapse/http/matrixfederationclient.py13
-rw-r--r--synapse/http/request_metrics.py2
-rw-r--r--synapse/http/server.py110
-rw-r--r--synapse/http/servlet.py7
-rw-r--r--synapse/logging/_structured.py6
-rw-r--r--synapse/logging/_terse_json.py4
-rw-r--r--synapse/logging/context.py10
-rw-r--r--synapse/logging/opentracing.py14
-rw-r--r--synapse/logging/utils.py6
-rw-r--r--synapse/metrics/__init__.py20
-rw-r--r--synapse/metrics/background_process_metrics.py6
-rw-r--r--synapse/module_api/__init__.py12
-rw-r--r--synapse/notifier.py89
-rw-r--r--synapse/push/action_generator.py2
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py101
-rw-r--r--synapse/push/emailpusher.py4
-rw-r--r--synapse/push/httppusher.py4
-rw-r--r--synapse/push/mailer.py76
-rw-r--r--synapse/push/push_rule_evaluator.py2
-rw-r--r--synapse/push/pusher.py33
-rw-r--r--synapse/push/pusherpool.py19
-rw-r--r--synapse/python_dependencies.py8
-rw-r--r--synapse/replication/http/_base.py2
-rw-r--r--synapse/replication/http/federation.py12
-rw-r--r--synapse/replication/http/membership.py10
-rw-r--r--synapse/replication/http/register.py4
-rw-r--r--synapse/replication/slave/storage/_slaved_id_tracker.py14
-rw-r--r--synapse/replication/slave/storage/account_data.py4
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py2
-rw-r--r--synapse/replication/slave/storage/devices.py7
-rw-r--r--synapse/replication/slave/storage/groups.py2
-rw-r--r--synapse/replication/slave/storage/presence.py2
-rw-r--r--synapse/replication/slave/storage/push_rule.py12
-rw-r--r--synapse/replication/slave/storage/pushers.py2
-rw-r--r--synapse/replication/slave/storage/receipts.py2
-rw-r--r--synapse/replication/slave/storage/room.py2
-rw-r--r--synapse/replication/tcp/client.py14
-rw-r--r--synapse/replication/tcp/commands.py12
-rw-r--r--synapse/replication/tcp/handler.py2
-rw-r--r--synapse/replication/tcp/protocol.py2
-rw-r--r--synapse/replication/tcp/resource.py4
-rw-r--r--synapse/replication/tcp/streams/_base.py10
-rw-r--r--synapse/replication/tcp/streams/events.py8
-rw-r--r--synapse/res/templates/password_reset_confirmation.html16
-rw-r--r--synapse/res/templates/saml_error.html52
-rw-r--r--synapse/res/templates/sso_error.html43
-rw-r--r--synapse/rest/__init__.py10
-rw-r--r--synapse/rest/admin/rooms.py3
-rw-r--r--synapse/rest/admin/users.py4
-rw-r--r--synapse/rest/client/transactions.py2
-rw-r--r--synapse/rest/client/v1/login.py60
-rw-r--r--synapse/rest/client/v1/push_rule.py24
-rw-r--r--synapse/rest/client/v1/room.py173
-rw-r--r--synapse/rest/client/v2_alpha/account.py190
-rw-r--r--synapse/rest/client/v2_alpha/groups.py4
-rw-r--r--synapse/rest/client/v2_alpha/register.py45
-rw-r--r--synapse/rest/client/v2_alpha/relations.py18
-rw-r--r--synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py14
-rw-r--r--synapse/rest/client/v2_alpha/shared_rooms.py68
-rw-r--r--synapse/rest/client/v2_alpha/sync.py6
-rw-r--r--synapse/rest/client/versions.py2
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py14
-rw-r--r--synapse/rest/media/v1/_base.py4
-rw-r--r--synapse/rest/media/v1/filepath.py21
-rw-r--r--synapse/rest/media/v1/media_repository.py71
-rw-r--r--synapse/rest/media/v1/media_storage.py30
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py2
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py5
-rw-r--r--synapse/rest/media/v1/thumbnailer.py20
-rw-r--r--synapse/rest/saml2/response_resource.py16
-rw-r--r--synapse/rest/synapse/__init__.py14
-rw-r--r--synapse/rest/synapse/client/__init__.py14
-rw-r--r--synapse/rest/synapse/client/password_reset.py127
-rw-r--r--synapse/rest/well_known.py6
-rw-r--r--synapse/secrets.py2
-rw-r--r--synapse/server_notices/consent_server_notices.py2
-rw-r--r--synapse/server_notices/resource_limits_server_notices.py2
-rw-r--r--synapse/server_notices/server_notices_manager.py2
-rw-r--r--synapse/server_notices/server_notices_sender.py2
-rw-r--r--synapse/server_notices/worker_server_notices_sender.py2
-rw-r--r--synapse/spam_checker_api/__init__.py21
-rw-r--r--synapse/state/__init__.py226
-rw-r--r--synapse/state/v1.py89
-rw-r--r--synapse/state/v2.py259
-rw-r--r--synapse/storage/__init__.py7
-rw-r--r--synapse/storage/_base.py7
-rw-r--r--synapse/storage/background_updates.py36
-rw-r--r--synapse/storage/database.py821
-rw-r--r--synapse/storage/databases/__init__.py36
-rw-r--r--synapse/storage/databases/main/__init__.py118
-rw-r--r--synapse/storage/databases/main/account_data.py63
-rw-r--r--synapse/storage/databases/main/appservice.py26
-rw-r--r--synapse/storage/databases/main/cache.py4
-rw-r--r--synapse/storage/databases/main/client_ips.py4
-rw-r--r--synapse/storage/databases/main/deviceinbox.py16
-rw-r--r--synapse/storage/databases/main/devices.py148
-rw-r--r--synapse/storage/databases/main/directory.py16
-rw-r--r--synapse/storage/databases/main/e2e_room_keys.py64
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py340
-rw-r--r--synapse/storage/databases/main/event_federation.py158
-rw-r--r--synapse/storage/databases/main/event_push_actions.py275
-rw-r--r--synapse/storage/databases/main/events.py74
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py46
-rw-r--r--synapse/storage/databases/main/events_worker.py457
-rw-r--r--synapse/storage/databases/main/filtering.py5
-rw-r--r--synapse/storage/databases/main/group_server.py325
-rw-r--r--synapse/storage/databases/main/keys.py76
-rw-r--r--synapse/storage/databases/main/media_repository.py139
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py30
-rw-r--r--synapse/storage/databases/main/openid.py14
-rw-r--r--synapse/storage/databases/main/presence.py34
-rw-r--r--synapse/storage/databases/main/profile.py51
-rw-r--r--synapse/storage/databases/main/purge_events.py32
-rw-r--r--synapse/storage/databases/main/push_rule.py267
-rw-r--r--synapse/storage/databases/main/pusher.py108
-rw-r--r--synapse/storage/databases/main/receipts.py116
-rw-r--r--synapse/storage/databases/main/registration.py393
-rw-r--r--synapse/storage/databases/main/rejections.py5
-rw-r--r--synapse/storage/databases/main/relations.py103
-rw-r--r--synapse/storage/databases/main/room.py97
-rw-r--r--synapse/storage/databases/main/roommember.py81
-rw-r--r--synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres33
-rw-r--r--synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite44
-rw-r--r--synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql25
-rw-r--r--synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql18
-rw-r--r--synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql28
-rw-r--r--synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql17
-rw-r--r--synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql16
-rw-r--r--synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres26
-rw-r--r--synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql42
-rw-r--r--synapse/storage/databases/main/schema/delta/58/15unread_count.sql26
-rw-r--r--synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql22
-rw-r--r--synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql21
-rw-r--r--synapse/storage/databases/main/search.py15
-rw-r--r--synapse/storage/databases/main/signatures.py40
-rw-r--r--synapse/storage/databases/main/state.py28
-rw-r--r--synapse/storage/databases/main/state_deltas.py21
-rw-r--r--synapse/storage/databases/main/stats.py166
-rw-r--r--synapse/storage/databases/main/stream.py485
-rw-r--r--synapse/storage/databases/main/tags.py15
-rw-r--r--synapse/storage/databases/main/transactions.py200
-rw-r--r--synapse/storage/databases/main/ui_auth.py67
-rw-r--r--synapse/storage/databases/main/user_directory.py102
-rw-r--r--synapse/storage/databases/main/user_erasure_store.py34
-rw-r--r--synapse/storage/databases/state/store.py26
-rw-r--r--synapse/storage/keys.py2
-rw-r--r--synapse/storage/persist_events.py20
-rw-r--r--synapse/storage/prepare_database.py118
-rw-r--r--synapse/storage/purge_events.py2
-rw-r--r--synapse/storage/relations.py8
-rw-r--r--synapse/storage/state.py20
-rw-r--r--synapse/storage/util/id_generators.py274
-rw-r--r--synapse/storage/util/sequence.py8
-rw-r--r--synapse/streams/config.py63
-rw-r--r--synapse/streams/events.py4
-rw-r--r--synapse/types.py197
-rw-r--r--synapse/util/__init__.py20
-rw-r--r--synapse/util/async_helpers.py159
-rw-r--r--synapse/util/caches/__init__.py4
-rw-r--r--synapse/util/caches/descriptors.py100
-rw-r--r--synapse/util/caches/dictionary_cache.py4
-rw-r--r--synapse/util/caches/expiringcache.py4
-rw-r--r--synapse/util/caches/lrucache.py4
-rw-r--r--synapse/util/caches/response_cache.py2
-rw-r--r--synapse/util/caches/treecache.py4
-rw-r--r--synapse/util/caches/ttlcache.py4
-rw-r--r--synapse/util/distributor.py54
-rw-r--r--synapse/util/file_consumer.py2
-rw-r--r--synapse/util/frozenutils.py5
-rw-r--r--synapse/util/jsonobject.py2
-rw-r--r--synapse/util/metrics.py2
-rw-r--r--synapse/util/ratelimitutils.py4
-rw-r--r--synapse/util/retryutils.py2
-rw-r--r--synapse/util/stringutils.py4
-rw-r--r--synapse/util/wheel_timer.py4
269 files changed, 7548 insertions, 4906 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 1282d19b3c..bf0bf192a5 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.19.1"
+__version__ = "1.20.0rc3"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index d8190f92ab..75388643ee 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -58,7 +58,7 @@ class _InvalidMacaroonException(Exception):
     pass
 
 
-class Auth(object):
+class Auth:
     """
     FIXME: This class contains a mix of functions for authenticating users
     of our client-server API and authenticating events added to room graphs.
@@ -213,6 +213,7 @@ class Auth(object):
             user = user_info["user"]
             token_id = user_info["token_id"]
             is_guest = user_info["is_guest"]
+            shadow_banned = user_info["shadow_banned"]
 
             # Deny the request if the user account has expired.
             if self._account_validity.enabled and not allow_expired:
@@ -252,7 +253,12 @@ class Auth(object):
                 opentracing.set_tag("device_id", device_id)
 
             return synapse.types.create_requester(
-                user, token_id, is_guest, device_id, app_service=app_service
+                user,
+                token_id,
+                is_guest,
+                shadow_banned,
+                device_id,
+                app_service=app_service,
             )
         except KeyError:
             raise MissingClientTokenError()
@@ -297,6 +303,7 @@ class Auth(object):
             dict that includes:
                `user` (UserID)
                `is_guest` (bool)
+               `shadow_banned` (bool)
                `token_id` (int|None): access token id. May be None if guest
                `device_id` (str|None): device corresponding to access token
         Raises:
@@ -356,6 +363,7 @@ class Auth(object):
                 ret = {
                     "user": user,
                     "is_guest": True,
+                    "shadow_banned": False,
                     "token_id": None,
                     # all guests get the same device id
                     "device_id": GUEST_DEVICE_ID,
@@ -365,6 +373,7 @@ class Auth(object):
                 ret = {
                     "user": user,
                     "is_guest": False,
+                    "shadow_banned": False,
                     "token_id": None,
                     "device_id": None,
                 }
@@ -488,6 +497,7 @@ class Auth(object):
             "user": UserID.from_string(ret.get("name")),
             "token_id": ret.get("token_id", None),
             "is_guest": False,
+            "shadow_banned": ret.get("shadow_banned"),
             "device_id": ret.get("device_id"),
             "valid_until_ms": ret.get("valid_until_ms"),
         }
diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
index 49093bf181..d8fafd7cb8 100644
--- a/synapse/api/auth_blocking.py
+++ b/synapse/api/auth_blocking.py
@@ -22,7 +22,7 @@ from synapse.config.server import is_threepid_reserved
 logger = logging.getLogger(__name__)
 
 
-class AuthBlocking(object):
+class AuthBlocking:
     def __init__(self, hs):
         self.store = hs.get_datastore()
 
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 6a6d32c302..46013cde15 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -28,7 +28,7 @@ MAX_ALIAS_LENGTH = 255
 MAX_USERID_LENGTH = 255
 
 
-class Membership(object):
+class Membership:
 
     """Represents the membership states of a user in a room."""
 
@@ -40,7 +40,7 @@ class Membership(object):
     LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN)
 
 
-class PresenceState(object):
+class PresenceState:
     """Represents the presence state of a user."""
 
     OFFLINE = "offline"
@@ -48,14 +48,14 @@ class PresenceState(object):
     ONLINE = "online"
 
 
-class JoinRules(object):
+class JoinRules:
     PUBLIC = "public"
     KNOCK = "knock"
     INVITE = "invite"
     PRIVATE = "private"
 
 
-class LoginType(object):
+class LoginType:
     PASSWORD = "m.login.password"
     EMAIL_IDENTITY = "m.login.email.identity"
     MSISDN = "m.login.msisdn"
@@ -65,7 +65,7 @@ class LoginType(object):
     DUMMY = "m.login.dummy"
 
 
-class EventTypes(object):
+class EventTypes:
     Member = "m.room.member"
     Create = "m.room.create"
     Tombstone = "m.room.tombstone"
@@ -96,17 +96,17 @@ class EventTypes(object):
     Presence = "m.presence"
 
 
-class RejectedReason(object):
+class RejectedReason:
     AUTH_ERROR = "auth_error"
 
 
-class RoomCreationPreset(object):
+class RoomCreationPreset:
     PRIVATE_CHAT = "private_chat"
     PUBLIC_CHAT = "public_chat"
     TRUSTED_PRIVATE_CHAT = "trusted_private_chat"
 
 
-class ThirdPartyEntityKind(object):
+class ThirdPartyEntityKind:
     USER = "user"
     LOCATION = "location"
 
@@ -115,7 +115,7 @@ ServerNoticeMsgType = "m.server_notice"
 ServerNoticeLimitReached = "m.server_notice.usage_limit_reached"
 
 
-class UserTypes(object):
+class UserTypes:
     """Allows for user type specific behaviour. With the benefit of hindsight
     'admin' and 'guest' users should also be UserTypes. Normal users are type None
     """
@@ -125,7 +125,7 @@ class UserTypes(object):
     ALL_USER_TYPES = (SUPPORT, BOT)
 
 
-class RelationTypes(object):
+class RelationTypes:
     """The types of relations known to this server.
     """
 
@@ -134,14 +134,14 @@ class RelationTypes(object):
     REFERENCE = "m.reference"
 
 
-class LimitBlockingTypes(object):
+class LimitBlockingTypes:
     """Reasons that a server may be blocked"""
 
     MONTHLY_ACTIVE_USER = "monthly_active_user"
     HS_DISABLED = "hs_disabled"
 
 
-class EventContentFields(object):
+class EventContentFields:
     """Fields found in events' content, regardless of type."""
 
     # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326
@@ -152,6 +152,6 @@ class EventContentFields(object):
     SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after"
 
 
-class RoomEncryptionAlgorithms(object):
+class RoomEncryptionAlgorithms:
     MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2"
     DEFAULT = MEGOLM_V1_AES_SHA2
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 6e40630ab6..94a9e58eae 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -21,17 +21,17 @@ import typing
 from http import HTTPStatus
 from typing import Dict, List, Optional, Union
 
-from canonicaljson import json
-
 from twisted.web import http
 
+from synapse.util import json_decoder
+
 if typing.TYPE_CHECKING:
     from synapse.types import JsonDict
 
 logger = logging.getLogger(__name__)
 
 
-class Codes(object):
+class Codes:
     UNRECOGNIZED = "M_UNRECOGNIZED"
     UNAUTHORIZED = "M_UNAUTHORIZED"
     FORBIDDEN = "M_FORBIDDEN"
@@ -593,7 +593,7 @@ class HttpResponseException(CodeMessageException):
         # try to parse the body as json, to get better errcode/msg, but
         # default to M_UNKNOWN with the HTTP status as the error text
         try:
-            j = json.loads(self.response.decode("utf-8"))
+            j = json_decoder.decode(self.response.decode("utf-8"))
         except ValueError:
             j = {}
 
@@ -604,3 +604,11 @@ class HttpResponseException(CodeMessageException):
         errmsg = j.pop("error", self.msg)
 
         return ProxiedRequestError(self.code, errmsg, errcode, j)
+
+
+class ShadowBanError(Exception):
+    """
+    Raised when a shadow-banned user attempts to perform an action.
+
+    This should be caught and a proper "fake" success response sent to the user.
+    """
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 7393d6cb74..bb33345be6 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -15,15 +15,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 from typing import List
 
 import jsonschema
-from canonicaljson import json
 from jsonschema import FormatChecker
 
 from synapse.api.constants import EventContentFields
 from synapse.api.errors import SynapseError
-from synapse.storage.presence import UserPresenceState
+from synapse.api.presence import UserPresenceState
 from synapse.types import RoomID, UserID
 
 FILTER_SCHEMA = {
@@ -130,7 +130,7 @@ def matrix_user_id_validator(user_id_str):
     return UserID.from_string(user_id_str)
 
 
-class Filtering(object):
+class Filtering:
     def __init__(self, hs):
         super(Filtering, self).__init__()
         self.store = hs.get_datastore()
@@ -168,7 +168,7 @@ class Filtering(object):
             raise SynapseError(400, str(e))
 
 
-class FilterCollection(object):
+class FilterCollection:
     def __init__(self, filter_json):
         self._filter_json = filter_json
 
@@ -249,7 +249,7 @@ class FilterCollection(object):
         )
 
 
-class Filter(object):
+class Filter:
     def __init__(self, filter_json):
         self.filter_json = filter_json
 
diff --git a/synapse/storage/presence.py b/synapse/api/presence.py
index 18a462f0ee..18a462f0ee 100644
--- a/synapse/storage/presence.py
+++ b/synapse/api/presence.py
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index e62ae50ac2..5d9d5a228f 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -21,7 +21,7 @@ from synapse.types import Requester
 from synapse.util import Clock
 
 
-class Ratelimiter(object):
+class Ratelimiter:
     """
     Ratelimit actions marked by arbitrary keys.
 
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index d7baf2bc39..f3ecbf36b6 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -18,7 +18,7 @@ from typing import Dict
 import attr
 
 
-class EventFormatVersions(object):
+class EventFormatVersions:
     """This is an internal enum for tracking the version of the event format,
     independently from the room version.
     """
@@ -35,20 +35,20 @@ KNOWN_EVENT_FORMAT_VERSIONS = {
 }
 
 
-class StateResolutionVersions(object):
+class StateResolutionVersions:
     """Enum to identify the state resolution algorithms"""
 
     V1 = 1  # room v1 state res
     V2 = 2  # MSC1442 state res: room v2 and later
 
 
-class RoomDisposition(object):
+class RoomDisposition:
     STABLE = "stable"
     UNSTABLE = "unstable"
 
 
 @attr.s(slots=True, frozen=True)
-class RoomVersion(object):
+class RoomVersion:
     """An object which describes the unique attributes of a room version."""
 
     identifier = attr.ib()  # str; the identifier for this version
@@ -69,7 +69,7 @@ class RoomVersion(object):
     limit_notifications_power_levels = attr.ib(type=bool)
 
 
-class RoomVersions(object):
+class RoomVersions:
     V1 = RoomVersion(
         "1",
         RoomDisposition.STABLE,
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index bd03ebca5a..6379c86dde 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -21,6 +21,7 @@ from urllib.parse import urlencode
 
 from synapse.config import ConfigError
 
+SYNAPSE_CLIENT_API_PREFIX = "/_synapse/client"
 CLIENT_API_PREFIX = "/_matrix/client"
 FEDERATION_PREFIX = "/_matrix/federation"
 FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1"
@@ -33,7 +34,7 @@ MEDIA_PREFIX = "/_matrix/media/r0"
 LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
 
 
-class ConsentURIBuilder(object):
+class ConsentURIBuilder:
     def __init__(self, hs_config):
         """
         Args:
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 2b2cd795e0..fb476ddaf5 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -334,6 +334,13 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
     This is to workaround https://twistedmatrix.com/trac/ticket/9620, where we
     can run out of file descriptors and infinite loop if we attempt to do too
     many DNS queries at once
+
+    XXX: I'm confused by this. reactor.nameResolver does not use twisted.names unless
+    you explicitly install twisted.names as the resolver; rather it uses a GAIResolver
+    backed by the reactor's default threadpool (which is limited to 10 threads). So
+    (a) I don't understand why twisted ticket 9620 is relevant, and (b) I don't
+    understand why we would run out of FDs if we did too many lookups at once.
+    -- richvdh 2020/08/29
     """
     new_resolver = _LimitedHostnameResolver(
         reactor.nameResolver, max_dns_requests_in_flight
@@ -342,7 +349,7 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
     reactor.installNameResolver(new_resolver)
 
 
-class _LimitedHostnameResolver(object):
+class _LimitedHostnameResolver:
     """Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups.
     """
 
@@ -402,7 +409,7 @@ class _LimitedHostnameResolver(object):
             yield deferred
 
 
-class _DeferredResolutionReceiver(object):
+class _DeferredResolutionReceiver:
     """Wraps a IResolutionReceiver and simply resolves the given deferred when
     resolution is complete
     """
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index a37818fe9a..7d309b1bb0 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -14,13 +14,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+import json
 import logging
 import os
 import sys
 import tempfile
 
-from canonicaljson import json
-
 from twisted.internet import defer, task
 
 import synapse
@@ -79,8 +78,7 @@ class AdminCmdServer(HomeServer):
         pass
 
 
-@defer.inlineCallbacks
-def export_data_command(hs, args):
+async def export_data_command(hs, args):
     """Export data for a user.
 
     Args:
@@ -91,10 +89,8 @@ def export_data_command(hs, args):
     user_id = args.user_id
     directory = args.output_directory
 
-    res = yield defer.ensureDeferred(
-        hs.get_handlers().admin_handler.export_user_data(
-            user_id, FileExfiltrationWriter(user_id, directory=directory)
-        )
+    res = await hs.get_handlers().admin_handler.export_user_data(
+        user_id, FileExfiltrationWriter(user_id, directory=directory)
     )
     print(res)
 
@@ -232,14 +228,15 @@ def start(config_options):
     # We also make sure that `_base.start` gets run before we actually run the
     # command.
 
-    @defer.inlineCallbacks
-    def run(_reactor):
+    async def run():
         with LoggingContext("command"):
-            yield _base.start(ss, [])
-            yield args.func(ss, args)
+            _base.start(ss, [])
+            await args.func(ss, args)
 
     _base.start_worker_reactor(
-        "synapse-admin-cmd", config, run_command=lambda: task.react(run)
+        "synapse-admin-cmd",
+        config,
+        run_command=lambda: task.react(lambda _reactor: defer.ensureDeferred(run())),
     )
 
 
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 739b013d4c..f985810e88 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -745,7 +745,7 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
             self.send_handler.wake_destination(server)
 
 
-class FederationSenderHandler(object):
+class FederationSenderHandler:
     """Processes the fedration replication stream
 
     This class is only instantiate on the worker responsible for sending outbound
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 98d0d14a12..b08319ca77 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -48,6 +48,7 @@ from synapse.api.urls import (
 from synapse.app import _base
 from synapse.app._base import listen_ssl, listen_tcp, quit_with_error
 from synapse.config._base import ConfigError
+from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import ListenerConfig
 from synapse.federation.transport.server import TransportLayerServer
@@ -209,6 +210,15 @@ class SynapseHomeServer(HomeServer):
 
                 resources["/_matrix/saml2"] = SAML2Resource(self)
 
+            if self.get_config().threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+                from synapse.rest.synapse.client.password_reset import (
+                    PasswordResetSubmitTokenResource,
+                )
+
+                resources[
+                    "/_synapse/client/password_reset/email/submit_token"
+                ] = PasswordResetSubmitTokenResource(self)
+
         if name == "consent":
             from synapse.rest.consent.consent_resource import ConsentResource
 
@@ -411,26 +421,24 @@ def setup(config_options):
 
         return provision
 
-    @defer.inlineCallbacks
-    def reprovision_acme():
+    async def reprovision_acme():
         """
         Provision a certificate from ACME, if required, and reload the TLS
         certificate if it's renewed.
         """
-        reprovisioned = yield defer.ensureDeferred(do_acme())
+        reprovisioned = await do_acme()
         if reprovisioned:
             _base.refresh_certificate(hs)
 
-    @defer.inlineCallbacks
-    def start():
+    async def start():
         try:
             # Run the ACME provisioning code, if it's enabled.
             if hs.config.acme_enabled:
                 acme = hs.get_acme_handler()
                 # Start up the webservices which we will respond to ACME
                 # challenges with, and then provision.
-                yield defer.ensureDeferred(acme.start_listening())
-                yield defer.ensureDeferred(do_acme())
+                await acme.start_listening()
+                await do_acme()
 
                 # Check if it needs to be reprovisioned every day.
                 hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
@@ -439,8 +447,8 @@ def setup(config_options):
             if hs.config.oidc_enabled:
                 oidc = hs.get_oidc_handler()
                 # Loading the provider metadata also ensures the provider config is valid.
-                yield defer.ensureDeferred(oidc.load_metadata())
-                yield defer.ensureDeferred(oidc.load_jwks())
+                await oidc.load_metadata()
+                await oidc.load_jwks()
 
             _base.start(hs, config.listeners)
 
@@ -456,7 +464,7 @@ def setup(config_options):
                 reactor.stop()
             sys.exit(1)
 
-    reactor.callWhenRunning(start)
+    reactor.callWhenRunning(lambda: defer.ensureDeferred(start()))
 
     return hs
 
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 1ffdc1ed95..13ec1f71a6 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -14,20 +14,25 @@
 # limitations under the License.
 import logging
 import re
+from typing import TYPE_CHECKING
 
 from synapse.api.constants import EventTypes
+from synapse.appservice.api import ApplicationServiceApi
 from synapse.types import GroupID, get_domain_from_id
 from synapse.util.caches.descriptors import cached
 
+if TYPE_CHECKING:
+    from synapse.storage.databases.main import DataStore
+
 logger = logging.getLogger(__name__)
 
 
-class ApplicationServiceState(object):
+class ApplicationServiceState:
     DOWN = "down"
     UP = "up"
 
 
-class AppServiceTransaction(object):
+class AppServiceTransaction:
     """Represents an application service transaction."""
 
     def __init__(self, service, id, events):
@@ -35,19 +40,19 @@ class AppServiceTransaction(object):
         self.id = id
         self.events = events
 
-    def send(self, as_api):
+    async def send(self, as_api: ApplicationServiceApi) -> bool:
         """Sends this transaction using the provided AS API interface.
 
         Args:
-            as_api(ApplicationServiceApi): The API to use to send.
+            as_api: The API to use to send.
         Returns:
-            An Awaitable which resolves to True if the transaction was sent.
+            True if the transaction was sent.
         """
-        return as_api.push_bulk(
+        return await as_api.push_bulk(
             service=self.service, events=self.events, txn_id=self.id
         )
 
-    def complete(self, store):
+    async def complete(self, store: "DataStore") -> None:
         """Completes this transaction as successful.
 
         Marks this transaction ID on the application service and removes the
@@ -55,13 +60,11 @@ class AppServiceTransaction(object):
 
         Args:
             store: The database store to operate on.
-        Returns:
-            A Deferred which resolves to True if the transaction was completed.
         """
-        return store.complete_appservice_txn(service=self.service, txn_id=self.id)
+        await store.complete_appservice_txn(service=self.service, txn_id=self.id)
 
 
-class ApplicationService(object):
+class ApplicationService:
     """Defines an application service. This definition is mostly what is
     provided to the /register AS API.
 
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index e72a0b9ac0..bb6fa8299a 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -14,18 +14,20 @@
 # limitations under the License.
 import logging
 import urllib
+from typing import TYPE_CHECKING, Optional
 
 from prometheus_client import Counter
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, ThirdPartyEntityKind
 from synapse.api.errors import CodeMessageException
 from synapse.events.utils import serialize_event
 from synapse.http.client import SimpleHttpClient
-from synapse.types import ThirdPartyInstanceID
+from synapse.types import JsonDict, ThirdPartyInstanceID
 from synapse.util.caches.response_cache import ResponseCache
 
+if TYPE_CHECKING:
+    from synapse.appservice import ApplicationService
+
 logger = logging.getLogger(__name__)
 
 sent_transactions_counter = Counter(
@@ -163,19 +165,20 @@ class ApplicationServiceApi(SimpleHttpClient):
             logger.warning("query_3pe to %s threw exception %s", uri, ex)
             return []
 
-    def get_3pe_protocol(self, service, protocol):
+    async def get_3pe_protocol(
+        self, service: "ApplicationService", protocol: str
+    ) -> Optional[JsonDict]:
         if service.url is None:
             return {}
 
-        @defer.inlineCallbacks
-        def _get():
+        async def _get() -> Optional[JsonDict]:
             uri = "%s%s/thirdparty/protocol/%s" % (
                 service.url,
                 APP_SERVICE_PREFIX,
                 urllib.parse.quote(protocol),
             )
             try:
-                info = yield defer.ensureDeferred(self.get_json(uri, {}))
+                info = await self.get_json(uri, {})
 
                 if not _is_valid_3pe_metadata(info):
                     logger.warning(
@@ -196,7 +199,7 @@ class ApplicationServiceApi(SimpleHttpClient):
                 return None
 
         key = (service.id, protocol)
-        return self.protocol_meta_cache.wrap(key, _get)
+        return await self.protocol_meta_cache.wrap(key, _get)
 
     async def push_bulk(self, service, events, txn_id=None):
         if service.url is None:
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index d5204b1314..8eb8c6f51c 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -57,7 +57,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 logger = logging.getLogger(__name__)
 
 
-class ApplicationServiceScheduler(object):
+class ApplicationServiceScheduler:
     """ Public facing API for this module. Does the required DI to tie the
     components together. This also serves as the "event_pool", which in this
     case is a simple array.
@@ -86,7 +86,7 @@ class ApplicationServiceScheduler(object):
         self.queuer.enqueue(service, event)
 
 
-class _ServiceQueuer(object):
+class _ServiceQueuer:
     """Queue of events waiting to be sent to appservices.
 
     Groups events into transactions per-appservice, and sends them on to the
@@ -133,7 +133,7 @@ class _ServiceQueuer(object):
             self.requests_in_flight.discard(service.id)
 
 
-class _TransactionController(object):
+class _TransactionController:
     """Transaction manager.
 
     Builds AppServiceTransactions and runs their lifecycle. Also starts a Recoverer
@@ -209,7 +209,7 @@ class _TransactionController(object):
         return state == ApplicationServiceState.UP or state is None
 
 
-class _Recoverer(object):
+class _Recoverer:
     """Manages retries and backoff for a DOWN appservice.
 
     We have one of these for each appservice which is currently considered DOWN.
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index fd137853b1..bb9bf8598d 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -18,12 +18,16 @@
 import argparse
 import errno
 import os
+import time
+import urllib.parse
 from collections import OrderedDict
 from hashlib import sha256
 from textwrap import dedent
-from typing import Any, List, MutableMapping, Optional
+from typing import Any, Callable, List, MutableMapping, Optional
 
 import attr
+import jinja2
+import pkg_resources
 import yaml
 
 
@@ -84,7 +88,7 @@ def path_exists(file_path):
         return False
 
 
-class Config(object):
+class Config:
     """
     A configuration section, containing configuration keys and values.
 
@@ -100,6 +104,11 @@ class Config(object):
     def __init__(self, root_config=None):
         self.root = root_config
 
+        # Get the path to the default Synapse template directory
+        self.default_template_dir = pkg_resources.resource_filename(
+            "synapse", "res/templates"
+        )
+
     def __getattr__(self, item: str) -> Any:
         """
         Try and fetch a configuration option that does not exist on this class.
@@ -184,8 +193,97 @@ class Config(object):
         with open(file_path) as file_stream:
             return file_stream.read()
 
+    def read_templates(
+        self, filenames: List[str], custom_template_directory: Optional[str] = None,
+    ) -> List[jinja2.Template]:
+        """Load a list of template files from disk using the given variables.
+
+        This function will attempt to load the given templates from the default Synapse
+        template directory. If `custom_template_directory` is supplied, that directory
+        is tried first.
+
+        Files read are treated as Jinja templates. These templates are not rendered yet.
+
+        Args:
+            filenames: A list of template filenames to read.
+
+            custom_template_directory: A directory to try to look for the templates
+                before using the default Synapse template directory instead.
+
+        Raises:
+            ConfigError: if the file's path is incorrect or otherwise cannot be read.
+
+        Returns:
+            A list of jinja2 templates.
+        """
+        templates = []
+        search_directories = [self.default_template_dir]
+
+        # The loader will first look in the custom template directory (if specified) for the
+        # given filename. If it doesn't find it, it will use the default template dir instead
+        if custom_template_directory:
+            # Check that the given template directory exists
+            if not self.path_exists(custom_template_directory):
+                raise ConfigError(
+                    "Configured template directory does not exist: %s"
+                    % (custom_template_directory,)
+                )
+
+            # Search the custom template directory as well
+            search_directories.insert(0, custom_template_directory)
+
+        loader = jinja2.FileSystemLoader(search_directories)
+        env = jinja2.Environment(loader=loader, autoescape=True)
+
+        # Update the environment with our custom filters
+        env.filters.update(
+            {
+                "format_ts": _format_ts_filter,
+                "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl),
+            }
+        )
+
+        for filename in filenames:
+            # Load the template
+            template = env.get_template(filename)
+            templates.append(template)
+
+        return templates
 
-class RootConfig(object):
+
+def _format_ts_filter(value: int, format: str):
+    return time.strftime(format, time.localtime(value / 1000))
+
+
+def _create_mxc_to_http_filter(public_baseurl: str) -> Callable:
+    """Create and return a jinja2 filter that converts MXC urls to HTTP
+
+    Args:
+        public_baseurl: The public, accessible base URL of the homeserver
+    """
+
+    def mxc_to_http_filter(value, width, height, resize_method="crop"):
+        if value[0:6] != "mxc://":
+            return ""
+
+        server_and_media_id = value[6:]
+        fragment = None
+        if "#" in server_and_media_id:
+            server_and_media_id, fragment = server_and_media_id.split("#", 1)
+            fragment = "#" + fragment
+
+        params = {"width": width, "height": height, "method": resize_method}
+        return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
+            public_baseurl,
+            server_and_media_id,
+            urllib.parse.urlencode(params),
+            fragment or "",
+        )
+
+    return mxc_to_http_filter
+
+
+class RootConfig:
     """
     Holder of an application's configuration.
 
@@ -734,11 +832,26 @@ class ShardedWorkerHandlingConfig:
     def should_handle(self, instance_name: str, key: str) -> bool:
         """Whether this instance is responsible for handling the given key.
         """
-
-        # If multiple instances are not defined we always return true.
+        # If multiple instances are not defined we always return true
         if not self.instances or len(self.instances) == 1:
             return True
 
+        return self.get_instance(key) == instance_name
+
+    def get_instance(self, key: str) -> str:
+        """Get the instance responsible for handling the given key.
+
+        Note: For things like federation sending the config for which instance
+        is sending is known only to the sender instance if there is only one.
+        Therefore `should_handle` should be used where possible.
+        """
+
+        if not self.instances:
+            return "master"
+
+        if len(self.instances) == 1:
+            return self.instances[0]
+
         # We shard by taking the hash, modulo it by the number of instances and
         # then checking whether this instance matches the instance at that
         # index.
@@ -748,7 +861,7 @@ class ShardedWorkerHandlingConfig:
         dest_hash = sha256(key.encode("utf8")).digest()
         dest_int = int.from_bytes(dest_hash, byteorder="little")
         remainder = dest_int % (len(self.instances))
-        return self.instances[remainder] == instance_name
+        return self.instances[remainder]
 
 
 __all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index eb911e8f9f..b8faafa9bd 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -142,3 +142,4 @@ class ShardedWorkerHandlingConfig:
     instances: List[str]
     def __init__(self, instances: List[str]) -> None: ...
     def should_handle(self, instance_name: str, key: str) -> bool: ...
+    def get_instance(self, key: str) -> str: ...
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index aff5b21ab2..8e03f14005 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -33,7 +33,7 @@ _DEFAULT_FACTOR_SIZE = 0.5
 _DEFAULT_EVENT_CACHE_SIZE = "10K"
 
 
-class CacheProperties(object):
+class CacheProperties:
     def __init__(self):
         # The default factor size for all caches
         self.default_factor_size = float(
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index a63acbdc63..72b42bfd62 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -23,7 +23,6 @@ from enum import Enum
 from typing import Optional
 
 import attr
-import pkg_resources
 
 from ._base import Config, ConfigError
 
@@ -98,21 +97,18 @@ class EmailConfig(Config):
             if parsed[1] == "":
                 raise RuntimeError("Invalid notif_from address")
 
+        # A user-configurable template directory
         template_dir = email_config.get("template_dir")
-        # we need an absolute path, because we change directory after starting (and
-        # we don't yet know what auxiliary templates like mail.css we will need).
-        # (Note that loading as package_resources with jinja.PackageLoader doesn't
-        # work for the same reason.)
-        if not template_dir:
-            template_dir = pkg_resources.resource_filename("synapse", "res/templates")
-
-        self.email_template_dir = os.path.abspath(template_dir)
+        if isinstance(template_dir, str):
+            # We need an absolute path, because we change directory after starting (and
+            # we don't yet know what auxiliary templates like mail.css we will need).
+            template_dir = os.path.abspath(template_dir)
+        elif template_dir is not None:
+            # If template_dir is something other than a str or None, warn the user
+            raise ConfigError("Config option email.template_dir must be type str")
 
         self.email_enable_notifs = email_config.get("enable_notifs", False)
 
-        account_validity_config = config.get("account_validity") or {}
-        account_validity_renewal_enabled = account_validity_config.get("renew_at")
-
         self.threepid_behaviour_email = (
             # Have Synapse handle the email sending if account_threepid_delegates.email
             # is not defined
@@ -166,19 +162,6 @@ class EmailConfig(Config):
             email_config.get("validation_token_lifetime", "1h")
         )
 
-        if (
-            self.email_enable_notifs
-            or account_validity_renewal_enabled
-            or self.threepid_behaviour_email == ThreepidBehaviour.LOCAL
-        ):
-            # make sure we can import the required deps
-            import bleach
-            import jinja2
-
-            # prevent unused warnings
-            jinja2
-            bleach
-
         if self.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
             missing = []
             if not self.email_notif_from:
@@ -196,84 +179,90 @@ class EmailConfig(Config):
 
             # These email templates have placeholders in them, and thus must be
             # parsed using a templating engine during a request
-            self.email_password_reset_template_html = email_config.get(
+            password_reset_template_html = email_config.get(
                 "password_reset_template_html", "password_reset.html"
             )
-            self.email_password_reset_template_text = email_config.get(
+            password_reset_template_text = email_config.get(
                 "password_reset_template_text", "password_reset.txt"
             )
-            self.email_registration_template_html = email_config.get(
+            registration_template_html = email_config.get(
                 "registration_template_html", "registration.html"
             )
-            self.email_registration_template_text = email_config.get(
+            registration_template_text = email_config.get(
                 "registration_template_text", "registration.txt"
             )
-            self.email_add_threepid_template_html = email_config.get(
+            add_threepid_template_html = email_config.get(
                 "add_threepid_template_html", "add_threepid.html"
             )
-            self.email_add_threepid_template_text = email_config.get(
+            add_threepid_template_text = email_config.get(
                 "add_threepid_template_text", "add_threepid.txt"
             )
 
-            self.email_password_reset_template_failure_html = email_config.get(
+            password_reset_template_failure_html = email_config.get(
                 "password_reset_template_failure_html", "password_reset_failure.html"
             )
-            self.email_registration_template_failure_html = email_config.get(
+            registration_template_failure_html = email_config.get(
                 "registration_template_failure_html", "registration_failure.html"
             )
-            self.email_add_threepid_template_failure_html = email_config.get(
+            add_threepid_template_failure_html = email_config.get(
                 "add_threepid_template_failure_html", "add_threepid_failure.html"
             )
 
             # These templates do not support any placeholder variables, so we
             # will read them from disk once during setup
-            email_password_reset_template_success_html = email_config.get(
+            password_reset_template_success_html = email_config.get(
                 "password_reset_template_success_html", "password_reset_success.html"
             )
-            email_registration_template_success_html = email_config.get(
+            registration_template_success_html = email_config.get(
                 "registration_template_success_html", "registration_success.html"
             )
-            email_add_threepid_template_success_html = email_config.get(
+            add_threepid_template_success_html = email_config.get(
                 "add_threepid_template_success_html", "add_threepid_success.html"
             )
 
-            # Check templates exist
-            for f in [
+            # Read all templates from disk
+            (
                 self.email_password_reset_template_html,
                 self.email_password_reset_template_text,
                 self.email_registration_template_html,
                 self.email_registration_template_text,
                 self.email_add_threepid_template_html,
                 self.email_add_threepid_template_text,
+                self.email_password_reset_template_confirmation_html,
                 self.email_password_reset_template_failure_html,
                 self.email_registration_template_failure_html,
                 self.email_add_threepid_template_failure_html,
-                email_password_reset_template_success_html,
-                email_registration_template_success_html,
-                email_add_threepid_template_success_html,
-            ]:
-                p = os.path.join(self.email_template_dir, f)
-                if not os.path.isfile(p):
-                    raise ConfigError("Unable to find template file %s" % (p,))
-
-            # Retrieve content of web templates
-            filepath = os.path.join(
-                self.email_template_dir, email_password_reset_template_success_html
-            )
-            self.email_password_reset_template_success_html = self.read_file(
-                filepath, "email.password_reset_template_success_html"
+                password_reset_template_success_html_template,
+                registration_template_success_html_template,
+                add_threepid_template_success_html_template,
+            ) = self.read_templates(
+                [
+                    password_reset_template_html,
+                    password_reset_template_text,
+                    registration_template_html,
+                    registration_template_text,
+                    add_threepid_template_html,
+                    add_threepid_template_text,
+                    "password_reset_confirmation.html",
+                    password_reset_template_failure_html,
+                    registration_template_failure_html,
+                    add_threepid_template_failure_html,
+                    password_reset_template_success_html,
+                    registration_template_success_html,
+                    add_threepid_template_success_html,
+                ],
+                template_dir,
             )
-            filepath = os.path.join(
-                self.email_template_dir, email_registration_template_success_html
-            )
-            self.email_registration_template_success_html_content = self.read_file(
-                filepath, "email.registration_template_success_html"
+
+            # Render templates that do not contain any placeholders
+            self.email_password_reset_template_success_html_content = (
+                password_reset_template_success_html_template.render()
             )
-            filepath = os.path.join(
-                self.email_template_dir, email_add_threepid_template_success_html
+            self.email_registration_template_success_html_content = (
+                registration_template_success_html_template.render()
             )
-            self.email_add_threepid_template_success_html_content = self.read_file(
-                filepath, "email.add_threepid_template_success_html"
+            self.email_add_threepid_template_success_html_content = (
+                add_threepid_template_success_html_template.render()
             )
 
         if self.email_enable_notifs:
@@ -290,17 +279,19 @@ class EmailConfig(Config):
                     % (", ".join(missing),)
                 )
 
-            self.email_notif_template_html = email_config.get(
+            notif_template_html = email_config.get(
                 "notif_template_html", "notif_mail.html"
             )
-            self.email_notif_template_text = email_config.get(
+            notif_template_text = email_config.get(
                 "notif_template_text", "notif_mail.txt"
             )
 
-            for f in self.email_notif_template_text, self.email_notif_template_html:
-                p = os.path.join(self.email_template_dir, f)
-                if not os.path.isfile(p):
-                    raise ConfigError("Unable to find email template file %s" % (p,))
+            (
+                self.email_notif_template_html,
+                self.email_notif_template_text,
+            ) = self.read_templates(
+                [notif_template_html, notif_template_text], template_dir,
+            )
 
             self.email_notif_for_new_users = email_config.get(
                 "notif_for_new_users", True
@@ -309,18 +300,20 @@ class EmailConfig(Config):
                 "client_base_url", email_config.get("riot_base_url", None)
             )
 
-        if account_validity_renewal_enabled:
-            self.email_expiry_template_html = email_config.get(
+        if self.account_validity.renew_by_email_enabled:
+            expiry_template_html = email_config.get(
                 "expiry_template_html", "notice_expiry.html"
             )
-            self.email_expiry_template_text = email_config.get(
+            expiry_template_text = email_config.get(
                 "expiry_template_text", "notice_expiry.txt"
             )
 
-            for f in self.email_expiry_template_text, self.email_expiry_template_html:
-                p = os.path.join(self.email_template_dir, f)
-                if not os.path.isfile(p):
-                    raise ConfigError("Unable to find email template file %s" % (p,))
+            (
+                self.account_validity_template_html,
+                self.account_validity_template_text,
+            ) = self.read_templates(
+                [expiry_template_html, expiry_template_text], template_dir,
+            )
 
         subjects_config = email_config.get("subjects", {})
         subjects = {}
@@ -400,9 +393,7 @@ class EmailConfig(Config):
           # Directory in which Synapse will try to find the template files below.
           # If not set, default templates from within the Synapse package will be used.
           #
-          # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
-          # If you *do* uncomment it, you will need to make sure that all the templates
-          # below are in the directory.
+          # Do not uncomment this setting unless you want to customise the templates.
           #
           # Synapse will look for the following templates in this directory:
           #
@@ -415,9 +406,13 @@ class EmailConfig(Config):
           # * The contents of password reset emails sent by the homeserver:
           #   'password_reset.html' and 'password_reset.txt'
           #
-          # * HTML pages for success and failure that a user will see when they follow
-          #   the link in the password reset email: 'password_reset_success.html' and
-          #   'password_reset_failure.html'
+          # * An HTML page that a user will see when they follow the link in the password
+          #   reset email. The user will be asked to confirm the action before their
+          #   password is reset: 'password_reset_confirmation.html'
+          #
+          # * HTML pages for success and failure that a user will see when they confirm
+          #   the password reset flow using the page above: 'password_reset_success.html'
+          #   and 'password_reset_failure.html'
           #
           # * The contents of address verification emails sent during registration:
           #   'registration.html' and 'registration.txt'
diff --git a/synapse/config/key.py b/synapse/config/key.py
index b529ea5da0..de964dff13 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -82,7 +82,7 @@ logger = logging.getLogger(__name__)
 
 
 @attr.s
-class TrustedKeyServer(object):
+class TrustedKeyServer:
     # string: name of the server.
     server_name = attr.ib()
 
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index c96e6ef62a..13d6f6a3ea 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -17,6 +17,7 @@ import logging
 import logging.config
 import os
 import sys
+import threading
 from string import Template
 
 import yaml
@@ -25,6 +26,7 @@ from twisted.logger import (
     ILogObserver,
     LogBeginner,
     STDLibLogObserver,
+    eventAsText,
     globalLogBeginner,
 )
 
@@ -216,8 +218,9 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
     # system.
     observer = STDLibLogObserver()
 
-    def _log(event):
+    threadlocal = threading.local()
 
+    def _log(event):
         if "log_text" in event:
             if event["log_text"].startswith("DNSDatagramProtocol starting on "):
                 return
@@ -228,7 +231,25 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
             if event["log_text"].startswith("Timing out client"):
                 return
 
-        return observer(event)
+        # this is a workaround to make sure we don't get stack overflows when the
+        # logging system raises an error which is written to stderr which is redirected
+        # to the logging system, etc.
+        if getattr(threadlocal, "active", False):
+            # write the text of the event, if any, to the *real* stderr (which may
+            # be redirected to /dev/null, but there's not much we can do)
+            try:
+                event_text = eventAsText(event)
+                print("logging during logging: %s" % event_text, file=sys.__stderr__)
+            except Exception:
+                # gah.
+                pass
+            return
+
+        try:
+            threadlocal.active = True
+            return observer(event)
+        finally:
+            threadlocal.active = False
 
     logBeginner.beginLoggingTo([_log], redirectStandardIO=not config.no_redirect_stdio)
     if not config.no_redirect_stdio:
diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
index 6aad0d37c0..dfd27e1523 100644
--- a/synapse/config/metrics.py
+++ b/synapse/config/metrics.py
@@ -22,7 +22,7 @@ from ._base import Config, ConfigError
 
 
 @attr.s
-class MetricsFlags(object):
+class MetricsFlags:
     known_servers = attr.ib(default=False, validator=attr.validators.instance_of(bool))
 
     @classmethod
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index b2c78ac40c..14b8836197 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -17,7 +17,7 @@ from typing import Dict
 from ._base import Config
 
 
-class RateLimitConfig(object):
+class RateLimitConfig:
     def __init__(
         self,
         config: Dict[str, float],
@@ -27,7 +27,7 @@ class RateLimitConfig(object):
         self.burst_count = config.get("burst_count", defaults["burst_count"])
 
 
-class FederationRateLimitConfig(object):
+class FederationRateLimitConfig:
     _items_and_default = {
         "window_size": 1000,
         "sleep_limit": 10,
diff --git a/synapse/config/room.py b/synapse/config/room.py
index 52cf0b62fc..692d7a1936 100644
--- a/synapse/config/room.py
+++ b/synapse/config/room.py
@@ -22,7 +22,7 @@ from ._base import Config, ConfigError
 logger = logging.Logger(__name__)
 
 
-class RoomDefaultEncryptionTypes(object):
+class RoomDefaultEncryptionTypes:
     """Possible values for the encryption_enabled_by_default_for_room_type config option"""
 
     ALL = "all"
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index 7ac7699676..6de1f9d103 100644
--- a/synapse/config/room_directory.py
+++ b/synapse/config/room_directory.py
@@ -149,7 +149,7 @@ class RoomDirectoryConfig(Config):
         return False
 
 
-class _RoomDirectoryRule(object):
+class _RoomDirectoryRule:
     """Helper class to test whether a room directory action is allowed, like
     creating an alias or publishing a room.
     """
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index 9277b5f342..99aa8b3bf1 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -18,8 +18,6 @@ import logging
 from typing import Any, List
 
 import attr
-import jinja2
-import pkg_resources
 
 from synapse.python_dependencies import DependencyException, check_requirements
 from synapse.util.module_loader import load_module, load_python_module
@@ -171,16 +169,6 @@ class SAML2Config(Config):
             saml2_config.get("saml_session_lifetime", "15m")
         )
 
-        template_dir = saml2_config.get("template_dir")
-        if not template_dir:
-            template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
-
-        loader = jinja2.FileSystemLoader(template_dir)
-        # enable auto-escape here, to having to remember to escape manually in the
-        # template
-        env = jinja2.Environment(loader=loader, autoescape=True)
-        self.saml2_error_html_template = env.get_template("saml_error.html")
-
     def _default_saml_config_dict(
         self, required_attributes: set, optional_attributes: set
     ):
@@ -233,11 +221,14 @@ class SAML2Config(Config):
         # At least one of `sp_config` or `config_path` must be set in this section to
         # enable SAML login.
         #
-        # (You will probably also want to set the following options to `false` to
+        # You will probably also want to set the following options to `false` to
         # disable the regular login/registration flows:
         #   * enable_registration
         #   * password_config.enabled
         #
+        # You will also want to investigate the settings under the "sso" configuration
+        # section below.
+        #
         # Once SAML support is enabled, a metadata file will be exposed at
         # https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to
         # use to configure your SAML IdP with. Alternatively, you can manually configure
@@ -359,31 +350,6 @@ class SAML2Config(Config):
           #    value: "staff"
           #  - attribute: department
           #    value: "sales"
-
-          # Directory in which Synapse will try to find the template files below.
-          # If not set, default templates from within the Synapse package will be used.
-          #
-          # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
-          # If you *do* uncomment it, you will need to make sure that all the templates
-          # below are in the directory.
-          #
-          # Synapse will look for the following templates in this directory:
-          #
-          # * HTML page to display to users if something goes wrong during the
-          #   authentication process: 'saml_error.html'.
-          #
-          #   When rendering, this template is given the following variables:
-          #     * code: an HTML error code corresponding to the error that is being
-          #       returned (typically 400 or 500)
-          #
-          #     * msg: a textual message describing the error.
-          #
-          #   The variables will automatically be HTML-escaped.
-          #
-          # You can see the default templates at:
-          # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
-          #
-          #template_dir: "res/templates"
         """ % {
             "config_dir_path": config_dir_path
         }
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 9f15ed109e..532b910470 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -19,14 +19,13 @@ import logging
 import os.path
 import re
 from textwrap import indent
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Set
 
 import attr
 import yaml
 
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.http.endpoint import parse_and_validate_server_name
-from synapse.python_dependencies import DependencyException, check_requirements
 
 from ._base import Config, ConfigError
 
@@ -425,7 +424,7 @@ class ServerConfig(Config):
         self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
 
         @attr.s
-        class LimitRemoteRoomsConfig(object):
+        class LimitRemoteRoomsConfig:
             enabled = attr.ib(
                 validator=attr.validators.instance_of(bool), default=False
             )
@@ -508,8 +507,6 @@ class ServerConfig(Config):
                 )
             )
 
-        _check_resource_config(self.listeners)
-
         self.cleanup_extremities_with_dummy_events = config.get(
             "cleanup_extremities_with_dummy_events", True
         )
@@ -545,6 +542,19 @@ class ServerConfig(Config):
             users_new_default_push_rules
         )  # type: set
 
+        # Whitelist of domain names that given next_link parameters must have
+        next_link_domain_whitelist = config.get(
+            "next_link_domain_whitelist"
+        )  # type: Optional[List[str]]
+
+        self.next_link_domain_whitelist = None  # type: Optional[Set[str]]
+        if next_link_domain_whitelist is not None:
+            if not isinstance(next_link_domain_whitelist, list):
+                raise ConfigError("'next_link_domain_whitelist' must be a list")
+
+            # Turn the list into a set to improve lookup speed.
+            self.next_link_domain_whitelist = set(next_link_domain_whitelist)
+
     def has_tls_listener(self) -> bool:
         return any(listener.tls for listener in self.listeners)
 
@@ -964,11 +974,10 @@ class ServerConfig(Config):
           #  min_lifetime: 1d
           #  max_lifetime: 1y
 
-          # Retention policy limits. If set, a user won't be able to send a
-          # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
-          # that's not within this range. This is especially useful in closed federations,
-          # in which server admins can make sure every federating server applies the same
-          # rules.
+          # Retention policy limits. If set, and the state of a room contains a
+          # 'm.room.retention' event in its state which contains a 'min_lifetime' or a
+          # 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy
+          # to these limits when running purge jobs.
           #
           #allowed_lifetime_min: 1d
           #allowed_lifetime_max: 1y
@@ -994,12 +1003,19 @@ class ServerConfig(Config):
           # (e.g. every 12h), but not want that purge to be performed by a job that's
           # iterating over every room it knows, which could be heavy on the server.
           #
+          # If any purge job is configured, it is strongly recommended to have at least
+          # a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime'
+          # set, or one job without 'shortest_max_lifetime' and one job without
+          # 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if
+          # 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a
+          # room's policy to these values is done after the policies are retrieved from
+          # Synapse's database (which is done using the range specified in a purge job's
+          # configuration).
+          #
           #purge_jobs:
-          #  - shortest_max_lifetime: 1d
-          #    longest_max_lifetime: 3d
+          #  - longest_max_lifetime: 3d
           #    interval: 12h
           #  - shortest_max_lifetime: 3d
-          #    longest_max_lifetime: 1y
           #    interval: 1d
 
         # Inhibits the /requestToken endpoints from returning an error that might leak
@@ -1011,6 +1027,24 @@ class ServerConfig(Config):
         # act as if no error happened and return a fake session ID ('sid') to clients.
         #
         #request_token_inhibit_3pid_errors: true
+
+        # A list of domains that the domain portion of 'next_link' parameters
+        # must match.
+        #
+        # This parameter is optionally provided by clients while requesting
+        # validation of an email or phone number, and maps to a link that
+        # users will be automatically redirected to after validation
+        # succeeds. Clients can make use this parameter to aid the validation
+        # process.
+        #
+        # The whitelist is applied whether the homeserver or an
+        # identity server is handling validation.
+        #
+        # The default value is no whitelist functionality; all domains are
+        # allowed. Setting this value to an empty list will instead disallow
+        # all domains.
+        #
+        #next_link_domain_whitelist: ["matrix.org"]
         """
             % locals()
         )
@@ -1133,20 +1167,3 @@ def _warn_if_webclient_configured(listeners: Iterable[ListenerConfig]) -> None:
                 if name == "webclient":
                     logger.warning(NO_MORE_WEB_CLIENT_WARNING)
                     return
-
-
-def _check_resource_config(listeners: Iterable[ListenerConfig]) -> None:
-    resource_names = {
-        res_name
-        for listener in listeners
-        if listener.http_options
-        for res in listener.http_options.resources
-        for res_name in res.names
-    }
-
-    for resource in resource_names:
-        if resource == "consent":
-            try:
-                check_requirements("resources.consent")
-            except DependencyException as e:
-                raise ConfigError(e.message)
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 73b7296399..4427676167 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -12,11 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from typing import Any, Dict
 
-import pkg_resources
-
 from ._base import Config
 
 
@@ -29,22 +26,32 @@ class SSOConfig(Config):
     def read_config(self, config, **kwargs):
         sso_config = config.get("sso") or {}  # type: Dict[str, Any]
 
-        # Pick a template directory in order of:
-        # * The sso-specific template_dir
-        # * /path/to/synapse/install/res/templates
+        # The sso-specific template_dir
         template_dir = sso_config.get("template_dir")
-        if not template_dir:
-            template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
 
-        self.sso_template_dir = template_dir
-        self.sso_account_deactivated_template = self.read_file(
-            os.path.join(self.sso_template_dir, "sso_account_deactivated.html"),
-            "sso_account_deactivated_template",
+        # Read templates from disk
+        (
+            self.sso_redirect_confirm_template,
+            self.sso_auth_confirm_template,
+            self.sso_error_template,
+            sso_account_deactivated_template,
+            sso_auth_success_template,
+        ) = self.read_templates(
+            [
+                "sso_redirect_confirm.html",
+                "sso_auth_confirm.html",
+                "sso_error.html",
+                "sso_account_deactivated.html",
+                "sso_auth_success.html",
+            ],
+            template_dir,
         )
-        self.sso_auth_success_template = self.read_file(
-            os.path.join(self.sso_template_dir, "sso_auth_success.html"),
-            "sso_auth_success_template",
+
+        # These templates have no placeholders, so render them here
+        self.sso_account_deactivated_template = (
+            sso_account_deactivated_template.render()
         )
+        self.sso_auth_success_template = sso_auth_success_template.render()
 
         self.sso_client_whitelist = sso_config.get("client_whitelist") or []
 
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index c784a71508..f23e42cdf9 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -13,12 +13,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import List, Union
+
 import attr
 
 from ._base import Config, ConfigError, ShardedWorkerHandlingConfig
 from .server import ListenerConfig, parse_listener_def
 
 
+def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
+    """Helper for allowing parsing a string or list of strings to a config
+    option expecting a list of strings.
+    """
+
+    if isinstance(obj, str):
+        return [obj]
+    return obj
+
+
 @attr.s
 class InstanceLocationConfig:
     """The host and port to talk to an instance via HTTP replication.
@@ -33,11 +45,13 @@ class WriterLocations:
     """Specifies the instances that write various streams.
 
     Attributes:
-        events: The instance that writes to the event and backfill streams.
-        events: The instance that writes to the typing stream.
+        events: The instances that write to the event and backfill streams.
+        typing: The instance that writes to the typing stream.
     """
 
-    events = attr.ib(default="master", type=str)
+    events = attr.ib(
+        default=["master"], type=List[str], converter=_instance_to_list_converter
+    )
     typing = attr.ib(default="master", type=str)
 
 
@@ -105,15 +119,18 @@ class WorkerConfig(Config):
         writers = config.get("stream_writers") or {}
         self.writers = WriterLocations(**writers)
 
-        # Check that the configured writer for events and typing also appears in
+        # Check that the configured writers for events and typing also appears in
         # `instance_map`.
         for stream in ("events", "typing"):
-            instance = getattr(self.writers, stream)
-            if instance != "master" and instance not in self.instance_map:
-                raise ConfigError(
-                    "Instance %r is configured to write %s but does not appear in `instance_map` config."
-                    % (instance, stream)
-                )
+            instances = _instance_to_list_converter(getattr(self.writers, stream))
+            for instance in instances:
+                if instance != "master" and instance not in self.instance_map:
+                    raise ConfigError(
+                        "Instance %r is configured to write %s but does not appear in `instance_map` config."
+                        % (instance, stream)
+                    )
+
+        self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
 
     def generate_config_section(self, config_dir_path, server_name, **kwargs):
         return """\
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 777c0f00b1..2b03f5ac76 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -83,7 +83,7 @@ class ServerContextFactory(ContextFactory):
 
 
 @implementer(IPolicyForHTTPS)
-class FederationPolicyForHTTPS(object):
+class FederationPolicyForHTTPS:
     """Factory for Twisted SSLClientConnectionCreators that are used to make connections
     to remote servers for federation.
 
@@ -152,7 +152,7 @@ class FederationPolicyForHTTPS(object):
 
 
 @implementer(IPolicyForHTTPS)
-class RegularPolicyForHTTPS(object):
+class RegularPolicyForHTTPS:
     """Factory for Twisted SSLClientConnectionCreators that are used to make connections
     to remote servers, for other than federation.
 
@@ -189,7 +189,7 @@ def _context_info_cb(ssl_connection, where, ret):
 
 
 @implementer(IOpenSSLClientConnectionCreator)
-class SSLClientConnectionCreator(object):
+class SSLClientConnectionCreator:
     """Creates openssl connection objects for client connections.
 
     Replaces twisted.internet.ssl.ClientTLSOptions
@@ -214,7 +214,7 @@ class SSLClientConnectionCreator(object):
         return connection
 
 
-class ConnectionVerifier(object):
+class ConnectionVerifier:
     """Set the SNI, and do cert verification
 
     This is a thing which is attached to the TLSMemoryBIOProtocol, and is called by
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 28ef7cfdb9..32c31b1cd1 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -57,7 +57,7 @@ logger = logging.getLogger(__name__)
 
 
 @attr.s(slots=True, cmp=False)
-class VerifyJsonRequest(object):
+class VerifyJsonRequest:
     """
     A request to verify a JSON object.
 
@@ -96,7 +96,7 @@ class KeyLookupError(ValueError):
     pass
 
 
-class Keyring(object):
+class Keyring:
     def __init__(self, hs, key_fetchers=None):
         self.clock = hs.get_clock()
 
@@ -420,7 +420,7 @@ class Keyring(object):
         remaining_requests.difference_update(completed)
 
 
-class KeyFetcher(object):
+class KeyFetcher:
     async def get_keys(self, keys_to_fetch):
         """
         Args:
@@ -456,7 +456,7 @@ class StoreKeyFetcher(KeyFetcher):
         return keys
 
 
-class BaseV2KeyFetcher(object):
+class BaseV2KeyFetcher:
     def __init__(self, hs):
         self.store = hs.get_datastore()
         self.config = hs.get_config()
@@ -757,9 +757,8 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
             except Exception:
                 logger.exception("Error getting keys %s from %s", key_ids, server_name)
 
-        return await yieldable_gather_results(
-            get_key, keys_to_fetch.items()
-        ).addCallback(lambda _: results)
+        await yieldable_gather_results(get_key, keys_to_fetch.items())
+        return results
 
     async def get_server_verify_key_v2_direct(self, server_name, key_ids):
         """
@@ -769,7 +768,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
             key_ids (iterable[str]):
 
         Returns:
-            Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result
+            dict[str, FetchKeyResult]: map from key ID to lookup result
 
         Raises:
             KeyLookupError if there was a problem making the lookup
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index c0981eee62..8c907ad596 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -47,7 +47,7 @@ def check(
     Args:
         room_version_obj: the version of the room
         event: the event being checked.
-        auth_events (dict: event-key -> event): the existing room state.
+        auth_events: the existing room state.
 
     Raises:
         AuthError if the checks fail
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index cc5deca75b..bf800a3852 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -18,7 +18,7 @@
 import abc
 import os
 from distutils.util import strtobool
-from typing import Dict, Optional, Type
+from typing import Dict, Optional, Tuple, Type
 
 from unpaddedbase64 import encode_base64
 
@@ -96,7 +96,7 @@ class DefaultDictProperty(DictProperty):
         return instance._dict.get(self.key, self.default)
 
 
-class _EventInternalMetadata(object):
+class _EventInternalMetadata:
     __slots__ = ["_dict"]
 
     def __init__(self, internal_metadata_dict: JsonDict):
@@ -120,7 +120,7 @@ class _EventInternalMetadata(object):
     # be here
     before = DictProperty("before")  # type: str
     after = DictProperty("after")  # type: str
-    order = DictProperty("order")  # type: int
+    order = DictProperty("order")  # type: Tuple[int, int]
 
     def get_dict(self) -> JsonDict:
         return dict(self._dict)
@@ -133,6 +133,8 @@ class _EventInternalMetadata(object):
         rejection. This is needed as those events are marked as outliers, but
         they still need to be processed as if they're new events (e.g. updating
         invite state in the database, relaying to clients, etc).
+
+        (Added in synapse 0.99.0, so may be unreliable for events received before that)
         """
         return self._dict.get("out_of_band_membership", False)
 
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 9ed24380dd..b6c47be646 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import attr
 from nacl.signing import SigningKey
@@ -36,7 +36,7 @@ from synapse.util.stringutils import random_string
 
 
 @attr.s(slots=True, cmp=False, frozen=True)
-class EventBuilder(object):
+class EventBuilder:
     """A format independent event builder used to build up the event content
     before signing the event.
 
@@ -97,14 +97,14 @@ class EventBuilder(object):
     def is_state(self):
         return self._state_key is not None
 
-    async def build(self, prev_event_ids):
+    async def build(self, prev_event_ids: List[str]) -> EventBase:
         """Transform into a fully signed and hashed event
 
         Args:
-            prev_event_ids (list[str]): The event IDs to use as the prev events
+            prev_event_ids: The event IDs to use as the prev events
 
         Returns:
-            FrozenEvent
+            The signed and hashed event.
         """
 
         state_ids = await self._state.get_current_state_ids(
@@ -114,8 +114,13 @@ class EventBuilder(object):
 
         format_version = self.room_version.event_format
         if format_version == EventFormatVersions.V1:
-            auth_events = await self._store.add_event_hashes(auth_ids)
-            prev_events = await self._store.add_event_hashes(prev_event_ids)
+            # The types of auth/prev events changes between event versions.
+            auth_events = await self._store.add_event_hashes(
+                auth_ids
+            )  # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
+            prev_events = await self._store.add_event_hashes(
+                prev_event_ids
+            )  # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
         else:
             auth_events = auth_ids
             prev_events = prev_event_ids
@@ -138,7 +143,7 @@ class EventBuilder(object):
             "unsigned": self.unsigned,
             "depth": depth,
             "prev_state": [],
-        }
+        }  # type: Dict[str, Any]
 
         if self.is_state():
             event_dict["state_key"] = self._state_key
@@ -159,7 +164,7 @@ class EventBuilder(object):
         )
 
 
-class EventBuilderFactory(object):
+class EventBuilderFactory:
     def __init__(self, hs):
         self.clock = hs.get_clock()
         self.hostname = hs.hostname
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 1ffc9525d1..b0fc859a47 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,16 +15,17 @@
 # limitations under the License.
 
 import inspect
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional, Tuple
 
-from synapse.spam_checker_api import SpamCheckerApi
+from synapse.spam_checker_api import RegistrationBehaviour, SpamCheckerApi
+from synapse.types import Collection
 
 MYPY = False
 if MYPY:
     import synapse.server
 
 
-class SpamChecker(object):
+class SpamChecker:
     def __init__(self, hs: "synapse.server.HomeServer"):
         self.spam_checkers = []  # type: List[Any]
 
@@ -160,3 +161,33 @@ class SpamChecker(object):
                     return True
 
         return False
+
+    def check_registration_for_spam(
+        self,
+        email_threepid: Optional[dict],
+        username: Optional[str],
+        request_info: Collection[Tuple[str, str]],
+    ) -> RegistrationBehaviour:
+        """Checks if we should allow the given registration request.
+
+        Args:
+            email_threepid: The email threepid used for registering, if any
+            username: The request user name, if any
+            request_info: List of tuples of user agent and IP that
+                were used during the registration process.
+
+        Returns:
+            Enum for how the request should be handled
+        """
+
+        for spam_checker in self.spam_checkers:
+            # For backwards compatibility, only run if the method exists on the
+            # spam checker
+            checker = getattr(spam_checker, "check_registration_for_spam", None)
+            if checker:
+                behaviour = checker(email_threepid, username, request_info)
+                assert isinstance(behaviour, RegistrationBehaviour)
+                if behaviour != RegistrationBehaviour.ALLOW:
+                    return behaviour
+
+        return RegistrationBehaviour.ALLOW
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 2956a64234..9d5310851c 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -18,7 +18,7 @@ from synapse.events.snapshot import EventContext
 from synapse.types import Requester
 
 
-class ThirdPartyEventRules(object):
+class ThirdPartyEventRules:
     """Allows server admins to provide a Python module implementing an extra
     set of rules to apply when processing events.
 
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 2d42e268c6..32c73d3413 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -322,7 +322,7 @@ def serialize_event(
     return d
 
 
-class EventClientSerializer(object):
+class EventClientSerializer:
     """Serializes events that are to be sent to clients.
 
     This is used for bundling extra information with any events to be sent to
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 588d222f36..9df35b54ba 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -20,7 +20,7 @@ from synapse.events.utils import validate_canonicaljson
 from synapse.types import EventID, RoomID, UserID
 
 
-class EventValidator(object):
+class EventValidator:
     def validate_new(self, event, config):
         """Validates the event has roughly the right format
 
@@ -74,15 +74,14 @@ class EventValidator(object):
                         )
 
         if event.type == EventTypes.Retention:
-            self._validate_retention(event, config)
+            self._validate_retention(event)
 
-    def _validate_retention(self, event, config):
+    def _validate_retention(self, event):
         """Checks that an event that defines the retention policy for a room respects the
-        boundaries imposed by the server's administrator.
+        format enforced by the spec.
 
         Args:
             event (FrozenEvent): The event to validate.
-            config (Config): The homeserver's configuration.
         """
         min_lifetime = event.content.get("min_lifetime")
         max_lifetime = event.content.get("max_lifetime")
@@ -95,32 +94,6 @@ class EventValidator(object):
                     errcode=Codes.BAD_JSON,
                 )
 
-            if (
-                config.retention_allowed_lifetime_min is not None
-                and min_lifetime < config.retention_allowed_lifetime_min
-            ):
-                raise SynapseError(
-                    code=400,
-                    msg=(
-                        "'min_lifetime' can't be lower than the minimum allowed"
-                        " value enforced by the server's administrator"
-                    ),
-                    errcode=Codes.BAD_JSON,
-                )
-
-            if (
-                config.retention_allowed_lifetime_max is not None
-                and min_lifetime > config.retention_allowed_lifetime_max
-            ):
-                raise SynapseError(
-                    code=400,
-                    msg=(
-                        "'min_lifetime' can't be greater than the maximum allowed"
-                        " value enforced by the server's administrator"
-                    ),
-                    errcode=Codes.BAD_JSON,
-                )
-
         if max_lifetime is not None:
             if not isinstance(max_lifetime, int):
                 raise SynapseError(
@@ -129,32 +102,6 @@ class EventValidator(object):
                     errcode=Codes.BAD_JSON,
                 )
 
-            if (
-                config.retention_allowed_lifetime_min is not None
-                and max_lifetime < config.retention_allowed_lifetime_min
-            ):
-                raise SynapseError(
-                    code=400,
-                    msg=(
-                        "'max_lifetime' can't be lower than the minimum allowed value"
-                        " enforced by the server's administrator"
-                    ),
-                    errcode=Codes.BAD_JSON,
-                )
-
-            if (
-                config.retention_allowed_lifetime_max is not None
-                and max_lifetime > config.retention_allowed_lifetime_max
-            ):
-                raise SynapseError(
-                    code=400,
-                    msg=(
-                        "'max_lifetime' can't be greater than the maximum allowed"
-                        " value enforced by the server's administrator"
-                    ),
-                    errcode=Codes.BAD_JSON,
-                )
-
         if (
             min_lifetime is not None
             and max_lifetime is not None
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 420df2385f..38aa47963f 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -39,7 +39,7 @@ from synapse.types import JsonDict, get_domain_from_id
 logger = logging.getLogger(__name__)
 
 
-class FederationBase(object):
+class FederationBase:
     def __init__(self, hs):
         self.hs = hs
 
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 11c5d63298..218df884b0 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -28,7 +28,6 @@ from typing import (
     Union,
 )
 
-from canonicaljson import json
 from prometheus_client import Counter, Histogram
 
 from twisted.internet import defer
@@ -63,7 +62,7 @@ from synapse.replication.http.federation import (
     ReplicationGetQueryRestServlet,
 )
 from synapse.types import JsonDict, get_domain_from_id
-from synapse.util import glob_to_regex, unwrapFirstError
+from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.caches.response_cache import ResponseCache
 
@@ -551,7 +550,7 @@ class FederationServer(FederationBase):
             for device_id, keys in device_keys.items():
                 for key_id, json_str in keys.items():
                     json_result.setdefault(user_id, {})[device_id] = {
-                        key_id: json.loads(json_str)
+                        key_id: json_decoder.decode(json_str)
                     }
 
         logger.info(
@@ -786,7 +785,7 @@ def _acl_entry_matches(server_name: str, acl_entry: str) -> Match:
     return regex.match(server_name)
 
 
-class FederationHandlerRegistry(object):
+class FederationHandlerRegistry:
     """Allows classes to register themselves as handlers for a given EDU or
     query type for incoming federation traffic.
     """
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index d68b4bd670..079e2b2fe0 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -20,13 +20,16 @@ These actions are mostly only used by the :py:mod:`.replication` module.
 """
 
 import logging
+from typing import Optional, Tuple
 
+from synapse.federation.units import Transaction
 from synapse.logging.utils import log_function
+from synapse.types import JsonDict
 
 logger = logging.getLogger(__name__)
 
 
-class TransactionActions(object):
+class TransactionActions:
     """ Defines persistence actions that relate to handling Transactions.
     """
 
@@ -34,30 +37,32 @@ class TransactionActions(object):
         self.store = datastore
 
     @log_function
-    def have_responded(self, origin, transaction):
-        """ Have we already responded to a transaction with the same id and
+    async def have_responded(
+        self, origin: str, transaction: Transaction
+    ) -> Optional[Tuple[int, JsonDict]]:
+        """Have we already responded to a transaction with the same id and
         origin?
 
         Returns:
-            Deferred: Results in `None` if we have not previously responded to
-            this transaction or a 2-tuple of `(int, dict)` representing the
-            response code and response body.
+            `None` if we have not previously responded to this transaction or a
+            2-tuple of `(int, dict)` representing the response code and response body.
         """
-        if not transaction.transaction_id:
+        transaction_id = transaction.transaction_id  # type: ignore
+        if not transaction_id:
             raise RuntimeError("Cannot persist a transaction with no transaction_id")
 
-        return self.store.get_received_txn_response(transaction.transaction_id, origin)
+        return await self.store.get_received_txn_response(transaction_id, origin)
 
     @log_function
-    def set_response(self, origin, transaction, code, response):
-        """ Persist how we responded to a transaction.
-
-        Returns:
-            Deferred
+    async def set_response(
+        self, origin: str, transaction: Transaction, code: int, response: JsonDict
+    ) -> None:
+        """Persist how we responded to a transaction.
         """
-        if not transaction.transaction_id:
+        transaction_id = transaction.transaction_id  # type: ignore
+        if not transaction_id:
             raise RuntimeError("Cannot persist a transaction with no transaction_id")
 
-        return self.store.set_received_txn_response(
-            transaction.transaction_id, origin, code, response
+        await self.store.set_received_txn_response(
+            transaction_id, origin, code, response
         )
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 2b0ab2dcbf..8e46957d15 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -37,8 +37,8 @@ from sortedcontainers import SortedDict
 
 from twisted.internet import defer
 
+from synapse.api.presence import UserPresenceState
 from synapse.metrics import LaterGauge
-from synapse.storage.presence import UserPresenceState
 from synapse.util.metrics import Measure
 
 from .units import Edu
@@ -46,7 +46,7 @@ from .units import Edu
 logger = logging.getLogger(__name__)
 
 
-class FederationRemoteSendQueue(object):
+class FederationRemoteSendQueue:
     """A drop in replacement for FederationSender"""
 
     def __init__(self, hs):
@@ -365,7 +365,7 @@ class FederationRemoteSendQueue(object):
         )
 
 
-class BaseFederationRow(object):
+class BaseFederationRow:
     """Base class for rows to be sent in the federation stream.
 
     Specifies how to identify, serialize and deserialize the different types.
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 94cc63001e..41a726878d 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -22,6 +22,7 @@ from twisted.internet import defer
 
 import synapse
 import synapse.metrics
+from synapse.api.presence import UserPresenceState
 from synapse.events import EventBase
 from synapse.federation.sender.per_destination_queue import PerDestinationQueue
 from synapse.federation.sender.transaction_manager import TransactionManager
@@ -39,7 +40,6 @@ from synapse.metrics import (
     events_processed_counter,
 )
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.presence import UserPresenceState
 from synapse.types import ReadReceipt
 from synapse.util.metrics import Measure, measure_func
 
@@ -56,7 +56,7 @@ sent_pdus_destination_dist_total = Counter(
 )
 
 
-class FederationSender(object):
+class FederationSender:
     def __init__(self, hs: "synapse.server.HomeServer"):
         self.hs = hs
         self.server_name = hs.hostname
@@ -108,8 +108,6 @@ class FederationSender(object):
             ),
         )
 
-        self._order = 1
-
         self._is_processing = False
         self._last_poked_id = -1
 
@@ -211,7 +209,7 @@ class FederationSender(object):
                     logger.debug("Sending %s to %r", event, destinations)
 
                     if destinations:
-                        self._send_pdu(event, destinations)
+                        await self._send_pdu(event, destinations)
 
                         now = self.clock.time_msec()
                         ts = await self.store.get_received_ts(event.event_id)
@@ -267,14 +265,11 @@ class FederationSender(object):
         finally:
             self._is_processing = False
 
-    def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None:
+    async def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None:
         # We loop through all destinations to see whether we already have
         # a transaction in progress. If we do, stick it in the pending_pdus
         # table and we'll get back to it later.
 
-        order = self._order
-        self._order += 1
-
         destinations = set(destinations)
         destinations.discard(self.server_name)
         logger.debug("Sending to: %s", str(destinations))
@@ -285,8 +280,15 @@ class FederationSender(object):
         sent_pdus_destination_dist_total.inc(len(destinations))
         sent_pdus_destination_dist_count.inc()
 
+        # track the fact that we have a PDU for these destinations,
+        # to allow us to perform catch-up later on if the remote is unreachable
+        # for a while.
+        await self.store.store_destination_rooms_entries(
+            destinations, pdu.room_id, pdu.internal_metadata.stream_ordering,
+        )
+
         for destination in destinations:
-            self._get_per_destination_queue(destination).send_pdu(pdu, order)
+            self._get_per_destination_queue(destination).send_pdu(pdu)
 
     async def send_read_receipt(self, receipt: ReadReceipt) -> None:
         """Send a RR to any other servers in the room
@@ -329,10 +331,10 @@ class FederationSender(object):
         room_id = receipt.room_id
 
         # Work out which remote servers should be poked and poke them.
-        domains = await self.state.get_current_hosts_in_room(room_id)
+        domains_set = await self.state.get_current_hosts_in_room(room_id)
         domains = [
             d
-            for d in domains
+            for d in domains_set
             if d != self.server_name
             and self._federation_shard_config.should_handle(self._instance_name, d)
         ]
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index dd150f89a6..2657767fd1 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 import datetime
 import logging
-from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast
 
 from prometheus_client import Counter
 
@@ -24,12 +24,12 @@ from synapse.api.errors import (
     HttpResponseException,
     RequestSendFailed,
 )
+from synapse.api.presence import UserPresenceState
 from synapse.events import EventBase
 from synapse.federation.units import Edu
 from synapse.handlers.presence import format_user_presence_state
 from synapse.metrics import sent_transactions_counter
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.presence import UserPresenceState
 from synapse.types import ReadReceipt
 from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
 
@@ -53,7 +53,7 @@ sent_edus_by_type = Counter(
 )
 
 
-class PerDestinationQueue(object):
+class PerDestinationQueue:
     """
     Manages the per-destination transmission queues.
 
@@ -92,8 +92,23 @@ class PerDestinationQueue(object):
         self._destination = destination
         self.transmission_loop_running = False
 
-        # a list of tuples of (pending pdu, order)
-        self._pending_pdus = []  # type: List[Tuple[EventBase, int]]
+        # True whilst we are sending events that the remote homeserver missed
+        # because it was unreachable. We start in this state so we can perform
+        # catch-up at startup.
+        # New events will only be sent once this is finished, at which point
+        # _catching_up is flipped to False.
+        self._catching_up = True  # type: bool
+
+        # The stream_ordering of the most recent PDU that was discarded due to
+        # being in catch-up mode.
+        self._catchup_last_skipped = 0  # type: int
+
+        # Cache of the last successfully-transmitted stream ordering for this
+        # destination (we are the only updater so this is safe)
+        self._last_successful_stream_ordering = None  # type: Optional[int]
+
+        # a list of pending PDUs
+        self._pending_pdus = []  # type: List[EventBase]
 
         # XXX this is never actually used: see
         # https://github.com/matrix-org/synapse/issues/7549
@@ -132,14 +147,19 @@ class PerDestinationQueue(object):
             + len(self._pending_edus_keyed)
         )
 
-    def send_pdu(self, pdu: EventBase, order: int) -> None:
+    def send_pdu(self, pdu: EventBase) -> None:
         """Add a PDU to the queue, and start the transmission loop if necessary
 
         Args:
             pdu: pdu to send
-            order
         """
-        self._pending_pdus.append((pdu, order))
+        if not self._catching_up or self._last_successful_stream_ordering is None:
+            # only enqueue the PDU if we are not catching up (False) or do not
+            # yet know if we have anything to catch up (None)
+            self._pending_pdus.append(pdu)
+        else:
+            self._catchup_last_skipped = pdu.internal_metadata.stream_ordering
+
         self.attempt_new_transaction()
 
     def send_presence(self, states: Iterable[UserPresenceState]) -> None:
@@ -185,7 +205,7 @@ class PerDestinationQueue(object):
         returns immediately. Otherwise kicks off the process of sending a
         transaction in the background.
         """
-        # list of (pending_pdu, deferred, order)
+
         if self.transmission_loop_running:
             # XXX: this can get stuck on by a never-ending
             # request at which point pending_pdus just keeps growing.
@@ -210,7 +230,7 @@ class PerDestinationQueue(object):
         )
 
     async def _transaction_transmission_loop(self) -> None:
-        pending_pdus = []  # type: List[Tuple[EventBase, int]]
+        pending_pdus = []  # type: List[EventBase]
         try:
             self.transmission_loop_running = True
 
@@ -219,6 +239,13 @@ class PerDestinationQueue(object):
             # hence why we throw the result away.
             await get_retry_limiter(self._destination, self._clock, self._store)
 
+            if self._catching_up:
+                # we potentially need to catch-up first
+                await self._catch_up_transmission_loop()
+                if self._catching_up:
+                    # not caught up yet
+                    return
+
             pending_pdus = []
             while True:
                 # We have to keep 2 free slots for presence and rr_edus
@@ -326,6 +353,17 @@ class PerDestinationQueue(object):
 
                     self._last_device_stream_id = device_stream_id
                     self._last_device_list_stream_id = dev_list_id
+
+                    if pending_pdus:
+                        # we sent some PDUs and it was successful, so update our
+                        # last_successful_stream_ordering in the destinations table.
+                        final_pdu = pending_pdus[-1]
+                        last_successful_stream_ordering = (
+                            final_pdu.internal_metadata.stream_ordering
+                        )
+                        await self._store.set_destination_last_successful_stream_ordering(
+                            self._destination, last_successful_stream_ordering
+                        )
                 else:
                     break
         except NotRetryingDestination as e:
@@ -337,6 +375,30 @@ class PerDestinationQueue(object):
                     (e.retry_last_ts + e.retry_interval) / 1000.0
                 ),
             )
+
+            if e.retry_interval > 60 * 60 * 1000:
+                # we won't retry for another hour!
+                # (this suggests a significant outage)
+                # We drop pending EDUs because otherwise they will
+                # rack up indefinitely.
+                # (Dropping PDUs is already performed by `_start_catching_up`.)
+                # Note that:
+                # - the EDUs that are being dropped here are those that we can
+                #   afford to drop (specifically, only typing notifications,
+                #   read receipts and presence updates are being dropped here)
+                # - Other EDUs such as to_device messages are queued with a
+                #   different mechanism
+                # - this is all volatile state that would be lost if the
+                #   federation sender restarted anyway
+
+                # dropping read receipts is a bit sad but should be solved
+                # through another mechanism, because this is all volatile!
+                self._pending_edus = []
+                self._pending_edus_keyed = {}
+                self._pending_presence = {}
+                self._pending_rrs = {}
+
+            self._start_catching_up()
         except FederationDeniedError as e:
             logger.info(e)
         except HttpResponseException as e:
@@ -346,25 +408,107 @@ class PerDestinationQueue(object):
                 e.code,
                 e,
             )
+
+            self._start_catching_up()
         except RequestSendFailed as e:
             logger.warning(
                 "TX [%s] Failed to send transaction: %s", self._destination, e
             )
 
-            for p, _ in pending_pdus:
+            for p in pending_pdus:
                 logger.info(
                     "Failed to send event %s to %s", p.event_id, self._destination
                 )
+
+            self._start_catching_up()
         except Exception:
             logger.exception("TX [%s] Failed to send transaction", self._destination)
-            for p, _ in pending_pdus:
+            for p in pending_pdus:
                 logger.info(
                     "Failed to send event %s to %s", p.event_id, self._destination
                 )
+
+            self._start_catching_up()
         finally:
             # We want to be *very* sure we clear this after we stop processing
             self.transmission_loop_running = False
 
+    async def _catch_up_transmission_loop(self) -> None:
+        first_catch_up_check = self._last_successful_stream_ordering is None
+
+        if first_catch_up_check:
+            # first catchup so get last_successful_stream_ordering from database
+            self._last_successful_stream_ordering = await self._store.get_destination_last_successful_stream_ordering(
+                self._destination
+            )
+
+        if self._last_successful_stream_ordering is None:
+            # if it's still None, then this means we don't have the information
+            # in our database ­ we haven't successfully sent a PDU to this server
+            # (at least since the introduction of the feature tracking
+            # last_successful_stream_ordering).
+            # Sadly, this means we can't do anything here as we don't know what
+            # needs catching up — so catching up is futile; let's stop.
+            self._catching_up = False
+            return
+
+        # get at most 50 catchup room/PDUs
+        while True:
+            event_ids = await self._store.get_catch_up_room_event_ids(
+                self._destination, self._last_successful_stream_ordering,
+            )
+
+            if not event_ids:
+                # No more events to catch up on, but we can't ignore the chance
+                # of a race condition, so we check that no new events have been
+                # skipped due to us being in catch-up mode
+
+                if self._catchup_last_skipped > self._last_successful_stream_ordering:
+                    # another event has been skipped because we were in catch-up mode
+                    continue
+
+                # we are done catching up!
+                self._catching_up = False
+                break
+
+            if first_catch_up_check:
+                # as this is our check for needing catch-up, we may have PDUs in
+                # the queue from before we *knew* we had to do catch-up, so
+                # clear those out now.
+                self._start_catching_up()
+
+            # fetch the relevant events from the event store
+            # - redacted behaviour of REDACT is fine, since we only send metadata
+            #   of redacted events to the destination.
+            # - don't need to worry about rejected events as we do not actively
+            #   forward received events over federation.
+            catchup_pdus = await self._store.get_events_as_list(event_ids)
+            if not catchup_pdus:
+                raise AssertionError(
+                    "No events retrieved when we asked for %r. "
+                    "This should not happen." % event_ids
+                )
+
+            if logger.isEnabledFor(logging.INFO):
+                rooms = (p.room_id for p in catchup_pdus)
+                logger.info("Catching up rooms to %s: %r", self._destination, rooms)
+
+            success = await self._transaction_manager.send_new_transaction(
+                self._destination, catchup_pdus, []
+            )
+
+            if not success:
+                return
+
+            sent_transactions_counter.inc()
+            final_pdu = catchup_pdus[-1]
+            self._last_successful_stream_ordering = cast(
+                int, final_pdu.internal_metadata.stream_ordering
+            )
+            await self._store.set_destination_last_successful_stream_ordering(
+                self._destination, self._last_successful_stream_ordering
+            )
+
     def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]:
         if not self._pending_rrs:
             return
@@ -425,3 +569,12 @@ class PerDestinationQueue(object):
         ]
 
         return (edus, stream_id)
+
+    def _start_catching_up(self) -> None:
+        """
+        Marks this destination as being in catch-up mode.
+
+        This throws away the PDU queue.
+        """
+        self._catching_up = True
+        self._pending_pdus = []
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index c7f6cb3d73..c84072ab73 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -13,9 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, List, Tuple
-
-from canonicaljson import json
+from typing import TYPE_CHECKING, List
 
 from synapse.api.errors import HttpResponseException
 from synapse.events import EventBase
@@ -28,6 +26,7 @@ from synapse.logging.opentracing import (
     tags,
     whitelisted_homeserver,
 )
+from synapse.util import json_decoder
 from synapse.util.metrics import measure_func
 
 if TYPE_CHECKING:
@@ -36,7 +35,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class TransactionManager(object):
+class TransactionManager:
     """Helper class which handles building and sending transactions
 
     shared between PerDestinationQueue objects
@@ -54,11 +53,17 @@ class TransactionManager(object):
 
     @measure_func("_send_new_transaction")
     async def send_new_transaction(
-        self,
-        destination: str,
-        pending_pdus: List[Tuple[EventBase, int]],
-        pending_edus: List[Edu],
-    ):
+        self, destination: str, pdus: List[EventBase], edus: List[Edu],
+    ) -> bool:
+        """
+        Args:
+            destination: The destination to send to (e.g. 'example.org')
+            pdus: In-order list of PDUs to send
+            edus: List of EDUs to send
+
+        Returns:
+            True iff the transaction was successful
+        """
 
         # Make a transaction-sending opentracing span. This span follows on from
         # all the edus in that transaction. This needs to be done since there is
@@ -68,20 +73,14 @@ class TransactionManager(object):
         span_contexts = []
         keep_destination = whitelisted_homeserver(destination)
 
-        for edu in pending_edus:
+        for edu in edus:
             context = edu.get_context()
             if context:
-                span_contexts.append(extract_text_map(json.loads(context)))
+                span_contexts.append(extract_text_map(json_decoder.decode(context)))
             if keep_destination:
                 edu.strip_context()
 
         with start_active_span_follows_from("send_transaction", span_contexts):
-
-            # Sort based on the order field
-            pending_pdus.sort(key=lambda t: t[1])
-            pdus = [x[0] for x in pending_pdus]
-            edus = pending_edus
-
             success = True
 
             logger.debug("TX [%s] _attempt_new_transaction", destination)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 9ea821dbb2..17a10f622e 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -30,7 +30,7 @@ from synapse.logging.utils import log_function
 logger = logging.getLogger(__name__)
 
 
-class TransportLayerClient(object):
+class TransportLayerClient:
     """Sends federation HTTP requests to other servers"""
 
     def __init__(self, hs):
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 5e111aa902..9325e0f857 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -100,7 +100,7 @@ class NoAuthenticationError(AuthenticationError):
     pass
 
 
-class Authenticator(object):
+class Authenticator:
     def __init__(self, hs: HomeServer):
         self._clock = hs.get_clock()
         self.keyring = hs.get_keyring()
@@ -228,7 +228,7 @@ def _parse_auth_header(header_bytes):
         )
 
 
-class BaseFederationServlet(object):
+class BaseFederationServlet:
     """Abstract base class for federation servlet classes.
 
     The servlet object should have a PATH attribute which takes the form of a regexp to
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index 6b32e0dcbf..64d98fc8f6 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -107,9 +107,7 @@ class Transaction(JsonEncodedObject):
         if "edus" in kwargs and not kwargs["edus"]:
             del kwargs["edus"]
 
-        super(Transaction, self).__init__(
-            transaction_id=transaction_id, pdus=pdus, **kwargs
-        )
+        super().__init__(transaction_id=transaction_id, pdus=pdus, **kwargs)
 
     @staticmethod
     def create_new(pdus, **kwargs):
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index e674bf44a2..a86b3debc5 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -60,7 +60,7 @@ DEFAULT_ATTESTATION_JITTER = (0.9, 1.3)
 UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
 
 
-class GroupAttestationSigning(object):
+class GroupAttestationSigning:
     """Creates and verifies group attestations.
     """
 
@@ -124,7 +124,7 @@ class GroupAttestationSigning(object):
         )
 
 
-class GroupAttestionRenewer(object):
+class GroupAttestionRenewer:
     """Responsible for sending and receiving attestation updates.
     """
 
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 8cb922ddc7..1dd20ee4e1 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
 # TODO: Flairs
 
 
-class GroupsServerWorkerHandler(object):
+class GroupsServerWorkerHandler:
     def __init__(self, hs):
         self.hs = hs
         self.store = hs.get_datastore()
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index 2dd183018a..286f0054be 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -20,7 +20,7 @@ from .identity import IdentityHandler
 from .search import SearchHandler
 
 
-class Handlers(object):
+class Handlers:
 
     """ Deprecated. A collection of handlers.
 
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index ba2bf99800..0206320e96 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -25,7 +25,7 @@ from synapse.types import UserID
 logger = logging.getLogger(__name__)
 
 
-class BaseHandler(object):
+class BaseHandler:
     """
     Common base class for the event handlers.
     """
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index a8d3fbc6de..9112a0ab86 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 
-class AccountDataEventSource(object):
+class AccountDataEventSource:
     def __init__(self, hs):
         self.store = hs.get_datastore()
 
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 590135d19c..4caf6d591a 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -26,15 +26,10 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import UserID
 from synapse.util import stringutils
 
-try:
-    from synapse.push.mailer import load_jinja2_templates
-except ImportError:
-    load_jinja2_templates = None
-
 logger = logging.getLogger(__name__)
 
 
-class AccountValidityHandler(object):
+class AccountValidityHandler:
     def __init__(self, hs):
         self.hs = hs
         self.config = hs.config
@@ -47,9 +42,11 @@ class AccountValidityHandler(object):
         if (
             self._account_validity.enabled
             and self._account_validity.renew_by_email_enabled
-            and load_jinja2_templates
         ):
             # Don't do email-specific configuration if renewal by email is disabled.
+            self._template_html = self.config.account_validity_template_html
+            self._template_text = self.config.account_validity_template_text
+
             try:
                 app_name = self.hs.config.email_app_name
 
@@ -65,17 +62,6 @@ class AccountValidityHandler(object):
 
             self._raw_from = email.utils.parseaddr(self._from_string)[1]
 
-            self._template_html, self._template_text = load_jinja2_templates(
-                self.config.email_template_dir,
-                [
-                    self.config.email_expiry_template_html,
-                    self.config.email_expiry_template_text,
-                ],
-                apply_format_ts_filter=True,
-                apply_mxc_to_http_filter=True,
-                public_baseurl=self.config.public_baseurl,
-            )
-
             # Check the renewal emails to send and send them every 30min.
             def send_emails():
                 # run as a background process to make sure that the database transactions
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index 7666d3abcd..8476256a59 100644
--- a/synapse/handlers/acme.py
+++ b/synapse/handlers/acme.py
@@ -34,7 +34,7 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC
 --------------------------------------------------------------------------------"""
 
 
-class AcmeHandler(object):
+class AcmeHandler:
     def __init__(self, hs):
         self.hs = hs
         self.reactor = hs.get_reactor()
diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py
index e1d4224e74..7294649d71 100644
--- a/synapse/handlers/acme_issuing_service.py
+++ b/synapse/handlers/acme_issuing_service.py
@@ -76,9 +76,9 @@ def create_issuing_service(reactor, acme_url, account_key_file, well_known_resou
     )
 
 
-@attr.s
+@attr.s(slots=True)
 @implementer(ICertificateStore)
-class ErsatzStore(object):
+class ErsatzStore:
     """
     A store that only stores in memory.
     """
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 506bb2b275..5e5a64037d 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -125,8 +125,8 @@ class AdminHandler(BaseHandler):
             else:
                 stream_ordering = room.stream_ordering
 
-            from_key = str(RoomStreamToken(0, 0))
-            to_key = str(RoomStreamToken(None, stream_ordering))
+            from_key = RoomStreamToken(0, 0)
+            to_key = RoomStreamToken(None, stream_ordering)
 
             written_events = set()  # Events that we've processed in this room
 
@@ -153,7 +153,7 @@ class AdminHandler(BaseHandler):
                 if not events:
                     break
 
-                from_key = events[-1].internal_metadata.after
+                from_key = RoomStreamToken.parse(events[-1].internal_metadata.after)
 
                 events = await filter_events_for_client(self.storage, user_id, events)
 
@@ -197,7 +197,7 @@ class AdminHandler(BaseHandler):
         return writer.finished()
 
 
-class ExfiltrationWriter(object):
+class ExfiltrationWriter:
     """Interface used to specify how to write exported data.
     """
 
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index c9044a5019..9d4e87dad6 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
 events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
 
 
-class ApplicationServicesHandler(object):
+class ApplicationServicesHandler:
     def __init__(self, hs):
         self.store = hs.get_datastore()
         self.is_mine_id = hs.is_mine_id
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c24e7bafe0..4e658d9a48 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -42,9 +42,9 @@ from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.module_api import ModuleApi
-from synapse.push.mailer import load_jinja2_templates
-from synapse.types import Requester, UserID
+from synapse.types import JsonDict, Requester, UserID
 from synapse.util import stringutils as stringutils
+from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.threepids import canonicalise_email
 
 from ._base import BaseHandler
@@ -52,6 +52,91 @@ from ._base import BaseHandler
 logger = logging.getLogger(__name__)
 
 
+def convert_client_dict_legacy_fields_to_identifier(
+    submission: JsonDict,
+) -> Dict[str, str]:
+    """
+    Convert a legacy-formatted login submission to an identifier dict.
+
+    Legacy login submissions (used in both login and user-interactive authentication)
+    provide user-identifying information at the top-level instead.
+
+    These are now deprecated and replaced with identifiers:
+    https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types
+
+    Args:
+        submission: The client dict to convert
+
+    Returns:
+        The matching identifier dict
+
+    Raises:
+        SynapseError: If the format of the client dict is invalid
+    """
+    identifier = submission.get("identifier", {})
+
+    # Generate an m.id.user identifier if "user" parameter is present
+    user = submission.get("user")
+    if user:
+        identifier = {"type": "m.id.user", "user": user}
+
+    # Generate an m.id.thirdparty identifier if "medium" and "address" parameters are present
+    medium = submission.get("medium")
+    address = submission.get("address")
+    if medium and address:
+        identifier = {
+            "type": "m.id.thirdparty",
+            "medium": medium,
+            "address": address,
+        }
+
+    # We've converted valid, legacy login submissions to an identifier. If the
+    # submission still doesn't have an identifier, it's invalid
+    if not identifier:
+        raise SynapseError(400, "Invalid login submission", Codes.INVALID_PARAM)
+
+    # Ensure the identifier has a type
+    if "type" not in identifier:
+        raise SynapseError(
+            400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM,
+        )
+
+    return identifier
+
+
+def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]:
+    """
+    Convert a phone login identifier type to a generic threepid identifier.
+
+    Args:
+        identifier: Login identifier dict of type 'm.id.phone'
+
+    Returns:
+        An equivalent m.id.thirdparty identifier dict
+    """
+    if "country" not in identifier or (
+        # The specification requires a "phone" field, while Synapse used to require a "number"
+        # field. Accept both for backwards compatibility.
+        "phone" not in identifier
+        and "number" not in identifier
+    ):
+        raise SynapseError(
+            400, "Invalid phone-type identifier", errcode=Codes.INVALID_PARAM
+        )
+
+    # Accept both "phone" and "number" as valid keys in m.id.phone
+    phone_number = identifier.get("phone", identifier["number"])
+
+    # Convert user-provided phone number to a consistent representation
+    msisdn = phone_number_to_msisdn(identifier["country"], phone_number)
+
+    return {
+        "type": "m.id.thirdparty",
+        "medium": "msisdn",
+        "address": msisdn,
+    }
+
+
 class AuthHandler(BaseHandler):
     SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
 
@@ -132,18 +217,17 @@ class AuthHandler(BaseHandler):
         # after the SSO completes and before redirecting them back to their client.
         # It notifies the user they are about to give access to their matrix account
         # to the client.
-        self._sso_redirect_confirm_template = load_jinja2_templates(
-            hs.config.sso_template_dir, ["sso_redirect_confirm.html"],
-        )[0]
+        self._sso_redirect_confirm_template = hs.config.sso_redirect_confirm_template
+
         # The following template is shown during user interactive authentication
         # in the fallback auth scenario. It notifies the user that they are
         # authenticating for an operation to occur on their account.
-        self._sso_auth_confirm_template = load_jinja2_templates(
-            hs.config.sso_template_dir, ["sso_auth_confirm.html"],
-        )[0]
+        self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
+
         # The following template is shown after a successful user interactive
         # authentication session. It tells the user they can close the window.
         self._sso_auth_success_template = hs.config.sso_auth_success_template
+
         # The following template is shown during the SSO authentication process if
         # the account is deactivated.
         self._sso_account_deactivated_template = (
@@ -366,6 +450,14 @@ class AuthHandler(BaseHandler):
             # authentication flow.
             await self.store.set_ui_auth_clientdict(sid, clientdict)
 
+        user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
+            0
+        ].decode("ascii", "surrogateescape")
+
+        await self.store.add_user_agent_ip_to_ui_auth_session(
+            session.session_id, user_agent, clientip
+        )
+
         if not authdict:
             raise InteractiveAuthIncompleteError(
                 session.session_id, self._auth_dict_for_flows(flows, session.session_id)
@@ -1143,8 +1235,8 @@ class AuthHandler(BaseHandler):
         return urllib.parse.urlunparse(url_parts)
 
 
-@attr.s
-class MacaroonGenerator(object):
+@attr.s(slots=True)
+class MacaroonGenerator:
 
     hs = attr.ib()
 
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 786e608fa2..a4cc4b9a5a 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -35,6 +35,7 @@ class CasHandler:
     """
 
     def __init__(self, hs):
+        self.hs = hs
         self._hostname = hs.hostname
         self._auth_handler = hs.get_auth_handler()
         self._registration_handler = hs.get_registration_handler()
@@ -210,8 +211,16 @@ class CasHandler:
 
         else:
             if not registered_user_id:
+                # Pull out the user-agent and IP from the request.
+                user_agent = request.requestHeaders.getRawHeaders(
+                    b"User-Agent", default=[b""]
+                )[0].decode("ascii", "surrogateescape")
+                ip_address = self.hs.get_ip_from_request(request)
+
                 registered_user_id = await self._registration_handler.register_user(
-                    localpart=localpart, default_display_name=user_display_name
+                    localpart=localpart,
+                    default_display_name=user_display_name,
+                    user_agent_ips=(user_agent, ip_address),
                 )
 
             await self._auth_handler.complete_sso_login(
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index db417d60de..4b0a4f96cc 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -29,6 +29,7 @@ from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import (
     RoomStreamToken,
+    StreamToken,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
 )
@@ -104,18 +105,15 @@ class DeviceWorkerHandler(BaseHandler):
 
     @trace
     @measure_func("device.get_user_ids_changed")
-    async def get_user_ids_changed(self, user_id, from_token):
+    async def get_user_ids_changed(self, user_id: str, from_token: StreamToken):
         """Get list of users that have had the devices updated, or have newly
         joined a room, that `user_id` may be interested in.
-
-        Args:
-            user_id (str)
-            from_token (StreamToken)
         """
 
         set_tag("user_id", user_id)
         set_tag("from_token", from_token)
-        now_room_key = await self.store.get_room_events_max_id()
+        now_room_id = self.store.get_room_max_stream_ordering()
+        now_room_key = RoomStreamToken(None, now_room_id)
 
         room_ids = await self.store.get_rooms_for_user(user_id)
 
@@ -142,7 +140,7 @@ class DeviceWorkerHandler(BaseHandler):
         )
         rooms_changed.update(event.room_id for event in member_events)
 
-        stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key).stream
+        stream_ordering = from_token.room_key.stream
 
         possibly_changed = set(changed)
         possibly_left = set()
@@ -234,7 +232,9 @@ class DeviceWorkerHandler(BaseHandler):
         return result
 
     async def on_federation_query_user_devices(self, user_id):
-        stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id)
+        stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query(
+            user_id
+        )
         master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
         self_signing_key = await self.store.get_e2e_cross_signing_key(
             user_id, "self_signing"
@@ -495,7 +495,7 @@ def _update_device_from_client_ips(device, client_ips):
     device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
 
 
-class DeviceListUpdater(object):
+class DeviceListUpdater:
     "Handles incoming device list updates from federation and updates the DB"
 
     def __init__(self, hs, device_handler):
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 610b08d00b..64ef7f63ab 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -16,8 +16,6 @@
 import logging
 from typing import Any, Dict
 
-from canonicaljson import json
-
 from synapse.api.errors import SynapseError
 from synapse.logging.context import run_in_background
 from synapse.logging.opentracing import (
@@ -27,12 +25,13 @@ from synapse.logging.opentracing import (
     start_active_span,
 )
 from synapse.types import UserID, get_domain_from_id
+from synapse.util import json_encoder
 from synapse.util.stringutils import random_string
 
 logger = logging.getLogger(__name__)
 
 
-class DeviceMessageHandler(object):
+class DeviceMessageHandler:
     def __init__(self, hs):
         """
         Args:
@@ -174,7 +173,7 @@ class DeviceMessageHandler(object):
                     "sender": sender_user_id,
                     "type": message_type,
                     "message_id": message_id,
-                    "org.matrix.opentracing_context": json.dumps(context),
+                    "org.matrix.opentracing_context": json_encoder.encode(context),
                 }
 
         log_kv({"local_messages": local_messages})
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 79a2df6201..46826eb784 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -23,6 +23,7 @@ from synapse.api.errors import (
     CodeMessageException,
     Codes,
     NotFoundError,
+    ShadowBanError,
     StoreError,
     SynapseError,
 )
@@ -199,6 +200,8 @@ class DirectoryHandler(BaseHandler):
 
         try:
             await self._update_canonical_alias(requester, user_id, room_id, room_alias)
+        except ShadowBanError as e:
+            logger.info("Failed to update alias events due to shadow-ban: %s", e)
         except AuthError as e:
             logger.info("Failed to update alias events: %s", e)
 
@@ -292,6 +295,9 @@ class DirectoryHandler(BaseHandler):
         """
         Send an updated canonical alias event if the removed alias was set as
         the canonical alias or listed in the alt_aliases field.
+
+        Raises:
+            ShadowBanError if the requester has been shadow-banned.
         """
         alias_event = await self.state.get_current_state(
             room_id, EventTypes.CanonicalAlias, ""
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 84169c1022..dd40fd1299 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -19,7 +19,7 @@ import logging
 from typing import Dict, List, Optional, Tuple
 
 import attr
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
 from signedjson.key import VerifyKey, decode_verify_key_bytes
 from signedjson.sign import SignatureVerifyException, verify_signed_json
 from unpaddedbase64 import decode_base64
@@ -35,7 +35,7 @@ from synapse.types import (
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
 )
-from synapse.util import unwrapFirstError
+from synapse.util import json_decoder, unwrapFirstError
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
@@ -43,7 +43,7 @@ from synapse.util.retryutils import NotRetryingDestination
 logger = logging.getLogger(__name__)
 
 
-class E2eKeysHandler(object):
+class E2eKeysHandler:
     def __init__(self, hs):
         self.store = hs.get_datastore()
         self.federation = hs.get_federation_client()
@@ -353,7 +353,7 @@ class E2eKeysHandler(object):
             # make sure that each queried user appears in the result dict
             result_dict[user_id] = {}
 
-        results = await self.store.get_e2e_device_keys(local_query)
+        results = await self.store.get_e2e_device_keys_for_cs_api(local_query)
 
         # Build the result structure
         for user_id, device_keys in results.items():
@@ -404,7 +404,7 @@ class E2eKeysHandler(object):
             for device_id, keys in device_keys.items():
                 for key_id, json_bytes in keys.items():
                     json_result.setdefault(user_id, {})[device_id] = {
-                        key_id: json.loads(json_bytes)
+                        key_id: json_decoder.decode(json_bytes)
                     }
 
         @trace
@@ -734,7 +734,7 @@ class E2eKeysHandler(object):
             # fetch our stored devices.  This is used to 1. verify
             # signatures on the master key, and 2. to compare with what
             # was sent if the device was signed
-            devices = await self.store.get_e2e_device_keys([(user_id, None)])
+            devices = await self.store.get_e2e_device_keys_for_cs_api([(user_id, None)])
 
             if user_id not in devices:
                 raise NotFoundError("No device keys found")
@@ -1186,7 +1186,7 @@ def _exception_to_failure(e):
 
 
 def _one_time_keys_match(old_key_json, new_key):
-    old_key = json.loads(old_key_json)
+    old_key = json_decoder.decode(old_key_json)
 
     # if either is a string rather than an object, they must match exactly
     if not isinstance(old_key, dict) or not isinstance(new_key, dict):
@@ -1201,7 +1201,7 @@ def _one_time_keys_match(old_key_json, new_key):
     return old_key == new_key_copy
 
 
-@attr.s
+@attr.s(slots=True)
 class SignatureListItem:
     """An item in the signature list as used by upload_signatures_for_device_keys.
     """
@@ -1212,7 +1212,7 @@ class SignatureListItem:
     signature = attr.ib()
 
 
-class SigningKeyEduUpdater(object):
+class SigningKeyEduUpdater:
     """Handles incoming signing key updates from federation and updates the DB"""
 
     def __init__(self, hs, e2e_keys_handler):
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 0bb983dc28..f01b090772 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -29,7 +29,7 @@ from synapse.util.async_helpers import Linearizer
 logger = logging.getLogger(__name__)
 
 
-class E2eRoomKeysHandler(object):
+class E2eRoomKeysHandler:
     """
     Implements an optional realtime backup mechanism for encrypted E2E megolm room keys.
     This gives a way for users to store and recover their megolm keys if they lose all
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 1924636c4d..fdce54c5c3 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -15,33 +15,30 @@
 
 import logging
 import random
+from typing import TYPE_CHECKING, Iterable, List, Optional
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError, SynapseError
 from synapse.events import EventBase
 from synapse.handlers.presence import format_user_presence_state
 from synapse.logging.utils import log_function
-from synapse.types import UserID
+from synapse.streams.config import PaginationConfig
+from synapse.types import JsonDict, UserID
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
 logger = logging.getLogger(__name__)
 
 
 class EventStreamHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super(EventStreamHandler, self).__init__(hs)
 
-        # Count of active streams per user
-        self._streams_per_user = {}
-        # Grace timers per user to delay the "stopped" signal
-        self._stop_timer_per_user = {}
-
-        self.distributor = hs.get_distributor()
-        self.distributor.declare("started_user_eventstream")
-        self.distributor.declare("stopped_user_eventstream")
-
         self.clock = hs.get_clock()
 
         self.notifier = hs.get_notifier()
@@ -52,14 +49,14 @@ class EventStreamHandler(BaseHandler):
     @log_function
     async def get_stream(
         self,
-        auth_user_id,
-        pagin_config,
-        timeout=0,
-        as_client_event=True,
-        affect_presence=True,
-        room_id=None,
-        is_guest=False,
-    ):
+        auth_user_id: str,
+        pagin_config: PaginationConfig,
+        timeout: int = 0,
+        as_client_event: bool = True,
+        affect_presence: bool = True,
+        room_id: Optional[str] = None,
+        is_guest: bool = False,
+    ) -> JsonDict:
         """Fetches the events stream for a given user.
         """
 
@@ -98,7 +95,7 @@ class EventStreamHandler(BaseHandler):
 
             # When the user joins a new room, or another user joins a currently
             # joined room, we need to send down presence for those users.
-            to_add = []
+            to_add = []  # type: List[JsonDict]
             for event in events:
                 if not isinstance(event, EventBase):
                     continue
@@ -110,7 +107,7 @@ class EventStreamHandler(BaseHandler):
                         # Send down presence for everyone in the room.
                         users = await self.state.get_current_users_in_room(
                             event.room_id
-                        )
+                        )  # type: Iterable[str]
                     else:
                         users = [event.state_key]
 
@@ -144,20 +141,22 @@ class EventStreamHandler(BaseHandler):
 
 
 class EventHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super(EventHandler, self).__init__(hs)
         self.storage = hs.get_storage()
 
-    async def get_event(self, user, room_id, event_id):
+    async def get_event(
+        self, user: UserID, room_id: Optional[str], event_id: str
+    ) -> Optional[EventBase]:
         """Retrieve a single specified event.
 
         Args:
-            user (synapse.types.UserID): The user requesting the event
-            room_id (str|None): The expected room id. We'll return None if the
+            user: The user requesting the event
+            room_id: The expected room id. We'll return None if the
                 event's room does not match.
-            event_id (str): The event ID to obtain.
+            event_id: The event ID to obtain.
         Returns:
-            dict: An event, or None if there is no event matching this ID.
+            An event, or None if there is no event matching this ID.
         Raises:
             SynapseError if there was a problem retrieving this event, or
             AuthError if the user does not have the rights to inspect this
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 593932adb7..262901363f 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -69,12 +69,16 @@ from synapse.replication.http.federation import (
     ReplicationFederationSendEventsRestServlet,
     ReplicationStoreRoomOnInviteRestServlet,
 )
-from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
 from synapse.state import StateResolutionStore, resolve_events_with_store
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
+from synapse.types import (
+    JsonDict,
+    MutableStateMap,
+    StateMap,
+    UserID,
+    get_domain_from_id,
+)
 from synapse.util.async_helpers import Linearizer, concurrently_execute
-from synapse.util.distributor import user_joined_room
 from synapse.util.retryutils import NotRetryingDestination
 from synapse.util.stringutils import shortstr
 from synapse.visibility import filter_events_for_server
@@ -82,7 +86,7 @@ from synapse.visibility import filter_events_for_server
 logger = logging.getLogger(__name__)
 
 
-@attr.s
+@attr.s(slots=True)
 class _NewEventInfo:
     """Holds information about a received event, ready for passing to _handle_new_events
 
@@ -96,7 +100,7 @@ class _NewEventInfo:
 
     event = attr.ib(type=EventBase)
     state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
-    auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)
+    auth_events = attr.ib(type=Optional[MutableStateMap[EventBase]], default=None)
 
 
 class FederationHandler(BaseHandler):
@@ -124,7 +128,6 @@ class FederationHandler(BaseHandler):
         self.keyring = hs.get_keyring()
         self.action_generator = hs.get_action_generator()
         self.is_mine_id = hs.is_mine_id
-        self.pusher_pool = hs.get_pusherpool()
         self.spam_checker = hs.get_spam_checker()
         self.event_creation_handler = hs.get_event_creation_handler()
         self._message_handler = hs.get_message_handler()
@@ -135,9 +138,6 @@ class FederationHandler(BaseHandler):
         self._replication = hs.get_replication_data_handler()
 
         self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
-        self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client(
-            hs
-        )
         self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
             hs
         )
@@ -434,11 +434,11 @@ class FederationHandler(BaseHandler):
         if not prevs - seen:
             return
 
-        latest = await self.store.get_latest_event_ids_in_room(room_id)
+        latest_list = await self.store.get_latest_event_ids_in_room(room_id)
 
         # We add the prev events that we have seen to the latest
         # list to ensure the remote server doesn't give them to us
-        latest = set(latest)
+        latest = set(latest_list)
         latest |= seen
 
         logger.info(
@@ -698,31 +698,10 @@ class FederationHandler(BaseHandler):
         logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
 
         try:
-            context = await self._handle_new_event(origin, event, state=state)
+            await self._handle_new_event(origin, event, state=state)
         except AuthError as e:
             raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
 
-        if event.type == EventTypes.Member:
-            if event.membership == Membership.JOIN:
-                # Only fire user_joined_room if the user has acutally
-                # joined the room. Don't bother if the user is just
-                # changing their profile info.
-                newly_joined = True
-
-                prev_state_ids = await context.get_prev_state_ids()
-
-                prev_state_id = prev_state_ids.get((event.type, event.state_key))
-                if prev_state_id:
-                    prev_state = await self.store.get_event(
-                        prev_state_id, allow_none=True
-                    )
-                    if prev_state and prev_state.membership == Membership.JOIN:
-                        newly_joined = False
-
-                if newly_joined:
-                    user = UserID.from_string(event.state_key)
-                    await self.user_joined_room(user, room_id)
-
         # For encrypted messages we check that we know about the sending device,
         # if we don't then we mark the device cache for that user as stale.
         if event.type == EventTypes.Encrypted:
@@ -775,7 +754,7 @@ class FederationHandler(BaseHandler):
                     # keys across all devices.
                     current_keys = [
                         key
-                        for device in cached_devices
+                        for device in cached_devices.values()
                         for key in device.get("keys", {}).get("keys", {}).values()
                     ]
 
@@ -917,7 +896,8 @@ class FederationHandler(BaseHandler):
                 )
             )
 
-        await self._handle_new_events(dest, ev_infos, backfilled=True)
+        if ev_infos:
+            await self._handle_new_events(dest, room_id, ev_infos, backfilled=True)
 
         # Step 2: Persist the rest of the events in the chunk one by one
         events.sort(key=lambda e: e.depth)
@@ -1210,7 +1190,7 @@ class FederationHandler(BaseHandler):
             event_infos.append(_NewEventInfo(event, None, auth))
 
         await self._handle_new_events(
-            destination, event_infos,
+            destination, room_id, event_infos,
         )
 
     def _sanity_check_event(self, ev):
@@ -1357,15 +1337,15 @@ class FederationHandler(BaseHandler):
             )
 
             max_stream_id = await self._persist_auth_tree(
-                origin, auth_chain, state, event, room_version_obj
+                origin, room_id, auth_chain, state, event, room_version_obj
             )
 
             # We wait here until this instance has seen the events come down
             # replication (if we're using replication) as the below uses caches.
-            #
-            # TODO: Currently the events stream is written to from master
             await self._replication.wait_for_stream_position(
-                self.config.worker.writers.events, "events", max_stream_id
+                self.config.worker.events_shard_config.get_instance(room_id),
+                "events",
+                max_stream_id,
             )
 
             # Check whether this room is the result of an upgrade of a room we already know
@@ -1544,11 +1524,6 @@ class FederationHandler(BaseHandler):
             event.signatures,
         )
 
-        if event.type == EventTypes.Member:
-            if event.content["membership"] == Membership.JOIN:
-                user = UserID.from_string(event.state_key)
-                await self.user_joined_room(user, event.room_id)
-
         prev_state_ids = await context.get_prev_state_ids()
 
         state_ids = list(prev_state_ids.values())
@@ -1619,7 +1594,7 @@ class FederationHandler(BaseHandler):
         )
 
         context = await self.state_handler.compute_event_context(event)
-        await self.persist_events_and_notify([(event, context)])
+        await self.persist_events_and_notify(event.room_id, [(event, context)])
 
         return event
 
@@ -1646,7 +1621,9 @@ class FederationHandler(BaseHandler):
         await self.federation_client.send_leave(host_list, event)
 
         context = await self.state_handler.compute_event_context(event)
-        stream_id = await self.persist_events_and_notify([(event, context)])
+        stream_id = await self.persist_events_and_notify(
+            event.room_id, [(event, context)]
+        )
 
         return event, stream_id
 
@@ -1777,9 +1754,7 @@ class FederationHandler(BaseHandler):
         """Returns the state at the event. i.e. not including said event.
         """
 
-        event = await self.store.get_event(
-            event_id, allow_none=False, check_room_id=room_id
-        )
+        event = await self.store.get_event(event_id, check_room_id=room_id)
 
         state_groups = await self.state_store.get_state_groups(room_id, [event_id])
 
@@ -1805,9 +1780,7 @@ class FederationHandler(BaseHandler):
     async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
         """Returns the state at the event. i.e. not including said event.
         """
-        event = await self.store.get_event(
-            event_id, allow_none=False, check_room_id=room_id
-        )
+        event = await self.store.get_event(event_id, check_room_id=room_id)
 
         state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
 
@@ -1877,8 +1850,8 @@ class FederationHandler(BaseHandler):
         else:
             return None
 
-    def get_min_depth_for_context(self, context):
-        return self.store.get_min_depth(context)
+    async def get_min_depth_for_context(self, context):
+        return await self.store.get_min_depth(context)
 
     async def _handle_new_event(
         self, origin, event, state=None, auth_events=None, backfilled=False
@@ -1898,7 +1871,7 @@ class FederationHandler(BaseHandler):
                 )
 
             await self.persist_events_and_notify(
-                [(event, context)], backfilled=backfilled
+                event.room_id, [(event, context)], backfilled=backfilled
             )
         except Exception:
             run_in_background(
@@ -1911,6 +1884,7 @@ class FederationHandler(BaseHandler):
     async def _handle_new_events(
         self,
         origin: str,
+        room_id: str,
         event_infos: Iterable[_NewEventInfo],
         backfilled: bool = False,
     ) -> None:
@@ -1942,6 +1916,7 @@ class FederationHandler(BaseHandler):
         )
 
         await self.persist_events_and_notify(
+            room_id,
             [
                 (ev_info.event, context)
                 for ev_info, context in zip(event_infos, contexts)
@@ -1952,6 +1927,7 @@ class FederationHandler(BaseHandler):
     async def _persist_auth_tree(
         self,
         origin: str,
+        room_id: str,
         auth_events: List[EventBase],
         state: List[EventBase],
         event: EventBase,
@@ -1966,6 +1942,7 @@ class FederationHandler(BaseHandler):
 
         Args:
             origin: Where the events came from
+            room_id,
             auth_events
             state
             event
@@ -2040,24 +2017,27 @@ class FederationHandler(BaseHandler):
                 events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
 
         await self.persist_events_and_notify(
+            room_id,
             [
                 (e, events_to_context[e.event_id])
                 for e in itertools.chain(auth_events, state)
-            ]
+            ],
         )
 
         new_event_context = await self.state_handler.compute_event_context(
             event, old_state=state
         )
 
-        return await self.persist_events_and_notify([(event, new_event_context)])
+        return await self.persist_events_and_notify(
+            room_id, [(event, new_event_context)]
+        )
 
     async def _prep_event(
         self,
         origin: str,
         event: EventBase,
         state: Optional[Iterable[EventBase]],
-        auth_events: Optional[StateMap[EventBase]],
+        auth_events: Optional[MutableStateMap[EventBase]],
         backfilled: bool,
     ) -> EventContext:
         context = await self.state_handler.compute_event_context(event, old_state=state)
@@ -2107,8 +2087,8 @@ class FederationHandler(BaseHandler):
         if backfilled or event.internal_metadata.is_outlier():
             return
 
-        extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
-        extrem_ids = set(extrem_ids)
+        extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id)
+        extrem_ids = set(extrem_ids_list)
         prev_event_ids = set(event.prev_event_ids())
 
         if extrem_ids == prev_event_ids:
@@ -2138,10 +2118,12 @@ class FederationHandler(BaseHandler):
             )
             state_sets = list(state_sets.values())
             state_sets.append(state)
-            current_state_ids = await self.state_handler.resolve_events(
+            current_states = await self.state_handler.resolve_events(
                 room_version, state_sets, event
             )
-            current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
+            current_state_ids = {
+                k: e.event_id for k, e in current_states.items()
+            }  # type: StateMap[str]
         else:
             current_state_ids = await self.state_handler.get_current_state_ids(
                 event.room_id, latest_event_ids=extrem_ids
@@ -2153,11 +2135,13 @@ class FederationHandler(BaseHandler):
 
         # Now check if event pass auth against said current state
         auth_types = auth_types_for_event(event)
-        current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
+        current_state_ids_list = [
+            e for k, e in current_state_ids.items() if k in auth_types
+        ]
 
-        current_auth_events = await self.store.get_events(current_state_ids)
+        auth_events_map = await self.store.get_events(current_state_ids_list)
         current_auth_events = {
-            (e.type, e.state_key): e for e in current_auth_events.values()
+            (e.type, e.state_key): e for e in auth_events_map.values()
         }
 
         try:
@@ -2173,9 +2157,7 @@ class FederationHandler(BaseHandler):
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
-        event = await self.store.get_event(
-            event_id, allow_none=False, check_room_id=room_id
-        )
+        event = await self.store.get_event(event_id, check_room_id=room_id)
 
         # Just go through and process each event in `remote_auth_chain`. We
         # don't want to fall into the trap of `missing` being wrong.
@@ -2227,7 +2209,7 @@ class FederationHandler(BaseHandler):
         origin: str,
         event: EventBase,
         context: EventContext,
-        auth_events: StateMap[EventBase],
+        auth_events: MutableStateMap[EventBase],
     ) -> EventContext:
         """
 
@@ -2278,7 +2260,7 @@ class FederationHandler(BaseHandler):
         origin: str,
         event: EventBase,
         context: EventContext,
-        auth_events: StateMap[EventBase],
+        auth_events: MutableStateMap[EventBase],
     ) -> EventContext:
         """Helper for do_auth. See there for docs.
 
@@ -2899,6 +2881,7 @@ class FederationHandler(BaseHandler):
 
     async def persist_events_and_notify(
         self,
+        room_id: str,
         event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
         backfilled: bool = False,
     ) -> int:
@@ -2906,14 +2889,19 @@ class FederationHandler(BaseHandler):
         necessary.
 
         Args:
-            event_and_contexts:
+            room_id: The room ID of events being persisted.
+            event_and_contexts: Sequence of events with their associated
+                context that should be persisted. All events must belong to
+                the same room.
             backfilled: Whether these events are a result of
                 backfilling or not
         """
-        if self.config.worker.writers.events != self._instance_name:
+        instance = self.config.worker.events_shard_config.get_instance(room_id)
+        if instance != self._instance_name:
             result = await self._send_events(
-                instance_name=self.config.worker.writers.events,
+                instance_name=instance,
                 store=self.store,
+                room_id=room_id,
                 event_and_contexts=event_and_contexts,
                 backfilled=backfilled,
             )
@@ -2966,8 +2954,6 @@ class FederationHandler(BaseHandler):
             event, event_stream_id, max_stream_id, extra_users=extra_users
         )
 
-        await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
-
     async def _clean_room_for_join(self, room_id: str) -> None:
         """Called to clean up any data in DB for a given room, ready for the
         server to join the room.
@@ -2980,16 +2966,6 @@ class FederationHandler(BaseHandler):
         else:
             await self.store.clean_room_for_join(room_id)
 
-    async def user_joined_room(self, user: UserID, room_id: str) -> None:
-        """Called when a new user has joined the room
-        """
-        if self.config.worker_app:
-            await self._notify_user_membership_change(
-                room_id=room_id, user_id=user.to_string(), change="joined"
-            )
-        else:
-            user_joined_room(self.distributor, user, room_id)
-
     async def get_room_complexity(
         self, remote_room_hosts: List[str], room_id: str
     ) -> Optional[dict]:
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 0e2656ccb3..44df567983 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -52,7 +52,7 @@ def _create_rerouter(func_name):
     return f
 
 
-class GroupsLocalWorkerHandler(object):
+class GroupsLocalWorkerHandler:
     def __init__(self, hs):
         self.hs = hs
         self.store = hs.get_datastore()
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 92b7404706..0ce6ddfbe4 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -21,8 +21,6 @@ import logging
 import urllib.parse
 from typing import Awaitable, Callable, Dict, List, Optional, Tuple
 
-from canonicaljson import json
-
 from twisted.internet.error import TimeoutError
 
 from synapse.api.errors import (
@@ -34,6 +32,7 @@ from synapse.api.errors import (
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.http.client import SimpleHttpClient
 from synapse.types import JsonDict, Requester
+from synapse.util import json_decoder
 from synapse.util.hash import sha256_and_url_safe_base64
 from synapse.util.stringutils import assert_valid_client_secret, random_string
 
@@ -177,7 +176,7 @@ class IdentityHandler(BaseHandler):
         except TimeoutError:
             raise SynapseError(500, "Timed out contacting identity server")
         except CodeMessageException as e:
-            data = json.loads(e.msg)  # XXX WAT?
+            data = json_decoder.decode(e.msg)  # XXX WAT?
             return data
 
         logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index ae6bd1d352..ba4828c713 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 from twisted.internet import defer
 
@@ -22,8 +23,9 @@ from synapse.api.errors import SynapseError
 from synapse.events.validator import EventValidator
 from synapse.handlers.presence import format_user_presence_state
 from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.storage.roommember import RoomsForUser
 from synapse.streams.config import PaginationConfig
-from synapse.types import StreamToken, UserID
+from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.response_cache import ResponseCache
@@ -31,11 +33,15 @@ from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
 logger = logging.getLogger(__name__)
 
 
 class InitialSyncHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super(InitialSyncHandler, self).__init__(hs)
         self.hs = hs
         self.state = hs.get_state_handler()
@@ -48,27 +54,25 @@ class InitialSyncHandler(BaseHandler):
 
     def snapshot_all_rooms(
         self,
-        user_id=None,
-        pagin_config=None,
-        as_client_event=True,
-        include_archived=False,
-    ):
+        user_id: str,
+        pagin_config: PaginationConfig,
+        as_client_event: bool = True,
+        include_archived: bool = False,
+    ) -> JsonDict:
         """Retrieve a snapshot of all rooms the user is invited or has joined.
 
         This snapshot may include messages for all rooms where the user is
         joined, depending on the pagination config.
 
         Args:
-            user_id (str): The ID of the user making the request.
-            pagin_config (synapse.api.streams.PaginationConfig): The pagination
-            config used to determine how many messages *PER ROOM* to return.
-            as_client_event (bool): True to get events in client-server format.
-            include_archived (bool): True to get rooms that the user has left
+            user_id: The ID of the user making the request.
+            pagin_config: The pagination config used to determine how many
+                messages *PER ROOM* to return.
+            as_client_event: True to get events in client-server format.
+            include_archived: True to get rooms that the user has left
         Returns:
-            A list of dicts with "room_id" and "membership" keys for all rooms
-            the user is currently invited or joined in on. Rooms where the user
-            is joined on, may return a "messages" key with messages, depending
-            on the specified PaginationConfig.
+            A JsonDict with the same format as the response to `/intialSync`
+            API
         """
         key = (
             user_id,
@@ -91,11 +95,11 @@ class InitialSyncHandler(BaseHandler):
 
     async def _snapshot_all_rooms(
         self,
-        user_id=None,
-        pagin_config=None,
-        as_client_event=True,
-        include_archived=False,
-    ):
+        user_id: str,
+        pagin_config: PaginationConfig,
+        as_client_event: bool = True,
+        include_archived: bool = False,
+    ) -> JsonDict:
 
         memberships = [Membership.INVITE, Membership.JOIN]
         if include_archived:
@@ -112,14 +116,13 @@ class InitialSyncHandler(BaseHandler):
         now_token = self.hs.get_event_sources().get_current_token()
 
         presence_stream = self.hs.get_event_sources().sources["presence"]
-        pagination_config = PaginationConfig(from_token=now_token)
-        presence, _ = await presence_stream.get_pagination_rows(
-            user, pagination_config.get_source_config("presence"), None
+        presence, _ = await presence_stream.get_new_events(
+            user, from_key=None, include_offline=False
         )
 
-        receipt_stream = self.hs.get_event_sources().sources["receipt"]
-        receipt, _ = await receipt_stream.get_pagination_rows(
-            user, pagination_config.get_source_config("receipt"), None
+        joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN]
+        receipt = await self.store.get_linearized_receipts_for_rooms(
+            joined_rooms, to_key=int(now_token.receipt_key),
         )
 
         tags_by_room = await self.store.get_tags_for_user(user_id)
@@ -134,7 +137,7 @@ class InitialSyncHandler(BaseHandler):
         if limit is None:
             limit = 10
 
-        async def handle_room(event):
+        async def handle_room(event: RoomsForUser):
             d = {
                 "room_id": event.room_id,
                 "membership": event.membership,
@@ -164,7 +167,7 @@ class InitialSyncHandler(BaseHandler):
                         self.state_handler.get_current_state, event.room_id
                     )
                 elif event.membership == Membership.LEAVE:
-                    room_end_token = "s%d" % (event.stream_ordering,)
+                    room_end_token = RoomStreamToken(None, event.stream_ordering,)
                     deferred_room_state = run_in_background(
                         self.state_store.get_state_for_events, [event.event_id]
                     )
@@ -251,17 +254,18 @@ class InitialSyncHandler(BaseHandler):
 
         return ret
 
-    async def room_initial_sync(self, requester, room_id, pagin_config=None):
+    async def room_initial_sync(
+        self, requester: Requester, room_id: str, pagin_config: PaginationConfig
+    ) -> JsonDict:
         """Capture the a snapshot of a room. If user is currently a member of
         the room this will be what is currently in the room. If the user left
         the room this will be what was in the room when they left.
 
         Args:
-            requester(Requester): The user to get a snapshot for.
-            room_id(str): The room to get a snapshot of.
-            pagin_config(synapse.streams.config.PaginationConfig):
-                The pagination config used to determine how many messages to
-                return.
+            requester: The user to get a snapshot for.
+            room_id: The room to get a snapshot of.
+            pagin_config: The pagination config used to determine how many
+                messages to return.
         Raises:
             AuthError if the user wasn't in the room.
         Returns:
@@ -305,8 +309,14 @@ class InitialSyncHandler(BaseHandler):
         return result
 
     async def _room_initial_sync_parted(
-        self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
-    ):
+        self,
+        user_id: str,
+        room_id: str,
+        pagin_config: PaginationConfig,
+        membership: Membership,
+        member_event_id: str,
+        is_peeking: bool,
+    ) -> JsonDict:
         room_state = await self.state_store.get_state_for_events([member_event_id])
 
         room_state = room_state[member_event_id]
@@ -350,8 +360,13 @@ class InitialSyncHandler(BaseHandler):
         }
 
     async def _room_initial_sync_joined(
-        self, user_id, room_id, pagin_config, membership, is_peeking
-    ):
+        self,
+        user_id: str,
+        room_id: str,
+        pagin_config: PaginationConfig,
+        membership: Membership,
+        is_peeking: bool,
+    ) -> JsonDict:
         current_state = await self.state.get_current_state(room_id=room_id)
 
         # TODO: These concurrently
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 2643438e84..a8fe5cf4e2 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -15,9 +15,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import random
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
 
 from twisted.internet.interfaces import IDelayedCall
 
@@ -34,6 +35,7 @@ from synapse.api.errors import (
     Codes,
     ConsentNotGivenError,
     NotFoundError,
+    ShadowBanError,
     SynapseError,
 )
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
@@ -47,14 +49,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
-from synapse.types import (
-    Collection,
-    Requester,
-    RoomAlias,
-    StreamToken,
-    UserID,
-    create_requester,
-)
+from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
+from synapse.util import json_decoder
 from synapse.util.async_helpers import Linearizer
 from synapse.util.frozenutils import frozendict_json_encoder
 from synapse.util.metrics import measure_func
@@ -68,7 +64,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class MessageHandler(object):
+class MessageHandler:
     """Contains some read only APIs to get state about a room
     """
 
@@ -92,12 +88,7 @@ class MessageHandler(object):
             )
 
     async def get_room_data(
-        self,
-        user_id: str,
-        room_id: str,
-        event_type: str,
-        state_key: str,
-        is_guest: bool,
+        self, user_id: str, room_id: str, event_type: str, state_key: str,
     ) -> dict:
         """ Get data from a room.
 
@@ -106,11 +97,10 @@ class MessageHandler(object):
             room_id
             event_type
             state_key
-            is_guest
         Returns:
             The path data content.
         Raises:
-            SynapseError if something went wrong.
+            SynapseError or AuthError if the user is not in the room
         """
         (
             membership,
@@ -127,6 +117,16 @@ class MessageHandler(object):
                 [membership_event_id], StateFilter.from_types([key])
             )
             data = room_state[membership_event_id].get(key)
+        else:
+            # check_user_in_room_or_world_readable, if it doesn't raise an AuthError, should
+            # only ever return a Membership.JOIN/LEAVE object
+            #
+            # Safeguard in case it returned something else
+            logger.error(
+                "Attempted to retrieve data from a room for a user that has never been in it. "
+                "This should not have happened."
+            )
+            raise SynapseError(403, "User not in room", errcode=Codes.FORBIDDEN)
 
         return data
 
@@ -361,7 +361,7 @@ class MessageHandler(object):
 _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY = 7 * 24 * 60 * 60 * 1000
 
 
-class EventCreationHandler(object):
+class EventCreationHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
@@ -376,9 +376,8 @@ class EventCreationHandler(object):
         self.notifier = hs.get_notifier()
         self.config = hs.config
         self.require_membership_for_aliases = hs.config.require_membership_for_aliases
-        self._is_event_writer = (
-            self.config.worker.writers.events == hs.get_instance_name()
-        )
+        self._events_shard_config = self.config.worker.events_shard_config
+        self._instance_name = hs.get_instance_name()
 
         self.room_invite_state_types = self.hs.config.room_invite_state_types
 
@@ -387,8 +386,6 @@ class EventCreationHandler(object):
         # This is only used to get at ratelimit function, and maybe_kick_guest_users
         self.base_handler = BaseHandler(hs)
 
-        self.pusher_pool = hs.get_pusherpool()
-
         # We arbitrarily limit concurrent event creation for a room to 5.
         # This is to stop us from diverging history *too* much.
         self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")
@@ -439,7 +436,7 @@ class EventCreationHandler(object):
         event_dict: dict,
         token_id: Optional[str] = None,
         txn_id: Optional[str] = None,
-        prev_event_ids: Optional[Collection[str]] = None,
+        prev_event_ids: Optional[List[str]] = None,
         require_consent: bool = True,
     ) -> Tuple[EventBase, EventContext]:
         """
@@ -644,37 +641,48 @@ class EventCreationHandler(object):
         event: EventBase,
         context: EventContext,
         ratelimit: bool = True,
+        ignore_shadow_ban: bool = False,
     ) -> int:
         """
         Persists and notifies local clients and federation of an event.
 
         Args:
-            requester
-            event the event to send.
-            context: the context of the event.
+            requester: The requester sending the event.
+            event: The event to send.
+            context: The context of the event.
             ratelimit: Whether to rate limit this send.
+            ignore_shadow_ban: True if shadow-banned users should be allowed to
+                send this event.
 
         Return:
             The stream_id of the persisted event.
+
+        Raises:
+            ShadowBanError if the requester has been shadow-banned.
         """
         if event.type == EventTypes.Member:
             raise SynapseError(
                 500, "Tried to send member event through non-member codepath"
             )
 
+        if not ignore_shadow_ban and requester.shadow_banned:
+            # We randomly sleep a bit just to annoy the requester.
+            await self.clock.sleep(random.randint(1, 10))
+            raise ShadowBanError()
+
         user = UserID.from_string(event.sender)
 
         assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
 
         if event.is_state():
-            prev_state = await self.deduplicate_state_event(event, context)
-            if prev_state is not None:
+            prev_event = await self.deduplicate_state_event(event, context)
+            if prev_event is not None:
                 logger.info(
                     "Not bothering to persist state event %s duplicated by %s",
                     event.event_id,
-                    prev_state.event_id,
+                    prev_event.event_id,
                 )
-                return prev_state
+                return await self.store.get_stream_id_for_event(prev_event.event_id)
 
         return await self.handle_new_client_event(
             requester=requester, event=event, context=context, ratelimit=ratelimit
@@ -682,27 +690,32 @@ class EventCreationHandler(object):
 
     async def deduplicate_state_event(
         self, event: EventBase, context: EventContext
-    ) -> None:
+    ) -> Optional[EventBase]:
         """
         Checks whether event is in the latest resolved state in context.
 
-        If so, returns the version of the event in context.
-        Otherwise, returns None.
+        Args:
+            event: The event to check for duplication.
+            context: The event context.
+
+        Returns:
+            The previous verion of the event is returned, if it is found in the
+            event context. Otherwise, None is returned.
         """
         prev_state_ids = await context.get_prev_state_ids()
         prev_event_id = prev_state_ids.get((event.type, event.state_key))
         if not prev_event_id:
-            return
+            return None
         prev_event = await self.store.get_event(prev_event_id, allow_none=True)
         if not prev_event:
-            return
+            return None
 
         if prev_event and event.user_id == prev_event.user_id:
             prev_content = encode_canonical_json(prev_event.content)
             next_content = encode_canonical_json(event.content)
             if prev_content == next_content:
                 return prev_event
-        return
+        return None
 
     async def create_and_send_nonmember_event(
         self,
@@ -710,12 +723,28 @@ class EventCreationHandler(object):
         event_dict: dict,
         ratelimit: bool = True,
         txn_id: Optional[str] = None,
+        ignore_shadow_ban: bool = False,
     ) -> Tuple[EventBase, int]:
         """
         Creates an event, then sends it.
 
         See self.create_event and self.send_nonmember_event.
+
+        Args:
+            requester: The requester sending the event.
+            event_dict: An entire event.
+            ratelimit: Whether to rate limit this send.
+            txn_id: The transaction ID.
+            ignore_shadow_ban: True if shadow-banned users should be allowed to
+                send this event.
+
+        Raises:
+            ShadowBanError if the requester has been shadow-banned.
         """
+        if not ignore_shadow_ban and requester.shadow_banned:
+            # We randomly sleep a bit just to annoy the requester.
+            await self.clock.sleep(random.randint(1, 10))
+            raise ShadowBanError()
 
         # We limit the number of concurrent event sends in a room so that we
         # don't fork the DAG too much. If we don't limit then we can end up in
@@ -734,7 +763,11 @@ class EventCreationHandler(object):
                 raise SynapseError(403, spam_error, Codes.FORBIDDEN)
 
             stream_id = await self.send_nonmember_event(
-                requester, event, context, ratelimit=ratelimit
+                requester,
+                event,
+                context,
+                ratelimit=ratelimit,
+                ignore_shadow_ban=ignore_shadow_ban,
             )
         return event, stream_id
 
@@ -743,7 +776,7 @@ class EventCreationHandler(object):
         self,
         builder: EventBuilder,
         requester: Optional[Requester] = None,
-        prev_event_ids: Optional[Collection[str]] = None,
+        prev_event_ids: Optional[List[str]] = None,
     ) -> Tuple[EventBase, EventContext]:
         """Create a new event for a local client
 
@@ -859,7 +892,7 @@ class EventCreationHandler(object):
         # Ensure that we can round trip before trying to persist in db
         try:
             dump = frozendict_json_encoder.encode(event.content)
-            json.loads(dump)
+            json_decoder.decode(dump)
         except Exception:
             logger.exception("Failed to encode content: %r", event.content)
             raise
@@ -868,9 +901,10 @@ class EventCreationHandler(object):
 
         try:
             # If we're a worker we need to hit out to the master.
-            if not self._is_event_writer:
+            writer_instance = self._events_shard_config.get_instance(event.room_id)
+            if writer_instance != self._instance_name:
                 result = await self.send_event(
-                    instance_name=self.config.worker.writers.events,
+                    instance_name=writer_instance,
                     event_id=event.event_id,
                     store=self.store,
                     requester=requester,
@@ -891,9 +925,7 @@ class EventCreationHandler(object):
         except Exception:
             # Ensure that we actually remove the entries in the push actions
             # staging area, if we calculated them.
-            run_in_background(
-                self.store.remove_push_actions_from_staging, event.event_id
-            )
+            await self.store.remove_push_actions_from_staging(event.event_id)
             raise
 
     async def _validate_canonical_alias(
@@ -940,7 +972,10 @@ class EventCreationHandler(object):
 
         This should only be run on the instance in charge of persisting events.
         """
-        assert self._is_event_writer
+        assert self.storage.persistence is not None
+        assert self._events_shard_config.should_handle(
+            self._instance_name, event.room_id
+        )
 
         if ratelimit:
             # We check if this is a room admin redacting an event so that we
@@ -957,7 +992,7 @@ class EventCreationHandler(object):
                     allow_none=True,
                 )
 
-                is_admin_redaction = (
+                is_admin_redaction = bool(
                     original_event and event.sender != original_event.sender
                 )
 
@@ -1077,8 +1112,8 @@ class EventCreationHandler(object):
             auth_events_ids = self.auth.compute_auth_events(
                 event, prev_state_ids, for_verification=True
             )
-            auth_events = await self.store.get_events(auth_events_ids)
-            auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
+            auth_events_map = await self.store.get_events(auth_events_ids)
+            auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()}
 
             room_version = await self.store.get_room_version_id(event.room_id)
             room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
@@ -1111,8 +1146,6 @@ class EventCreationHandler(object):
             # If there's an expiry timestamp on the event, schedule its expiry.
             self._message_handler.maybe_schedule_expiry(event)
 
-        await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
-
         def _notify():
             try:
                 self.notifier.on_new_room_event(
@@ -1176,8 +1209,14 @@ class EventCreationHandler(object):
 
                     event.internal_metadata.proactively_send = False
 
+                    # Since this is a dummy-event it is OK if it is sent by a
+                    # shadow-banned user.
                     await self.send_nonmember_event(
-                        requester, event, context, ratelimit=False
+                        requester,
+                        event,
+                        context,
+                        ratelimit=False,
+                        ignore_shadow_ban=True,
                     )
                     dummy_event_sent = True
                     break
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index fa5ee5de8f..4230dbaf99 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
 import logging
 from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
 from urllib.parse import urlencode
@@ -38,8 +37,8 @@ from synapse.config import ConfigError
 from synapse.http.server import respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
-from synapse.push.mailer import load_jinja2_templates
 from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.util import json_decoder
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -94,6 +93,7 @@ class OidcHandler:
     """
 
     def __init__(self, hs: "HomeServer"):
+        self.hs = hs
         self._callback_url = hs.config.oidc_callback_url  # type: str
         self._scopes = hs.config.oidc_scopes  # type: List[str]
         self._client_auth = ClientAuth(
@@ -123,9 +123,7 @@ class OidcHandler:
         self._hostname = hs.hostname  # type: str
         self._server_name = hs.config.server_name  # type: str
         self._macaroon_secret_key = hs.config.macaroon_secret_key
-        self._error_template = load_jinja2_templates(
-            hs.config.sso_template_dir, ["sso_error.html"]
-        )[0]
+        self._error_template = hs.config.sso_error_template
 
         # identifier for the external_ids table
         self._auth_provider_id = "oidc"
@@ -133,10 +131,10 @@ class OidcHandler:
     def _render_error(
         self, request, error: str, error_description: Optional[str] = None
     ) -> None:
-        """Renders the error template and respond with it.
+        """Render the error template and respond to the request with it.
 
         This is used to show errors to the user. The template of this page can
-        be found under ``synapse/res/templates/sso_error.html``.
+        be found under `synapse/res/templates/sso_error.html`.
 
         Args:
             request: The incoming request from the browser.
@@ -370,7 +368,7 @@ class OidcHandler:
             # and check for an error field. If not, we respond with a generic
             # error message.
             try:
-                resp = json.loads(resp_body.decode("utf-8"))
+                resp = json_decoder.decode(resp_body.decode("utf-8"))
                 error = resp["error"]
                 description = resp.get("error_description", error)
             except (ValueError, KeyError):
@@ -387,7 +385,7 @@ class OidcHandler:
 
         # Since it is a not a 5xx code, body should be a valid JSON. It will
         # raise if not.
-        resp = json.loads(resp_body.decode("utf-8"))
+        resp = json_decoder.decode(resp_body.decode("utf-8"))
 
         if "error" in resp:
             error = resp["error"]
@@ -692,9 +690,17 @@ class OidcHandler:
                 self._render_error(request, "invalid_token", str(e))
                 return
 
+        # Pull out the user-agent and IP from the request.
+        user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
+            0
+        ].decode("ascii", "surrogateescape")
+        ip_address = self.hs.get_ip_from_request(request)
+
         # Call the mapper to register/login the user
         try:
-            user_id = await self._map_userinfo_to_user(userinfo, token)
+            user_id = await self._map_userinfo_to_user(
+                userinfo, token, user_agent, ip_address
+            )
         except MappingException as e:
             logger.exception("Could not map user")
             self._render_error(request, "mapping_error", str(e))
@@ -831,7 +837,9 @@ class OidcHandler:
         now = self._clock.time_msec()
         return now < expiry
 
-    async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str:
+    async def _map_userinfo_to_user(
+        self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str
+    ) -> str:
         """Maps a UserInfo object to a mxid.
 
         UserInfo should have a claim that uniquely identifies users. This claim
@@ -846,6 +854,8 @@ class OidcHandler:
         Args:
             userinfo: an object representing the user
             token: a dict with the tokens obtained from the provider
+            user_agent: The user agent of the client making the request.
+            ip_address: The IP address of the client making the request.
 
         Raises:
             MappingException: if there was an error while mapping some properties
@@ -859,6 +869,9 @@ class OidcHandler:
             raise MappingException(
                 "Failed to extract subject from OIDC response: %s" % (e,)
             )
+        # Some OIDC providers use integer IDs, but Synapse expects external IDs
+        # to be strings.
+        remote_user_id = str(remote_user_id)
 
         logger.info(
             "Looking for existing mapping for user %s:%s",
@@ -902,7 +915,9 @@ class OidcHandler:
         # It's the first time this user is logging in and the mapped mxid was
         # not taken, register the user
         registered_user_id = await self._registration_handler.register_user(
-            localpart=localpart, default_display_name=attributes["display_name"],
+            localpart=localpart,
+            default_display_name=attributes["display_name"],
+            user_agent_ips=(user_agent, ip_address),
         )
 
         await self._datastore.record_user_external_id(
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 487420bb5d..d929a68f7d 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -14,23 +14,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import TYPE_CHECKING, Any, Dict, Optional, Set
 
 from twisted.python.failure import Failure
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import SynapseError
+from synapse.api.filtering import Filter
 from synapse.logging.context import run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.state import StateFilter
-from synapse.types import RoomStreamToken
+from synapse.streams.config import PaginationConfig
+from synapse.types import Requester, RoomStreamToken
 from synapse.util.async_helpers import ReadWriteLock
 from synapse.util.stringutils import random_string
 from synapse.visibility import filter_events_for_client
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
 logger = logging.getLogger(__name__)
 
 
-class PurgeStatus(object):
+class PurgeStatus:
     """Object tracking the status of a purge request
 
     This class contains information on the progress of a purge request, for
@@ -58,14 +65,14 @@ class PurgeStatus(object):
         return {"status": PurgeStatus.STATUS_TEXT[self.status]}
 
 
-class PaginationHandler(object):
+class PaginationHandler:
     """Handles pagination and purge history requests.
 
     These are in the same handler due to the fact we need to block clients
     paginating during a purge.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
@@ -75,13 +82,16 @@ class PaginationHandler(object):
         self._server_name = hs.hostname
 
         self.pagination_lock = ReadWriteLock()
-        self._purges_in_progress_by_room = set()
+        self._purges_in_progress_by_room = set()  # type: Set[str]
         # map from purge id to PurgeStatus
-        self._purges_by_id = {}
+        self._purges_by_id = {}  # type: Dict[str, PurgeStatus]
         self._event_serializer = hs.get_event_client_serializer()
 
         self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
 
+        self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min
+        self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max
+
         if hs.config.retention_enabled:
             # Run the purge jobs described in the configuration file.
             for job in hs.config.retention_purge_jobs:
@@ -96,7 +106,9 @@ class PaginationHandler(object):
                     job["longest_max_lifetime"],
                 )
 
-    async def purge_history_for_rooms_in_range(self, min_ms, max_ms):
+    async def purge_history_for_rooms_in_range(
+        self, min_ms: Optional[int], max_ms: Optional[int]
+    ):
         """Purge outdated events from rooms within the given retention range.
 
         If a default retention policy is defined in the server's configuration and its
@@ -104,14 +116,14 @@ class PaginationHandler(object):
         retention policy.
 
         Args:
-            min_ms (int|None): Duration in milliseconds that define the lower limit of
+            min_ms: Duration in milliseconds that define the lower limit of
                 the range to handle (exclusive). If None, it means that the range has no
                 lower limit.
-            max_ms (int|None): Duration in milliseconds that define the upper limit of
+            max_ms: Duration in milliseconds that define the upper limit of
                 the range to handle (inclusive). If None, it means that the range has no
                 upper limit.
         """
-        # We want the storage layer to to include rooms with no retention policy in its
+        # We want the storage layer to include rooms with no retention policy in its
         # return value only if a default retention policy is defined in the server's
         # configuration and that policy's 'max_lifetime' is either lower (or equal) than
         # max_ms or higher than min_ms (or both).
@@ -152,13 +164,32 @@ class PaginationHandler(object):
                 )
                 continue
 
-            max_lifetime = retention_policy["max_lifetime"]
+            # If max_lifetime is None, it means that the room has no retention policy.
+            # Given we only retrieve such rooms when there's a default retention policy
+            # defined in the server's configuration, we can safely assume that's the
+            # case and use it for this room.
+            max_lifetime = (
+                retention_policy["max_lifetime"] or self._retention_default_max_lifetime
+            )
 
-            if max_lifetime is None:
-                # If max_lifetime is None, it means that include_null equals True,
-                # therefore we can safely assume that there is a default policy defined
-                # in the server's configuration.
-                max_lifetime = self._retention_default_max_lifetime
+            # Cap the effective max_lifetime to be within the range allowed in the
+            # config.
+            # We do this in two steps:
+            #   1. Make sure it's higher or equal to the minimum allowed value, and if
+            #      it's not replace it with that value. This is because the server
+            #      operator can be required to not delete information before a given
+            #      time, e.g. to comply with freedom of information laws.
+            #   2. Make sure the resulting value is lower or equal to the maximum allowed
+            #      value, and if it's not replace it with that value. This is because the
+            #      server operator can be required to delete any data after a specific
+            #      amount of time.
+            if self._retention_allowed_lifetime_min is not None:
+                max_lifetime = max(self._retention_allowed_lifetime_min, max_lifetime)
+
+            if self._retention_allowed_lifetime_max is not None:
+                max_lifetime = min(max_lifetime, self._retention_allowed_lifetime_max)
+
+            logger.debug("[purge] max_lifetime for room %s: %s", room_id, max_lifetime)
 
             # Figure out what token we should start purging at.
             ts = self.clock.time_msec() - max_lifetime
@@ -195,18 +226,19 @@ class PaginationHandler(object):
                 "_purge_history", self._purge_history, purge_id, room_id, token, True,
             )
 
-    def start_purge_history(self, room_id, token, delete_local_events=False):
+    def start_purge_history(
+        self, room_id: str, token: str, delete_local_events: bool = False
+    ) -> str:
         """Start off a history purge on a room.
 
         Args:
-            room_id (str): The room to purge from
-
-            token (str): topological token to delete events before
-            delete_local_events (bool): True to delete local events as well as
+            room_id: The room to purge from
+            token: topological token to delete events before
+            delete_local_events: True to delete local events as well as
                 remote ones
 
         Returns:
-            str: unique ID for this purge transaction.
+            unique ID for this purge transaction.
         """
         if room_id in self._purges_in_progress_by_room:
             raise SynapseError(
@@ -225,15 +257,16 @@ class PaginationHandler(object):
         )
         return purge_id
 
-    async def _purge_history(self, purge_id, room_id, token, delete_local_events):
+    async def _purge_history(
+        self, purge_id: str, room_id: str, token: str, delete_local_events: bool
+    ) -> None:
         """Carry out a history purge on a room.
 
         Args:
-            purge_id (str): The id for this purge
-            room_id (str): The room to purge from
-            token (str): topological token to delete events before
-            delete_local_events (bool): True to delete local events as well as
-                remote ones
+            purge_id: The id for this purge
+            room_id: The room to purge from
+            token: topological token to delete events before
+            delete_local_events: True to delete local events as well as remote ones
         """
         self._purges_in_progress_by_room.add(room_id)
         try:
@@ -258,20 +291,17 @@ class PaginationHandler(object):
 
             self.hs.get_reactor().callLater(24 * 3600, clear_purge)
 
-    def get_purge_status(self, purge_id):
+    def get_purge_status(self, purge_id: str) -> Optional[PurgeStatus]:
         """Get the current status of an active purge
 
         Args:
-            purge_id (str): purge_id returned by start_purge_history
-
-        Returns:
-            PurgeStatus|None
+            purge_id: purge_id returned by start_purge_history
         """
         return self._purges_by_id.get(purge_id)
 
-    async def purge_room(self, room_id):
+    async def purge_room(self, room_id: str) -> None:
         """Purge the given room from the database"""
-        with (await self.pagination_lock.write(room_id)):
+        with await self.pagination_lock.write(room_id):
             # check we know about the room
             await self.store.get_room_version_id(room_id)
 
@@ -285,43 +315,38 @@ class PaginationHandler(object):
 
     async def get_messages(
         self,
-        requester,
-        room_id=None,
-        pagin_config=None,
-        as_client_event=True,
-        event_filter=None,
-    ):
+        requester: Requester,
+        room_id: str,
+        pagin_config: PaginationConfig,
+        as_client_event: bool = True,
+        event_filter: Optional[Filter] = None,
+    ) -> Dict[str, Any]:
         """Get messages in a room.
 
         Args:
-            requester (Requester): The user requesting messages.
-            room_id (str): The room they want messages from.
-            pagin_config (synapse.api.streams.PaginationConfig): The pagination
-                config rules to apply, if any.
-            as_client_event (bool): True to get events in client-server format.
-            event_filter (Filter): Filter to apply to results or None
+            requester: The user requesting messages.
+            room_id: The room they want messages from.
+            pagin_config: The pagination config rules to apply, if any.
+            as_client_event: True to get events in client-server format.
+            event_filter: Filter to apply to results or None
         Returns:
-            dict: Pagination API results
+            Pagination API results
         """
         user_id = requester.user.to_string()
 
         if pagin_config.from_token:
-            room_token = pagin_config.from_token.room_key
+            from_token = pagin_config.from_token
         else:
-            pagin_config.from_token = (
-                self.hs.get_event_sources().get_current_token_for_pagination()
-            )
-            room_token = pagin_config.from_token.room_key
+            from_token = self.hs.get_event_sources().get_current_token_for_pagination()
 
-        room_token = RoomStreamToken.parse(room_token)
-
-        pagin_config.from_token = pagin_config.from_token.copy_and_replace(
-            "room_key", str(room_token)
-        )
+        if pagin_config.limit is None:
+            # This shouldn't happen as we've set a default limit before this
+            # gets called.
+            raise Exception("limit not set")
 
-        source_config = pagin_config.get_source_config("room")
+        room_token = from_token.room_key
 
-        with (await self.pagination_lock.read(room_id)):
+        with await self.pagination_lock.read(room_id):
             (
                 membership,
                 member_event_id,
@@ -329,7 +354,7 @@ class PaginationHandler(object):
                 room_id, user_id, allow_departed_users=True
             )
 
-            if source_config.direction == "b":
+            if pagin_config.direction == "b":
                 # if we're going backwards, we might need to backfill. This
                 # requires that we have a topo token.
                 if room_token.topological:
@@ -343,27 +368,40 @@ class PaginationHandler(object):
                     # If they have left the room then clamp the token to be before
                     # they left the room, to save the effort of loading from the
                     # database.
-                    leave_token = await self.store.get_topological_token_for_event(
+
+                    # This is only None if the room is world_readable, in which
+                    # case "JOIN" would have been returned.
+                    assert member_event_id
+
+                    leave_token_str = await self.store.get_topological_token_for_event(
                         member_event_id
                     )
-                    leave_token = RoomStreamToken.parse(leave_token)
+                    leave_token = RoomStreamToken.parse(leave_token_str)
+                    assert leave_token.topological is not None
+
                     if leave_token.topological < max_topo:
-                        source_config.from_key = str(leave_token)
+                        from_token = from_token.copy_and_replace(
+                            "room_key", leave_token
+                        )
 
                 await self.hs.get_handlers().federation_handler.maybe_backfill(
                     room_id, max_topo
                 )
 
+            to_room_key = None
+            if pagin_config.to_token:
+                to_room_key = pagin_config.to_token.room_key
+
             events, next_key = await self.store.paginate_room_events(
                 room_id=room_id,
-                from_key=source_config.from_key,
-                to_key=source_config.to_key,
-                direction=source_config.direction,
-                limit=source_config.limit,
+                from_key=from_token.room_key,
+                to_key=to_room_key,
+                direction=pagin_config.direction,
+                limit=pagin_config.limit,
                 event_filter=event_filter,
             )
 
-            next_token = pagin_config.from_token.copy_and_replace("room_key", next_key)
+            next_token = from_token.copy_and_replace("room_key", next_key)
 
         if events:
             if event_filter:
@@ -376,7 +414,7 @@ class PaginationHandler(object):
         if not events:
             return {
                 "chunk": [],
-                "start": pagin_config.from_token.to_string(),
+                "start": from_token.to_string(),
                 "end": next_token.to_string(),
             }
 
@@ -394,8 +432,8 @@ class PaginationHandler(object):
             )
 
             if state_ids:
-                state = await self.store.get_events(list(state_ids.values()))
-                state = state.values()
+                state_dict = await self.store.get_events(list(state_ids.values()))
+                state = state_dict.values()
 
         time_now = self.clock.time_msec()
 
@@ -405,7 +443,7 @@ class PaginationHandler(object):
                     events, time_now, as_client_event=as_client_event
                 )
             ),
-            "start": pagin_config.from_token.to_string(),
+            "start": from_token.to_string(),
             "end": next_token.to_string(),
         }
 
diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py
index d06b110269..88e2f87200 100644
--- a/synapse/handlers/password_policy.py
+++ b/synapse/handlers/password_policy.py
@@ -22,7 +22,7 @@ from synapse.api.errors import Codes, PasswordRefusedError
 logger = logging.getLogger(__name__)
 
 
-class PasswordPolicyHandler(object):
+class PasswordPolicyHandler:
     def __init__(self, hs):
         self.policy = hs.config.password_policy
         self.enabled = hs.config.password_policy_enabled
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 5387b3724f..1000ac95ff 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -33,14 +33,14 @@ from typing_extensions import ContextManager
 import synapse.metrics
 from synapse.api.constants import EventTypes, Membership, PresenceState
 from synapse.api.errors import SynapseError
+from synapse.api.presence import UserPresenceState
 from synapse.logging.context import run_in_background
 from synapse.logging.utils import log_function
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.state import StateHandler
 from synapse.storage.databases.main import DataStore
-from synapse.storage.presence import UserPresenceState
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.descriptors import cached
 from synapse.util.metrics import Measure
@@ -1010,7 +1010,7 @@ def format_user_presence_state(state, now, include_user_id=True):
     return content
 
 
-class PresenceEventSource(object):
+class PresenceEventSource:
     def __init__(self, hs):
         # We can't call get_presence_handler here because there's a cycle:
         #
@@ -1108,9 +1108,6 @@ class PresenceEventSource(object):
     def get_current_key(self):
         return self.store.get_current_presence_token()
 
-    async def get_pagination_rows(self, user, pagination_config, key):
-        return await self.get_new_events(user, from_key=None, include_offline=False)
-
     @cached(num_args=2, cache_context=True)
     async def _get_interested_in(self, user, explicit_room_id, cache_context):
         """Returns the set of users that the given user should see presence
@@ -1318,7 +1315,7 @@ async def get_interested_parties(
 
 async def get_interested_remotes(
     store: DataStore, states: List[UserPresenceState], state_handler: StateHandler
-) -> List[Tuple[List[str], List[UserPresenceState]]]:
+) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
     """Given a list of presence states figure out which remote servers
     should be sent which.
 
@@ -1334,7 +1331,7 @@ async def get_interested_remotes(
         each tuple the list of UserPresenceState should be sent to each
         destination
     """
-    hosts_and_states = []
+    hosts_and_states = []  # type: List[Tuple[Collection[str], List[UserPresenceState]]]
 
     # First we look up the rooms each user is in (as well as any explicit
     # subscriptions), then for each distinct room we look up the remote
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 31a2e5ea18..0cb8fad89a 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+import random
 
 from synapse.api.errors import (
     AuthError,
@@ -160,6 +161,9 @@ class BaseProfileHandler(BaseHandler):
                     Codes.FORBIDDEN,
                 )
 
+        if not isinstance(new_displayname, str):
+            raise SynapseError(400, "Invalid displayname")
+
         if len(new_displayname) > MAX_DISPLAYNAME_LEN:
             raise SynapseError(
                 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
@@ -213,8 +217,14 @@ class BaseProfileHandler(BaseHandler):
     async def set_avatar_url(
         self, target_user, requester, new_avatar_url, by_admin=False
     ):
-        """target_user is the user whose avatar_url is to be changed;
-        auth_user is the user attempting to make this change."""
+        """Set a new avatar URL for a user.
+
+        Args:
+            target_user (UserID): the user whose avatar URL is to be changed.
+            requester (Requester): The user attempting to make this change.
+            new_avatar_url (str): The avatar URL to give this user.
+            by_admin (bool): Whether this change was made by an administrator.
+        """
         if not self.hs.is_mine(target_user):
             raise SynapseError(400, "User is not hosted on this homeserver")
 
@@ -228,6 +238,9 @@ class BaseProfileHandler(BaseHandler):
                     400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
                 )
 
+        if not isinstance(new_avatar_url, str):
+            raise SynapseError(400, "Invalid displayname")
+
         if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
             raise SynapseError(
                 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
@@ -278,6 +291,12 @@ class BaseProfileHandler(BaseHandler):
 
         await self.ratelimit(requester)
 
+        # Do not actually update the room state for shadow-banned users.
+        if requester.shadow_banned:
+            # We randomly sleep a bit just to annoy the requester.
+            await self.clock.sleep(random.randint(1, 10))
+            return
+
         room_ids = await self.store.get_rooms_for_user(target_user.to_string())
 
         for room_id in room_ids:
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index f922d8a545..bdd8e52edd 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -123,7 +123,7 @@ class ReceiptsHandler(BaseHandler):
         await self.federation.send_read_receipt(receipt)
 
 
-class ReceiptEventSource(object):
+class ReceiptEventSource:
     def __init__(self, hs):
         self.store = hs.get_datastore()
 
@@ -142,18 +142,3 @@ class ReceiptEventSource(object):
 
     def get_current_key(self, direction="f"):
         return self.store.get_max_receipt_stream_id()
-
-    async def get_pagination_rows(self, user, config, key):
-        to_key = int(config.from_key)
-
-        if config.to_key:
-            from_key = int(config.to_key)
-        else:
-            from_key = None
-
-        room_ids = await self.store.get_rooms_for_user(user.to_string())
-        events = await self.store.get_linearized_receipts_for_rooms(
-            room_ids, from_key=from_key, to_key=to_key
-        )
-
-        return (events, to_key)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c94209ab3d..cde2dbca92 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -26,6 +26,7 @@ from synapse.replication.http.register import (
     ReplicationPostRegisterActionsServlet,
     ReplicationRegisterServlet,
 )
+from synapse.spam_checker_api import RegistrationBehaviour
 from synapse.storage.state import StateFilter
 from synapse.types import RoomAlias, UserID, create_requester
 
@@ -52,6 +53,8 @@ class RegistrationHandler(BaseHandler):
         self.macaroon_gen = hs.get_macaroon_generator()
         self._server_notices_mxid = hs.config.server_notices_mxid
 
+        self.spam_checker = hs.get_spam_checker()
+
         if hs.config.worker_app:
             self._register_client = ReplicationRegisterServlet.make_client(hs)
             self._register_device_client = RegisterDeviceReplicationServlet.make_client(
@@ -124,7 +127,9 @@ class RegistrationHandler(BaseHandler):
             try:
                 int(localpart)
                 raise SynapseError(
-                    400, "Numeric user IDs are reserved for guest users."
+                    400,
+                    "Numeric user IDs are reserved for guest users.",
+                    errcode=Codes.INVALID_USERNAME,
                 )
             except ValueError:
                 pass
@@ -142,6 +147,7 @@ class RegistrationHandler(BaseHandler):
         address=None,
         bind_emails=[],
         by_admin=False,
+        user_agent_ips=None,
     ):
         """Registers a new client on the server.
 
@@ -159,6 +165,8 @@ class RegistrationHandler(BaseHandler):
             bind_emails (List[str]): list of emails to bind to this account.
             by_admin (bool): True if this registration is being made via the
               admin api, otherwise False.
+            user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used
+                during the registration process.
         Returns:
             str: user_id
         Raises:
@@ -166,6 +174,24 @@ class RegistrationHandler(BaseHandler):
         """
         self.check_registration_ratelimit(address)
 
+        result = self.spam_checker.check_registration_for_spam(
+            threepid, localpart, user_agent_ips or [],
+        )
+
+        if result == RegistrationBehaviour.DENY:
+            logger.info(
+                "Blocked registration of %r", localpart,
+            )
+            # We return a 429 to make it not obvious that they've been
+            # denied.
+            raise SynapseError(429, "Rate limited")
+
+        shadow_banned = result == RegistrationBehaviour.SHADOW_BAN
+        if shadow_banned:
+            logger.info(
+                "Shadow banning registration of %r", localpart,
+            )
+
         # do not check_auth_blocking if the call is coming through the Admin API
         if not by_admin:
             await self.auth.check_auth_blocking(threepid=threepid)
@@ -194,6 +220,7 @@ class RegistrationHandler(BaseHandler):
                 admin=admin,
                 user_type=user_type,
                 address=address,
+                shadow_banned=shadow_banned,
             )
 
             if self.hs.config.user_directory_search_all_users:
@@ -224,6 +251,7 @@ class RegistrationHandler(BaseHandler):
                         make_guest=make_guest,
                         create_profile_with_displayname=default_display_name,
                         address=address,
+                        shadow_banned=shadow_banned,
                     )
 
                     # Successfully registered
@@ -529,6 +557,7 @@ class RegistrationHandler(BaseHandler):
         admin=False,
         user_type=None,
         address=None,
+        shadow_banned=False,
     ):
         """Register user in the datastore.
 
@@ -546,6 +575,7 @@ class RegistrationHandler(BaseHandler):
             user_type (str|None): type of user. One of the values from
                 api.constants.UserTypes, or None for a normal user.
             address (str|None): the IP address used to perform the registration.
+            shadow_banned (bool): Whether to shadow-ban the user
 
         Returns:
             Awaitable
@@ -561,6 +591,7 @@ class RegistrationHandler(BaseHandler):
                 admin=admin,
                 user_type=user_type,
                 address=address,
+                shadow_banned=shadow_banned,
             )
         else:
             return self.store.register_user(
@@ -572,6 +603,7 @@ class RegistrationHandler(BaseHandler):
                 create_profile_with_displayname=create_profile_with_displayname,
                 admin=admin,
                 user_type=user_type,
+                shadow_banned=shadow_banned,
             )
 
     async def register_device(
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a8545255b1..eeade6ad3f 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -20,9 +20,10 @@
 import itertools
 import logging
 import math
+import random
 import string
 from collections import OrderedDict
-from typing import Awaitable, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
 
 from synapse.api.constants import (
     EventTypes,
@@ -32,11 +33,15 @@ from synapse.api.constants import (
     RoomEncryptionAlgorithms,
 )
 from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
+from synapse.api.filtering import Filter
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
+from synapse.events import EventBase
 from synapse.events.utils import copy_power_levels_contents
 from synapse.http.endpoint import parse_and_validate_server_name
 from synapse.storage.state import StateFilter
 from synapse.types import (
+    JsonDict,
+    MutableStateMap,
     Requester,
     RoomAlias,
     RoomID,
@@ -47,12 +52,15 @@ from synapse.types import (
     create_requester,
 )
 from synapse.util import stringutils
-from synapse.util.async_helpers import Linearizer, maybe_awaitable
+from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 id_server_scheme = "https://"
@@ -61,7 +69,7 @@ FIVE_MINUTES_IN_MS = 5 * 60 * 1000
 
 
 class RoomCreationHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super(RoomCreationHandler, self).__init__(hs)
 
         self.spam_checker = hs.get_spam_checker()
@@ -92,7 +100,7 @@ class RoomCreationHandler(BaseHandler):
                 "guest_can_join": False,
                 "power_level_content_override": {},
             },
-        }
+        }  # type: Dict[str, Dict[str, Any]]
 
         # Modify presets to selectively enable encryption by default per homeserver config
         for preset_name, preset_config in self._presets_dict.items():
@@ -129,6 +137,9 @@ class RoomCreationHandler(BaseHandler):
 
         Returns:
             the new room id
+
+        Raises:
+            ShadowBanError if the requester is shadow-banned.
         """
         await self.ratelimit(requester)
 
@@ -164,6 +175,15 @@ class RoomCreationHandler(BaseHandler):
     async def _upgrade_room(
         self, requester: Requester, old_room_id: str, new_version: RoomVersion
     ):
+        """
+        Args:
+            requester: the user requesting the upgrade
+            old_room_id: the id of the room to be replaced
+            new_versions: the version to upgrade the room to
+
+        Raises:
+            ShadowBanError if the requester is shadow-banned.
+        """
         user_id = requester.user.to_string()
 
         # start by allocating a new room id
@@ -215,6 +235,9 @@ class RoomCreationHandler(BaseHandler):
 
         old_room_state = await tombstone_context.get_current_state_ids()
 
+        # We know the tombstone event isn't an outlier so it has current state.
+        assert old_room_state is not None
+
         # update any aliases
         await self._move_aliases_to_new_room(
             requester, old_room_id, new_room_id, old_room_state
@@ -247,6 +270,9 @@ class RoomCreationHandler(BaseHandler):
             old_room_id: the id of the room to be replaced
             new_room_id: the id of the replacement room
             old_room_state: the state map for the old room
+
+        Raises:
+            ShadowBanError if the requester is shadow-banned.
         """
         old_room_pl_event_id = old_room_state.get((EventTypes.PowerLevels, ""))
 
@@ -425,7 +451,7 @@ class RoomCreationHandler(BaseHandler):
         old_room_member_state_events = await self.store.get_events(
             old_room_member_state_ids.values()
         )
-        for k, old_event in old_room_member_state_events.items():
+        for old_event in old_room_member_state_events.values():
             # Only transfer ban events
             if (
                 "membership" in old_event.content
@@ -528,17 +554,21 @@ class RoomCreationHandler(BaseHandler):
             logger.error("Unable to send updated alias events in new room: %s", e)
 
     async def create_room(
-        self, requester, config, ratelimit=True, creator_join_profile=None
+        self,
+        requester: Requester,
+        config: JsonDict,
+        ratelimit: bool = True,
+        creator_join_profile: Optional[JsonDict] = None,
     ) -> Tuple[dict, int]:
         """ Creates a new room.
 
         Args:
-            requester (synapse.types.Requester):
+            requester:
                 The user who requested the room creation.
-            config (dict) : A dict of configuration options.
-            ratelimit (bool): set to False to disable the rate limiter
+            config : A dict of configuration options.
+            ratelimit: set to False to disable the rate limiter
 
-            creator_join_profile (dict|None):
+            creator_join_profile:
                 Set to override the displayname and avatar for the creating
                 user in this room. If unset, displayname and avatar will be
                 derived from the user's profile. If set, should contain the
@@ -601,6 +631,7 @@ class RoomCreationHandler(BaseHandler):
                 Codes.UNSUPPORTED_ROOM_VERSION,
             )
 
+        room_alias = None
         if "room_alias_name" in config:
             for wchar in string.whitespace:
                 if wchar in config["room_alias_name"]:
@@ -611,9 +642,8 @@ class RoomCreationHandler(BaseHandler):
 
             if mapping:
                 raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
-        else:
-            room_alias = None
 
+        invite_3pid_list = config.get("invite_3pid", [])
         invite_list = config.get("invite", [])
         for i in invite_list:
             try:
@@ -622,6 +652,14 @@ class RoomCreationHandler(BaseHandler):
             except Exception:
                 raise SynapseError(400, "Invalid user_id: %s" % (i,))
 
+        if (invite_list or invite_3pid_list) and requester.shadow_banned:
+            # We randomly sleep a bit just to annoy the requester.
+            await self.clock.sleep(random.randint(1, 10))
+
+            # Allow the request to go through, but remove any associated invites.
+            invite_3pid_list = []
+            invite_list = []
+
         await self.event_creation_handler.assert_accepted_privacy_policy(requester)
 
         power_level_content_override = config.get("power_level_content_override")
@@ -636,8 +674,6 @@ class RoomCreationHandler(BaseHandler):
                 % (user_id,),
             )
 
-        invite_3pid_list = config.get("invite_3pid", [])
-
         visibility = config.get("visibility", None)
         is_public = visibility == "public"
 
@@ -732,6 +768,8 @@ class RoomCreationHandler(BaseHandler):
             if is_direct:
                 content["is_direct"] = is_direct
 
+            # Note that update_membership with an action of "invite" can raise a
+            # ShadowBanError, but this was handled above by emptying invite_list.
             _, last_stream_id = await self.room_member_handler.update_membership(
                 requester,
                 UserID.from_string(invitee),
@@ -746,6 +784,8 @@ class RoomCreationHandler(BaseHandler):
             id_access_token = invite_3pid.get("id_access_token")  # optional
             address = invite_3pid["address"]
             medium = invite_3pid["medium"]
+            # Note that do_3pid_invite can raise a  ShadowBanError, but this was
+            # handled above by emptying invite_3pid_list.
             last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite(
                 room_id,
                 requester.user,
@@ -764,30 +804,39 @@ class RoomCreationHandler(BaseHandler):
 
         # Always wait for room creation to progate before returning
         await self._replication.wait_for_stream_position(
-            self.hs.config.worker.writers.events, "events", last_stream_id
+            self.hs.config.worker.events_shard_config.get_instance(room_id),
+            "events",
+            last_stream_id,
         )
 
         return result, last_stream_id
 
     async def _send_events_for_new_room(
         self,
-        creator,  # A Requester object.
-        room_id,
-        preset_config,
-        invite_list,
-        initial_state,
-        creation_content,
-        room_alias=None,
-        power_level_content_override=None,  # Doesn't apply when initial state has power level state event content
-        creator_join_profile=None,
+        creator: Requester,
+        room_id: str,
+        preset_config: str,
+        invite_list: List[str],
+        initial_state: MutableStateMap,
+        creation_content: JsonDict,
+        room_alias: Optional[RoomAlias] = None,
+        power_level_content_override: Optional[JsonDict] = None,
+        creator_join_profile: Optional[JsonDict] = None,
     ) -> int:
         """Sends the initial events into a new room.
 
+        `power_level_content_override` doesn't apply when initial state has
+        power level state event content.
+
         Returns:
             The stream_id of the last event persisted.
         """
 
-        def create(etype, content, **kwargs):
+        creator_id = creator.user.to_string()
+
+        event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
+
+        def create(etype: str, content: JsonDict, **kwargs) -> JsonDict:
             e = {"type": etype, "content": content}
 
             e.update(event_keys)
@@ -795,23 +844,21 @@ class RoomCreationHandler(BaseHandler):
 
             return e
 
-        async def send(etype, content, **kwargs) -> int:
+        async def send(etype: str, content: JsonDict, **kwargs) -> int:
             event = create(etype, content, **kwargs)
             logger.debug("Sending %s in new room", etype)
+            # Allow these events to be sent even if the user is shadow-banned to
+            # allow the room creation to complete.
             (
                 _,
                 last_stream_id,
             ) = await self.event_creation_handler.create_and_send_nonmember_event(
-                creator, event, ratelimit=False
+                creator, event, ratelimit=False, ignore_shadow_ban=True,
             )
             return last_stream_id
 
         config = self._presets_dict[preset_config]
 
-        creator_id = creator.user.to_string()
-
-        event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
-
         creation_content.update({"creator": creator_id})
         await send(etype=EventTypes.Create, content=creation_content)
 
@@ -852,7 +899,7 @@ class RoomCreationHandler(BaseHandler):
                 "kick": 50,
                 "redact": 50,
                 "invite": 50,
-            }
+            }  # type: JsonDict
 
             if config["original_invitees_have_ops"]:
                 for invitee in invite_list:
@@ -906,7 +953,7 @@ class RoomCreationHandler(BaseHandler):
         return last_sent_stream_id
 
     async def _generate_room_id(
-        self, creator_id: str, is_public: str, room_version: RoomVersion,
+        self, creator_id: str, is_public: bool, room_version: RoomVersion,
     ):
         # autogen room IDs and try to create it. We may clash, so just
         # try a few times till one goes through, giving up eventually.
@@ -929,24 +976,31 @@ class RoomCreationHandler(BaseHandler):
         raise StoreError(500, "Couldn't generate a room ID.")
 
 
-class RoomContextHandler(object):
-    def __init__(self, hs):
+class RoomContextHandler:
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
 
-    async def get_event_context(self, user, room_id, event_id, limit, event_filter):
+    async def get_event_context(
+        self,
+        user: UserID,
+        room_id: str,
+        event_id: str,
+        limit: int,
+        event_filter: Optional[Filter],
+    ) -> Optional[JsonDict]:
         """Retrieves events, pagination tokens and state around a given event
         in a room.
 
         Args:
-            user (UserID)
-            room_id (str)
-            event_id (str)
-            limit (int): The maximum number of events to return in total
+            user
+            room_id
+            event_id
+            limit: The maximum number of events to return in total
                 (excluding state).
-            event_filter (Filter|None): the filter to apply to the events returned
+            event_filter: the filter to apply to the events returned
                 (excluding the target event_id)
 
         Returns:
@@ -1032,21 +1086,26 @@ class RoomContextHandler(object):
         return results
 
 
-class RoomEventSource(object):
-    def __init__(self, hs):
+class RoomEventSource:
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
 
     async def get_new_events(
-        self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
-    ):
+        self,
+        user: UserID,
+        from_key: RoomStreamToken,
+        limit: int,
+        room_ids: List[str],
+        is_guest: bool,
+        explicit_room_id: Optional[str] = None,
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         # We just ignore the key for now.
 
         to_key = self.get_current_key()
 
-        from_token = RoomStreamToken.parse(from_key)
-        if from_token.topological:
+        if from_key.topological:
             logger.warning("Stream has topological part!!!! %r", from_key)
-            from_key = "s%s" % (from_token.stream,)
+            from_key = RoomStreamToken(None, from_key.stream)
 
         app_service = self.store.get_app_service_by_user_id(user.to_string())
         if app_service:
@@ -1075,20 +1134,20 @@ class RoomEventSource(object):
                 events[:] = events[:limit]
 
             if events:
-                end_key = events[-1].internal_metadata.after
+                end_key = RoomStreamToken.parse(events[-1].internal_metadata.after)
             else:
                 end_key = to_key
 
         return (events, end_key)
 
-    def get_current_key(self) -> str:
-        return "s%d" % (self.store.get_room_max_stream_ordering(),)
+    def get_current_key(self) -> RoomStreamToken:
+        return RoomStreamToken(None, self.store.get_room_max_stream_ordering())
 
     def get_current_key_for_room(self, room_id: str) -> Awaitable[str]:
         return self.store.get_room_events_max_id(room_id)
 
 
-class RoomShutdownHandler(object):
+class RoomShutdownHandler:
 
     DEFAULT_MESSAGE = (
         "Sharing illegal content on this server is not permitted and rooms in"
@@ -1096,7 +1155,7 @@ class RoomShutdownHandler(object):
     )
     DEFAULT_ROOM_NAME = "Content Violation Notification"
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.room_member_handler = hs.get_room_member_handler()
         self._room_creation_handler = hs.get_room_creation_handler()
@@ -1202,10 +1261,10 @@ class RoomShutdownHandler(object):
             # We now wait for the create room to come back in via replication so
             # that we can assume that all the joins/invites have propogated before
             # we try and auto join below.
-            #
-            # TODO: Currently the events stream is written to from master
             await self._replication.wait_for_stream_position(
-                self.hs.config.worker.writers.events, "events", stream_id
+                self.hs.config.worker.events_shard_config.get_instance(new_room_id),
+                "events",
+                stream_id,
             )
         else:
             new_room_id = None
@@ -1235,7 +1294,9 @@ class RoomShutdownHandler(object):
 
                 # Wait for leave to come in over replication before trying to forget.
                 await self._replication.wait_for_stream_position(
-                    self.hs.config.worker.writers.events, "events", stream_id
+                    self.hs.config.worker.events_shard_config.get_instance(room_id),
+                    "events",
+                    stream_id,
                 )
 
                 await self.room_member_handler.forget(target_requester.user, room_id)
@@ -1272,9 +1333,7 @@ class RoomShutdownHandler(object):
                 ratelimit=False,
             )
 
-            aliases_for_room = await maybe_awaitable(
-                self.store.get_aliases_for_room(room_id)
-            )
+            aliases_for_room = await self.store.get_aliases_for_room(room_id)
 
             await self.store.update_aliases_for_room(
                 room_id, new_room_id, requester_user_id
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 9fcabb22c7..01a6e88262 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -15,14 +15,21 @@
 
 import abc
 import logging
+import random
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
 
 from unpaddedbase64 import encode_base64
 
 from synapse import types
 from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
+from synapse.api.errors import (
+    AuthError,
+    Codes,
+    LimitExceededError,
+    ShadowBanError,
+    SynapseError,
+)
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.api.room_versions import EventFormatVersions
 from synapse.crypto.event_signing import compute_event_reference_hash
@@ -31,9 +38,9 @@ from synapse.events.builder import create_local_event_from_event_dict
 from synapse.events.snapshot import EventContext
 from synapse.events.validator import EventValidator
 from synapse.storage.roommember import RoomsForUser
-from synapse.types import Collection, JsonDict, Requester, RoomAlias, RoomID, UserID
+from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
 from synapse.util.async_helpers import Linearizer
-from synapse.util.distributor import user_joined_room, user_left_room
+from synapse.util.distributor import user_left_room
 
 from ._base import BaseHandler
 
@@ -44,7 +51,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class RoomMemberHandler(object):
+class RoomMemberHandler:
     # TODO(paul): This handler currently contains a messy conflation of
     #   low-level API that works on UserID objects and so on, and REST-level
     #   API that takes ID strings and returns pagination chunks. These concerns
@@ -75,13 +82,6 @@ class RoomMemberHandler(object):
         self._enable_lookup = hs.config.enable_3pid_lookup
         self.allow_per_room_profiles = self.config.allow_per_room_profiles
 
-        self._event_stream_writer_instance = hs.config.worker.writers.events
-        self._is_on_event_persistence_instance = (
-            self._event_stream_writer_instance == hs.get_instance_name()
-        )
-        if self._is_on_event_persistence_instance:
-            self.persist_event_storage = hs.get_storage().persistence
-
         self._join_rate_limiter_local = Ratelimiter(
             clock=self.clock,
             rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
@@ -142,17 +142,6 @@ class RoomMemberHandler(object):
         raise NotImplementedError()
 
     @abc.abstractmethod
-    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
-        """Notifies distributor on master process that the user has joined the
-        room.
-
-        Args:
-            target
-            room_id
-        """
-        raise NotImplementedError()
-
-    @abc.abstractmethod
     async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Notifies distributor on master process that the user has left the
         room.
@@ -169,7 +158,7 @@ class RoomMemberHandler(object):
         target: UserID,
         room_id: str,
         membership: str,
-        prev_event_ids: Collection[str],
+        prev_event_ids: List[str],
         txn_id: Optional[str] = None,
         ratelimit: bool = True,
         content: Optional[dict] = None,
@@ -214,7 +203,6 @@ class RoomMemberHandler(object):
 
         prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
 
-        newly_joined = False
         if event.membership == Membership.JOIN:
             newly_joined = True
             if prev_member_event_id:
@@ -239,12 +227,7 @@ class RoomMemberHandler(object):
             requester, event, context, extra_users=[target], ratelimit=ratelimit,
         )
 
-        if event.membership == Membership.JOIN and newly_joined:
-            # Only fire user_joined_room if the user has actually joined the
-            # room. Don't bother if the user is just changing their profile
-            # info.
-            await self._user_joined_room(target, room_id)
-        elif event.membership == Membership.LEAVE:
+        if event.membership == Membership.LEAVE:
             if prev_member_event_id:
                 prev_member_event = await self.store.get_event(prev_member_event_id)
                 if prev_member_event.membership == Membership.JOIN:
@@ -301,6 +284,31 @@ class RoomMemberHandler(object):
         content: Optional[dict] = None,
         require_consent: bool = True,
     ) -> Tuple[str, int]:
+        """Update a user's membership in a room.
+
+        Params:
+            requester: The user who is performing the update.
+            target: The user whose membership is being updated.
+            room_id: The room ID whose membership is being updated.
+            action: The membership change, see synapse.api.constants.Membership.
+            txn_id: The transaction ID, if given.
+            remote_room_hosts: Remote servers to send the update to.
+            third_party_signed: Information from a 3PID invite.
+            ratelimit: Whether to rate limit the request.
+            content: The content of the created event.
+            require_consent: Whether consent is required.
+
+        Returns:
+            A tuple of the new event ID and stream ID.
+
+        Raises:
+            ShadowBanError if a shadow-banned requester attempts to send an invite.
+        """
+        if action == Membership.INVITE and requester.shadow_banned:
+            # We randomly sleep a bit just to annoy the requester.
+            await self.clock.sleep(random.randint(1, 10))
+            raise ShadowBanError()
+
         key = (room_id,)
 
         with (await self.member_linearizer.queue(key)):
@@ -340,7 +348,7 @@ class RoomMemberHandler(object):
             # later on.
             content = dict(content)
 
-        if not self.allow_per_room_profiles:
+        if not self.allow_per_room_profiles or requester.shadow_banned:
             # Strip profile data, knowing that new profile data will be added to the
             # event's content in event_creation_handler.create_event() using the target's
             # global profile.
@@ -694,25 +702,13 @@ class RoomMemberHandler(object):
             (EventTypes.Member, event.state_key), None
         )
 
-        if event.membership == Membership.JOIN:
-            # Only fire user_joined_room if the user has actually joined the
-            # room. Don't bother if the user is just changing their profile
-            # info.
-            newly_joined = True
-            if prev_member_event_id:
-                prev_member_event = await self.store.get_event(prev_member_event_id)
-                newly_joined = prev_member_event.membership != Membership.JOIN
-            if newly_joined:
-                await self._user_joined_room(target_user, room_id)
-        elif event.membership == Membership.LEAVE:
+        if event.membership == Membership.LEAVE:
             if prev_member_event_id:
                 prev_member_event = await self.store.get_event(prev_member_event_id)
                 if prev_member_event.membership == Membership.JOIN:
                     await self._user_left_room(target_user, room_id)
 
-    async def _can_guest_join(
-        self, current_state_ids: Dict[Tuple[str, str], str]
-    ) -> bool:
+    async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool:
         """
         Returns whether a guest can join a room based on its current state.
         """
@@ -722,7 +718,7 @@ class RoomMemberHandler(object):
 
         guest_access = await self.store.get_event(guest_access_id)
 
-        return (
+        return bool(
             guest_access
             and guest_access.content
             and "guest_access" in guest_access.content
@@ -779,6 +775,25 @@ class RoomMemberHandler(object):
         txn_id: Optional[str],
         id_access_token: Optional[str] = None,
     ) -> int:
+        """Invite a 3PID to a room.
+
+        Args:
+            room_id: The room to invite the 3PID to.
+            inviter: The user sending the invite.
+            medium: The 3PID's medium.
+            address: The 3PID's address.
+            id_server: The identity server to use.
+            requester: The user making the request.
+            txn_id: The transaction ID this is part of, or None if this is not
+                part of a transaction.
+            id_access_token: The optional identity server access token.
+
+        Returns:
+             The new stream ID.
+
+        Raises:
+            ShadowBanError if the requester has been shadow-banned.
+        """
         if self.config.block_non_admin_invites:
             is_requester_admin = await self.auth.is_server_admin(requester.user)
             if not is_requester_admin:
@@ -786,6 +801,11 @@ class RoomMemberHandler(object):
                     403, "Invites have been disabled on this server", Codes.FORBIDDEN
                 )
 
+        if requester.shadow_banned:
+            # We randomly sleep a bit just to annoy the requester.
+            await self.clock.sleep(random.randint(1, 10))
+            raise ShadowBanError()
+
         # We need to rate limit *before* we send out any 3PID invites, so we
         # can't just rely on the standard ratelimiting of events.
         await self.base_handler.ratelimit(requester)
@@ -810,6 +830,8 @@ class RoomMemberHandler(object):
         )
 
         if invitee:
+            # Note that update_membership with an action of "invite" can raise
+            # a ShadowBanError, but this was done above already.
             _, stream_id = await self.update_membership(
                 requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
             )
@@ -915,9 +937,7 @@ class RoomMemberHandler(object):
         )
         return stream_id
 
-    async def _is_host_in_room(
-        self, current_state_ids: Dict[Tuple[str, str], str]
-    ) -> bool:
+    async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
         # Have we just created the room, and is this about to be the very
         # first member event?
         create_event_id = current_state_ids.get(("m.room.create", ""))
@@ -948,10 +968,9 @@ class RoomMemberHandler(object):
 
 class RoomMemberMasterHandler(RoomMemberHandler):
     def __init__(self, hs):
-        super(RoomMemberMasterHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.distributor = hs.get_distributor()
-        self.distributor.declare("user_joined_room")
         self.distributor.declare("user_left_room")
 
     async def _is_remote_room_too_complex(
@@ -1031,7 +1050,6 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         event_id, stream_id = await self.federation_handler.do_invite_join(
             remote_room_hosts, room_id, user.to_string(), content
         )
-        await self._user_joined_room(user, room_id)
 
         # Check the room we just joined wasn't too large, if we didn't fetch the
         # complexity of it before.
@@ -1048,7 +1066,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
                 return event_id, stream_id
 
             # The room is too large. Leave.
-            requester = types.create_requester(user, None, False, None)
+            requester = types.create_requester(user, None, False, False, None)
             await self.update_membership(
                 requester=requester, target=user, room_id=room_id, action="leave"
             )
@@ -1174,11 +1192,6 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         )
         return event.event_id, stream_id
 
-    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
-        """Implements RoomMemberHandler._user_joined_room
-        """
-        user_joined_room(self.distributor, target, room_id)
-
     async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_left_room
         """
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 897338fd54..e7f34737c6 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -57,8 +57,6 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
             content=content,
         )
 
-        await self._user_joined_room(user, room_id)
-
         return ret["event_id"], ret["stream_id"]
 
     async def remote_reject_invite(
@@ -81,13 +79,6 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
         )
         return ret["event_id"], ret["stream_id"]
 
-    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
-        """Implements RoomMemberHandler._user_joined_room
-        """
-        await self._notify_change_client(
-            user_id=target.to_string(), room_id=room_id, change="joined"
-        )
-
     async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_left_room
         """
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index c1fcb98454..285c481a96 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -21,9 +21,10 @@ import saml2
 import saml2.response
 from saml2.client import Saml2Client
 
-from synapse.api.errors import AuthError, SynapseError
+from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
 from synapse.config.saml2_config import SamlAttributeRequirement
+from synapse.http.server import respond_with_html
 from synapse.http.servlet import parse_string
 from synapse.http.site import SynapseRequest
 from synapse.module_api import ModuleApi
@@ -41,7 +42,11 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-@attr.s
+class MappingException(Exception):
+    """Used to catch errors when mapping the SAML2 response to a user."""
+
+
+@attr.s(slots=True)
 class Saml2SessionData:
     """Data we track about SAML2 sessions"""
 
@@ -54,6 +59,7 @@ class Saml2SessionData:
 
 class SamlHandler:
     def __init__(self, hs: "synapse.server.HomeServer"):
+        self.hs = hs
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
         self._auth = hs.get_auth()
         self._auth_handler = hs.get_auth_handler()
@@ -67,6 +73,7 @@ class SamlHandler:
             hs.config.saml2_grandfathered_mxid_source_attribute
         )
         self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements
+        self._error_template = hs.config.sso_error_template
 
         # plugin to do custom mapping from saml response to mxid
         self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
@@ -83,6 +90,25 @@ class SamlHandler:
         # a lock on the mappings
         self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
 
+    def _render_error(
+        self, request, error: str, error_description: Optional[str] = None
+    ) -> None:
+        """Render the error template and respond to the request with it.
+
+        This is used to show errors to the user. The template of this page can
+        be found under `synapse/res/templates/sso_error.html`.
+
+        Args:
+            request: The incoming request from the browser.
+                We'll respond with an HTML page describing the error.
+            error: A technical identifier for this error.
+            error_description: A human-readable description of the error.
+        """
+        html = self._error_template.render(
+            error=error, error_description=error_description
+        )
+        respond_with_html(request, 400, html)
+
     def handle_redirect_request(
         self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
     ) -> bytes:
@@ -133,37 +159,6 @@ class SamlHandler:
         # the dict.
         self.expire_sessions()
 
-        user_id, current_session = await self._map_saml_response_to_user(
-            resp_bytes, relay_state
-        )
-
-        # Complete the interactive auth session or the login.
-        if current_session and current_session.ui_auth_session_id:
-            await self._auth_handler.complete_sso_ui_auth(
-                user_id, current_session.ui_auth_session_id, request
-            )
-
-        else:
-            await self._auth_handler.complete_sso_login(user_id, request, relay_state)
-
-    async def _map_saml_response_to_user(
-        self, resp_bytes: str, client_redirect_url: str
-    ) -> Tuple[str, Optional[Saml2SessionData]]:
-        """
-        Given a sample response, retrieve the cached session and user for it.
-
-        Args:
-            resp_bytes: The SAML response.
-            client_redirect_url: The redirect URL passed in by the client.
-
-        Returns:
-             Tuple of the user ID and SAML session associated with this response.
-
-        Raises:
-            SynapseError if there was a problem with the response.
-            RedirectException: some mapping providers may raise this if they need
-                to redirect to an interstitial page.
-        """
         try:
             saml2_auth = self._saml_client.parse_authn_request_response(
                 resp_bytes,
@@ -176,12 +171,23 @@ class SamlHandler:
             # in the (user-visible) exception message, so let's log the exception here
             # so we can track down the session IDs later.
             logger.warning(str(e))
-            raise SynapseError(400, "Unexpected SAML2 login.")
+            self._render_error(
+                request, "unsolicited_response", "Unexpected SAML2 login."
+            )
+            return
         except Exception as e:
-            raise SynapseError(400, "Unable to parse SAML2 response: %s." % (e,))
+            self._render_error(
+                request,
+                "invalid_response",
+                "Unable to parse SAML2 response: %s." % (e,),
+            )
+            return
 
         if saml2_auth.not_signed:
-            raise SynapseError(400, "SAML2 response was not signed.")
+            self._render_error(
+                request, "unsigned_respond", "SAML2 response was not signed."
+            )
+            return
 
         logger.debug("SAML2 response: %s", saml2_auth.origxml)
         for assertion in saml2_auth.assertions:
@@ -200,15 +206,73 @@ class SamlHandler:
             saml2_auth.in_response_to, None
         )
 
+        # Ensure that the attributes of the logged in user meet the required
+        # attributes.
         for requirement in self._saml2_attribute_requirements:
-            _check_attribute_requirement(saml2_auth.ava, requirement)
+            if not _check_attribute_requirement(saml2_auth.ava, requirement):
+                self._render_error(
+                    request, "unauthorised", "You are not authorised to log in here."
+                )
+                return
+
+        # Pull out the user-agent and IP from the request.
+        user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
+            0
+        ].decode("ascii", "surrogateescape")
+        ip_address = self.hs.get_ip_from_request(request)
+
+        # Call the mapper to register/login the user
+        try:
+            user_id = await self._map_saml_response_to_user(
+                saml2_auth, relay_state, user_agent, ip_address
+            )
+        except MappingException as e:
+            logger.exception("Could not map user")
+            self._render_error(request, "mapping_error", str(e))
+            return
+
+        # Complete the interactive auth session or the login.
+        if current_session and current_session.ui_auth_session_id:
+            await self._auth_handler.complete_sso_ui_auth(
+                user_id, current_session.ui_auth_session_id, request
+            )
+
+        else:
+            await self._auth_handler.complete_sso_login(user_id, request, relay_state)
+
+    async def _map_saml_response_to_user(
+        self,
+        saml2_auth: saml2.response.AuthnResponse,
+        client_redirect_url: str,
+        user_agent: str,
+        ip_address: str,
+    ) -> str:
+        """
+        Given a SAML response, retrieve the user ID for it and possibly register the user.
+
+        Args:
+            saml2_auth: The parsed SAML2 response.
+            client_redirect_url: The redirect URL passed in by the client.
+            user_agent: The user agent of the client making the request.
+            ip_address: The IP address of the client making the request.
+
+        Returns:
+             The user ID associated with this response.
+
+        Raises:
+            MappingException if there was a problem mapping the response to a user.
+            RedirectException: some mapping providers may raise this if they need
+                to redirect to an interstitial page.
+        """
 
         remote_user_id = self._user_mapping_provider.get_remote_user_id(
             saml2_auth, client_redirect_url
         )
 
         if not remote_user_id:
-            raise Exception("Failed to extract remote user id from SAML response")
+            raise MappingException(
+                "Failed to extract remote user id from SAML response"
+            )
 
         with (await self._mapping_lock.queue(self._auth_provider_id)):
             # first of all, check if we already have a mapping for this user
@@ -222,7 +286,7 @@ class SamlHandler:
             )
             if registered_user_id is not None:
                 logger.info("Found existing mapping %s", registered_user_id)
-                return registered_user_id, current_session
+                return registered_user_id
 
             # backwards-compatibility hack: see if there is an existing user with a
             # suitable mapping from the uid
@@ -247,7 +311,7 @@ class SamlHandler:
                     await self._datastore.record_user_external_id(
                         self._auth_provider_id, remote_user_id, registered_user_id
                     )
-                    return registered_user_id, current_session
+                    return registered_user_id
 
             # Map saml response to user attributes using the configured mapping provider
             for i in range(1000):
@@ -264,7 +328,7 @@ class SamlHandler:
 
                 localpart = attribute_dict.get("mxid_localpart")
                 if not localpart:
-                    raise Exception(
+                    raise MappingException(
                         "Error parsing SAML2 response: SAML mapping provider plugin "
                         "did not return a mxid_localpart value"
                     )
@@ -281,8 +345,8 @@ class SamlHandler:
             else:
                 # Unable to generate a username in 1000 iterations
                 # Break and return error to the user
-                raise SynapseError(
-                    500, "Unable to generate a Matrix ID from the SAML response"
+                raise MappingException(
+                    "Unable to generate a Matrix ID from the SAML response"
                 )
 
             logger.info("Mapped SAML user to local part %s", localpart)
@@ -291,12 +355,13 @@ class SamlHandler:
                 localpart=localpart,
                 default_display_name=displayname,
                 bind_emails=emails,
+                user_agent_ips=(user_agent, ip_address),
             )
 
             await self._datastore.record_user_external_id(
                 self._auth_provider_id, remote_user_id, registered_user_id
             )
-            return registered_user_id, current_session
+            return registered_user_id
 
     def expire_sessions(self):
         expire_before = self._clock.time_msec() - self._saml2_session_lifetime
@@ -309,11 +374,11 @@ class SamlHandler:
             del self._outstanding_requests_dict[reqid]
 
 
-def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement):
+def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool:
     values = ava.get(req.attribute, [])
     for v in values:
         if v == req.value:
-            return
+            return True
 
     logger.info(
         "SAML2 attribute %s did not match required value '%s' (was '%s')",
@@ -321,7 +386,7 @@ def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement):
         req.value,
         values,
     )
-    raise AuthError(403, "You are not authorized to log in here.")
+    return False
 
 
 DOT_REPLACE_PATTERN = re.compile(
@@ -346,12 +411,12 @@ MXID_MAPPER_MAP = {
 
 
 @attr.s
-class SamlConfig(object):
+class SamlConfig:
     mxid_source_attribute = attr.ib()
     mxid_mapper = attr.ib()
 
 
-class DefaultSamlMappingProvider(object):
+class DefaultSamlMappingProvider:
     __version__ = "0.0.1"
 
     def __init__(self, parsed_config: SamlConfig, module_api: ModuleApi):
@@ -376,7 +441,7 @@ class DefaultSamlMappingProvider(object):
             return saml_response.ava["uid"][0]
         except KeyError:
             logger.warning("SAML2 response lacks a 'uid' attestation")
-            raise SynapseError(400, "'uid' not in SAML2 response")
+            raise MappingException("'uid' not in SAML2 response")
 
     def saml_response_to_user_attributes(
         self,
diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
index 8590c1eff4..7a4ae0727a 100644
--- a/synapse/handlers/state_deltas.py
+++ b/synapse/handlers/state_deltas.py
@@ -18,7 +18,7 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-class StateDeltasHandler(object):
+class StateDeltasHandler:
     def __init__(self, hs):
         self.store = hs.get_datastore()
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c42dac18f5..9b3a4f638b 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -16,7 +16,7 @@
 
 import itertools
 import logging
-from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
 
 import attr
 from prometheus_client import Counter
@@ -31,6 +31,7 @@ from synapse.storage.state import StateFilter
 from synapse.types import (
     Collection,
     JsonDict,
+    MutableStateMap,
     RoomStreamToken,
     StateMap,
     StreamToken,
@@ -43,6 +44,9 @@ from synapse.util.caches.response_cache import ResponseCache
 from synapse.util.metrics import Measure, measure_func
 from synapse.visibility import filter_events_for_client
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 # Debug logger for https://github.com/matrix-org/synapse/issues/4422
@@ -85,16 +89,19 @@ class TimelineBatch:
     events = attr.ib(type=List[EventBase])
     limited = attr.ib(bool)
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
         to tell if room needs to be part of the sync result.
         """
         return bool(self.events)
 
-    __bool__ = __nonzero__  # python3
-
 
-@attr.s(slots=True, frozen=True)
+# We can't freeze this class, because we need to update it after it's instantiated to
+# update its unread count. This is because we calculate the unread count for a room only
+# if there are updates for it, which we check after the instance has been created.
+# This should not be a big deal because we update the notification counts afterwards as
+# well anyway.
+@attr.s(slots=True)
 class JoinedSyncResult:
     room_id = attr.ib(type=str)
     timeline = attr.ib(type=TimelineBatch)
@@ -103,8 +110,9 @@ class JoinedSyncResult:
     account_data = attr.ib(type=List[JsonDict])
     unread_notifications = attr.ib(type=JsonDict)
     summary = attr.ib(type=Optional[JsonDict])
+    unread_count = attr.ib(type=int)
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
         to tell if room needs to be part of the sync result.
         """
@@ -117,8 +125,6 @@ class JoinedSyncResult:
             # else in the result, we don't need to send it.
         )
 
-    __bool__ = __nonzero__  # python3
-
 
 @attr.s(slots=True, frozen=True)
 class ArchivedSyncResult:
@@ -127,26 +133,22 @@ class ArchivedSyncResult:
     state = attr.ib(type=StateMap[EventBase])
     account_data = attr.ib(type=List[JsonDict])
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
         to tell if room needs to be part of the sync result.
         """
         return bool(self.timeline or self.state or self.account_data)
 
-    __bool__ = __nonzero__  # python3
-
 
 @attr.s(slots=True, frozen=True)
 class InvitedSyncResult:
     room_id = attr.ib(type=str)
     invite = attr.ib(type=EventBase)
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         """Invited rooms should always be reported to the client"""
         return True
 
-    __bool__ = __nonzero__  # python3
-
 
 @attr.s(slots=True, frozen=True)
 class GroupsSyncResult:
@@ -154,11 +156,9 @@ class GroupsSyncResult:
     invite = attr.ib(type=JsonDict)
     leave = attr.ib(type=JsonDict)
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         return bool(self.join or self.invite or self.leave)
 
-    __bool__ = __nonzero__  # python3
-
 
 @attr.s(slots=True, frozen=True)
 class DeviceLists:
@@ -171,13 +171,11 @@ class DeviceLists:
     changed = attr.ib(type=Collection[str])
     left = attr.ib(type=Collection[str])
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         return bool(self.changed or self.left)
 
-    __bool__ = __nonzero__  # python3
 
-
-@attr.s
+@attr.s(slots=True)
 class _RoomChanges:
     """The set of room entries to include in the sync, plus the set of joined
     and left room IDs since last sync.
@@ -217,7 +215,7 @@ class SyncResult:
     device_one_time_keys_count = attr.ib(type=JsonDict)
     groups = attr.ib(type=Optional[GroupsSyncResult])
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
         to tell if the notifier needs to wait for more events when polling for
         events.
@@ -233,11 +231,9 @@ class SyncResult:
             or self.groups
         )
 
-    __bool__ = __nonzero__  # python3
-
 
-class SyncHandler(object):
-    def __init__(self, hs):
+class SyncHandler:
+    def __init__(self, hs: "HomeServer"):
         self.hs_config = hs.config
         self.store = hs.get_datastore()
         self.notifier = hs.get_notifier()
@@ -368,7 +364,7 @@ class SyncHandler(object):
         sync_config = sync_result_builder.sync_config
 
         with Measure(self.clock, "ephemeral_by_room"):
-            typing_key = since_token.typing_key if since_token else "0"
+            typing_key = since_token.typing_key if since_token else 0
 
             room_ids = sync_result_builder.joined_room_ids
 
@@ -392,7 +388,7 @@ class SyncHandler(object):
                 event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
                 ephemeral_by_room.setdefault(room_id, []).append(event_copy)
 
-            receipt_key = since_token.receipt_key if since_token else "0"
+            receipt_key = since_token.receipt_key if since_token else 0
 
             receipt_source = self.event_sources.sources["receipt"]
             receipts, receipt_key = await receipt_source.get_new_events(
@@ -523,7 +519,7 @@ class SyncHandler(object):
             if len(recents) > timeline_limit:
                 limited = True
                 recents = recents[-timeline_limit:]
-                room_key = recents[0].internal_metadata.before
+                room_key = RoomStreamToken.parse(recents[0].internal_metadata.before)
 
             prev_batch_token = now_token.copy_and_replace("room_key", room_key)
 
@@ -588,7 +584,7 @@ class SyncHandler(object):
         room_id: str,
         sync_config: SyncConfig,
         batch: TimelineBatch,
-        state: StateMap[EventBase],
+        state: MutableStateMap[EventBase],
         now_token: StreamToken,
     ) -> Optional[JsonDict]:
         """ Works out a room summary block for this room, summarising the number
@@ -710,9 +706,8 @@ class SyncHandler(object):
         ]
 
         missing_hero_state = await self.store.get_events(missing_hero_event_ids)
-        missing_hero_state = missing_hero_state.values()
 
-        for s in missing_hero_state:
+        for s in missing_hero_state.values():
             cache.set(s.state_key, s.event_id)
             state[(EventTypes.Member, s.state_key)] = s
 
@@ -736,7 +731,7 @@ class SyncHandler(object):
         since_token: Optional[StreamToken],
         now_token: StreamToken,
         full_state: bool,
-    ) -> StateMap[EventBase]:
+    ) -> MutableStateMap[EventBase]:
         """ Works out the difference in state between the start of the timeline
         and the previous sync.
 
@@ -930,7 +925,7 @@ class SyncHandler(object):
 
     async def unread_notifs_for_room_id(
         self, room_id: str, sync_config: SyncConfig
-    ) -> Optional[Dict[str, str]]:
+    ) -> Dict[str, int]:
         with Measure(self.clock, "unread_notifs_for_room_id"):
             last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
                 user_id=sync_config.user.to_string(),
@@ -938,15 +933,10 @@ class SyncHandler(object):
                 receipt_type="m.read",
             )
 
-            if last_unread_event_id:
-                notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
-                    room_id, sync_config.user.to_string(), last_unread_event_id
-                )
-                return notifs
-
-        # There is no new information in this period, so your notification
-        # count is whatever it was last time.
-        return None
+            notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
+                room_id, sync_config.user.to_string(), last_unread_event_id
+            )
+            return notifs
 
     async def generate_sync_result(
         self,
@@ -1306,12 +1296,11 @@ class SyncHandler(object):
         presence_source = self.event_sources.sources["presence"]
 
         since_token = sync_result_builder.since_token
+        presence_key = None
+        include_offline = False
         if since_token and not sync_result_builder.full_state:
             presence_key = since_token.presence_key
             include_offline = True
-        else:
-            presence_key = None
-            include_offline = False
 
         presence, presence_key = await presence_source.get_new_events(
             user=user,
@@ -1319,6 +1308,7 @@ class SyncHandler(object):
             is_guest=sync_config.is_guest,
             include_offline=include_offline,
         )
+        assert presence_key
         sync_result_builder.now_token = now_token.copy_and_replace(
             "presence_key", presence_key
         )
@@ -1481,7 +1471,7 @@ class SyncHandler(object):
         if rooms_changed:
             return True
 
-        stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
+        stream_id = since_token.room_key.stream
         for room_id in sync_result_builder.joined_room_ids:
             if self.store.has_room_changed_since(room_id, stream_id):
                 return True
@@ -1747,7 +1737,7 @@ class SyncHandler(object):
                             continue
 
                 leave_token = now_token.copy_and_replace(
-                    "room_key", "s%d" % (event.stream_ordering,)
+                    "room_key", RoomStreamToken(None, event.stream_ordering)
                 )
                 room_entries.append(
                     RoomSyncResultBuilder(
@@ -1769,7 +1759,7 @@ class SyncHandler(object):
         ignored_users: Set[str],
         room_builder: "RoomSyncResultBuilder",
         ephemeral: List[JsonDict],
-        tags: Optional[List[JsonDict]],
+        tags: Optional[Dict[str, Dict[str, Any]]],
         account_data: Dict[str, JsonDict],
         always_include: bool = False,
     ):
@@ -1885,7 +1875,7 @@ class SyncHandler(object):
             )
 
         if room_builder.rtype == "joined":
-            unread_notifications = {}  # type: Dict[str, str]
+            unread_notifications = {}  # type: Dict[str, int]
             room_sync = JoinedSyncResult(
                 room_id=room_id,
                 timeline=batch,
@@ -1894,14 +1884,16 @@ class SyncHandler(object):
                 account_data=account_data_events,
                 unread_notifications=unread_notifications,
                 summary=summary,
+                unread_count=0,
             )
 
             if room_sync or always_include:
                 notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
 
-                if notifs is not None:
-                    unread_notifications["notification_count"] = notifs["notify_count"]
-                    unread_notifications["highlight_count"] = notifs["highlight_count"]
+                unread_notifications["notification_count"] = notifs["notify_count"]
+                unread_notifications["highlight_count"] = notifs["highlight_count"]
+
+                room_sync.unread_count = notifs["unread_count"]
 
                 sync_result_builder.joined.append(room_sync)
 
@@ -2032,7 +2024,7 @@ def _calculate_state(
     return {event_id_to_key[e]: e for e in state_ids}
 
 
-@attr.s
+@attr.s(slots=True)
 class SyncResultBuilder:
     """Used to help build up a new SyncResult for a user
 
@@ -2068,8 +2060,8 @@ class SyncResultBuilder:
     to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list))
 
 
-@attr.s
-class RoomSyncResultBuilder(object):
+@attr.s(slots=True)
+class RoomSyncResultBuilder:
     """Stores information needed to create either a `JoinedSyncResult` or
     `ArchivedSyncResult`.
 
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index a86ac0150e..3cbfc2d780 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -14,10 +14,11 @@
 # limitations under the License.
 
 import logging
+import random
 from collections import namedtuple
 from typing import TYPE_CHECKING, List, Set, Tuple
 
-from synapse.api.errors import AuthError, SynapseError
+from synapse.api.errors import AuthError, ShadowBanError, SynapseError
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.streams import TypingStream
 from synapse.types import UserID, get_domain_from_id
@@ -227,9 +228,9 @@ class TypingWriterHandler(FollowerTypingHandler):
             self._stopped_typing(member)
             return
 
-    async def started_typing(self, target_user, auth_user, room_id, timeout):
+    async def started_typing(self, target_user, requester, room_id, timeout):
         target_user_id = target_user.to_string()
-        auth_user_id = auth_user.to_string()
+        auth_user_id = requester.user.to_string()
 
         if not self.is_mine_id(target_user_id):
             raise SynapseError(400, "User is not hosted on this homeserver")
@@ -237,6 +238,11 @@ class TypingWriterHandler(FollowerTypingHandler):
         if target_user_id != auth_user_id:
             raise AuthError(400, "Cannot set another user's typing state")
 
+        if requester.shadow_banned:
+            # We randomly sleep a bit just to annoy the requester.
+            await self.clock.sleep(random.randint(1, 10))
+            raise ShadowBanError()
+
         await self.auth.check_user_in_room(room_id, target_user_id)
 
         logger.debug("%s has started typing in %s", target_user_id, room_id)
@@ -256,9 +262,9 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         self._push_update(member=member, typing=True)
 
-    async def stopped_typing(self, target_user, auth_user, room_id):
+    async def stopped_typing(self, target_user, requester, room_id):
         target_user_id = target_user.to_string()
-        auth_user_id = auth_user.to_string()
+        auth_user_id = requester.user.to_string()
 
         if not self.is_mine_id(target_user_id):
             raise SynapseError(400, "User is not hosted on this homeserver")
@@ -266,6 +272,11 @@ class TypingWriterHandler(FollowerTypingHandler):
         if target_user_id != auth_user_id:
             raise AuthError(400, "Cannot set another user's typing state")
 
+        if requester.shadow_banned:
+            # We randomly sleep a bit just to annoy the requester.
+            await self.clock.sleep(random.randint(1, 10))
+            raise ShadowBanError()
+
         await self.auth.check_user_in_room(room_id, target_user_id)
 
         logger.debug("%s has stopped typing in %s", target_user_id, room_id)
@@ -401,7 +412,7 @@ class TypingWriterHandler(FollowerTypingHandler):
         raise Exception("Typing writer instance got typing info over replication")
 
 
-class TypingNotificationEventSource(object):
+class TypingNotificationEventSource:
     def __init__(self, hs):
         self.hs = hs
         self.clock = hs.get_clock()
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index a011e9fe29..9146dc1a3b 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -16,13 +16,12 @@
 import logging
 from typing import Any
 
-from canonicaljson import json
-
 from twisted.web.client import PartialDownloadError
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.util import json_decoder
 
 logger = logging.getLogger(__name__)
 
@@ -117,7 +116,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
         except PartialDownloadError as pde:
             # Twisted is silly
             data = pde.response
-            resp_body = json.loads(data.decode("utf-8"))
+            resp_body = json_decoder.decode(data.decode("utf-8"))
 
         if "success" in resp_body:
             # Note that we do NOT check the hostname here: we explicitly
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 521b6d620d..e21f8dbc58 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -234,7 +234,7 @@ class UserDirectoryHandler(StateDeltasHandler):
     async def _handle_room_publicity_change(
         self, room_id, prev_event_id, event_id, typ
     ):
-        """Handle a room having potentially changed from/to world_readable/publically
+        """Handle a room having potentially changed from/to world_readable/publicly
         joinable.
 
         Args:
@@ -388,9 +388,15 @@ class UserDirectoryHandler(StateDeltasHandler):
 
         prev_name = prev_event.content.get("displayname")
         new_name = event.content.get("displayname")
+        # If the new name is an unexpected form, do not update the directory.
+        if not isinstance(new_name, str):
+            new_name = prev_name
 
         prev_avatar = prev_event.content.get("avatar_url")
         new_avatar = event.content.get("avatar_url")
+        # If the new avatar is an unexpected form, do not update the directory.
+        if not isinstance(new_avatar, str):
+            new_avatar = prev_avatar
 
         if prev_name != new_name or prev_avatar != new_avatar:
             await self.store.update_profile_in_user_dir(user_id, new_name, new_avatar)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 8aeb70cdec..13fcab3378 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -19,7 +19,7 @@ import urllib
 from io import BytesIO
 
 import treq
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
 from netaddr import IPAddress
 from prometheus_client import Counter
 from zope.interface import implementer, provider
@@ -47,6 +47,7 @@ from synapse.http import (
 from synapse.http.proxyagent import ProxyAgent
 from synapse.logging.context import make_deferred_yieldable
 from synapse.logging.opentracing import set_tag, start_active_span, tags
+from synapse.util import json_decoder
 from synapse.util.async_helpers import timeout_deferred
 
 logger = logging.getLogger(__name__)
@@ -85,7 +86,7 @@ def _make_scheduler(reactor):
     return _scheduler
 
 
-class IPBlacklistingResolver(object):
+class IPBlacklistingResolver:
     """
     A proxy for reactor.nameResolver which only produces non-blacklisted IP
     addresses, preventing DNS rebinding attacks on URL preview.
@@ -132,7 +133,7 @@ class IPBlacklistingResolver(object):
             r.resolutionComplete()
 
         @provider(IResolutionReceiver)
-        class EndpointReceiver(object):
+        class EndpointReceiver:
             @staticmethod
             def resolutionBegan(resolutionInProgress):
                 pass
@@ -191,7 +192,7 @@ class BlacklistingAgentWrapper(Agent):
         )
 
 
-class SimpleHttpClient(object):
+class SimpleHttpClient:
     """
     A simple, no-frills HTTP client with methods that wrap up common ways of
     using HTTP in Matrix
@@ -243,7 +244,7 @@ class SimpleHttpClient(object):
             )
 
             @implementer(IReactorPluggableNameResolver)
-            class Reactor(object):
+            class Reactor:
                 def __getattr__(_self, attr):
                     if attr == "nameResolver":
                         return nameResolver
@@ -391,7 +392,7 @@ class SimpleHttpClient(object):
         body = await make_deferred_yieldable(readBody(response))
 
         if 200 <= response.code < 300:
-            return json.loads(body.decode("utf-8"))
+            return json_decoder.decode(body.decode("utf-8"))
         else:
             raise HttpResponseException(
                 response.code, response.phrase.decode("ascii", errors="replace"), body
@@ -433,7 +434,7 @@ class SimpleHttpClient(object):
         body = await make_deferred_yieldable(readBody(response))
 
         if 200 <= response.code < 300:
-            return json.loads(body.decode("utf-8"))
+            return json_decoder.decode(body.decode("utf-8"))
         else:
             raise HttpResponseException(
                 response.code, response.phrase.decode("ascii", errors="replace"), body
@@ -463,7 +464,7 @@ class SimpleHttpClient(object):
             actual_headers.update(headers)
 
         body = await self.get_raw(uri, args, headers=headers)
-        return json.loads(body.decode("utf-8"))
+        return json_decoder.decode(body.decode("utf-8"))
 
     async def put_json(self, uri, json_body, args={}, headers=None):
         """ Puts some json to the given URI.
@@ -506,7 +507,7 @@ class SimpleHttpClient(object):
         body = await make_deferred_yieldable(readBody(response))
 
         if 200 <= response.code < 300:
-            return json.loads(body.decode("utf-8"))
+            return json_decoder.decode(body.decode("utf-8"))
         else:
             raise HttpResponseException(
                 response.code, response.phrase.decode("ascii", errors="replace"), body
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index be7b2ceb8e..856e28454f 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -31,7 +31,7 @@ class ProxyConnectError(ConnectError):
 
 
 @implementer(IStreamClientEndpoint)
-class HTTPConnectProxyEndpoint(object):
+class HTTPConnectProxyEndpoint:
     """An Endpoint implementation which will send a CONNECT request to an http proxy
 
     Wraps an existing HostnameEndpoint for the proxy.
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 369bf9c2fc..83d6196d4a 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)
 
 
 @implementer(IAgent)
-class MatrixFederationAgent(object):
+class MatrixFederationAgent:
     """An Agent-like thing which provides a `request` method which correctly
     handles resolving matrix server names when using matrix://. Handles standard
     https URIs as normal.
@@ -134,8 +134,8 @@ class MatrixFederationAgent(object):
             and not _is_ip_literal(parsed_uri.hostname)
             and not parsed_uri.port
         ):
-            well_known_result = yield self._well_known_resolver.get_well_known(
-                parsed_uri.hostname
+            well_known_result = yield defer.ensureDeferred(
+                self._well_known_resolver.get_well_known(parsed_uri.hostname)
             )
             delegated_server = well_known_result.delegated_server
 
@@ -175,7 +175,7 @@ class MatrixFederationAgent(object):
 
 
 @implementer(IAgentEndpointFactory)
-class MatrixHostnameEndpointFactory(object):
+class MatrixHostnameEndpointFactory:
     """Factory for MatrixHostnameEndpoint for parsing to an Agent.
     """
 
@@ -198,7 +198,7 @@ class MatrixHostnameEndpointFactory(object):
 
 
 @implementer(IStreamClientEndpoint)
-class MatrixHostnameEndpoint(object):
+class MatrixHostnameEndpoint:
     """An endpoint that resolves matrix:// URLs using Matrix server name
     resolution (i.e. via SRV). Does not check for well-known delegation.
 
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index 2ede90a9b1..d9620032d2 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -33,7 +33,7 @@ SERVER_CACHE = {}
 
 
 @attr.s(slots=True, frozen=True)
-class Server(object):
+class Server:
     """
     Our record of an individual server which can be tried to reach a destination.
 
@@ -96,7 +96,7 @@ def _sort_server_list(server_list):
     return results
 
 
-class SrvResolver(object):
+class SrvResolver:
     """Interface to the dns client to do SRV lookups, with result caching.
 
     The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 89a3b041ce..a306faa267 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -13,10 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
 import logging
 import random
 import time
+from typing import Callable, Dict, Optional, Tuple
 
 import attr
 
@@ -24,9 +24,10 @@ from twisted.internet import defer
 from twisted.web.client import RedirectAgent, readBody
 from twisted.web.http import stringToDatetime
 from twisted.web.http_headers import Headers
+from twisted.web.iweb import IResponse
 
 from synapse.logging.context import make_deferred_yieldable
-from synapse.util import Clock
+from synapse.util import Clock, json_decoder
 from synapse.util.caches.ttlcache import TTLCache
 from synapse.util.metrics import Measure
 
@@ -70,11 +71,11 @@ _had_valid_well_known_cache = TTLCache("had-valid-well-known")
 
 
 @attr.s(slots=True, frozen=True)
-class WellKnownLookupResult(object):
+class WellKnownLookupResult:
     delegated_server = attr.ib()
 
 
-class WellKnownResolver(object):
+class WellKnownResolver:
     """Handles well-known lookups for matrix servers.
     """
 
@@ -100,15 +101,14 @@ class WellKnownResolver(object):
         self._well_known_agent = RedirectAgent(agent)
         self.user_agent = user_agent
 
-    @defer.inlineCallbacks
-    def get_well_known(self, server_name):
+    async def get_well_known(self, server_name: bytes) -> WellKnownLookupResult:
         """Attempt to fetch and parse a .well-known file for the given server
 
         Args:
-            server_name (bytes): name of the server, from the requested url
+            server_name: name of the server, from the requested url
 
         Returns:
-            Deferred[WellKnownLookupResult]: The result of the lookup
+            The result of the lookup
         """
         try:
             prev_result, expiry, ttl = self._well_known_cache.get_with_expiry(
@@ -125,7 +125,9 @@ class WellKnownResolver(object):
         # requests for the same server in parallel?
         try:
             with Measure(self._clock, "get_well_known"):
-                result, cache_period = yield self._fetch_well_known(server_name)
+                result, cache_period = await self._fetch_well_known(
+                    server_name
+                )  # type: Tuple[Optional[bytes], float]
 
         except _FetchWellKnownFailure as e:
             if prev_result and e.temporary:
@@ -154,18 +156,17 @@ class WellKnownResolver(object):
 
         return WellKnownLookupResult(delegated_server=result)
 
-    @defer.inlineCallbacks
-    def _fetch_well_known(self, server_name):
+    async def _fetch_well_known(self, server_name: bytes) -> Tuple[bytes, float]:
         """Actually fetch and parse a .well-known, without checking the cache
 
         Args:
-            server_name (bytes): name of the server, from the requested url
+            server_name: name of the server, from the requested url
 
         Raises:
             _FetchWellKnownFailure if we fail to lookup a result
 
         Returns:
-            Deferred[Tuple[bytes,int]]: The lookup result and cache period.
+            The lookup result and cache period.
         """
 
         had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False)
@@ -173,7 +174,7 @@ class WellKnownResolver(object):
         # We do this in two steps to differentiate between possibly transient
         # errors (e.g. can't connect to host, 503 response) and more permenant
         # errors (such as getting a 404 response).
-        response, body = yield self._make_well_known_request(
+        response, body = await self._make_well_known_request(
             server_name, retry=had_valid_well_known
         )
 
@@ -181,7 +182,7 @@ class WellKnownResolver(object):
             if response.code != 200:
                 raise Exception("Non-200 response %s" % (response.code,))
 
-            parsed_body = json.loads(body.decode("utf-8"))
+            parsed_body = json_decoder.decode(body.decode("utf-8"))
             logger.info("Response from .well-known: %s", parsed_body)
 
             result = parsed_body["m.server"].encode("ascii")
@@ -216,20 +217,20 @@ class WellKnownResolver(object):
 
         return result, cache_period
 
-    @defer.inlineCallbacks
-    def _make_well_known_request(self, server_name, retry):
+    async def _make_well_known_request(
+        self, server_name: bytes, retry: bool
+    ) -> Tuple[IResponse, bytes]:
         """Make the well known request.
 
         This will retry the request if requested and it fails (with unable
         to connect or receives a 5xx error).
 
         Args:
-            server_name (bytes)
-            retry (bool): Whether to retry the request if it fails.
+            server_name: name of the server, from the requested url
+            retry: Whether to retry the request if it fails.
 
         Returns:
-            Deferred[tuple[IResponse, bytes]] Returns the response object and
-            body. Response may be a non-200 response.
+            Returns the response object and body. Response may be a non-200 response.
         """
         uri = b"https://%s/.well-known/matrix/server" % (server_name,)
         uri_str = uri.decode("ascii")
@@ -244,12 +245,12 @@ class WellKnownResolver(object):
 
             logger.info("Fetching %s", uri_str)
             try:
-                response = yield make_deferred_yieldable(
+                response = await make_deferred_yieldable(
                     self._well_known_agent.request(
                         b"GET", uri, headers=Headers(headers)
                     )
                 )
-                body = yield make_deferred_yieldable(readBody(response))
+                body = await make_deferred_yieldable(readBody(response))
 
                 if 500 <= response.code < 600:
                     raise Exception("Non-200 response %s" % (response.code,))
@@ -266,21 +267,24 @@ class WellKnownResolver(object):
                 logger.info("Error fetching %s: %s. Retrying", uri_str, e)
 
             # Sleep briefly in the hopes that they come back up
-            yield self._clock.sleep(0.5)
+            await self._clock.sleep(0.5)
 
 
-def _cache_period_from_headers(headers, time_now=time.time):
+def _cache_period_from_headers(
+    headers: Headers, time_now: Callable[[], float] = time.time
+) -> Optional[float]:
     cache_controls = _parse_cache_control(headers)
 
     if b"no-store" in cache_controls:
         return 0
 
     if b"max-age" in cache_controls:
-        try:
-            max_age = int(cache_controls[b"max-age"])
-            return max_age
-        except ValueError:
-            pass
+        max_age = cache_controls[b"max-age"]
+        if max_age:
+            try:
+                return int(max_age)
+            except ValueError:
+                pass
 
     expires = headers.getRawHeaders(b"expires")
     if expires is not None:
@@ -296,7 +300,7 @@ def _cache_period_from_headers(headers, time_now=time.time):
     return None
 
 
-def _parse_cache_control(headers):
+def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
     cache_controls = {}
     for hdr in headers.getRawHeaders(b"cache-control", []):
         for directive in hdr.split(b","):
@@ -307,7 +311,7 @@ def _parse_cache_control(headers):
     return cache_controls
 
 
-@attr.s()
+@attr.s(slots=True)
 class _FetchWellKnownFailure(Exception):
     # True if we didn't get a non-5xx HTTP response, i.e. this may or may not be
     # a temporary failure.
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 738be43f46..3c86cbc546 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -54,6 +54,7 @@ from synapse.logging.opentracing import (
     start_active_span,
     tags,
 )
+from synapse.util import json_decoder
 from synapse.util.async_helpers import timeout_deferred
 from synapse.util.metrics import Measure
 
@@ -75,8 +76,8 @@ MAXINT = sys.maxsize
 _next_id = 1
 
 
-@attr.s(frozen=True)
-class MatrixFederationRequest(object):
+@attr.s(slots=True, frozen=True)
+class MatrixFederationRequest:
     method = attr.ib()
     """HTTP method
     :type: str
@@ -164,7 +165,9 @@ async def _handle_json_response(
     try:
         check_content_type_is_json(response.headers)
 
-        d = treq.json_content(response)
+        # Use the custom JSON decoder (partially re-implements treq.json_content).
+        d = treq.text_content(response, encoding="utf-8")
+        d.addCallback(json_decoder.decode)
         d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
 
         body = await make_deferred_yieldable(d)
@@ -203,7 +206,7 @@ async def _handle_json_response(
     return body
 
 
-class MatrixFederationHttpClient(object):
+class MatrixFederationHttpClient:
     """HTTP client used to talk to other homeservers over the federation
     protocol. Send client certificates and signs requests.
 
@@ -226,7 +229,7 @@ class MatrixFederationHttpClient(object):
         )
 
         @implementer(IReactorPluggableNameResolver)
-        class Reactor(object):
+        class Reactor:
             def __getattr__(_self, attr):
                 if attr == "nameResolver":
                     return nameResolver
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index b58ae3d9db..cd94e789e8 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -145,7 +145,7 @@ LaterGauge(
 )
 
 
-class RequestMetrics(object):
+class RequestMetrics:
     def start(self, time_sec, name, method):
         self.start = time_sec
         self.start_context = current_context()
diff --git a/synapse/http/server.py b/synapse/http/server.py
index ffe6cfa09e..996a31a9ec 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -22,12 +22,13 @@ import types
 import urllib
 from http import HTTPStatus
 from io import BytesIO
-from typing import Any, Callable, Dict, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
 
 import jinja2
-from canonicaljson import encode_canonical_json, encode_pretty_printed_json
+from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json
+from zope.interface import implementer
 
-from twisted.internet import defer
+from twisted.internet import defer, interfaces
 from twisted.python import failure
 from twisted.web import resource
 from twisted.web.server import NOT_DONE_YET, Request
@@ -173,7 +174,7 @@ def wrap_async_request_handler(h):
     return preserve_fn(wrapped_async_request_handler)
 
 
-class HttpServer(object):
+class HttpServer:
     """ Interface for registering callbacks on a HTTP server
     """
 
@@ -499,6 +500,90 @@ class RootOptionsRedirectResource(OptionsResource, RootRedirect):
     pass
 
 
+@implementer(interfaces.IPushProducer)
+class _ByteProducer:
+    """
+    Iteratively write bytes to the request.
+    """
+
+    # The minimum number of bytes for each chunk. Note that the last chunk will
+    # usually be smaller than this.
+    min_chunk_size = 1024
+
+    def __init__(
+        self, request: Request, iterator: Iterator[bytes],
+    ):
+        self._request = request
+        self._iterator = iterator
+        self._paused = False
+
+        # Register the producer and start producing data.
+        self._request.registerProducer(self, True)
+        self.resumeProducing()
+
+    def _send_data(self, data: List[bytes]) -> None:
+        """
+        Send a list of bytes as a chunk of a response.
+        """
+        if not data:
+            return
+        self._request.write(b"".join(data))
+
+    def pauseProducing(self) -> None:
+        self._paused = True
+
+    def resumeProducing(self) -> None:
+        # We've stopped producing in the meantime (note that this might be
+        # re-entrant after calling write).
+        if not self._request:
+            return
+
+        self._paused = False
+
+        # Write until there's backpressure telling us to stop.
+        while not self._paused:
+            # Get the next chunk and write it to the request.
+            #
+            # The output of the JSON encoder is buffered and coalesced until
+            # min_chunk_size is reached. This is because JSON encoders produce
+            # very small output per iteration and the Request object converts
+            # each call to write() to a separate chunk. Without this there would
+            # be an explosion in bytes written (e.g. b"{" becoming "1\r\n{\r\n").
+            #
+            # Note that buffer stores a list of bytes (instead of appending to
+            # bytes) to hopefully avoid many allocations.
+            buffer = []
+            buffered_bytes = 0
+            while buffered_bytes < self.min_chunk_size:
+                try:
+                    data = next(self._iterator)
+                    buffer.append(data)
+                    buffered_bytes += len(data)
+                except StopIteration:
+                    # The entire JSON object has been serialized, write any
+                    # remaining data, finalize the producer and the request, and
+                    # clean-up any references.
+                    self._send_data(buffer)
+                    self._request.unregisterProducer()
+                    self._request.finish()
+                    self.stopProducing()
+                    return
+
+            self._send_data(buffer)
+
+    def stopProducing(self) -> None:
+        # Clear a circular reference.
+        self._request = None
+
+
+def _encode_json_bytes(json_object: Any) -> Iterator[bytes]:
+    """
+    Encode an object into JSON. Returns an iterator of bytes.
+    """
+    for chunk in json_encoder.iterencode(json_object):
+        yield chunk.encode("utf-8")
+
+
 def respond_with_json(
     request: Request,
     code: int,
@@ -533,15 +618,22 @@ def respond_with_json(
         return None
 
     if pretty_print:
-        json_bytes = encode_pretty_printed_json(json_object) + b"\n"
+        encoder = iterencode_pretty_printed_json
     else:
         if canonical_json or synapse.events.USE_FROZEN_DICTS:
-            # canonicaljson already encodes to bytes
-            json_bytes = encode_canonical_json(json_object)
+            encoder = iterencode_canonical_json
         else:
-            json_bytes = json_encoder.encode(json_object).encode("utf-8")
+            encoder = _encode_json_bytes
+
+    request.setResponseCode(code)
+    request.setHeader(b"Content-Type", b"application/json")
+    request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
 
-    return respond_with_json_bytes(request, code, json_bytes, send_cors=send_cors)
+    if send_cors:
+        set_cors_headers(request)
+
+    _ByteProducer(request, encoder(json_object))
+    return NOT_DONE_YET
 
 
 def respond_with_json_bytes(
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index a34e5ead88..fd90ba7828 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -17,9 +17,8 @@
 
 import logging
 
-from canonicaljson import json
-
 from synapse.api.errors import Codes, SynapseError
+from synapse.util import json_decoder
 
 logger = logging.getLogger(__name__)
 
@@ -215,7 +214,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
         return None
 
     try:
-        content = json.loads(content_bytes.decode("utf-8"))
+        content = json_decoder.decode(content_bytes.decode("utf-8"))
     except Exception as e:
         logger.warning("Unable to parse JSON: %s", e)
         raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
@@ -257,7 +256,7 @@ def assert_params_in_dict(body, required):
         raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
 
 
-class RestServlet(object):
+class RestServlet:
 
     """ A Synapse REST Servlet.
 
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index 7372450b45..144506c8f2 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -55,7 +55,7 @@ def stdlib_log_level_to_twisted(level: str) -> LogLevel:
 
 @attr.s
 @implementer(ILogObserver)
-class LogContextObserver(object):
+class LogContextObserver:
     """
     An ILogObserver which adds Synapse-specific log context information.
 
@@ -169,7 +169,7 @@ class OutputPipeType(Values):
 
 
 @attr.s
-class DrainConfiguration(object):
+class DrainConfiguration:
     name = attr.ib()
     type = attr.ib()
     location = attr.ib()
@@ -177,7 +177,7 @@ class DrainConfiguration(object):
 
 
 @attr.s
-class NetworkJSONTerseOptions(object):
+class NetworkJSONTerseOptions:
     maximum_buffer = attr.ib(type=int)
 
 
diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
index c0b9384189..1b8916cfa2 100644
--- a/synapse/logging/_terse_json.py
+++ b/synapse/logging/_terse_json.py
@@ -152,7 +152,7 @@ def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogOb
 
 @attr.s
 @implementer(IPushProducer)
-class LogProducer(object):
+class LogProducer:
     """
     An IPushProducer that writes logs from its buffer to its transport when it
     is resumed.
@@ -190,7 +190,7 @@ class LogProducer(object):
 
 @attr.s
 @implementer(ILogObserver)
-class TerseJSONToTCPLogObserver(object):
+class TerseJSONToTCPLogObserver:
     """
     An IObserver that writes JSON logs to a TCP target.
 
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index cbeeb870cb..2e282d9d67 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -74,7 +74,7 @@ except Exception:
 get_thread_id = threading.get_ident
 
 
-class ContextResourceUsage(object):
+class ContextResourceUsage:
     """Object for tracking the resources used by a log context
 
     Attributes:
@@ -179,7 +179,7 @@ class ContextResourceUsage(object):
 LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
 
 
-class _Sentinel(object):
+class _Sentinel:
     """Sentinel to represent the root context"""
 
     __slots__ = ["previous_context", "finished", "request", "scope", "tag"]
@@ -217,16 +217,14 @@ class _Sentinel(object):
     def record_event_fetch(self, event_count):
         pass
 
-    def __nonzero__(self):
+    def __bool__(self):
         return False
 
-    __bool__ = __nonzero__  # python3
-
 
 SENTINEL_CONTEXT = _Sentinel()
 
 
-class LoggingContext(object):
+class LoggingContext:
     """Additional context for log formatting. Contexts are scoped within a
     "with" block.
 
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 21dbd9f415..e58850faff 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -172,11 +172,11 @@ from functools import wraps
 from typing import TYPE_CHECKING, Dict, Optional, Type
 
 import attr
-from canonicaljson import json
 
 from twisted.internet import defer
 
 from synapse.config import ConfigError
+from synapse.util import json_decoder, json_encoder
 
 if TYPE_CHECKING:
     from synapse.http.site import SynapseRequest
@@ -185,7 +185,7 @@ if TYPE_CHECKING:
 # Helper class
 
 
-class _DummyTagNames(object):
+class _DummyTagNames:
     """wrapper of opentracings tags. We need to have them if we
     want to reference them without opentracing around. Clearly they
     should never actually show up in a trace. `set_tags` overwrites
@@ -499,7 +499,9 @@ def start_active_span_from_edu(
     if opentracing is None:
         return _noop_context_manager()
 
-    carrier = json.loads(edu_content.get("context", "{}")).get("opentracing", {})
+    carrier = json_decoder.decode(edu_content.get("context", "{}")).get(
+        "opentracing", {}
+    )
     context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
     _references = [
         opentracing.child_of(span_context_from_string(x))
@@ -507,7 +509,7 @@ def start_active_span_from_edu(
     ]
 
     # For some reason jaeger decided not to support the visualization of multiple parent
-    # spans or explicitely show references. I include the span context as a tag here as
+    # spans or explicitly show references. I include the span context as a tag here as
     # an aid to people debugging but it's really not an ideal solution.
 
     references += _references
@@ -690,7 +692,7 @@ def active_span_context_as_string():
         opentracing.tracer.inject(
             opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
         )
-    return json.dumps(carrier)
+    return json_encoder.encode(carrier)
 
 
 @only_if_tracing
@@ -699,7 +701,7 @@ def span_context_from_string(carrier):
     Returns:
         The active span context decoded from a string.
     """
-    carrier = json.loads(carrier)
+    carrier = json_decoder.decode(carrier)
     return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
 
 
diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py
index fea774e2e5..becf66dd86 100644
--- a/synapse/logging/utils.py
+++ b/synapse/logging/utils.py
@@ -29,11 +29,11 @@ def _log_debug_as_f(f, msg, msg_args):
         lineno = f.__code__.co_firstlineno
         pathname = f.__code__.co_filename
 
-        record = logging.LogRecord(
+        record = logger.makeRecord(
             name=name,
             level=logging.DEBUG,
-            pathname=pathname,
-            lineno=lineno,
+            fn=pathname,
+            lno=lineno,
             msg=msg,
             args=msg_args,
             exc_info=None,
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 6035672698..a1f7ca3449 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -51,7 +51,7 @@ all_gauges = {}  # type: Dict[str, Union[LaterGauge, InFlightGauge, BucketCollec
 HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
 
 
-class RegistryProxy(object):
+class RegistryProxy:
     @staticmethod
     def collect():
         for metric in REGISTRY.collect():
@@ -59,8 +59,8 @@ class RegistryProxy(object):
                 yield metric
 
 
-@attr.s(hash=True)
-class LaterGauge(object):
+@attr.s(slots=True, hash=True)
+class LaterGauge:
 
     name = attr.ib(type=str)
     desc = attr.ib(type=str)
@@ -100,7 +100,7 @@ class LaterGauge(object):
         all_gauges[self.name] = self
 
 
-class InFlightGauge(object):
+class InFlightGauge:
     """Tracks number of things (e.g. requests, Measure blocks, etc) in flight
     at any given time.
 
@@ -205,8 +205,8 @@ class InFlightGauge(object):
         all_gauges[self.name] = self
 
 
-@attr.s(hash=True)
-class BucketCollector(object):
+@attr.s(slots=True, hash=True)
+class BucketCollector:
     """
     Like a Histogram, but allows buckets to be point-in-time instead of
     incrementally added to.
@@ -269,7 +269,7 @@ class BucketCollector(object):
 #
 
 
-class CPUMetrics(object):
+class CPUMetrics:
     def __init__(self):
         ticks_per_sec = 100
         try:
@@ -329,7 +329,7 @@ gc_time = Histogram(
 )
 
 
-class GCCounts(object):
+class GCCounts:
     def collect(self):
         cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"])
         for n, m in enumerate(gc.get_count()):
@@ -347,7 +347,7 @@ if not running_on_pypy:
 #
 
 
-class PyPyGCStats(object):
+class PyPyGCStats:
     def collect(self):
 
         # @stats is a pretty-printer object with __str__() returning a nice table,
@@ -482,7 +482,7 @@ build_info.labels(
 last_ticked = time.time()
 
 
-class ReactorLastSeenMetric(object):
+class ReactorLastSeenMetric:
     def collect(self):
         cm = GaugeMetricFamily(
             "python_twisted_reactor_last_seen",
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index f766d16db6..5b73463504 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -105,7 +105,7 @@ _background_processes_active_since_last_scrape = set()  # type: Set[_BackgroundP
 _bg_metrics_lock = threading.Lock()
 
 
-class _Collector(object):
+class _Collector:
     """A custom metrics collector for the background process metrics.
 
     Ensures that all of the metrics are up-to-date with any in-flight processes
@@ -140,7 +140,7 @@ class _Collector(object):
 REGISTRY.register(_Collector())
 
 
-class _BackgroundProcess(object):
+class _BackgroundProcess:
     def __init__(self, desc, ctx):
         self.desc = desc
         self._context = ctx
@@ -175,7 +175,7 @@ def run_as_background_process(desc: str, func, *args, **kwargs):
     It returns a Deferred which completes when the function completes, but it doesn't
     follow the synapse logcontext rules, which makes it appropriate for passing to
     clock.looping_call and friends (or for firing-and-forgetting in the middle of a
-    normal synapse inlineCallbacks function).
+    normal synapse async function).
 
     Args:
         desc: a description for this background process type
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index c2fb757d9a..fcbd5378c4 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -31,7 +31,7 @@ __all__ = ["errors", "make_deferred_yieldable", "run_in_background", "ModuleApi"
 logger = logging.getLogger(__name__)
 
 
-class ModuleApi(object):
+class ModuleApi:
     """A proxy object that gets passed to various plugin modules so they
     can register new users etc if necessary.
     """
@@ -167,8 +167,10 @@ class ModuleApi(object):
             external_id: id on that system
             user_id: complete mxid that it is mapped to
         """
-        return self._store.record_user_external_id(
-            auth_provider_id, remote_user_id, registered_user_id
+        return defer.ensureDeferred(
+            self._store.record_user_external_id(
+                auth_provider_id, remote_user_id, registered_user_id
+            )
         )
 
     def generate_short_term_login_token(
@@ -223,7 +225,9 @@ class ModuleApi(object):
         Returns:
             Deferred[object]: result of func
         """
-        return self._store.db_pool.runInteraction(desc, func, *args, **kwargs)
+        return defer.ensureDeferred(
+            self._store.db_pool.runInteraction(desc, func, *args, **kwargs)
+        )
 
     def complete_sso_login(
         self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
diff --git a/synapse/notifier.py b/synapse/notifier.py
index dfb096e589..a8fd3ef886 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -42,7 +42,7 @@ from synapse.logging.utils import log_function
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.streams.config import PaginationConfig
-from synapse.types import Collection, StreamToken, UserID
+from synapse.types import Collection, RoomStreamToken, StreamToken, UserID
 from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
 from synapse.util.metrics import Measure
 from synapse.visibility import filter_events_for_client
@@ -68,7 +68,7 @@ def count(func: Callable[[T], bool], it: Iterable[T]) -> int:
     return n
 
 
-class _NotificationListener(object):
+class _NotificationListener:
     """ This represents a single client connection to the events stream.
     The events stream handler will have yielded to the deferred, so to
     notify the handler it is sufficient to resolve the deferred.
@@ -80,7 +80,7 @@ class _NotificationListener(object):
         self.deferred = deferred
 
 
-class _NotifierUserStream(object):
+class _NotifierUserStream:
     """This represents a user connected to the event stream.
     It tracks the most recent stream token for that user.
     At a given point a user may have a number of streams listening for
@@ -112,7 +112,9 @@ class _NotifierUserStream(object):
         with PreserveLoggingContext():
             self.notify_deferred = ObservableDeferred(defer.Deferred())
 
-    def notify(self, stream_key: str, stream_id: int, time_now_ms: int):
+    def notify(
+        self, stream_key: str, stream_id: Union[int, RoomStreamToken], time_now_ms: int,
+    ):
         """Notify any listeners for this user of a new event from an
         event source.
         Args:
@@ -162,13 +164,11 @@ class _NotifierUserStream(object):
 
 
 class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
-    def __nonzero__(self):
+    def __bool__(self):
         return bool(self.events)
 
-    __bool__ = __nonzero__  # python3
-
 
-class Notifier(object):
+class Notifier:
     """ This class is responsible for notifying any listeners when there are
     new events available for it.
 
@@ -187,7 +187,7 @@ class Notifier(object):
         self.store = hs.get_datastore()
         self.pending_new_room_events = (
             []
-        )  # type: List[Tuple[int, EventBase, Collection[Union[str, UserID]]]]
+        )  # type: List[Tuple[int, EventBase, Collection[UserID]]]
 
         # Called when there are new things to stream over replication
         self.replication_callbacks = []  # type: List[Callable[[], None]]
@@ -198,6 +198,7 @@ class Notifier(object):
 
         self.clock = hs.get_clock()
         self.appservice_handler = hs.get_application_service_handler()
+        self._pusher_pool = hs.get_pusherpool()
 
         self.federation_sender = None
         if hs.should_send_federation():
@@ -247,7 +248,7 @@ class Notifier(object):
         event: EventBase,
         room_stream_id: int,
         max_room_stream_id: int,
-        extra_users: Collection[Union[str, UserID]] = [],
+        extra_users: Collection[UserID] = [],
     ):
         """ Used by handlers to inform the notifier something has happened
         in the room, room event wise.
@@ -274,47 +275,68 @@ class Notifier(object):
         """
         pending = self.pending_new_room_events
         self.pending_new_room_events = []
+
+        users = set()  # type: Set[UserID]
+        rooms = set()  # type: Set[str]
+
         for room_stream_id, event, extra_users in pending:
             if room_stream_id > max_room_stream_id:
                 self.pending_new_room_events.append(
                     (room_stream_id, event, extra_users)
                 )
             else:
-                self._on_new_room_event(event, room_stream_id, extra_users)
+                if (
+                    event.type == EventTypes.Member
+                    and event.membership == Membership.JOIN
+                ):
+                    self._user_joined_room(event.state_key, event.room_id)
+
+                users.update(extra_users)
+                rooms.add(event.room_id)
+
+        if users or rooms:
+            self.on_new_event(
+                "room_key",
+                RoomStreamToken(None, max_room_stream_id),
+                users=users,
+                rooms=rooms,
+            )
+            self._on_updated_room_token(max_room_stream_id)
+
+    def _on_updated_room_token(self, max_room_stream_id: int):
+        """Poke services that might care that the room position has been
+        updated.
+        """
 
-    def _on_new_room_event(
-        self,
-        event: EventBase,
-        room_stream_id: int,
-        extra_users: Collection[Union[str, UserID]] = [],
-    ):
-        """Notify any user streams that are interested in this room event"""
         # poke any interested application service.
         run_as_background_process(
-            "notify_app_services", self._notify_app_services, room_stream_id
+            "_notify_app_services", self._notify_app_services, max_room_stream_id
         )
 
-        if self.federation_sender:
-            self.federation_sender.notify_new_events(room_stream_id)
-
-        if event.type == EventTypes.Member and event.membership == Membership.JOIN:
-            self._user_joined_room(event.state_key, event.room_id)
-
-        self.on_new_event(
-            "room_key", room_stream_id, users=extra_users, rooms=[event.room_id]
+        run_as_background_process(
+            "_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_id
         )
 
-    async def _notify_app_services(self, room_stream_id: int):
+        if self.federation_sender:
+            self.federation_sender.notify_new_events(max_room_stream_id)
+
+    async def _notify_app_services(self, max_room_stream_id: int):
         try:
-            await self.appservice_handler.notify_interested_services(room_stream_id)
+            await self.appservice_handler.notify_interested_services(max_room_stream_id)
         except Exception:
             logger.exception("Error notifying application services of event")
 
+    async def _notify_pusher_pool(self, max_room_stream_id: int):
+        try:
+            await self._pusher_pool.on_new_notifications(max_room_stream_id)
+        except Exception:
+            logger.exception("Error pusher pool of event")
+
     def on_new_event(
         self,
         stream_key: str,
-        new_token: int,
-        users: Collection[Union[str, UserID]] = [],
+        new_token: Union[int, RoomStreamToken],
+        users: Collection[UserID] = [],
         rooms: Collection[str] = [],
     ):
         """ Used to inform listeners that something has happened event wise.
@@ -432,8 +454,9 @@ class Notifier(object):
         If explicit_room_id is set, that room will be polled for events only if
         it is world readable or the user has joined the room.
         """
-        from_token = pagination_config.from_token
-        if not from_token:
+        if pagination_config.from_token:
+            from_token = pagination_config.from_token
+        else:
             from_token = self.event_sources.get_current_token()
 
         limit = pagination_config.limit
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index 0d23142653..fabc9ba126 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -22,7 +22,7 @@ from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
 logger = logging.getLogger(__name__)
 
 
-class ActionGenerator(object):
+class ActionGenerator:
     def __init__(self, hs):
         self.hs = hs
         self.clock = hs.get_clock()
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index e7fcee0e87..c440f2545c 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -19,8 +19,10 @@ from collections import namedtuple
 
 from prometheus_client import Counter
 
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RelationTypes
 from synapse.event_auth import get_user_power_level
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
 from synapse.state import POWER_KEY
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import register_cache
@@ -51,7 +53,49 @@ push_rules_delta_state_cache_metric = register_cache(
 )
 
 
-class BulkPushRuleEvaluator(object):
+STATE_EVENT_TYPES_TO_MARK_UNREAD = {
+    EventTypes.Topic,
+    EventTypes.Name,
+    EventTypes.RoomAvatar,
+    EventTypes.Tombstone,
+}
+
+
+def _should_count_as_unread(event: EventBase, context: EventContext) -> bool:
+    # Exclude rejected and soft-failed events.
+    if context.rejected or event.internal_metadata.is_soft_failed():
+        return False
+
+    # Exclude notices.
+    if (
+        not event.is_state()
+        and event.type == EventTypes.Message
+        and event.content.get("msgtype") == "m.notice"
+    ):
+        return False
+
+    # Exclude edits.
+    relates_to = event.content.get("m.relates_to", {})
+    if relates_to.get("rel_type") == RelationTypes.REPLACE:
+        return False
+
+    # Mark events that have a non-empty string body as unread.
+    body = event.content.get("body")
+    if isinstance(body, str) and body:
+        return True
+
+    # Mark some state events as unread.
+    if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
+        return True
+
+    # Mark encrypted events as unread.
+    if not event.is_state() and event.type == EventTypes.Encrypted:
+        return True
+
+    return False
+
+
+class BulkPushRuleEvaluator:
     """Calculates the outcome of push rules for an event for all users in the
     room at once.
     """
@@ -133,9 +177,12 @@ class BulkPushRuleEvaluator(object):
         return pl_event.content if pl_event else {}, sender_level
 
     async def action_for_event_by_user(self, event, context) -> None:
-        """Given an event and context, evaluate the push rules and insert the
-        results into the event_push_actions_staging table.
+        """Given an event and context, evaluate the push rules, check if the message
+        should increment the unread count, and insert the results into the
+        event_push_actions_staging table.
         """
+        count_as_unread = _should_count_as_unread(event, context)
+
         rules_by_user = await self._get_rules_for_event(event, context)
         actions_by_user = {}
 
@@ -172,6 +219,13 @@ class BulkPushRuleEvaluator(object):
                 if event.type == EventTypes.Member and event.state_key == uid:
                     display_name = event.content.get("displayname", None)
 
+            if count_as_unread:
+                # Add an element for the current user if the event needs to be marked as
+                # unread, so that add_push_actions_to_staging iterates over it.
+                # If the event shouldn't be marked as unread but should notify the
+                # current user, it'll be added to the dict later.
+                actions_by_user[uid] = []
+
             for rule in rules:
                 if "enabled" in rule and not rule["enabled"]:
                     continue
@@ -189,7 +243,9 @@ class BulkPushRuleEvaluator(object):
         # Mark in the DB staging area the push actions for users who should be
         # notified for this event. (This will then get handled when we persist
         # the event)
-        await self.store.add_push_actions_to_staging(event.event_id, actions_by_user)
+        await self.store.add_push_actions_to_staging(
+            event.event_id, actions_by_user, count_as_unread,
+        )
 
 
 def _condition_checker(evaluator, conditions, uid, display_name, cache):
@@ -212,7 +268,7 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache):
     return True
 
 
-class RulesForRoom(object):
+class RulesForRoom:
     """Caches push rules for users in a room.
 
     This efficiently handles users joining/leaving the room by not invalidating
@@ -369,8 +425,8 @@ class RulesForRoom(object):
         Args:
             ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
                 updated with any new rules.
-            member_event_ids (list): List of event ids for membership events that
-                have happened since the last time we filled rules_by_user
+            member_event_ids (dict): Dict of user id to event id for membership events
+                that have happened since the last time we filled rules_by_user
             state_group: The state group we are currently computing push rules
                 for. Used when updating the cache.
         """
@@ -390,34 +446,19 @@ class RulesForRoom(object):
         if logger.isEnabledFor(logging.DEBUG):
             logger.debug("Found members %r: %r", self.room_id, members.values())
 
-        interested_in_user_ids = {
+        user_ids = {
             user_id
             for user_id, membership in members.values()
             if membership == Membership.JOIN
         }
 
-        logger.debug("Joined: %r", interested_in_user_ids)
-
-        if_users_with_pushers = await self.store.get_if_users_have_pushers(
-            interested_in_user_ids, on_invalidate=self.invalidate_all_cb
-        )
-
-        user_ids = {
-            uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
-        }
-
-        logger.debug("With pushers: %r", user_ids)
-
-        users_with_receipts = await self.store.get_users_with_read_receipts_in_room(
-            self.room_id, on_invalidate=self.invalidate_all_cb
-        )
-
-        logger.debug("With receipts: %r", users_with_receipts)
+        logger.debug("Joined: %r", user_ids)
 
-        # any users with pushers must be ours: they have pushers
-        for uid in users_with_receipts:
-            if uid in interested_in_user_ids:
-                user_ids.add(uid)
+        # Previously we only considered users with pushers or read receipts in that
+        # room. We can't do this anymore because we use push actions to calculate unread
+        # counts, which don't rely on the user having pushers or sent a read receipt into
+        # the room. Therefore we just need to filter for local users here.
+        user_ids = list(filter(self.is_mine_id, user_ids))
 
         rules_by_user = await self.store.bulk_get_push_rules(
             user_ids, on_invalidate=self.invalidate_all_cb
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 568c13eaea..28bd8ab748 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -45,7 +45,7 @@ THROTTLE_RESET_AFTER_MS = 12 * 60 * 60 * 1000
 INCLUDE_ALL_UNREAD_NOTIFS = False
 
 
-class EmailPusher(object):
+class EmailPusher:
     """
     A pusher that sends email notifications about events (approximately)
     when they happen.
@@ -91,7 +91,7 @@ class EmailPusher(object):
                 pass
             self.timed_call = None
 
-    def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
+    def on_new_notifications(self, max_stream_ordering):
         if self.max_stream_ordering:
             self.max_stream_ordering = max(
                 max_stream_ordering, self.max_stream_ordering
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 4c469efb20..26706bf3e1 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -49,7 +49,7 @@ http_badges_failed_counter = Counter(
 )
 
 
-class HttpPusher(object):
+class HttpPusher:
     INITIAL_BACKOFF_SEC = 1  # in seconds because that's what Twisted takes
     MAX_BACKOFF_SEC = 60 * 60
 
@@ -114,7 +114,7 @@ class HttpPusher(object):
         if should_check_for_notifs:
             self._start_processing()
 
-    def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
+    def on_new_notifications(self, max_stream_ordering):
         self.max_stream_ordering = max(
             max_stream_ordering, self.max_stream_ordering or 0
         )
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index af117fddf9..455a1acb46 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -16,8 +16,7 @@
 import email.mime.multipart
 import email.utils
 import logging
-import time
-import urllib
+import urllib.parse
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
 from typing import Iterable, List, TypeVar
@@ -93,7 +92,7 @@ ALLOWED_ATTRS = {
 # ALLOWED_SCHEMES = ["http", "https", "ftp", "mailto"]
 
 
-class Mailer(object):
+class Mailer:
     def __init__(self, hs, app_name, template_html, template_text):
         self.hs = hs
         self.template_html = template_html
@@ -124,7 +123,7 @@ class Mailer(object):
         params = {"token": token, "client_secret": client_secret, "sid": sid}
         link = (
             self.hs.config.public_baseurl
-            + "_matrix/client/unstable/password_reset/email/submit_token?%s"
+            + "_synapse/client/password_reset/email/submit_token?%s"
             % urllib.parse.urlencode(params)
         )
 
@@ -640,72 +639,3 @@ def string_ordinal_total(s):
     for c in s:
         tot += ord(c)
     return tot
-
-
-def format_ts_filter(value, format):
-    return time.strftime(format, time.localtime(value / 1000))
-
-
-def load_jinja2_templates(
-    template_dir,
-    template_filenames,
-    apply_format_ts_filter=False,
-    apply_mxc_to_http_filter=False,
-    public_baseurl=None,
-):
-    """Loads and returns one or more jinja2 templates and applies optional filters
-
-    Args:
-        template_dir (str): The directory where templates are stored
-        template_filenames (list[str]): A list of template filenames
-        apply_format_ts_filter (bool): Whether to apply a template filter that formats
-            timestamps
-        apply_mxc_to_http_filter (bool): Whether to apply a template filter that converts
-            mxc urls to http urls
-        public_baseurl (str|None): The public baseurl of the server. Required for
-            apply_mxc_to_http_filter to be enabled
-
-    Returns:
-        A list of jinja2 templates corresponding to the given list of filenames,
-        with order preserved
-    """
-    logger.info(
-        "loading email templates %s from '%s'", template_filenames, template_dir
-    )
-    loader = jinja2.FileSystemLoader(template_dir)
-    env = jinja2.Environment(loader=loader)
-
-    if apply_format_ts_filter:
-        env.filters["format_ts"] = format_ts_filter
-
-    if apply_mxc_to_http_filter and public_baseurl:
-        env.filters["mxc_to_http"] = _create_mxc_to_http_filter(public_baseurl)
-
-    templates = []
-    for template_filename in template_filenames:
-        template = env.get_template(template_filename)
-        templates.append(template)
-
-    return templates
-
-
-def _create_mxc_to_http_filter(public_baseurl):
-    def mxc_to_http_filter(value, width, height, resize_method="crop"):
-        if value[0:6] != "mxc://":
-            return ""
-
-        serverAndMediaId = value[6:]
-        fragment = None
-        if "#" in serverAndMediaId:
-            (serverAndMediaId, fragment) = serverAndMediaId.split("#", 1)
-            fragment = "#" + fragment
-
-        params = {"width": width, "height": height, "method": resize_method}
-        return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
-            public_baseurl,
-            serverAndMediaId,
-            urllib.parse.urlencode(params),
-            fragment or "",
-        )
-
-    return mxc_to_http_filter
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 2d79ada189..709ace01e5 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -105,7 +105,7 @@ def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]:
     return tweaks
 
 
-class PushRuleEvaluatorForEvent(object):
+class PushRuleEvaluatorForEvent:
     def __init__(
         self,
         event: EventBase,
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index 8ad0bf5936..2a52e226e3 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -15,24 +15,15 @@
 
 import logging
 
+from synapse.push.emailpusher import EmailPusher
+from synapse.push.mailer import Mailer
+
 from .httppusher import HttpPusher
 
 logger = logging.getLogger(__name__)
 
-# We try importing this if we can (it will fail if we don't
-# have the optional email dependencies installed). We don't
-# yet have the config to know if we need the email pusher,
-# but importing this after daemonizing seems to fail
-# (even though a simple test of importing from a daemonized
-# process works fine)
-try:
-    from synapse.push.emailpusher import EmailPusher
-    from synapse.push.mailer import Mailer, load_jinja2_templates
-except Exception:
-    pass
-
 
-class PusherFactory(object):
+class PusherFactory:
     def __init__(self, hs):
         self.hs = hs
         self.config = hs.config
@@ -43,16 +34,8 @@ class PusherFactory(object):
         if hs.config.email_enable_notifs:
             self.mailers = {}  # app_name -> Mailer
 
-            self.notif_template_html, self.notif_template_text = load_jinja2_templates(
-                self.config.email_template_dir,
-                [
-                    self.config.email_notif_template_html,
-                    self.config.email_notif_template_text,
-                ],
-                apply_format_ts_filter=True,
-                apply_mxc_to_http_filter=True,
-                public_baseurl=self.config.public_baseurl,
-            )
+            self._notif_template_html = hs.config.email_notif_template_html
+            self._notif_template_text = hs.config.email_notif_template_text
 
             self.pusher_types["email"] = self._create_email_pusher
 
@@ -73,8 +56,8 @@ class PusherFactory(object):
             mailer = Mailer(
                 hs=self.hs,
                 app_name=app_name,
-                template_html=self.notif_template_html,
-                template_text=self.notif_template_text,
+                template_html=self._notif_template_html,
+                template_text=self._notif_template_text,
             )
             self.mailers[app_name] = mailer
         return EmailPusher(self.hs, pusherdict, mailer)
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 3c3262a88c..cc839ffce4 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -64,6 +64,12 @@ class PusherPool:
         self._pusher_shard_config = hs.config.push.pusher_shard_config
         self._instance_name = hs.get_instance_name()
 
+        # Record the last stream ID that we were poked about so we can get
+        # changes since then. We set this to the current max stream ID on
+        # startup as every individual pusher will have checked for changes on
+        # startup.
+        self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering()
+
         # map from user id to app_id:pushkey to pusher
         self.pushers = {}  # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
 
@@ -178,20 +184,27 @@ class PusherPool:
                 )
                 await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
 
-    async def on_new_notifications(self, min_stream_id, max_stream_id):
+    async def on_new_notifications(self, max_stream_id: int):
         if not self.pushers:
             # nothing to do here.
             return
 
+        if max_stream_id < self._last_room_stream_id_seen:
+            # Nothing to do
+            return
+
+        prev_stream_id = self._last_room_stream_id_seen
+        self._last_room_stream_id_seen = max_stream_id
+
         try:
             users_affected = await self.store.get_push_action_users_in_range(
-                min_stream_id, max_stream_id
+                prev_stream_id, max_stream_id
             )
 
             for u in users_affected:
                 if u in self.pushers:
                     for p in self.pushers[u].values():
-                        p.on_new_notifications(min_stream_id, max_stream_id)
+                        p.on_new_notifications(max_stream_id)
 
         except Exception:
             logger.exception("Exception in pusher on_new_notifications")
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index e5f22fb858..ff0c67228b 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -43,7 +43,7 @@ REQUIREMENTS = [
     "jsonschema>=2.5.1",
     "frozendict>=1",
     "unpaddedbase64>=1.1.0",
-    "canonicaljson>=1.2.0",
+    "canonicaljson>=1.4.0",
     # we use the type definitions added in signedjson 1.1.
     "signedjson>=1.1.0",
     "pynacl>=1.2.1",
@@ -66,7 +66,9 @@ REQUIREMENTS = [
     "msgpack>=0.5.2",
     "phonenumbers>=8.2.0",
     "prometheus_client>=0.0.18,<0.9.0",
-    # we use attr.validators.deep_iterable, which arrived in 19.1.0
+    # we use attr.validators.deep_iterable, which arrived in 19.1.0 (Note:
+    # Fedora 31 only has 19.1, so if we want to upgrade we should wait until 33
+    # is out in November.)
     "attrs>=19.1.0",
     "netaddr>=0.7.18",
     "Jinja2>=2.9",
@@ -78,8 +80,6 @@ CONDITIONAL_REQUIREMENTS = {
     "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"],
     # we use execute_batch, which arrived in psycopg 2.7.
     "postgres": ["psycopg2>=2.7"],
-    # ConsentResource uses select_autoescape, which arrived in jinja 2.9
-    "resources.consent": ["Jinja2>=2.9"],
     # ACME support is required to provision TLS certificates from authorities
     # that use the protocol, such as Let's Encrypt.
     "acme": [
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 6a28c2db9d..ba16f22c91 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -33,7 +33,7 @@ from synapse.util.stringutils import random_string
 logger = logging.getLogger(__name__)
 
 
-class ReplicationEndpoint(object):
+class ReplicationEndpoint:
     """Helper base class for defining new replication HTTP endpoints.
 
     This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..`
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 6b56315148..5c8be747e1 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -65,10 +65,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         self.federation_handler = hs.get_handlers().federation_handler
 
     @staticmethod
-    async def _serialize_payload(store, event_and_contexts, backfilled):
+    async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
         """
         Args:
             store
+            room_id (str)
             event_and_contexts (list[tuple[FrozenEvent, EventContext]])
             backfilled (bool): Whether or not the events are the result of
                 backfilling
@@ -88,7 +89,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
                 }
             )
 
-        payload = {"events": event_payloads, "backfilled": backfilled}
+        payload = {
+            "events": event_payloads,
+            "backfilled": backfilled,
+            "room_id": room_id,
+        }
 
         return payload
 
@@ -96,6 +101,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         with Measure(self.clock, "repl_fed_send_events_parse"):
             content = parse_json_object_from_request(request)
 
+            room_id = content["room_id"]
             backfilled = content["backfilled"]
 
             event_payloads = content["events"]
@@ -120,7 +126,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         logger.info("Got %d events from federation", len(event_and_contexts))
 
         max_stream_id = await self.federation_handler.persist_events_and_notify(
-            event_and_contexts, backfilled
+            room_id, event_and_contexts, backfilled
         )
 
         return 200, {"max_stream_id": max_stream_id}
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 741329ab5f..08095fdf7d 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Optional
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 from synapse.types import JsonDict, Requester, UserID
-from synapse.util.distributor import user_joined_room, user_left_room
+from synapse.util.distributor import user_left_room
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -181,9 +181,9 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
         Args:
             room_id (str)
             user_id (str)
-            change (str): Either "joined" or "left"
+            change (str): "left"
         """
-        assert change in ("joined", "left")
+        assert change == "left"
 
         return {}
 
@@ -192,9 +192,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
 
         user = UserID.from_string(user_id)
 
-        if change == "joined":
-            user_joined_room(self.distributor, user, room_id)
-        elif change == "left":
+        if change == "left":
             user_left_room(self.distributor, user, room_id)
         else:
             raise Exception("Unrecognized change: %r", change)
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index ce9420aa69..a02b27474d 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -44,6 +44,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
         admin,
         user_type,
         address,
+        shadow_banned,
     ):
         """
         Args:
@@ -60,6 +61,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
             user_type (str|None): type of user. One of the values from
                 api.constants.UserTypes, or None for a normal user.
             address (str|None): the IP address used to perform the regitration.
+            shadow_banned (bool): Whether to shadow-ban the user
         """
         return {
             "password_hash": password_hash,
@@ -70,6 +72,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
             "admin": admin,
             "user_type": user_type,
             "address": address,
+            "shadow_banned": shadow_banned,
         }
 
     async def _handle_request(self, request, user_id):
@@ -87,6 +90,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
             admin=content["admin"],
             user_type=content["user_type"],
             address=content["address"],
+            shadow_banned=content["shadow_banned"],
         )
 
         return 200, {}
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 9d1d173b2f..eb74903d68 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -16,14 +16,14 @@
 from synapse.storage.util.id_generators import _load_current_id
 
 
-class SlavedIdTracker(object):
+class SlavedIdTracker:
     def __init__(self, db_conn, table, column, extra_tables=[], step=1):
         self.step = step
         self._current = _load_current_id(db_conn, table, column, step)
         for table, column in extra_tables:
-            self.advance(_load_current_id(db_conn, table, column))
+            self.advance(None, _load_current_id(db_conn, table, column))
 
-    def advance(self, new_id):
+    def advance(self, instance_name, new_id):
         self._current = (max if self.step > 0 else min)(self._current, new_id)
 
     def get_current_token(self):
@@ -33,3 +33,11 @@ class SlavedIdTracker(object):
             int
         """
         return self._current
+
+    def get_current_token_for_writer(self, instance_name: str) -> int:
+        """Returns the position of the given writer.
+
+        For streams with single writers this is equivalent to
+        `get_current_token`.
+        """
+        return self.get_current_token()
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 154f0e687c..bb66ba9b80 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -41,12 +41,12 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == TagAccountDataStream.NAME:
-            self._account_data_id_gen.advance(token)
+            self._account_data_id_gen.advance(instance_name, token)
             for row in rows:
                 self.get_tags_for_user.invalidate((row.user_id,))
                 self._account_data_stream_cache.entity_has_changed(row.user_id, token)
         elif stream_name == AccountDataStream.NAME:
-            self._account_data_id_gen.advance(token)
+            self._account_data_id_gen.advance(instance_name, token)
             for row in rows:
                 if not row.room_id:
                     self.get_global_account_data_by_type_for_user.invalidate(
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index ee7f69a918..533d927701 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -46,7 +46,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == ToDeviceStream.NAME:
-            self._device_inbox_id_gen.advance(token)
+            self._device_inbox_id_gen.advance(instance_name, token)
             for row in rows:
                 if row.entity.startswith("@"):
                     self._device_inbox_stream_cache.entity_has_changed(
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 722f3745e9..3b788c9625 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -48,12 +48,15 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
             "DeviceListFederationStreamChangeCache", device_list_max
         )
 
+    def get_device_stream_token(self) -> int:
+        return self._device_list_id_gen.get_current_token()
+
     def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == DeviceListsStream.NAME:
-            self._device_list_id_gen.advance(token)
+            self._device_list_id_gen.advance(instance_name, token)
             self._invalidate_caches_for_devices(token, rows)
         elif stream_name == UserSignatureStream.NAME:
-            self._device_list_id_gen.advance(token)
+            self._device_list_id_gen.advance(instance_name, token)
             for row in rows:
                 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 3291558c7a..567b4a5cc1 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -40,7 +40,7 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == GroupServerStream.NAME:
-            self._group_updates_id_gen.advance(token)
+            self._group_updates_id_gen.advance(instance_name, token)
             for row in rows:
                 self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
 
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index a912c04360..025f6f6be8 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -44,7 +44,7 @@ class SlavedPresenceStore(BaseSlavedStore):
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == PresenceStream.NAME:
-            self._presence_id_gen.advance(token)
+            self._presence_id_gen.advance(instance_name, token)
             for row in rows:
                 self.presence_stream_cache.entity_has_changed(row.user_id, token)
                 self._get_presence_for_user.invalidate((row.user_id,))
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 590187df46..de904c943c 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import PushRulesStream
 from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
 
@@ -21,18 +22,15 @@ from .events import SlavedEventStore
 
 
 class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
-    def get_push_rules_stream_token(self):
-        return (
-            self._push_rules_stream_id_gen.get_current_token(),
-            self._stream_id_gen.get_current_token(),
-        )
-
     def get_max_push_rules_stream_id(self):
         return self._push_rules_stream_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
+        # We assert this for the benefit of mypy
+        assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
+
         if stream_name == PushRulesStream.NAME:
-            self._push_rules_stream_id_gen.advance(token)
+            self._push_rules_stream_id_gen.advance(instance_name, token)
             for row in rows:
                 self.get_push_rules_for_user.invalidate((row.user_id,))
                 self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 63300e5da6..9da218bfe8 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -34,5 +34,5 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == PushersStream.NAME:
-            self._pushers_id_gen.advance(token)
+            self._pushers_id_gen.advance(instance_name, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 17ba1f22ac..5c2986e050 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -46,7 +46,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == ReceiptsStream.NAME:
-            self._receipts_id_gen.advance(token)
+            self._receipts_id_gen.advance(instance_name, token)
             for row in rows:
                 self.invalidate_caches_for_receipt(
                     row.room_id, row.receipt_type, row.user_id
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 427c81772b..80ae803ad9 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -33,6 +33,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == PublicRoomsStream.NAME:
-            self._public_room_id_gen.advance(token)
+            self._public_room_id_gen.advance(instance_name, token)
 
         return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index fcf8ebf1e7..e82b9e386f 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 """A replication client for use by synapse workers.
 """
-import heapq
 import logging
 from typing import TYPE_CHECKING, Dict, List, Tuple
 
@@ -30,6 +29,7 @@ from synapse.replication.tcp.streams.events import (
     EventsStreamEventRow,
     EventsStreamRow,
 )
+from synapse.types import UserID
 from synapse.util.async_helpers import timeout_deferred
 from synapse.util.metrics import Measure
 
@@ -99,7 +99,6 @@ class ReplicationDataHandler:
 
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
-        self.pusher_pool = hs.get_pusherpool()
         self.notifier = hs.get_notifier()
         self._reactor = hs.get_reactor()
         self._clock = hs.get_clock()
@@ -149,14 +148,12 @@ class ReplicationDataHandler:
                 if event.rejected_reason:
                     continue
 
-                extra_users = ()  # type: Tuple[str, ...]
+                extra_users = ()  # type: Tuple[UserID, ...]
                 if event.type == EventTypes.Member:
-                    extra_users = (event.state_key,)
+                    extra_users = (UserID.from_string(event.state_key),)
                 max_token = self.store.get_room_max_stream_ordering()
                 self.notifier.on_new_room_event(event, token, max_token, extra_users)
 
-            await self.pusher_pool.on_new_notifications(token, token)
-
         # Notify any waiting deferreds. The list is ordered by position so we
         # just iterate through the list until we reach a position that is
         # greater than the received row position.
@@ -219,9 +216,8 @@ class ReplicationDataHandler:
 
         waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
 
-        # We insert into the list using heapq as it is more efficient than
-        # pushing then resorting each time.
-        heapq.heappush(waiting_list, (position, deferred))
+        waiting_list.append((position, deferred))
+        waiting_list.sort(key=lambda t: t[0])
 
         # We measure here to get in flight counts and average waiting time.
         with Measure(self._clock, "repl.wait_for_stream_position"):
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index d853e4447e..8cd47770c1 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -21,9 +21,7 @@ import abc
 import logging
 from typing import Tuple, Type
 
-from canonicaljson import json
-
-from synapse.util import json_encoder as _json_encoder
+from synapse.util import json_decoder, json_encoder
 
 logger = logging.getLogger(__name__)
 
@@ -125,7 +123,7 @@ class RdataCommand(Command):
             stream_name,
             instance_name,
             None if token == "batch" else int(token),
-            json.loads(row_json),
+            json_decoder.decode(row_json),
         )
 
     def to_line(self):
@@ -134,7 +132,7 @@ class RdataCommand(Command):
                 self.stream_name,
                 self.instance_name,
                 str(self.token) if self.token is not None else "batch",
-                _json_encoder.encode(self.row),
+                json_encoder.encode(self.row),
             )
         )
 
@@ -359,7 +357,7 @@ class UserIpCommand(Command):
     def from_line(cls, line):
         user_id, jsn = line.split(" ", 1)
 
-        access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
+        access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn)
 
         return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
 
@@ -367,7 +365,7 @@ class UserIpCommand(Command):
         return (
             self.user_id
             + " "
-            + _json_encoder.encode(
+            + json_encoder.encode(
                 (
                     self.access_token,
                     self.ip,
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 1c303f3a46..b323841f73 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -109,7 +109,7 @@ class ReplicationCommandHandler:
             if isinstance(stream, (EventsStream, BackfillStream)):
                 # Only add EventStream and BackfillStream as a source on the
                 # instance in charge of event persistence.
-                if hs.config.worker.writers.events == hs.get_instance_name():
+                if hs.get_instance_name() in hs.config.worker.writers.events:
                     self._streams_to_replicate.append(stream)
 
                 continue
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 0350923898..0b0d204e64 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -113,7 +113,7 @@ PING_TIMEOUT_MULTIPLIER = 5
 PING_TIMEOUT_MS = PING_TIME * PING_TIMEOUT_MULTIPLIER
 
 
-class ConnectionStates(object):
+class ConnectionStates:
     CONNECTING = "connecting"
     ESTABLISHED = "established"
     PAUSED = "paused"
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 41569305df..687984e7a8 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -58,7 +58,7 @@ class ReplicationStreamProtocolFactory(Factory):
         )
 
 
-class ReplicationStreamer(object):
+class ReplicationStreamer:
     """Handles replication connections.
 
     This needs to be poked when new replication data may be available. When new
@@ -93,7 +93,7 @@ class ReplicationStreamer(object):
         """
         if not self.command_handler.connected():
             # Don't bother if nothing is listening. We still need to advance
-            # the stream tokens otherwise they'll fall beihind forever
+            # the stream tokens otherwise they'll fall behind forever
             for stream in self.streams:
                 stream.discard_updates_and_advance()
             return
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 7a42de3f7d..1f609f158c 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -79,7 +79,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
 UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]
 
 
-class Stream(object):
+class Stream:
     """Base class for the streams.
 
     Provides a `get_updates()` function that returns new updates since the last
@@ -352,7 +352,7 @@ class PushRulesStream(Stream):
         )
 
     def _current_token(self, instance_name: str) -> int:
-        push_rules_token, _ = self.store.get_push_rules_stream_token()
+        push_rules_token = self.store.get_max_push_rules_stream_id()
         return push_rules_token
 
 
@@ -383,7 +383,7 @@ class CachesStream(Stream):
     the cache on the workers
     """
 
-    @attr.s
+    @attr.s(slots=True)
     class CachesStreamRow:
         """Stream to inform workers they should invalidate their cache.
 
@@ -405,7 +405,7 @@ class CachesStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_cache_stream_token,
+            store.get_cache_stream_token_for_writer,
             store.get_all_updated_caches,
         )
 
@@ -441,7 +441,7 @@ class DeviceListsStream(Stream):
     told about a device update.
     """
 
-    @attr.s
+    @attr.s(slots=True)
     class DeviceListsStreamRow:
         entity = attr.ib(type=str)
 
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 16c63ff4ec..ccc7ca30d8 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -19,7 +19,7 @@ from typing import List, Tuple, Type
 
 import attr
 
-from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
+from ._base import Stream, StreamUpdateResult, Token
 
 """Handling of the 'events' replication stream
 
@@ -49,14 +49,14 @@ data part are:
 
 
 @attr.s(slots=True, frozen=True)
-class EventsStreamRow(object):
+class EventsStreamRow:
     """A parsed row from the events replication stream"""
 
     type = attr.ib()  # str: the TypeId of one of the *EventsStreamRows
     data = attr.ib()  # BaseEventsStreamRow
 
 
-class BaseEventsStreamRow(object):
+class BaseEventsStreamRow:
     """Base class for rows to be sent in the events stream.
 
     Specifies how to identify, serialize and deserialize the different types.
@@ -117,7 +117,7 @@ class EventsStream(Stream):
         self._store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(self._store.get_current_events_token),
+            self._store._stream_id_gen.get_current_token_for_writer,
             self._update_function,
         )
 
diff --git a/synapse/res/templates/password_reset_confirmation.html b/synapse/res/templates/password_reset_confirmation.html
new file mode 100644
index 0000000000..def4b5162b
--- /dev/null
+++ b/synapse/res/templates/password_reset_confirmation.html
@@ -0,0 +1,16 @@
+<html>
+<head></head>
+<body>
+<!--Use a hidden form to resubmit the information necessary to reset the password-->
+<form method="post">
+    <input type="hidden" name="sid" value="{{ sid }}">
+    <input type="hidden" name="token" value="{{ token }}">
+    <input type="hidden" name="client_secret" value="{{ client_secret }}">
+
+    <p>You have requested to <strong>reset your Matrix account password</strong>. Click the link below to confirm this action. <br /><br />
+        If you did not mean to do this, please close this page and your password will not be changed.</p>
+    <p><button type="submit">Confirm changing my password</button></p>
+</form>
+</body>
+</html>
+
diff --git a/synapse/res/templates/saml_error.html b/synapse/res/templates/saml_error.html
deleted file mode 100644
index 01cd9bdaf3..0000000000
--- a/synapse/res/templates/saml_error.html
+++ /dev/null
@@ -1,52 +0,0 @@
-<!DOCTYPE html>
-<html lang="en">
-<head>
-    <meta charset="UTF-8">
-    <title>SSO login error</title>
-</head>
-<body>
-{# a 403 means we have actively rejected their login #}
-{% if code == 403 %}
-    <p>You are not allowed to log in here.</p>
-{% else %}
-    <p>
-        There was an error during authentication:
-    </p>
-    <div id="errormsg" style="margin:20px 80px">{{ msg }}</div>
-    <p>
-        If you are seeing this page after clicking a link sent to you via email, make
-        sure you only click the confirmation link once, and that you open the
-        validation link in the same client you're logging in from.
-    </p>
-    <p>
-        Try logging in again from your Matrix client and if the problem persists
-        please contact the server's administrator.
-    </p>
-
-    <script type="text/javascript">
-        // Error handling to support Auth0 errors that we might get through a GET request
-        // to the validation endpoint. If an error is provided, it's either going to be
-        // located in the query string or in a query string-like URI fragment.
-        // We try to locate the error from any of these two locations, but if we can't
-        // we just don't print anything specific.
-        let searchStr = "";
-        if (window.location.search) {
-            // window.location.searchParams isn't always defined when
-            // window.location.search is, so it's more reliable to parse the latter.
-            searchStr = window.location.search;
-        } else if (window.location.hash) {
-            // Replace the # with a ? so that URLSearchParams does the right thing and
-            // doesn't parse the first parameter incorrectly.
-            searchStr = window.location.hash.replace("#", "?");
-        }
-
-        // We might end up with no error in the URL, so we need to check if we have one
-        // to print one.
-        let errorDesc = new URLSearchParams(searchStr).get("error_description")
-        if (errorDesc) {
-            document.getElementById("errormsg").innerText = errorDesc;
-        }
-    </script>
-{% endif %}
-</body>
-</html>
diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html
index 43a211386b..af8459719a 100644
--- a/synapse/res/templates/sso_error.html
+++ b/synapse/res/templates/sso_error.html
@@ -5,14 +5,49 @@
     <title>SSO error</title>
 </head>
 <body>
-    <p>Oops! Something went wrong during authentication.</p>
+{# If an error of unauthorised is returned it means we have actively rejected their login #}
+{% if error == "unauthorised" %}
+    <p>You are not allowed to log in here.</p>
+{% else %}
+    <p>
+        There was an error during authentication:
+    </p>
+    <div id="errormsg" style="margin:20px 80px">{{ error_description }}</div>
+    <p>
+        If you are seeing this page after clicking a link sent to you via email, make
+        sure you only click the confirmation link once, and that you open the
+        validation link in the same client you're logging in from.
+    </p>
     <p>
         Try logging in again from your Matrix client and if the problem persists
         please contact the server's administrator.
     </p>
     <p>Error: <code>{{ error }}</code></p>
-    {% if error_description %}
-    <pre><code>{{ error_description }}</code></pre>
-    {% endif %}
+
+    <script type="text/javascript">
+        // Error handling to support Auth0 errors that we might get through a GET request
+        // to the validation endpoint. If an error is provided, it's either going to be
+        // located in the query string or in a query string-like URI fragment.
+        // We try to locate the error from any of these two locations, but if we can't
+        // we just don't print anything specific.
+        let searchStr = "";
+        if (window.location.search) {
+            // window.location.searchParams isn't always defined when
+            // window.location.search is, so it's more reliable to parse the latter.
+            searchStr = window.location.search;
+        } else if (window.location.hash) {
+            // Replace the # with a ? so that URLSearchParams does the right thing and
+            // doesn't parse the first parameter incorrectly.
+            searchStr = window.location.hash.replace("#", "?");
+        }
+
+        // We might end up with no error in the URL, so we need to check if we have one
+        // to print one.
+        let errorDesc = new URLSearchParams(searchStr).get("error_description")
+        if (errorDesc) {
+            document.getElementById("errormsg").innerText = errorDesc;
+        }
+    </script>
+{% endif %}
 </body>
 </html>
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 46e458e95b..40f5c32db2 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -13,8 +13,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import synapse.rest.admin
 from synapse.http.server import JsonResource
+from synapse.rest import admin
 from synapse.rest.client import versions
 from synapse.rest.client.v1 import (
     directory,
@@ -50,6 +50,7 @@ from synapse.rest.client.v2_alpha import (
     room_keys,
     room_upgrade_rest_servlet,
     sendtodevice,
+    shared_rooms,
     sync,
     tags,
     thirdparty,
@@ -122,6 +123,7 @@ class ClientRestResource(JsonResource):
         password_policy.register_servlets(hs, client_resource)
 
         # moving to /_synapse/admin
-        synapse.rest.admin.register_servlets_for_client_rest_resource(
-            hs, client_resource
-        )
+        admin.register_servlets_for_client_rest_resource(hs, client_resource)
+
+        # unstable
+        shared_rooms.register_servlets(hs, client_resource)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 7c292ef3f9..09726d52d6 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -316,6 +316,9 @@ class JoinRoomAliasServlet(RestServlet):
         join_rules_event = room_state.get((EventTypes.JoinRules, ""))
         if join_rules_event:
             if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
+                # update_membership with an action of "invite" can raise a
+                # ShadowBanError. This is not handled since it is assumed that
+                # an admin isn't going to call this API with a shadow-banned user.
                 await self.room_member_handler.update_membership(
                     requester=requester,
                     target=fake_requester.user,
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index cc0bdfa5c9..f3e77da850 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -73,6 +73,7 @@ class UsersRestServletV2(RestServlet):
     The parameters `from` and `limit` are required only for pagination.
     By default, a `limit` of 100 is used.
     The parameter `user_id` can be used to filter by user id.
+    The parameter `name` can be used to filter by user id or display name.
     The parameter `guests` can be used to exclude guest users.
     The parameter `deactivated` can be used to include deactivated users.
     """
@@ -89,11 +90,12 @@ class UsersRestServletV2(RestServlet):
         start = parse_integer(request, "from", default=0)
         limit = parse_integer(request, "limit", default=100)
         user_id = parse_string(request, "user_id", default=None)
+        name = parse_string(request, "name", default=None)
         guests = parse_boolean(request, "guests", default=True)
         deactivated = parse_boolean(request, "deactivated", default=False)
 
         users, total = await self.store.get_users_paginate(
-            start, limit, user_id, guests, deactivated
+            start, limit, user_id, name, guests, deactivated
         )
         ret = {"users": users, "total": total}
         if len(users) >= limit:
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 6da71dc46f..7be5c0fb88 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 CLEANUP_PERIOD_MS = 1000 * 60 * 30  # 30 mins
 
 
-class HttpTransactionCache(object):
+class HttpTransactionCache:
     def __init__(self, hs):
         self.hs = hs
         self.auth = self.hs.get_auth()
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 379f668d6f..a14618ac84 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -18,6 +18,10 @@ from typing import Awaitable, Callable, Dict, Optional
 
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
+from synapse.handlers.auth import (
+    convert_client_dict_legacy_fields_to_identifier,
+    login_id_phone_to_thirdparty,
+)
 from synapse.http.server import finish_request
 from synapse.http.servlet import (
     RestServlet,
@@ -28,56 +32,11 @@ from synapse.http.site import SynapseRequest
 from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.rest.well_known import WellKnownBuilder
 from synapse.types import JsonDict, UserID
-from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.threepids import canonicalise_email
 
 logger = logging.getLogger(__name__)
 
 
-def login_submission_legacy_convert(submission):
-    """
-    If the input login submission is an old style object
-    (ie. with top-level user / medium / address) convert it
-    to a typed object.
-    """
-    if "user" in submission:
-        submission["identifier"] = {"type": "m.id.user", "user": submission["user"]}
-        del submission["user"]
-
-    if "medium" in submission and "address" in submission:
-        submission["identifier"] = {
-            "type": "m.id.thirdparty",
-            "medium": submission["medium"],
-            "address": submission["address"],
-        }
-        del submission["medium"]
-        del submission["address"]
-
-
-def login_id_thirdparty_from_phone(identifier):
-    """
-    Convert a phone login identifier type to a generic threepid identifier
-    Args:
-        identifier(dict): Login identifier dict of type 'm.id.phone'
-
-    Returns: Login identifier dict of type 'm.id.threepid'
-    """
-    if "country" not in identifier or (
-        # The specification requires a "phone" field, while Synapse used to require a "number"
-        # field. Accept both for backwards compatibility.
-        "phone" not in identifier
-        and "number" not in identifier
-    ):
-        raise SynapseError(400, "Invalid phone-type identifier")
-
-    # Accept both "phone" and "number" as valid keys in m.id.phone
-    phone_number = identifier.get("phone", identifier["number"])
-
-    msisdn = phone_number_to_msisdn(identifier["country"], phone_number)
-
-    return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
-
-
 class LoginRestServlet(RestServlet):
     PATTERNS = client_patterns("/login$", v1=True)
     CAS_TYPE = "m.login.cas"
@@ -194,18 +153,11 @@ class LoginRestServlet(RestServlet):
             login_submission.get("address"),
             login_submission.get("user"),
         )
-        login_submission_legacy_convert(login_submission)
-
-        if "identifier" not in login_submission:
-            raise SynapseError(400, "Missing param: identifier")
-
-        identifier = login_submission["identifier"]
-        if "type" not in identifier:
-            raise SynapseError(400, "Login identifier has no type")
+        identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
 
         # convert phone type identifiers to generic threepids
         if identifier["type"] == "m.id.phone":
-            identifier = login_id_thirdparty_from_phone(identifier)
+            identifier = login_id_phone_to_thirdparty(identifier)
 
         # convert threepid identifiers to user IDs
         if identifier["type"] == "m.id.thirdparty":
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 00831879f3..ddf8ed5e9c 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from synapse.api.errors import (
     NotFoundError,
     StoreError,
@@ -160,10 +159,22 @@ class PushRuleRestServlet(RestServlet):
         return 200, {}
 
     def notify_user(self, user_id):
-        stream_id, _ = self.store.get_push_rules_stream_token()
+        stream_id = self.store.get_max_push_rules_stream_id()
         self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
 
-    def set_rule_attr(self, user_id, spec, val):
+    async def set_rule_attr(self, user_id, spec, val):
+        if spec["attr"] not in ("enabled", "actions"):
+            # for the sake of potential future expansion, shouldn't report
+            # 404 in the case of an unknown request so check it corresponds to
+            # a known attribute first.
+            raise UnrecognizedRequestError()
+
+        namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
+        rule_id = spec["rule_id"]
+        is_default_rule = rule_id.startswith(".")
+        if is_default_rule:
+            if namespaced_rule_id not in BASE_RULE_IDS:
+                raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,))
         if spec["attr"] == "enabled":
             if isinstance(val, dict) and "enabled" in val:
                 val = val["enabled"]
@@ -172,8 +183,9 @@ class PushRuleRestServlet(RestServlet):
                 # This should *actually* take a dict, but many clients pass
                 # bools directly, so let's not break them.
                 raise SynapseError(400, "Value for 'enabled' must be boolean")
-            namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
-            return self.store.set_push_rule_enabled(user_id, namespaced_rule_id, val)
+            return await self.store.set_push_rule_enabled(
+                user_id, namespaced_rule_id, val, is_default_rule
+            )
         elif spec["attr"] == "actions":
             actions = val.get("actions")
             _check_actions(actions)
@@ -188,7 +200,7 @@ class PushRuleRestServlet(RestServlet):
 
                 if namespaced_rule_id not in rule_ids:
                     raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
-            return self.store.set_push_rule_actions(
+            return await self.store.set_push_rule_actions(
                 user_id, namespaced_rule_id, actions, is_default_rule
             )
         else:
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 2ab30ce897..84baf3d59b 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -21,14 +21,13 @@ import re
 from typing import List, Optional
 from urllib import parse as urlparse
 
-from canonicaljson import json
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import (
     AuthError,
     Codes,
     HttpResponseException,
     InvalidClientCredentialsError,
+    ShadowBanError,
     SynapseError,
 )
 from synapse.api.filtering import Filter
@@ -46,6 +45,8 @@ from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.storage.state import StateFilter
 from synapse.streams.config import PaginationConfig
 from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
+from synapse.util import json_decoder
+from synapse.util.stringutils import random_string
 
 MYPY = False
 if MYPY:
@@ -170,7 +171,6 @@ class RoomStateEventRestServlet(TransactionRestServlet):
             room_id=room_id,
             event_type=event_type,
             state_key=state_key,
-            is_guest=requester.is_guest,
         )
 
         if not data:
@@ -200,23 +200,26 @@ class RoomStateEventRestServlet(TransactionRestServlet):
         if state_key is not None:
             event_dict["state_key"] = state_key
 
-        if event_type == EventTypes.Member:
-            membership = content.get("membership", None)
-            event_id, _ = await self.room_member_handler.update_membership(
-                requester,
-                target=UserID.from_string(state_key),
-                room_id=room_id,
-                action=membership,
-                content=content,
-            )
-        else:
-            (
-                event,
-                _,
-            ) = await self.event_creation_handler.create_and_send_nonmember_event(
-                requester, event_dict, txn_id=txn_id
-            )
-            event_id = event.event_id
+        try:
+            if event_type == EventTypes.Member:
+                membership = content.get("membership", None)
+                event_id, _ = await self.room_member_handler.update_membership(
+                    requester,
+                    target=UserID.from_string(state_key),
+                    room_id=room_id,
+                    action=membership,
+                    content=content,
+                )
+            else:
+                (
+                    event,
+                    _,
+                ) = await self.event_creation_handler.create_and_send_nonmember_event(
+                    requester, event_dict, txn_id=txn_id
+                )
+                event_id = event.event_id
+        except ShadowBanError:
+            event_id = "$" + random_string(43)
 
         set_tag("event_id", event_id)
         ret = {"event_id": event_id}
@@ -249,12 +252,19 @@ class RoomSendEventRestServlet(TransactionRestServlet):
         if b"ts" in request.args and requester.app_service:
             event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
 
-        event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
-            requester, event_dict, txn_id=txn_id
-        )
+        try:
+            (
+                event,
+                _,
+            ) = await self.event_creation_handler.create_and_send_nonmember_event(
+                requester, event_dict, txn_id=txn_id
+            )
+            event_id = event.event_id
+        except ShadowBanError:
+            event_id = "$" + random_string(43)
 
-        set_tag("event_id", event.event_id)
-        return 200, {"event_id": event.event_id}
+        set_tag("event_id", event_id)
+        return 200, {"event_id": event_id}
 
     def on_GET(self, request, room_id, event_type, txn_id):
         return 200, "Not implemented"
@@ -519,7 +529,9 @@ class RoomMessageListRestServlet(RestServlet):
         filter_str = parse_string(request, b"filter", encoding="utf-8")
         if filter_str:
             filter_json = urlparse.unquote(filter_str)
-            event_filter = Filter(json.loads(filter_json))  # type: Optional[Filter]
+            event_filter = Filter(
+                json_decoder.decode(filter_json)
+            )  # type: Optional[Filter]
             if (
                 event_filter
                 and event_filter.filter_json.get("event_format", "client")
@@ -631,7 +643,9 @@ class RoomEventContextServlet(RestServlet):
         filter_str = parse_string(request, b"filter", encoding="utf-8")
         if filter_str:
             filter_json = urlparse.unquote(filter_str)
-            event_filter = Filter(json.loads(filter_json))  # type: Optional[Filter]
+            event_filter = Filter(
+                json_decoder.decode(filter_json)
+            )  # type: Optional[Filter]
         else:
             event_filter = None
 
@@ -716,16 +730,20 @@ class RoomMembershipRestServlet(TransactionRestServlet):
             content = {}
 
         if membership_action == "invite" and self._has_3pid_invite_keys(content):
-            await self.room_member_handler.do_3pid_invite(
-                room_id,
-                requester.user,
-                content["medium"],
-                content["address"],
-                content["id_server"],
-                requester,
-                txn_id,
-                content.get("id_access_token"),
-            )
+            try:
+                await self.room_member_handler.do_3pid_invite(
+                    room_id,
+                    requester.user,
+                    content["medium"],
+                    content["address"],
+                    content["id_server"],
+                    requester,
+                    txn_id,
+                    content.get("id_access_token"),
+                )
+            except ShadowBanError:
+                # Pretend the request succeeded.
+                pass
             return 200, {}
 
         target = requester.user
@@ -737,15 +755,19 @@ class RoomMembershipRestServlet(TransactionRestServlet):
         if "reason" in content:
             event_content = {"reason": content["reason"]}
 
-        await self.room_member_handler.update_membership(
-            requester=requester,
-            target=target,
-            room_id=room_id,
-            action=membership_action,
-            txn_id=txn_id,
-            third_party_signed=content.get("third_party_signed", None),
-            content=event_content,
-        )
+        try:
+            await self.room_member_handler.update_membership(
+                requester=requester,
+                target=target,
+                room_id=room_id,
+                action=membership_action,
+                txn_id=txn_id,
+                third_party_signed=content.get("third_party_signed", None),
+                content=event_content,
+            )
+        except ShadowBanError:
+            # Pretend the request succeeded.
+            pass
 
         return_value = {}
 
@@ -783,20 +805,27 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
         requester = await self.auth.get_user_by_req(request)
         content = parse_json_object_from_request(request)
 
-        event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
-            requester,
-            {
-                "type": EventTypes.Redaction,
-                "content": content,
-                "room_id": room_id,
-                "sender": requester.user.to_string(),
-                "redacts": event_id,
-            },
-            txn_id=txn_id,
-        )
+        try:
+            (
+                event,
+                _,
+            ) = await self.event_creation_handler.create_and_send_nonmember_event(
+                requester,
+                {
+                    "type": EventTypes.Redaction,
+                    "content": content,
+                    "room_id": room_id,
+                    "sender": requester.user.to_string(),
+                    "redacts": event_id,
+                },
+                txn_id=txn_id,
+            )
+            event_id = event.event_id
+        except ShadowBanError:
+            event_id = "$" + random_string(43)
 
-        set_tag("event_id", event.event_id)
-        return 200, {"event_id": event.event_id}
+        set_tag("event_id", event_id)
+        return 200, {"event_id": event_id}
 
     def on_PUT(self, request, room_id, event_id, txn_id):
         set_tag("txn_id", txn_id)
@@ -839,17 +868,21 @@ class RoomTypingRestServlet(RestServlet):
         # Limit timeout to stop people from setting silly typing timeouts.
         timeout = min(content.get("timeout", 30000), 120000)
 
-        if content["typing"]:
-            await self.typing_handler.started_typing(
-                target_user=target_user,
-                auth_user=requester.user,
-                room_id=room_id,
-                timeout=timeout,
-            )
-        else:
-            await self.typing_handler.stopped_typing(
-                target_user=target_user, auth_user=requester.user, room_id=room_id
-            )
+        try:
+            if content["typing"]:
+                await self.typing_handler.started_typing(
+                    target_user=target_user,
+                    requester=requester,
+                    room_id=room_id,
+                    timeout=timeout,
+                )
+            else:
+                await self.typing_handler.stopped_typing(
+                    target_user=target_user, requester=requester, room_id=room_id
+                )
+        except ShadowBanError:
+            # Pretend this worked without error.
+            pass
 
         return 200, {}
 
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index fead85074b..ade97a6708 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -15,7 +15,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import random
 from http import HTTPStatus
+from typing import TYPE_CHECKING
+from urllib.parse import urlparse
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import (
@@ -32,7 +38,7 @@ from synapse.http.servlet import (
     parse_json_object_from_request,
     parse_string,
 )
-from synapse.push.mailer import Mailer, load_jinja2_templates
+from synapse.push.mailer import Mailer
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.stringutils import assert_valid_client_secret, random_string
 from synapse.util.threepids import canonicalise_email, check_3pid_allowed
@@ -53,21 +59,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
         self.identity_handler = hs.get_handlers().identity_handler
 
         if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
-            template_html, template_text = load_jinja2_templates(
-                self.config.email_template_dir,
-                [
-                    self.config.email_password_reset_template_html,
-                    self.config.email_password_reset_template_text,
-                ],
-                apply_format_ts_filter=True,
-                apply_mxc_to_http_filter=True,
-                public_baseurl=self.config.public_baseurl,
-            )
             self.mailer = Mailer(
                 hs=self.hs,
                 app_name=self.config.email_app_name,
-                template_html=template_html,
-                template_text=template_text,
+                template_html=self.config.email_password_reset_template_html,
+                template_text=self.config.email_password_reset_template_text,
             )
 
     async def on_POST(self, request):
@@ -107,6 +103,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        # Raise if the provided next_link value isn't valid
+        assert_valid_next_link(self.hs, next_link)
+
         # The email will be sent to the stored address.
         # This avoids a potential account hijack by requesting a password reset to
         # an email address which is controlled by the attacker but which, after
@@ -119,6 +118,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
             if self.config.request_token_inhibit_3pid_errors:
                 # Make the client think the operation succeeded. See the rationale in the
                 # comments for request_token_inhibit_3pid_errors.
+                # Also wait for some random amount of time between 100ms and 1s to make it
+                # look like we did something.
+                await self.hs.clock.sleep(random.randint(1, 10) / 10)
                 return 200, {"sid": random_string(16)}
 
             raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
@@ -150,82 +152,6 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
         return 200, ret
 
 
-class PasswordResetSubmitTokenServlet(RestServlet):
-    """Handles 3PID validation token submission"""
-
-    PATTERNS = client_patterns(
-        "/password_reset/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True
-    )
-
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
-        super(PasswordResetSubmitTokenServlet, self).__init__()
-        self.hs = hs
-        self.auth = hs.get_auth()
-        self.config = hs.config
-        self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
-        if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
-            (self.failure_email_template,) = load_jinja2_templates(
-                self.config.email_template_dir,
-                [self.config.email_password_reset_template_failure_html],
-            )
-
-    async def on_GET(self, request, medium):
-        # We currently only handle threepid token submissions for email
-        if medium != "email":
-            raise SynapseError(
-                400, "This medium is currently not supported for password resets"
-            )
-        if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
-            if self.config.local_threepid_handling_disabled_due_to_email_config:
-                logger.warning(
-                    "Password reset emails have been disabled due to lack of an email config"
-                )
-            raise SynapseError(
-                400, "Email-based password resets are disabled on this server"
-            )
-
-        sid = parse_string(request, "sid", required=True)
-        token = parse_string(request, "token", required=True)
-        client_secret = parse_string(request, "client_secret", required=True)
-        assert_valid_client_secret(client_secret)
-
-        # Attempt to validate a 3PID session
-        try:
-            # Mark the session as valid
-            next_link = await self.store.validate_threepid_session(
-                sid, client_secret, token, self.clock.time_msec()
-            )
-
-            # Perform a 302 redirect if next_link is set
-            if next_link:
-                if next_link.startswith("file:///"):
-                    logger.warning(
-                        "Not redirecting to next_link as it is a local file: address"
-                    )
-                else:
-                    request.setResponseCode(302)
-                    request.setHeader("Location", next_link)
-                    finish_request(request)
-                    return None
-
-            # Otherwise show the success template
-            html = self.config.email_password_reset_template_success_html
-            status_code = 200
-        except ThreepidValidationError as e:
-            status_code = e.code
-
-            # Show a failure page with a reason
-            template_vars = {"failure_reason": e.msg}
-            html = self.failure_email_template.render(**template_vars)
-
-        respond_with_html(request, status_code, html)
-
-
 class PasswordRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/password$")
 
@@ -375,7 +301,7 @@ class DeactivateAccountRestServlet(RestServlet):
 
         requester = await self.auth.get_user_by_req(request)
 
-        # allow ASes to dectivate their own users
+        # allow ASes to deactivate their own users
         if requester.app_service:
             await self._deactivate_account_handler.deactivate_account(
                 requester.user.to_string(), erase
@@ -411,19 +337,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
         self.store = self.hs.get_datastore()
 
         if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
-            template_html, template_text = load_jinja2_templates(
-                self.config.email_template_dir,
-                [
-                    self.config.email_add_threepid_template_html,
-                    self.config.email_add_threepid_template_text,
-                ],
-                public_baseurl=self.config.public_baseurl,
-            )
             self.mailer = Mailer(
                 hs=self.hs,
                 app_name=self.config.email_app_name,
-                template_html=template_html,
-                template_text=template_text,
+                template_html=self.config.email_add_threepid_template_html,
+                template_text=self.config.email_add_threepid_template_text,
             )
 
     async def on_POST(self, request):
@@ -461,12 +379,18 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        # Raise if the provided next_link value isn't valid
+        assert_valid_next_link(self.hs, next_link)
+
         existing_user_id = await self.store.get_user_id_by_threepid("email", email)
 
         if existing_user_id is not None:
             if self.config.request_token_inhibit_3pid_errors:
                 # Make the client think the operation succeeded. See the rationale in the
                 # comments for request_token_inhibit_3pid_errors.
+                # Also wait for some random amount of time between 100ms and 1s to make it
+                # look like we did something.
+                await self.hs.clock.sleep(random.randint(1, 10) / 10)
                 return 200, {"sid": random_string(16)}
 
             raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@@ -529,12 +453,18 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        # Raise if the provided next_link value isn't valid
+        assert_valid_next_link(self.hs, next_link)
+
         existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
 
         if existing_user_id is not None:
             if self.hs.config.request_token_inhibit_3pid_errors:
                 # Make the client think the operation succeeded. See the rationale in the
                 # comments for request_token_inhibit_3pid_errors.
+                # Also wait for some random amount of time between 100ms and 1s to make it
+                # look like we did something.
+                await self.hs.clock.sleep(random.randint(1, 10) / 10)
                 return 200, {"sid": random_string(16)}
 
             raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
@@ -578,9 +508,8 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
-            (self.failure_email_template,) = load_jinja2_templates(
-                self.config.email_template_dir,
-                [self.config.email_add_threepid_template_failure_html],
+            self._failure_email_template = (
+                self.config.email_add_threepid_template_failure_html
             )
 
     async def on_GET(self, request):
@@ -613,15 +542,10 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
 
             # Perform a 302 redirect if next_link is set
             if next_link:
-                if next_link.startswith("file:///"):
-                    logger.warning(
-                        "Not redirecting to next_link as it is a local file: address"
-                    )
-                else:
-                    request.setResponseCode(302)
-                    request.setHeader("Location", next_link)
-                    finish_request(request)
-                    return None
+                request.setResponseCode(302)
+                request.setHeader("Location", next_link)
+                finish_request(request)
+                return None
 
             # Otherwise show the success template
             html = self.config.email_add_threepid_template_success_html_content
@@ -631,7 +555,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
 
             # Show a failure page with a reason
             template_vars = {"failure_reason": e.msg}
-            html = self.failure_email_template.render(**template_vars)
+            html = self._failure_email_template.render(**template_vars)
 
         respond_with_html(request, status_code, html)
 
@@ -885,6 +809,45 @@ class ThreepidDeleteRestServlet(RestServlet):
         return 200, {"id_server_unbind_result": id_server_unbind_result}
 
 
+def assert_valid_next_link(hs: "HomeServer", next_link: str):
+    """
+    Raises a SynapseError if a given next_link value is invalid
+
+    next_link is valid if the scheme is http(s) and the next_link.domain_whitelist config
+    option is either empty or contains a domain that matches the one in the given next_link
+
+    Args:
+        hs: The homeserver object
+        next_link: The next_link value given by the client
+
+    Raises:
+        SynapseError: If the next_link is invalid
+    """
+    valid = True
+
+    # Parse the contents of the URL
+    next_link_parsed = urlparse(next_link)
+
+    # Scheme must not point to the local drive
+    if next_link_parsed.scheme == "file":
+        valid = False
+
+    # If the domain whitelist is set, the domain must be in it
+    if (
+        valid
+        and hs.config.next_link_domain_whitelist is not None
+        and next_link_parsed.hostname not in hs.config.next_link_domain_whitelist
+    ):
+        valid = False
+
+    if not valid:
+        raise SynapseError(
+            400,
+            "'next_link' domain not included in whitelist, or not http(s)",
+            errcode=Codes.INVALID_PARAM,
+        )
+
+
 class WhoamiRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/whoami$")
 
@@ -900,7 +863,6 @@ class WhoamiRestServlet(RestServlet):
 
 def register_servlets(hs, http_server):
     EmailPasswordRequestTokenRestServlet(hs).register(http_server)
-    PasswordResetSubmitTokenServlet(hs).register(http_server)
     PasswordRestServlet(hs).register(http_server)
     DeactivateAccountRestServlet(hs).register(http_server)
     EmailThreepidRequestTokenRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index d84a6d7e11..13ecf7005d 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -16,6 +16,7 @@
 
 import logging
 
+from synapse.api.errors import SynapseError
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.types import GroupID
 
@@ -325,6 +326,9 @@ class GroupRoomServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
+        if not GroupID.is_valid(group_id):
+            raise SynapseError(400, "%s was not legal group ID" % (group_id,))
+
         result = await self.groups_handler.get_rooms_in_group(
             group_id, requester_user_id
         )
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index f808175698..b6b90a8b30 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -16,6 +16,7 @@
 
 import hmac
 import logging
+import random
 from typing import List, Union
 
 import synapse
@@ -44,7 +45,7 @@ from synapse.http.servlet import (
     parse_json_object_from_request,
     parse_string,
 )
-from synapse.push.mailer import load_jinja2_templates
+from synapse.push.mailer import Mailer
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -81,23 +82,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
         self.config = hs.config
 
         if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
-            from synapse.push.mailer import Mailer, load_jinja2_templates
-
-            template_html, template_text = load_jinja2_templates(
-                self.config.email_template_dir,
-                [
-                    self.config.email_registration_template_html,
-                    self.config.email_registration_template_text,
-                ],
-                apply_format_ts_filter=True,
-                apply_mxc_to_http_filter=True,
-                public_baseurl=self.config.public_baseurl,
-            )
             self.mailer = Mailer(
                 hs=self.hs,
                 app_name=self.config.email_app_name,
-                template_html=template_html,
-                template_text=template_text,
+                template_html=self.config.email_registration_template_html,
+                template_text=self.config.email_registration_template_text,
             )
 
     async def on_POST(self, request):
@@ -143,6 +132,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
             if self.hs.config.request_token_inhibit_3pid_errors:
                 # Make the client think the operation succeeded. See the rationale in the
                 # comments for request_token_inhibit_3pid_errors.
+                # Also wait for some random amount of time between 100ms and 1s to make it
+                # look like we did something.
+                await self.hs.clock.sleep(random.randint(1, 10) / 10)
                 return 200, {"sid": random_string(16)}
 
             raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@@ -215,6 +207,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
             if self.hs.config.request_token_inhibit_3pid_errors:
                 # Make the client think the operation succeeded. See the rationale in the
                 # comments for request_token_inhibit_3pid_errors.
+                # Also wait for some random amount of time between 100ms and 1s to make it
+                # look like we did something.
+                await self.hs.clock.sleep(random.randint(1, 10) / 10)
                 return 200, {"sid": random_string(16)}
 
             raise SynapseError(
@@ -262,15 +257,8 @@ class RegistrationSubmitTokenServlet(RestServlet):
         self.store = hs.get_datastore()
 
         if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
-            (self.failure_email_template,) = load_jinja2_templates(
-                self.config.email_template_dir,
-                [self.config.email_registration_template_failure_html],
-            )
-
-        if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
-            (self.failure_email_template,) = load_jinja2_templates(
-                self.config.email_template_dir,
-                [self.config.email_registration_template_failure_html],
+            self._failure_email_template = (
+                self.config.email_registration_template_failure_html
             )
 
     async def on_GET(self, request, medium):
@@ -318,7 +306,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
 
             # Show a failure page with a reason
             template_vars = {"failure_reason": e.msg}
-            html = self.failure_email_template.render(**template_vars)
+            html = self._failure_email_template.render(**template_vars)
 
         respond_with_html(request, status_code, html)
 
@@ -610,12 +598,17 @@ class RegisterRestServlet(RestServlet):
                                 Codes.THREEPID_IN_USE,
                             )
 
+            entries = await self.store.get_user_agents_ips_to_ui_auth_session(
+                session_id
+            )
+
             registered_user_id = await self.registration_handler.register_user(
                 localpart=desired_username,
                 password_hash=password_hash,
                 guest_access_token=guest_access_token,
                 threepid=threepid,
                 address=client_addr,
+                user_agent_ips=entries,
             )
             # Necessary due to auth checks prior to the threepid being
             # written to the db
@@ -665,7 +658,7 @@ class RegisterRestServlet(RestServlet):
             (object) params: registration parameters, from which we pull
                 device_id, initial_device_name and inhibit_login
         Returns:
-            (object) dictionary for response from /register
+             dictionary for response from /register
         """
         result = {"user_id": user_id, "home_server": self.hs.hostname}
         if not params.get("inhibit_login", False):
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index 89002ffbff..e29f49f7f5 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -22,7 +22,7 @@ any time to reflect changes in the MSC.
 import logging
 
 from synapse.api.constants import EventTypes, RelationTypes
-from synapse.api.errors import SynapseError
+from synapse.api.errors import ShadowBanError, SynapseError
 from synapse.http.servlet import (
     RestServlet,
     parse_integer,
@@ -35,6 +35,7 @@ from synapse.storage.relations import (
     PaginationChunk,
     RelationPaginationToken,
 )
+from synapse.util.stringutils import random_string
 
 from ._base import client_patterns
 
@@ -111,11 +112,18 @@ class RelationSendServlet(RestServlet):
             "sender": requester.user.to_string(),
         }
 
-        event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
-            requester, event_dict=event_dict, txn_id=txn_id
-        )
+        try:
+            (
+                event,
+                _,
+            ) = await self.event_creation_handler.create_and_send_nonmember_event(
+                requester, event_dict=event_dict, txn_id=txn_id
+            )
+            event_id = event.event_id
+        except ShadowBanError:
+            event_id = "$" + random_string(43)
 
-        return 200, {"event_id": event.event_id}
+        return 200, {"event_id": event_id}
 
 
 class RelationPaginationServlet(RestServlet):
diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
index f357015a70..39a5518614 100644
--- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -15,13 +15,14 @@
 
 import logging
 
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import Codes, ShadowBanError, SynapseError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
     parse_json_object_from_request,
 )
+from synapse.util import stringutils
 
 from ._base import client_patterns
 
@@ -62,7 +63,6 @@ class RoomUpgradeRestServlet(RestServlet):
 
         content = parse_json_object_from_request(request)
         assert_params_in_dict(content, ("new_version",))
-        new_version = content["new_version"]
 
         new_version = KNOWN_ROOM_VERSIONS.get(content["new_version"])
         if new_version is None:
@@ -72,9 +72,13 @@ class RoomUpgradeRestServlet(RestServlet):
                 Codes.UNSUPPORTED_ROOM_VERSION,
             )
 
-        new_room_id = await self._room_creation_handler.upgrade_room(
-            requester, room_id, new_version
-        )
+        try:
+            new_room_id = await self._room_creation_handler.upgrade_room(
+                requester, room_id, new_version
+            )
+        except ShadowBanError:
+            # Generate a random room ID.
+            new_room_id = stringutils.random_string(18)
 
         ret = {"replacement_room": new_room_id}
 
diff --git a/synapse/rest/client/v2_alpha/shared_rooms.py b/synapse/rest/client/v2_alpha/shared_rooms.py
new file mode 100644
index 0000000000..2492634dac
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/shared_rooms.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Half-Shot
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import RestServlet
+from synapse.types import UserID
+
+from ._base import client_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class UserSharedRoomsServlet(RestServlet):
+    """
+    GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1
+    """
+
+    PATTERNS = client_patterns(
+        "/uk.half-shot.msc2666/user/shared_rooms/(?P<user_id>[^/]*)",
+        releases=(),  # This is an unstable feature
+    )
+
+    def __init__(self, hs):
+        super(UserSharedRoomsServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.user_directory_active = hs.config.update_user_directory
+
+    async def on_GET(self, request, user_id):
+
+        if not self.user_directory_active:
+            raise SynapseError(
+                code=400,
+                msg="The user directory is disabled on this server. Cannot determine shared rooms.",
+                errcode=Codes.FORBIDDEN,
+            )
+
+        UserID.from_string(user_id)
+
+        requester = await self.auth.get_user_by_req(request)
+        if user_id == requester.user.to_string():
+            raise SynapseError(
+                code=400,
+                msg="You cannot request a list of shared rooms with yourself",
+                errcode=Codes.FORBIDDEN,
+            )
+        rooms = await self.store.get_shared_rooms_for_users(
+            requester.user.to_string(), user_id
+        )
+
+        return 200, {"joined": list(rooms)}
+
+
+def register_servlets(hs, http_server):
+    UserSharedRoomsServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index a5c24fbd63..a0b00135e1 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -16,8 +16,6 @@
 import itertools
 import logging
 
-from canonicaljson import json
-
 from synapse.api.constants import PresenceState
 from synapse.api.errors import Codes, StoreError, SynapseError
 from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
@@ -29,6 +27,7 @@ from synapse.handlers.presence import format_user_presence_state
 from synapse.handlers.sync import SyncConfig
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
 from synapse.types import StreamToken
+from synapse.util import json_decoder
 
 from ._base import client_patterns, set_timeline_upper_limit
 
@@ -125,7 +124,7 @@ class SyncRestServlet(RestServlet):
             filter_collection = DEFAULT_FILTER_COLLECTION
         elif filter_id.startswith("{"):
             try:
-                filter_object = json.loads(filter_id)
+                filter_object = json_decoder.decode(filter_id)
                 set_timeline_upper_limit(
                     filter_object, self.hs.config.filter_timeline_limit
                 )
@@ -426,6 +425,7 @@ class SyncRestServlet(RestServlet):
             result["ephemeral"] = {"events": ephemeral_events}
             result["unread_notifications"] = room.unread_notifications
             result["summary"] = room.summary
+            result["org.matrix.msc2654.unread_count"] = room.unread_count
 
         return result
 
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 0d668df0b6..24ac57f35d 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -60,6 +60,8 @@ class VersionsRestServlet(RestServlet):
                     "org.matrix.e2e_cross_signing": True,
                     # Implements additional endpoints as described in MSC2432
                     "org.matrix.msc2432": True,
+                    # Implements additional endpoints as described in MSC2666
+                    "uk.half-shot.msc2666": True,
                 },
             },
         )
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 9b3f85b306..f843f02454 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -15,19 +15,19 @@
 import logging
 from typing import Dict, Set
 
-from canonicaljson import encode_canonical_json, json
 from signedjson.sign import sign_json
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.crypto.keyring import ServerKeyFetcher
-from synapse.http.server import DirectServeJsonResource, respond_with_json_bytes
+from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.servlet import parse_integer, parse_json_object_from_request
+from synapse.util import json_decoder
 
 logger = logging.getLogger(__name__)
 
 
 class RemoteKey(DirectServeJsonResource):
-    """HTTP resource for retreiving the TLS certificate and NACL signature
+    """HTTP resource for retrieving the TLS certificate and NACL signature
     verification keys for a collection of servers. Checks that the reported
     X.509 TLS certificate matches the one used in the HTTPS connection. Checks
     that the NACL signature for the remote server is valid. Returns a dict of
@@ -35,7 +35,7 @@ class RemoteKey(DirectServeJsonResource):
 
     Supports individual GET APIs and a bulk query POST API.
 
-    Requsts:
+    Requests:
 
     GET /_matrix/key/v2/query/remote.server.example.com HTTP/1.1
 
@@ -209,13 +209,15 @@ class RemoteKey(DirectServeJsonResource):
                     # Cast to bytes since postgresql returns a memoryview.
                     json_results.add(bytes(result["key_json"]))
 
+        # If there is a cache miss, request the missing keys, then recurse (and
+        # ensure the result is sent).
         if cache_misses and query_remote_on_cache_miss:
             await self.fetcher.get_keys(cache_misses)
             await self.query_keys(request, query, query_remote_on_cache_miss=False)
         else:
             signed_keys = []
             for key_json in json_results:
-                key_json = json.loads(key_json.decode("utf-8"))
+                key_json = json_decoder.decode(key_json.decode("utf-8"))
                 for signing_key in self.config.key_server_signing_keys:
                     key_json = sign_json(key_json, self.config.server_name, signing_key)
 
@@ -223,4 +225,4 @@ class RemoteKey(DirectServeJsonResource):
 
             results = {"server_keys": signed_keys}
 
-            respond_with_json_bytes(request, 200, encode_canonical_json(results))
+            respond_with_json(request, 200, results, canonical_json=True)
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 20ddb9550b..6568e61829 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -235,7 +235,7 @@ async def respond_with_responder(
     finish_request(request)
 
 
-class Responder(object):
+class Responder:
     """Represents a response that can be streamed to the requester.
 
     Responder is a context manager which *must* be used, so that any resources
@@ -260,7 +260,7 @@ class Responder(object):
         pass
 
 
-class FileInfo(object):
+class FileInfo:
     """Details about a requested/uploaded file.
 
     Attributes:
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index e25c382c9c..7447eeaebe 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -33,7 +33,7 @@ def _wrap_in_base_path(func):
     return _wrapped
 
 
-class MediaFilePaths(object):
+class MediaFilePaths:
     """Describes where files are stored on disk.
 
     Most of the functions have a `*_rel` variant which returns a file path that
@@ -80,7 +80,7 @@ class MediaFilePaths(object):
         self, server_name, file_id, width, height, content_type, method
     ):
         top_level_type, sub_type = content_type.split("/")
-        file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
+        file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
             "remote_thumbnail",
             server_name,
@@ -92,6 +92,23 @@ class MediaFilePaths(object):
 
     remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
 
+    # Legacy path that was used to store thumbnails previously.
+    # Should be removed after some time, when most of the thumbnails are stored
+    # using the new path.
+    def remote_media_thumbnail_rel_legacy(
+        self, server_name, file_id, width, height, content_type
+    ):
+        top_level_type, sub_type = content_type.split("/")
+        file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
+        return os.path.join(
+            "remote_thumbnail",
+            server_name,
+            file_id[0:2],
+            file_id[2:4],
+            file_id[4:],
+            file_name,
+        )
+
     def remote_media_thumbnail_dir(self, server_name, file_id):
         return os.path.join(
             self.base_path,
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 6fb4039e98..69f353d46f 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -53,7 +53,7 @@ from .media_storage import MediaStorage
 from .preview_url_resource import PreviewUrlResource
 from .storage_provider import StorageProviderWrapper
 from .thumbnail_resource import ThumbnailResource
-from .thumbnailer import Thumbnailer
+from .thumbnailer import Thumbnailer, ThumbnailError
 from .upload_resource import UploadResource
 
 logger = logging.getLogger(__name__)
@@ -62,7 +62,7 @@ logger = logging.getLogger(__name__)
 UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
 
 
-class MediaRepository(object):
+class MediaRepository:
     def __init__(self, hs):
         self.hs = hs
         self.auth = hs.get_auth()
@@ -460,13 +460,30 @@ class MediaRepository(object):
         return t_byte_source
 
     async def generate_local_exact_thumbnail(
-        self, media_id, t_width, t_height, t_method, t_type, url_cache
-    ):
+        self,
+        media_id: str,
+        t_width: int,
+        t_height: int,
+        t_method: str,
+        t_type: str,
+        url_cache: str,
+    ) -> Optional[str]:
         input_path = await self.media_storage.ensure_media_is_in_local_cache(
             FileInfo(None, media_id, url_cache=url_cache)
         )
 
-        thumbnailer = Thumbnailer(input_path)
+        try:
+            thumbnailer = Thumbnailer(input_path)
+        except ThumbnailError as e:
+            logger.warning(
+                "Unable to generate a thumbnail for local media %s using a method of %s and type of %s: %s",
+                media_id,
+                t_method,
+                t_type,
+                e,
+            )
+            return None
+
         t_byte_source = await defer_to_thread(
             self.hs.get_reactor(),
             self._generate_thumbnail,
@@ -506,14 +523,36 @@ class MediaRepository(object):
 
             return output_path
 
+        # Could not generate thumbnail.
+        return None
+
     async def generate_remote_exact_thumbnail(
-        self, server_name, file_id, media_id, t_width, t_height, t_method, t_type
-    ):
+        self,
+        server_name: str,
+        file_id: str,
+        media_id: str,
+        t_width: int,
+        t_height: int,
+        t_method: str,
+        t_type: str,
+    ) -> Optional[str]:
         input_path = await self.media_storage.ensure_media_is_in_local_cache(
             FileInfo(server_name, file_id, url_cache=False)
         )
 
-        thumbnailer = Thumbnailer(input_path)
+        try:
+            thumbnailer = Thumbnailer(input_path)
+        except ThumbnailError as e:
+            logger.warning(
+                "Unable to generate a thumbnail for remote media %s from %s using a method of %s and type of %s: %s",
+                media_id,
+                server_name,
+                t_method,
+                t_type,
+                e,
+            )
+            return None
+
         t_byte_source = await defer_to_thread(
             self.hs.get_reactor(),
             self._generate_thumbnail,
@@ -559,6 +598,9 @@ class MediaRepository(object):
 
             return output_path
 
+        # Could not generate thumbnail.
+        return None
+
     async def _generate_thumbnails(
         self,
         server_name: Optional[str],
@@ -590,7 +632,18 @@ class MediaRepository(object):
             FileInfo(server_name, file_id, url_cache=url_cache)
         )
 
-        thumbnailer = Thumbnailer(input_path)
+        try:
+            thumbnailer = Thumbnailer(input_path)
+        except ThumbnailError as e:
+            logger.warning(
+                "Unable to generate thumbnails for remote media %s from %s using a method of %s and type of %s: %s",
+                media_id,
+                server_name,
+                media_type,
+                e,
+            )
+            return None
+
         m_width = thumbnailer.width
         m_height = thumbnailer.height
 
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index ab1fa705bf..5681677fc9 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -34,7 +34,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class MediaStorage(object):
+class MediaStorage:
     """Responsible for storing/fetching files from local sources.
 
     Args:
@@ -147,6 +147,20 @@ class MediaStorage(object):
         if os.path.exists(local_path):
             return FileResponder(open(local_path, "rb"))
 
+        # Fallback for paths without method names
+        # Should be removed in the future
+        if file_info.thumbnail and file_info.server_name:
+            legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
+                server_name=file_info.server_name,
+                file_id=file_info.file_id,
+                width=file_info.thumbnail_width,
+                height=file_info.thumbnail_height,
+                content_type=file_info.thumbnail_type,
+            )
+            legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
+            if os.path.exists(legacy_local_path):
+                return FileResponder(open(legacy_local_path, "rb"))
+
         for provider in self.storage_providers:
             res = await provider.fetch(path, file_info)  # type: Any
             if res:
@@ -170,6 +184,20 @@ class MediaStorage(object):
         if os.path.exists(local_path):
             return local_path
 
+        # Fallback for paths without method names
+        # Should be removed in the future
+        if file_info.thumbnail and file_info.server_name:
+            legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
+                server_name=file_info.server_name,
+                file_id=file_info.file_id,
+                width=file_info.thumbnail_width,
+                height=file_info.thumbnail_height,
+                content_type=file_info.thumbnail_type,
+            )
+            legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
+            if os.path.exists(legacy_local_path):
+                return legacy_local_path
+
         dirname = os.path.dirname(local_path)
         if not os.path.exists(dirname):
             os.makedirs(dirname)
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index cd8c246594..987765e877 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -102,7 +102,7 @@ for endpoint, globs in _oembed_globs.items():
         _oembed_patterns[re.compile(pattern)] = endpoint
 
 
-@attr.s
+@attr.s(slots=True)
 class OEmbedResult:
     # Either HTML content or URL must be provided.
     html = attr.ib(type=Optional[str])
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index a83535b97b..30421b663a 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -16,6 +16,7 @@
 
 import logging
 
+from synapse.api.errors import SynapseError
 from synapse.http.server import DirectServeJsonResource, set_cors_headers
 from synapse.http.servlet import parse_integer, parse_string
 
@@ -173,7 +174,7 @@ class ThumbnailResource(DirectServeJsonResource):
             await respond_with_file(request, desired_type, file_path)
         else:
             logger.warning("Failed to generate thumbnail")
-            respond_404(request)
+            raise SynapseError(400, "Failed to generate thumbnail.")
 
     async def _select_or_generate_remote_thumbnail(
         self,
@@ -235,7 +236,7 @@ class ThumbnailResource(DirectServeJsonResource):
             await respond_with_file(request, desired_type, file_path)
         else:
             logger.warning("Failed to generate thumbnail")
-            respond_404(request)
+            raise SynapseError(400, "Failed to generate thumbnail.")
 
     async def _respond_remote_thumbnail(
         self, request, server_name, media_id, width, height, method, m_type
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 7126997134..32a8e4f960 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -15,7 +15,7 @@
 import logging
 from io import BytesIO
 
-from PIL import Image as Image
+from PIL import Image
 
 logger = logging.getLogger(__name__)
 
@@ -31,12 +31,22 @@ EXIF_TRANSPOSE_MAPPINGS = {
 }
 
 
-class Thumbnailer(object):
+class ThumbnailError(Exception):
+    """An error occurred generating a thumbnail."""
+
+
+class Thumbnailer:
 
     FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
 
     def __init__(self, input_path):
-        self.image = Image.open(input_path)
+        try:
+            self.image = Image.open(input_path)
+        except OSError as e:
+            # If an error occurs opening the image, a thumbnail won't be able to
+            # be generated.
+            raise ThumbnailError from e
+
         self.width, self.height = self.image.size
         self.transpose_method = None
         try:
@@ -73,7 +83,7 @@ class Thumbnailer(object):
 
         Args:
             max_width: The largest possible width.
-            max_height: The larget possible height.
+            max_height: The largest possible height.
         """
 
         if max_width * self.height < max_height * self.width:
@@ -107,7 +117,7 @@ class Thumbnailer(object):
 
         Args:
             max_width: The largest possible width.
-            max_height: The larget possible height.
+            max_height: The largest possible height.
 
         Returns:
             BytesIO: the bytes of the encoded image ready to be written to disk
diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py
index c10188a5d7..f6668fb5e3 100644
--- a/synapse/rest/saml2/response_resource.py
+++ b/synapse/rest/saml2/response_resource.py
@@ -13,10 +13,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from twisted.python import failure
 
-from synapse.api.errors import SynapseError
-from synapse.http.server import DirectServeHtmlResource, return_html_error
+from synapse.http.server import DirectServeHtmlResource
 
 
 class SAML2ResponseResource(DirectServeHtmlResource):
@@ -27,21 +25,15 @@ class SAML2ResponseResource(DirectServeHtmlResource):
     def __init__(self, hs):
         super().__init__()
         self._saml_handler = hs.get_saml_handler()
-        self._error_html_template = hs.config.saml2.saml2_error_html_template
 
     async def _async_render_GET(self, request):
         # We're not expecting any GET request on that resource if everything goes right,
         # but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
         # In this case, just tell the user that something went wrong and they should
         # try to authenticate again.
-        f = failure.Failure(
-            SynapseError(400, "Unexpected GET request on /saml2/authn_response")
+        self._saml_handler._render_error(
+            request, "unexpected_get", "Unexpected GET request on /saml2/authn_response"
         )
-        return_html_error(f, request, self._error_html_template)
 
     async def _async_render_POST(self, request):
-        try:
-            await self._saml_handler.handle_saml_response(request)
-        except Exception:
-            f = failure.Failure()
-            return_html_error(f, request, self._error_html_template)
+        await self._saml_handler.handle_saml_response(request)
diff --git a/synapse/rest/synapse/__init__.py b/synapse/rest/synapse/__init__.py
new file mode 100644
index 0000000000..c0b733488b
--- /dev/null
+++ b/synapse/rest/synapse/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py
new file mode 100644
index 0000000000..c0b733488b
--- /dev/null
+++ b/synapse/rest/synapse/client/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py
new file mode 100644
index 0000000000..9e4fbc0cbd
--- /dev/null
+++ b/synapse/rest/synapse/client/password_reset.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.http import Request
+
+from synapse.api.errors import ThreepidValidationError
+from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.http.server import DirectServeHtmlResource
+from synapse.http.servlet import parse_string
+from synapse.util.stringutils import assert_valid_client_secret
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class PasswordResetSubmitTokenResource(DirectServeHtmlResource):
+    """Handles 3PID validation token submission
+
+    This resource gets mounted under /_synapse/client/password_reset/email/submit_token
+    """
+
+    isLeaf = 1
+
+    def __init__(self, hs: "HomeServer"):
+        """
+        Args:
+            hs: server
+        """
+        super().__init__()
+
+        self.clock = hs.get_clock()
+        self.store = hs.get_datastore()
+
+        self._local_threepid_handling_disabled_due_to_email_config = (
+            hs.config.local_threepid_handling_disabled_due_to_email_config
+        )
+        self._confirmation_email_template = (
+            hs.config.email_password_reset_template_confirmation_html
+        )
+        self._email_password_reset_template_success_html = (
+            hs.config.email_password_reset_template_success_html_content
+        )
+        self._failure_email_template = (
+            hs.config.email_password_reset_template_failure_html
+        )
+
+        # This resource should not be mounted if threepid behaviour is not LOCAL
+        assert hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL
+
+    async def _async_render_GET(self, request: Request) -> Tuple[int, bytes]:
+        sid = parse_string(request, "sid", required=True)
+        token = parse_string(request, "token", required=True)
+        client_secret = parse_string(request, "client_secret", required=True)
+        assert_valid_client_secret(client_secret)
+
+        # Show a confirmation page, just in case someone accidentally clicked this link when
+        # they didn't mean to
+        template_vars = {
+            "sid": sid,
+            "token": token,
+            "client_secret": client_secret,
+        }
+        return (
+            200,
+            self._confirmation_email_template.render(**template_vars).encode("utf-8"),
+        )
+
+    async def _async_render_POST(self, request: Request) -> Tuple[int, bytes]:
+        sid = parse_string(request, "sid", required=True)
+        token = parse_string(request, "token", required=True)
+        client_secret = parse_string(request, "client_secret", required=True)
+
+        # Attempt to validate a 3PID session
+        try:
+            # Mark the session as valid
+            next_link = await self.store.validate_threepid_session(
+                sid, client_secret, token, self.clock.time_msec()
+            )
+
+            # Perform a 302 redirect if next_link is set
+            if next_link:
+                if next_link.startswith("file:///"):
+                    logger.warning(
+                        "Not redirecting to next_link as it is a local file: address"
+                    )
+                else:
+                    next_link_bytes = next_link.encode("utf-8")
+                    request.setHeader("Location", next_link_bytes)
+                    return (
+                        302,
+                        (
+                            b'You are being redirected to <a src="%s">%s</a>.'
+                            % (next_link_bytes, next_link_bytes)
+                        ),
+                    )
+
+            # Otherwise show the success template
+            html_bytes = self._email_password_reset_template_success_html.encode(
+                "utf-8"
+            )
+            status_code = 200
+        except ThreepidValidationError as e:
+            status_code = e.code
+
+            # Show a failure page with a reason
+            template_vars = {"failure_reason": e.msg}
+            html_bytes = self._failure_email_template.render(**template_vars).encode(
+                "utf-8"
+            )
+
+        return status_code, html_bytes
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index 20177b44e7..f591cc6c5c 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -13,17 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
 import logging
 
 from twisted.web.resource import Resource
 
 from synapse.http.server import set_cors_headers
+from synapse.util import json_encoder
 
 logger = logging.getLogger(__name__)
 
 
-class WellKnownBuilder(object):
+class WellKnownBuilder:
     """Utility to construct the well-known response
 
     Args:
@@ -67,4 +67,4 @@ class WellKnownResource(Resource):
 
         logger.debug("returning: %s", r)
         request.setHeader(b"Content-Type", b"application/json")
-        return json.dumps(r).encode("utf-8")
+        return json_encoder.encode(r).encode("utf-8")
diff --git a/synapse/secrets.py b/synapse/secrets.py
index ff86950a54..fb6d90a3b7 100644
--- a/synapse/secrets.py
+++ b/synapse/secrets.py
@@ -37,7 +37,7 @@ else:
     import binascii
     import os
 
-    class Secrets(object):
+    class Secrets:
         def token_bytes(self, nbytes=32):
             return os.urandom(nbytes)
 
diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py
index 089cfef0b3..3673e7f47e 100644
--- a/synapse/server_notices/consent_server_notices.py
+++ b/synapse/server_notices/consent_server_notices.py
@@ -23,7 +23,7 @@ from synapse.types import get_localpart_from_id
 logger = logging.getLogger(__name__)
 
 
-class ConsentServerNotices(object):
+class ConsentServerNotices:
     """Keeps track of whether we need to send users server_notices about
     privacy policy consent, and sends one if we do.
     """
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index c2faef6eab..2258d306d9 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -27,7 +27,7 @@ from synapse.server_notices.server_notices_manager import SERVER_NOTICE_ROOM_TAG
 logger = logging.getLogger(__name__)
 
 
-class ResourceLimitsServerNotices(object):
+class ResourceLimitsServerNotices:
     """ Keeps track of whether the server has reached it's resource limit and
     ensures that the client is kept up to date.
     """
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index ed96aa8571..0422d4c7ce 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 SERVER_NOTICE_ROOM_TAG = "m.server_notice"
 
 
-class ServerNoticesManager(object):
+class ServerNoticesManager:
     def __init__(self, hs):
         """
 
diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py
index a754f75db4..6870b67ca0 100644
--- a/synapse/server_notices/server_notices_sender.py
+++ b/synapse/server_notices/server_notices_sender.py
@@ -20,7 +20,7 @@ from synapse.server_notices.resource_limits_server_notices import (
 )
 
 
-class ServerNoticesSender(object):
+class ServerNoticesSender:
     """A centralised place which sends server notices automatically when
     Certain Events take place
     """
diff --git a/synapse/server_notices/worker_server_notices_sender.py b/synapse/server_notices/worker_server_notices_sender.py
index e9390b19da..9273e61895 100644
--- a/synapse/server_notices/worker_server_notices_sender.py
+++ b/synapse/server_notices/worker_server_notices_sender.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 
-class WorkerServerNoticesSender(object):
+class WorkerServerNoticesSender:
     """Stub impl of ServerNoticesSender which does nothing"""
 
     def __init__(self, hs):
diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py
index 9b78924d96..395ac5ab02 100644
--- a/synapse/spam_checker_api/__init__.py
+++ b/synapse/spam_checker_api/__init__.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from enum import Enum
 
 from twisted.internet import defer
 
@@ -25,7 +26,17 @@ if MYPY:
 logger = logging.getLogger(__name__)
 
 
-class SpamCheckerApi(object):
+class RegistrationBehaviour(Enum):
+    """
+    Enum to define whether a registration request should allowed, denied, or shadow-banned.
+    """
+
+    ALLOW = "allow"
+    SHADOW_BAN = "shadow_ban"
+    DENY = "deny"
+
+
+class SpamCheckerApi:
     """A proxy object that gets passed to spam checkers so they can get
     access to rooms and other relevant information.
     """
@@ -48,8 +59,10 @@ class SpamCheckerApi(object):
             twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]:
                 The filtered state events in the room.
         """
-        state_ids = yield self._store.get_filtered_current_state_ids(
-            room_id=room_id, state_filter=StateFilter.from_types(types)
+        state_ids = yield defer.ensureDeferred(
+            self._store.get_filtered_current_state_ids(
+                room_id=room_id, state_filter=StateFilter.from_types(types)
+            )
         )
-        state = yield self._store.get_events(state_ids.values())
+        state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
         return state.values()
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index a1d3884667..56d6afb863 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,11 +16,23 @@
 
 import logging
 from collections import namedtuple
-from typing import Awaitable, Dict, Iterable, List, Optional, Set
+from typing import (
+    Awaitable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Union,
+    cast,
+    overload,
+)
 
 import attr
 from frozendict import frozendict
 from prometheus_client import Histogram
+from typing_extensions import Literal
 
 from synapse.api.constants import EventTypes
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
@@ -30,7 +42,7 @@ from synapse.logging.utils import log_function
 from synapse.state import v1, v2
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.roommember import ProfileInfo
-from synapse.types import StateMap
+from synapse.types import Collection, MutableStateMap, StateMap
 from synapse.util import Clock
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -65,11 +77,17 @@ def _gen_state_id():
     return s
 
 
-class _StateCacheEntry(object):
+class _StateCacheEntry:
     __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
 
-    def __init__(self, state, state_group, prev_group=None, delta_ids=None):
-        # dict[(str, str), str] map  from (type, state_key) to event_id
+    def __init__(
+        self,
+        state: StateMap[str],
+        state_group: Optional[int],
+        prev_group: Optional[int] = None,
+        delta_ids: Optional[StateMap[str]] = None,
+    ):
+        # A map from (type, state_key) to event_id.
         self.state = frozendict(state)
 
         # the ID of a state group if one and only one is involved.
@@ -95,7 +113,7 @@ class _StateCacheEntry(object):
         return len(self.state)
 
 
-class StateHandler(object):
+class StateHandler:
     """Fetches bits of state from the stores, and does state resolution
     where necessary
     """
@@ -107,24 +125,49 @@ class StateHandler(object):
         self.hs = hs
         self._state_resolution_handler = hs.get_state_resolution_handler()
 
+    @overload
     async def get_current_state(
-        self, room_id, event_type=None, state_key="", latest_event_ids=None
-    ):
-        """ Retrieves the current state for the room. This is done by
+        self,
+        room_id: str,
+        event_type: Literal[None] = None,
+        state_key: str = "",
+        latest_event_ids: Optional[List[str]] = None,
+    ) -> StateMap[EventBase]:
+        ...
+
+    @overload
+    async def get_current_state(
+        self,
+        room_id: str,
+        event_type: str,
+        state_key: str = "",
+        latest_event_ids: Optional[List[str]] = None,
+    ) -> Optional[EventBase]:
+        ...
+
+    async def get_current_state(
+        self,
+        room_id: str,
+        event_type: Optional[str] = None,
+        state_key: str = "",
+        latest_event_ids: Optional[List[str]] = None,
+    ) -> Union[Optional[EventBase], StateMap[EventBase]]:
+        """Retrieves the current state for the room. This is done by
         calling `get_latest_events_in_room` to get the leading edges of the
         event graph and then resolving any of the state conflicts.
 
         This is equivalent to getting the state of an event that were to send
         next before receiving any new events.
 
-        If `event_type` is specified, then the method returns only the one
-        event (or None) with that `event_type` and `state_key`.
-
         Returns:
-            map from (type, state_key) to event
+            If `event_type` is specified, then the method returns only the one
+            event (or None) with that `event_type` and `state_key`.
+
+            Otherwise, a map from (type, state_key) to event.
         """
         if not latest_event_ids:
             latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
+        assert latest_event_ids is not None
 
         logger.debug("calling resolve_state_groups from get_current_state")
         ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
@@ -140,34 +183,30 @@ class StateHandler(object):
         state_map = await self.store.get_events(
             list(state.values()), get_prev_content=False
         )
-        state = {
+        return {
             key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
         }
 
-        return state
-
-    async def get_current_state_ids(self, room_id, latest_event_ids=None):
+    async def get_current_state_ids(
+        self, room_id: str, latest_event_ids: Optional[Iterable[str]] = None
+    ) -> StateMap[str]:
         """Get the current state, or the state at a set of events, for a room
 
         Args:
-            room_id (str):
-
-            latest_event_ids (iterable[str]|None): if given, the forward
-                extremities to resolve. If None, we look them up from the
-                database (via a cache)
+            room_id:
+            latest_event_ids: if given, the forward extremities to resolve. If
+                None, we look them up from the database (via a cache).
 
         Returns:
-            Deferred[dict[(str, str), str)]]: the state dict, mapping from
-                (event_type, state_key) -> event_id
+            the state dict, mapping from (event_type, state_key) -> event_id
         """
         if not latest_event_ids:
             latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
+        assert latest_event_ids is not None
 
         logger.debug("calling resolve_state_groups from get_current_state_ids")
         ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
-        state = ret.state
-
-        return state
+        return ret.state
 
     async def get_current_users_in_room(
         self, room_id: str, latest_event_ids: Optional[List[str]] = None
@@ -183,32 +222,34 @@ class StateHandler(object):
         """
         if not latest_event_ids:
             latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
+        assert latest_event_ids is not None
+
         logger.debug("calling resolve_state_groups from get_current_users_in_room")
         entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
-        joined_users = await self.store.get_joined_users_from_state(room_id, entry)
-        return joined_users
+        return await self.store.get_joined_users_from_state(room_id, entry)
 
-    async def get_current_hosts_in_room(self, room_id):
+    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
         event_ids = await self.store.get_latest_event_ids_in_room(room_id)
         return await self.get_hosts_in_room_at_events(room_id, event_ids)
 
-    async def get_hosts_in_room_at_events(self, room_id, event_ids):
+    async def get_hosts_in_room_at_events(
+        self, room_id: str, event_ids: List[str]
+    ) -> Set[str]:
         """Get the hosts that were in a room at the given event ids
 
         Args:
-            room_id (str):
-            event_ids (list[str]):
+            room_id:
+            event_ids:
 
         Returns:
-            Deferred[list[str]]: the hosts in the room at the given events
+            The hosts in the room at the given events
         """
         entry = await self.resolve_state_groups_for_events(room_id, event_ids)
-        joined_hosts = await self.store.get_joined_hosts(room_id, entry)
-        return joined_hosts
+        return await self.store.get_joined_hosts(room_id, entry)
 
     async def compute_event_context(
         self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
-    ):
+    ) -> EventContext:
         """Build an EventContext structure for the event.
 
         This works out what the current state should be for the event, and
@@ -221,7 +262,7 @@ class StateHandler(object):
                 when receiving an event from federation where we don't have the
                 prev events for, e.g. when backfilling.
         Returns:
-            synapse.events.snapshot.EventContext:
+            The event context.
         """
 
         if event.internal_metadata.is_outlier():
@@ -262,7 +303,7 @@ class StateHandler(object):
             # if we're given the state before the event, then we use that
             state_ids_before_event = {
                 (s.type, s.state_key): s.event_id for s in old_state
-            }
+            }  # type: StateMap[str]
             state_group_before_event = None
             state_group_before_event_prev_group = None
             deltas_to_state_group_before_event = None
@@ -346,19 +387,18 @@ class StateHandler(object):
         )
 
     @measure_func()
-    async def resolve_state_groups_for_events(self, room_id, event_ids):
+    async def resolve_state_groups_for_events(
+        self, room_id: str, event_ids: Iterable[str]
+    ) -> _StateCacheEntry:
         """ Given a list of event_ids this method fetches the state at each
         event, resolves conflicts between them and returns them.
 
         Args:
-            room_id (str)
-            event_ids (list[str])
-            explicit_room_version (str|None): If set uses the the given room
-                version to choose the resolution algorithm. If None, then
-                checks the database for room version.
+            room_id
+            event_ids
 
         Returns:
-            Deferred[_StateCacheEntry]: resolved state
+            The resolved state
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
 
@@ -394,7 +434,12 @@ class StateHandler(object):
         )
         return result
 
-    async def resolve_events(self, room_version, state_sets, event):
+    async def resolve_events(
+        self,
+        room_version: str,
+        state_sets: Collection[Iterable[EventBase]],
+        event: EventBase,
+    ) -> StateMap[EventBase]:
         logger.info(
             "Resolving state for %s with %d groups", event.room_id, len(state_sets)
         )
@@ -414,12 +459,10 @@ class StateHandler(object):
                 state_res_store=StateResolutionStore(self.store),
             )
 
-        new_state = {key: state_map[ev_id] for key, ev_id in new_state.items()}
+        return {key: state_map[ev_id] for key, ev_id in new_state.items()}
 
-        return new_state
 
-
-class StateResolutionHandler(object):
+class StateResolutionHandler:
     """Responsible for doing state conflict resolution.
 
     Note that the storage layer depends on this handler, so all functions must
@@ -444,7 +487,12 @@ class StateResolutionHandler(object):
 
     @log_function
     async def resolve_state_groups(
-        self, room_id, room_version, state_groups_ids, event_map, state_res_store
+        self,
+        room_id: str,
+        room_version: str,
+        state_groups_ids: Dict[int, StateMap[str]],
+        event_map: Optional[Dict[str, EventBase]],
+        state_res_store: "StateResolutionStore",
     ):
         """Resolves conflicts between a set of state groups
 
@@ -452,13 +500,13 @@ class StateResolutionHandler(object):
         not be called for a single state group
 
         Args:
-            room_id (str): room we are resolving for (used for logging and sanity checks)
-            room_version (str): version of the room
-            state_groups_ids (dict[int, dict[(str, str), str]]):
-                 map from state group id to the state in that state group
+            room_id: room we are resolving for (used for logging and sanity checks)
+            room_version: version of the room
+            state_groups_ids:
+                A map from state group id to the state in that state group
                 (where 'state' is a map from state key to event id)
 
-            event_map(dict[str,FrozenEvent]|None):
+            event_map:
                 a dict from event_id to event, for any events that we happen to
                 have in flight (eg, those currently being persisted). This will be
                 used as a starting point fof finding the state we need; any missing
@@ -466,10 +514,10 @@ class StateResolutionHandler(object):
 
                 If None, all events will be fetched via state_res_store.
 
-            state_res_store (StateResolutionStore)
+            state_res_store
 
         Returns:
-            _StateCacheEntry: resolved state
+            The resolved state
         """
         logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
 
@@ -493,7 +541,7 @@ class StateResolutionHandler(object):
             #
             # XXX: is this actually worthwhile, or should we just let
             # resolve_events_with_store do it?
-            new_state = {}
+            new_state = {}  # type: MutableStateMap[str]
             conflicted_state = False
             for st in state_groups_ids.values():
                 for key, e_id in st.items():
@@ -507,13 +555,20 @@ class StateResolutionHandler(object):
             if conflicted_state:
                 logger.info("Resolving conflicted state for %r", room_id)
                 with Measure(self.clock, "state._resolve_events"):
-                    new_state = await resolve_events_with_store(
-                        self.clock,
-                        room_id,
-                        room_version,
-                        list(state_groups_ids.values()),
-                        event_map=event_map,
-                        state_res_store=state_res_store,
+                    # resolve_events_with_store returns a StateMap, but we can
+                    # treat it as a MutableStateMap as it is above. It isn't
+                    # actually mutated anymore (and is frozen in
+                    # _make_state_cache_entry below).
+                    new_state = cast(
+                        MutableStateMap,
+                        await resolve_events_with_store(
+                            self.clock,
+                            room_id,
+                            room_version,
+                            list(state_groups_ids.values()),
+                            event_map=event_map,
+                            state_res_store=state_res_store,
+                        ),
                     )
 
             # if the new state matches any of the input state groups, we can
@@ -530,21 +585,22 @@ class StateResolutionHandler(object):
             return cache
 
 
-def _make_state_cache_entry(new_state, state_groups_ids):
+def _make_state_cache_entry(
+    new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]]
+) -> _StateCacheEntry:
     """Given a resolved state, and a set of input state groups, pick one to base
     a new state group on (if any), and return an appropriately-constructed
     _StateCacheEntry.
 
     Args:
-        new_state (dict[(str, str), str]): resolved state map (mapping from
-           (type, state_key) to event_id)
+        new_state: resolved state map (mapping from (type, state_key) to event_id)
 
-        state_groups_ids (dict[int, dict[(str, str), str]]):
-                 map from state group id to the state in that state group
-                (where 'state' is a map from state key to event id)
+        state_groups_ids:
+            map from state group id to the state in that state group (where
+            'state' is a map from state key to event id)
 
     Returns:
-        _StateCacheEntry
+        The cache entry.
     """
     # if the new state matches any of the input state groups, we can
     # use that state group again. Otherwise we will generate a state_id
@@ -585,7 +641,7 @@ def resolve_events_with_store(
     clock: Clock,
     room_id: str,
     room_version: str,
-    state_sets: List[StateMap[str]],
+    state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
     state_res_store: "StateResolutionStore",
 ) -> Awaitable[StateMap[str]]:
@@ -622,8 +678,8 @@ def resolve_events_with_store(
         )
 
 
-@attr.s
-class StateResolutionStore(object):
+@attr.s(slots=True)
+class StateResolutionStore:
     """Interface that allows state resolution algorithms to access the database
     in well defined way.
 
@@ -633,15 +689,17 @@ class StateResolutionStore(object):
 
     store = attr.ib()
 
-    def get_events(self, event_ids, allow_rejected=False):
+    def get_events(
+        self, event_ids: Iterable[str], allow_rejected: bool = False
+    ) -> Awaitable[Dict[str, EventBase]]:
         """Get events from the database
 
         Args:
-            event_ids (list): The event_ids of the events to fetch
-            allow_rejected (bool): If True return rejected events.
+            event_ids: The event_ids of the events to fetch
+            allow_rejected: If True return rejected events.
 
         Returns:
-            Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
+            An awaitable which resolves to a dict from event_id to event.
         """
 
         return self.store.get_events(
@@ -651,7 +709,9 @@ class StateResolutionStore(object):
             allow_rejected=allow_rejected,
         )
 
-    def get_auth_chain_difference(self, state_sets: List[Set[str]]):
+    def get_auth_chain_difference(
+        self, state_sets: List[Set[str]]
+    ) -> Awaitable[Set[str]]:
         """Given sets of state events figure out the auth chain difference (as
         per state res v2 algorithm).
 
@@ -660,7 +720,7 @@ class StateResolutionStore(object):
         chain.
 
         Returns:
-            Deferred[Set[str]]: Set of event IDs.
+            An awaitable that resolves to a set of event IDs.
         """
 
         return self.store.get_auth_chain_difference(state_sets)
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index ab5e24841d..a493279cbd 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -15,14 +15,24 @@
 
 import hashlib
 import logging
-from typing import Awaitable, Callable, Dict, List, Optional
+from typing import (
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+)
 
 from synapse import event_auth
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase
-from synapse.types import StateMap
+from synapse.types import MutableStateMap, StateMap
 
 logger = logging.getLogger(__name__)
 
@@ -32,10 +42,10 @@ POWER_KEY = (EventTypes.PowerLevels, "")
 
 async def resolve_events_with_store(
     room_id: str,
-    state_sets: List[StateMap[str]],
+    state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
-    state_map_factory: Callable[[List[str]], Awaitable],
-):
+    state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
+) -> StateMap[str]:
     """
     Args:
         room_id: the room we are working in
@@ -56,8 +66,7 @@ async def resolve_events_with_store(
             an Awaitable that resolves to a dict of event_id to event.
 
     Returns:
-        Deferred[dict[(str, str), str]]:
-            a map from (type, state_key) to event_id.
+        A map from (type, state_key) to event_id.
     """
     if len(state_sets) == 1:
         return state_sets[0]
@@ -75,8 +84,8 @@ async def resolve_events_with_store(
         "Asking for %d/%d conflicted events", len(needed_events), needed_event_count
     )
 
-    # dict[str, FrozenEvent]: a map from state event id to event. Only includes
-    # the state events which are in conflict (and those in event_map)
+    # A map from state event id to event. Only includes the state events which
+    # are in conflict (and those in event_map).
     state_map = await state_map_factory(needed_events)
     if event_map is not None:
         state_map.update(event_map)
@@ -91,8 +100,6 @@ async def resolve_events_with_store(
 
     # get the ids of the auth events which allow us to authenticate the
     # conflicted state, picking only from the unconflicting state.
-    #
-    # dict[(str, str), str]: a map from state key to event id
     auth_events = _create_auth_events_from_maps(
         unconflicted_state, conflicted_state, state_map
     )
@@ -122,29 +129,30 @@ async def resolve_events_with_store(
     )
 
 
-def _seperate(state_sets):
+def _seperate(
+    state_sets: Iterable[StateMap[str]],
+) -> Tuple[MutableStateMap[str], MutableStateMap[Set[str]]]:
     """Takes the state_sets and figures out which keys are conflicted and
     which aren't. i.e., which have multiple different event_ids associated
     with them in different state sets.
 
     Args:
-        state_sets(iterable[dict[(str, str), str]]):
+        state_sets:
             List of dicts of (type, state_key) -> event_id, which are the
             different state groups to resolve.
 
     Returns:
-        (dict[(str, str), str], dict[(str, str), set[str]]):
-            A tuple of (unconflicted_state, conflicted_state), where:
+        A tuple of (unconflicted_state, conflicted_state), where:
 
-            unconflicted_state is a dict mapping (type, state_key)->event_id
-            for unconflicted state keys.
+        unconflicted_state is a dict mapping (type, state_key)->event_id
+        for unconflicted state keys.
 
-            conflicted_state is a dict mapping (type, state_key) to a set of
-            event ids for conflicted state keys.
+        conflicted_state is a dict mapping (type, state_key) to a set of
+        event ids for conflicted state keys.
     """
     state_set_iterator = iter(state_sets)
     unconflicted_state = dict(next(state_set_iterator))
-    conflicted_state = {}
+    conflicted_state = {}  # type: MutableStateMap[Set[str]]
 
     for state_set in state_set_iterator:
         for key, value in state_set.items():
@@ -171,7 +179,21 @@ def _seperate(state_sets):
     return unconflicted_state, conflicted_state
 
 
-def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
+def _create_auth_events_from_maps(
+    unconflicted_state: StateMap[str],
+    conflicted_state: StateMap[Set[str]],
+    state_map: Dict[str, EventBase],
+) -> StateMap[str]:
+    """
+
+    Args:
+        unconflicted_state: The unconflicted state map.
+        conflicted_state: The conflicted state map.
+        state_map:
+
+    Returns:
+        A map from state key to event id.
+    """
     auth_events = {}
     for event_ids in conflicted_state.values():
         for event_id in event_ids:
@@ -179,14 +201,17 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
                 keys = event_auth.auth_types_for_event(state_map[event_id])
                 for key in keys:
                     if key not in auth_events:
-                        event_id = unconflicted_state.get(key, None)
-                        if event_id:
-                            auth_events[key] = event_id
+                        auth_event_id = unconflicted_state.get(key, None)
+                        if auth_event_id:
+                            auth_events[key] = auth_event_id
     return auth_events
 
 
 def _resolve_with_state(
-    unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map
+    unconflicted_state_ids: MutableStateMap[str],
+    conflicted_state_ids: StateMap[Set[str]],
+    auth_event_ids: StateMap[str],
+    state_map: Dict[str, EventBase],
 ):
     conflicted_state = {}
     for key, event_ids in conflicted_state_ids.items():
@@ -215,7 +240,9 @@ def _resolve_with_state(
     return new_state
 
 
-def _resolve_state_events(conflicted_state, auth_events):
+def _resolve_state_events(
+    conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase]
+) -> StateMap[EventBase]:
     """ This is where we actually decide which of the conflicted state to
     use.
 
@@ -255,7 +282,9 @@ def _resolve_state_events(conflicted_state, auth_events):
     return resolved_state
 
 
-def _resolve_auth_events(events, auth_events):
+def _resolve_auth_events(
+    events: List[EventBase], auth_events: StateMap[EventBase]
+) -> EventBase:
     reverse = list(reversed(_ordered_events(events)))
 
     auth_keys = {
@@ -289,7 +318,9 @@ def _resolve_auth_events(events, auth_events):
     return event
 
 
-def _resolve_normal_events(events, auth_events):
+def _resolve_normal_events(
+    events: List[EventBase], auth_events: StateMap[EventBase]
+) -> EventBase:
     for event in _ordered_events(events):
         try:
             # The signatures have already been checked at this point
@@ -309,7 +340,7 @@ def _resolve_normal_events(events, auth_events):
     return event
 
 
-def _ordered_events(events):
+def _ordered_events(events: Iterable[EventBase]) -> List[EventBase]:
     def key_func(e):
         # we have to use utf-8 rather than ascii here because it turns out we allow
         # people to send us events with non-ascii event IDs :/
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 6634955cdc..edf94e7ad6 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -16,7 +16,21 @@
 import heapq
 import itertools
 import logging
-from typing import Dict, List, Optional
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    overload,
+)
+
+from typing_extensions import Literal
 
 import synapse.state
 from synapse import event_auth
@@ -24,7 +38,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
-from synapse.types import StateMap
+from synapse.types import MutableStateMap, StateMap
 from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
@@ -40,10 +54,10 @@ async def resolve_events_with_store(
     clock: Clock,
     room_id: str,
     room_version: str,
-    state_sets: List[StateMap[str]],
+    state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
     state_res_store: "synapse.state.StateResolutionStore",
-):
+) -> StateMap[str]:
     """Resolves the state using the v2 state resolution algorithm
 
     Args:
@@ -63,8 +77,7 @@ async def resolve_events_with_store(
         state_res_store:
 
     Returns:
-        Deferred[dict[(str, str), str]]:
-            a map from (type, state_key) to event_id.
+        A map from (type, state_key) to event_id.
     """
 
     logger.debug("Computing conflicted state")
@@ -171,18 +184,23 @@ async def resolve_events_with_store(
     return resolved_state
 
 
-async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
+async def _get_power_level_for_sender(
+    room_id: str,
+    event_id: str,
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+) -> int:
     """Return the power level of the sender of the given event according to
     their auth events.
 
     Args:
-        room_id (str)
-        event_id (str)
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
+        room_id
+        event_id
+        event_map
+        state_res_store
 
     Returns:
-        Deferred[int]
+        The power level.
     """
     event = await _get_event(room_id, event_id, event_map, state_res_store)
 
@@ -217,17 +235,21 @@ async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_st
         return int(level)
 
 
-async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
+async def _get_auth_chain_difference(
+    state_sets: Sequence[StateMap[str]],
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+) -> Set[str]:
     """Compare the auth chains of each state set and return the set of events
     that only appear in some but not all of the auth chains.
 
     Args:
-        state_sets (list)
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
+        state_sets
+        event_map
+        state_res_store
 
     Returns:
-        Deferred[set[str]]: Set of event IDs
+        Set of event IDs
     """
 
     difference = await state_res_store.get_auth_chain_difference(
@@ -237,17 +259,19 @@ async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
     return difference
 
 
-def _seperate(state_sets):
+def _seperate(
+    state_sets: Iterable[StateMap[str]],
+) -> Tuple[StateMap[str], StateMap[Set[str]]]:
     """Return the unconflicted and conflicted state. This is different than in
     the original algorithm, as this defines a key to be conflicted if one of
     the state sets doesn't have that key.
 
     Args:
-        state_sets (list)
+        state_sets
 
     Returns:
-        tuple[dict, dict]: A tuple of unconflicted and conflicted state. The
-        conflicted state dict is a map from type/state_key to set of event IDs
+        A tuple of unconflicted and conflicted state. The conflicted state dict
+        is a map from type/state_key to set of event IDs
     """
     unconflicted_state = {}
     conflicted_state = {}
@@ -260,18 +284,20 @@ def _seperate(state_sets):
             event_ids.discard(None)
             conflicted_state[key] = event_ids
 
-    return unconflicted_state, conflicted_state
+    # mypy doesn't understand that discarding None above means that conflicted
+    # state is StateMap[Set[str]], not StateMap[Set[Optional[Str]]].
+    return unconflicted_state, conflicted_state  # type: ignore
 
 
-def _is_power_event(event):
+def _is_power_event(event: EventBase) -> bool:
     """Return whether or not the event is a "power event", as defined by the
     v2 state resolution algorithm
 
     Args:
-        event (FrozenEvent)
+        event
 
     Returns:
-        boolean
+        True if the event is a power event.
     """
     if (event.type, event.state_key) in (
         (EventTypes.PowerLevels, ""),
@@ -288,19 +314,23 @@ def _is_power_event(event):
 
 
 async def _add_event_and_auth_chain_to_graph(
-    graph, room_id, event_id, event_map, state_res_store, auth_diff
-):
+    graph: Dict[str, Set[str]],
+    room_id: str,
+    event_id: str,
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+    auth_diff: Set[str],
+) -> None:
     """Helper function for _reverse_topological_power_sort that add the event
     and its auth chain (that is in the auth diff) to the graph
 
     Args:
-        graph (dict[str, set[str]]): A map from event ID to the events auth
-            event IDs
-        room_id (str): the room we are working in
-        event_id (str): Event to add to the graph
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
-        auth_diff (set[str]): Set of event IDs that are in the auth difference.
+        graph: A map from event ID to the events auth event IDs
+        room_id: the room we are working in
+        event_id: Event to add to the graph
+        event_map
+        state_res_store
+        auth_diff: Set of event IDs that are in the auth difference.
     """
 
     state = [event_id]
@@ -318,24 +348,29 @@ async def _add_event_and_auth_chain_to_graph(
 
 
 async def _reverse_topological_power_sort(
-    clock, room_id, event_ids, event_map, state_res_store, auth_diff
-):
+    clock: Clock,
+    room_id: str,
+    event_ids: Iterable[str],
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+    auth_diff: Set[str],
+) -> List[str]:
     """Returns a list of the event_ids sorted by reverse topological ordering,
     and then by power level and origin_server_ts
 
     Args:
-        clock (Clock)
-        room_id (str): the room we are working in
-        event_ids (list[str]): The events to sort
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
-        auth_diff (set[str]): Set of event IDs that are in the auth difference.
+        clock
+        room_id: the room we are working in
+        event_ids: The events to sort
+        event_map
+        state_res_store
+        auth_diff: Set of event IDs that are in the auth difference.
 
     Returns:
-        Deferred[list[str]]: The sorted list
+        The sorted list
     """
 
-    graph = {}
+    graph = {}  # type: Dict[str, Set[str]]
     for idx, event_id in enumerate(event_ids, start=1):
         await _add_event_and_auth_chain_to_graph(
             graph, room_id, event_id, event_map, state_res_store, auth_diff
@@ -372,24 +407,30 @@ async def _reverse_topological_power_sort(
 
 
 async def _iterative_auth_checks(
-    clock, room_id, room_version, event_ids, base_state, event_map, state_res_store
-):
+    clock: Clock,
+    room_id: str,
+    room_version: str,
+    event_ids: List[str],
+    base_state: StateMap[str],
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+) -> MutableStateMap[str]:
     """Sequentially apply auth checks to each event in given list, updating the
     state as it goes along.
 
     Args:
-        clock (Clock)
-        room_id (str)
-        room_version (str)
-        event_ids (list[str]): Ordered list of events to apply auth checks to
-        base_state (StateMap[str]): The set of state to start with
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
+        clock
+        room_id
+        room_version
+        event_ids: Ordered list of events to apply auth checks to
+        base_state: The set of state to start with
+        event_map
+        state_res_store
 
     Returns:
-        Deferred[StateMap[str]]: Returns the final updated state
+        Returns the final updated state
     """
-    resolved_state = base_state.copy()
+    resolved_state = dict(base_state)
     room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
 
     for idx, event_id in enumerate(event_ids, start=1):
@@ -439,21 +480,26 @@ async def _iterative_auth_checks(
 
 
 async def _mainline_sort(
-    clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store
-):
+    clock: Clock,
+    room_id: str,
+    event_ids: List[str],
+    resolved_power_event_id: Optional[str],
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+) -> List[str]:
     """Returns a sorted list of event_ids sorted by mainline ordering based on
     the given event resolved_power_event_id
 
     Args:
-        clock (Clock)
-        room_id (str): room we're working in
-        event_ids (list[str]): Events to sort
-        resolved_power_event_id (str): The final resolved power level event ID
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
+        clock
+        room_id: room we're working in
+        event_ids: Events to sort
+        resolved_power_event_id: The final resolved power level event ID
+        event_map
+        state_res_store
 
     Returns:
-        Deferred[list[str]]: The sorted list
+        The sorted list
     """
     if not event_ids:
         # It's possible for there to be no event IDs here to sort, so we can
@@ -505,59 +551,90 @@ async def _mainline_sort(
 
 
 async def _get_mainline_depth_for_event(
-    event, mainline_map, event_map, state_res_store
-):
+    event: EventBase,
+    mainline_map: Dict[str, int],
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+) -> int:
     """Get the mainline depths for the given event based on the mainline map
 
     Args:
-        event (FrozenEvent)
-        mainline_map (dict[str, int]): Map from event_id to mainline depth for
-            events in the mainline.
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
+        event
+        mainline_map: Map from event_id to mainline depth for events in the mainline.
+        event_map
+        state_res_store
 
     Returns:
-        Deferred[int]
+        The mainline depth
     """
 
     room_id = event.room_id
+    tmp_event = event  # type: Optional[EventBase]
 
     # We do an iterative search, replacing `event with the power level in its
     # auth events (if any)
-    while event:
+    while tmp_event:
         depth = mainline_map.get(event.event_id)
         if depth is not None:
             return depth
 
-        auth_events = event.auth_event_ids()
-        event = None
+        auth_events = tmp_event.auth_event_ids()
+        tmp_event = None
 
         for aid in auth_events:
             aev = await _get_event(
                 room_id, aid, event_map, state_res_store, allow_none=True
             )
             if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
-                event = aev
+                tmp_event = aev
                 break
 
     # Didn't find a power level auth event, so we just return 0
     return 0
 
 
-async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
+@overload
+async def _get_event(
+    room_id: str,
+    event_id: str,
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+    allow_none: Literal[False] = False,
+) -> EventBase:
+    ...
+
+
+@overload
+async def _get_event(
+    room_id: str,
+    event_id: str,
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+    allow_none: Literal[True],
+) -> Optional[EventBase]:
+    ...
+
+
+async def _get_event(
+    room_id: str,
+    event_id: str,
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+    allow_none: bool = False,
+) -> Optional[EventBase]:
     """Helper function to look up event in event_map, falling back to looking
     it up in the store
 
     Args:
-        room_id (str)
-        event_id (str)
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
-        allow_none (bool): if the event is not found, return None rather than raising
+        room_id
+        event_id
+        event_map
+        state_res_store
+        allow_none: if the event is not found, return None rather than raising
             an exception
 
     Returns:
-        Deferred[Optional[FrozenEvent]]
+        The event, or none if the event does not exist (and allow_none is True).
     """
     if event_id not in event_map:
         events = await state_res_store.get_events([event_id], allow_rejected=True)
@@ -577,7 +654,9 @@ async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=F
     return event
 
 
-def lexicographical_topological_sort(graph, key):
+def lexicographical_topological_sort(
+    graph: Dict[str, Set[str]], key: Callable[[str], Any]
+) -> Generator[str, None, None]:
     """Performs a lexicographic reverse topological sort on the graph.
 
     This returns a reverse topological sort (i.e. if node A references B then B
@@ -587,20 +666,20 @@ def lexicographical_topological_sort(graph, key):
     NOTE: `graph` is modified during the sort.
 
     Args:
-        graph (dict[str, set[str]]): A representation of the graph where each
-            node is a key in the dict and its value are the nodes edges.
-        key (func): A function that takes a node and returns a value that is
-            comparable and used to order nodes
+        graph: A representation of the graph where each node is a key in the
+            dict and its value are the nodes edges.
+        key: A function that takes a node and returns a value that is comparable
+            and used to order nodes
 
     Yields:
-        str: The next node in the topological sort
+        The next node in the topological sort
     """
 
     # Note, this is basically Kahn's algorithm except we look at nodes with no
     # outgoing edges, c.f.
     # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
     outdegree_map = graph
-    reverse_graph = {}
+    reverse_graph = {}  # type: Dict[str, Set[str]]
 
     # Lists of nodes with zero out degree. Is actually a tuple of
     # `(key(node), node)` so that sorting does the right thing
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 5ef3853559..bbff3c8d5b 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -37,7 +37,7 @@ from synapse.storage.state import StateGroupStorage
 __all__ = ["DataStores", "DataStore"]
 
 
-class Storage(object):
+class Storage:
     """The high level interfaces for talking to various storage layers.
     """
 
@@ -47,6 +47,9 @@ class Storage(object):
         # interfaces.
         self.main = stores.main
 
-        self.persistence = EventsPersistenceStorage(hs, stores)
         self.purge_events = PurgeEventsStorage(hs, stores)
         self.state = StateGroupStorage(hs, stores)
+
+        self.persistence = None
+        if stores.persist_events:
+            self.persistence = EventsPersistenceStorage(hs, stores)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 6814bf5fcf..ab49d227de 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -19,12 +19,11 @@ import random
 from abc import ABCMeta
 from typing import Any, Optional
 
-from canonicaljson import json
-
 from synapse.storage.database import LoggingTransaction  # noqa: F401
 from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
 from synapse.storage.database import DatabasePool
 from synapse.types import Collection, get_domain_from_id
+from synapse.util import json_decoder
 
 logger = logging.getLogger(__name__)
 
@@ -99,13 +98,13 @@ def db_to_json(db_content):
     if isinstance(db_content, memoryview):
         db_content = db_content.tobytes()
 
-    # Decode it to a Unicode string before feeding it to json.loads, since
+    # Decode it to a Unicode string before feeding it to the JSON decoder, since
     # Python 3.5 does not support deserializing bytes.
     if isinstance(db_content, (bytes, bytearray)):
         db_content = db_content.decode("utf8")
 
     try:
-        return json.loads(db_content)
+        return json_decoder.decode(db_content)
     except Exception:
         logging.warning("Tried to decode '%r' as JSON and failed", db_content)
         raise
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index f43463df53..810721ebe9 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -16,18 +16,15 @@
 import logging
 from typing import Optional
 
-from canonicaljson import json
-
-from twisted.internet import defer
-
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import json_encoder
 
 from . import engines
 
 logger = logging.getLogger(__name__)
 
 
-class BackgroundUpdatePerformance(object):
+class BackgroundUpdatePerformance:
     """Tracks the how long a background update is taking to update its items"""
 
     def __init__(self, name):
@@ -74,7 +71,7 @@ class BackgroundUpdatePerformance(object):
             return float(self.total_item_count) / float(self.total_duration_ms)
 
 
-class BackgroundUpdater(object):
+class BackgroundUpdater:
     """ Background updates are updates to the database that run in the
     background. Each update processes a batch of data at once. We attempt to
     limit the impact of each update by monitoring how long each batch takes to
@@ -308,9 +305,8 @@ class BackgroundUpdater(object):
             update_name (str): Name of update
         """
 
-        @defer.inlineCallbacks
-        def noop_update(progress, batch_size):
-            yield self._end_background_update(update_name)
+        async def noop_update(progress, batch_size):
+            await self._end_background_update(update_name)
             return 1
 
         self.register_background_update_handler(update_name, noop_update)
@@ -409,23 +405,23 @@ class BackgroundUpdater(object):
         else:
             runner = create_index_sqlite
 
-        @defer.inlineCallbacks
-        def updater(progress, batch_size):
+        async def updater(progress, batch_size):
             if runner is not None:
                 logger.info("Adding index %s to %s", index_name, table)
-                yield self.db_pool.runWithConnection(runner)
-            yield self._end_background_update(update_name)
+                await self.db_pool.runWithConnection(runner)
+            await self._end_background_update(update_name)
             return 1
 
         self.register_background_update_handler(update_name, updater)
 
-    def _end_background_update(self, update_name):
+    async def _end_background_update(self, update_name: str) -> None:
         """Removes a completed background update task from the queue.
 
         Args:
-            update_name(str): The name of the completed task to remove
+            update_name:: The name of the completed task to remove
+
         Returns:
-            A deferred that completes once the task is removed.
+            None, completes once the task is removed.
         """
         if update_name != self._current_background_update:
             raise Exception(
@@ -433,11 +429,11 @@ class BackgroundUpdater(object):
                 % update_name
             )
         self._current_background_update = None
-        return self.db_pool.simple_delete_one(
+        await self.db_pool.simple_delete_one(
             "background_updates", keyvalues={"update_name": update_name}
         )
 
-    def _background_update_progress(self, update_name: str, progress: dict):
+    async def _background_update_progress(self, update_name: str, progress: dict):
         """Update the progress of a background update
 
         Args:
@@ -445,7 +441,7 @@ class BackgroundUpdater(object):
             progress: The progress of the update.
         """
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "background_update_progress",
             self._background_update_progress_txn,
             update_name,
@@ -461,7 +457,7 @@ class BackgroundUpdater(object):
             progress(dict): The progress of the update.
         """
 
-        progress_json = json.dumps(progress)
+        progress_json = json_encoder.encode(progress)
 
         self.db_pool.simple_update_one_txn(
             txn,
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 4ada6f5563..79ec8f119d 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -28,12 +28,14 @@ from typing import (
     Optional,
     Tuple,
     TypeVar,
+    cast,
+    overload,
 )
 
 from prometheus_client import Histogram
+from typing_extensions import Literal
 
 from twisted.enterprise import adbapi
-from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 from synapse.config.database import DatabaseConnectionConfig
@@ -125,7 +127,7 @@ class LoggingTransaction:
     method.
 
     Args:
-        txn: The database transcation object to wrap.
+        txn: The database transaction object to wrap.
         name: The name of this transactions for logging.
         database_engine
         after_callbacks: A list that callbacks will be appended to
@@ -160,7 +162,7 @@ class LoggingTransaction:
         self.after_callbacks = after_callbacks
         self.exception_callbacks = exception_callbacks
 
-    def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
+    def call_after(self, callback: "Callable[..., None]", *args: Any, **kwargs: Any):
         """Call the given callback on the main twisted thread after the
         transaction has finished. Used to invalidate the caches on the
         correct thread.
@@ -171,7 +173,9 @@ class LoggingTransaction:
         assert self.after_callbacks is not None
         self.after_callbacks.append((callback, args, kwargs))
 
-    def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
+    def call_on_exception(
+        self, callback: "Callable[..., None]", *args: Any, **kwargs: Any
+    ):
         # if self.exception_callbacks is None, that means that whatever constructed the
         # LoggingTransaction isn't expecting there to be any callbacks; assert that
         # is not the case.
@@ -195,7 +199,7 @@ class LoggingTransaction:
     def description(self) -> Any:
         return self.txn.description
 
-    def execute_batch(self, sql, args):
+    def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
         if isinstance(self.database_engine, PostgresEngine):
             from psycopg2.extras import execute_batch  # type: ignore
 
@@ -204,17 +208,17 @@ class LoggingTransaction:
             for val in args:
                 self.execute(sql, val)
 
-    def execute(self, sql: str, *args: Any):
+    def execute(self, sql: str, *args: Any) -> None:
         self._do_execute(self.txn.execute, sql, *args)
 
-    def executemany(self, sql: str, *args: Any):
+    def executemany(self, sql: str, *args: Any) -> None:
         self._do_execute(self.txn.executemany, sql, *args)
 
     def _make_sql_one_line(self, sql: str) -> str:
         "Strip newlines out of SQL so that the loggers in the DB are on one line"
         return " ".join(line.strip() for line in sql.splitlines() if line.strip())
 
-    def _do_execute(self, func, sql, *args):
+    def _do_execute(self, func, sql: str, *args: Any) -> None:
         sql = self._make_sql_one_line(sql)
 
         # TODO(paul): Maybe use 'info' and 'debug' for values?
@@ -240,22 +244,22 @@ class LoggingTransaction:
             sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
             sql_query_timer.labels(sql.split()[0]).observe(secs)
 
-    def close(self):
+    def close(self) -> None:
         self.txn.close()
 
 
-class PerformanceCounters(object):
+class PerformanceCounters:
     def __init__(self):
         self.current_counters = {}
         self.previous_counters = {}
 
-    def update(self, key, duration_secs):
+    def update(self, key: str, duration_secs: float) -> None:
         count, cum_time = self.current_counters.get(key, (0, 0))
         count += 1
         cum_time += duration_secs
         self.current_counters[key] = (count, cum_time)
 
-    def interval(self, interval_duration_secs, limit=3):
+    def interval(self, interval_duration_secs: float, limit: int = 3) -> str:
         counters = []
         for name, (count, cum_time) in self.current_counters.items():
             prev_count, prev_time = self.previous_counters.get(name, (0, 0))
@@ -279,7 +283,10 @@ class PerformanceCounters(object):
         return top_n_counters
 
 
-class DatabasePool(object):
+R = TypeVar("R")
+
+
+class DatabasePool:
     """Wraps a single physical database and connection pool.
 
     A single database may be used by multiple data stores.
@@ -327,13 +334,12 @@ class DatabasePool(object):
                 self._check_safe_to_upsert,
             )
 
-    def is_running(self):
+    def is_running(self) -> bool:
         """Is the database pool currently running
         """
         return self._db_pool.running
 
-    @defer.inlineCallbacks
-    def _check_safe_to_upsert(self):
+    async def _check_safe_to_upsert(self) -> None:
         """
         Is it safe to use native UPSERT?
 
@@ -342,7 +348,7 @@ class DatabasePool(object):
 
         If the background updates have not completed, wait 15 sec and check again.
         """
-        updates = yield self.simple_select_list(
+        updates = await self.simple_select_list(
             "background_updates",
             keyvalues=None,
             retcols=["update_name"],
@@ -364,7 +370,7 @@ class DatabasePool(object):
                 self._check_safe_to_upsert,
             )
 
-    def start_profiling(self):
+    def start_profiling(self) -> None:
         self._previous_loop_ts = monotonic_time()
 
         def loop():
@@ -388,8 +394,15 @@ class DatabasePool(object):
         self._clock.looping_call(loop, 10000)
 
     def new_transaction(
-        self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
-    ):
+        self,
+        conn: Connection,
+        desc: str,
+        after_callbacks: List[_CallbackListEntry],
+        exception_callbacks: List[_CallbackListEntry],
+        func: "Callable[..., R]",
+        *args: Any,
+        **kwargs: Any
+    ) -> R:
         start = monotonic_time()
         txn_id = self._TXN_ID
 
@@ -494,8 +507,9 @@ class DatabasePool(object):
             self._txn_perf_counters.update(desc, duration)
             sql_txn_timer.labels(desc).observe(duration)
 
-    @defer.inlineCallbacks
-    def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
+    async def runInteraction(
+        self, desc: str, func: "Callable[..., R]", *args: Any, **kwargs: Any
+    ) -> R:
         """Starts a transaction on the database and runs a given function
 
         Arguments:
@@ -508,7 +522,7 @@ class DatabasePool(object):
             kwargs: named args to pass to `func`
 
         Returns:
-            Deferred: The result of func
+            The result of func
         """
         after_callbacks = []  # type: List[_CallbackListEntry]
         exception_callbacks = []  # type: List[_CallbackListEntry]
@@ -517,7 +531,7 @@ class DatabasePool(object):
             logger.warning("Starting db txn '%s' from sentinel context", desc)
 
         try:
-            result = yield self.runWithConnection(
+            result = await self.runWithConnection(
                 self.new_transaction,
                 desc,
                 after_callbacks,
@@ -534,10 +548,11 @@ class DatabasePool(object):
                 after_callback(*after_args, **after_kwargs)
             raise
 
-        return result
+        return cast(R, result)
 
-    @defer.inlineCallbacks
-    def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
+    async def runWithConnection(
+        self, func: "Callable[..., R]", *args: Any, **kwargs: Any
+    ) -> R:
         """Wraps the .runWithConnection() method on the underlying db_pool.
 
         Arguments:
@@ -548,7 +563,7 @@ class DatabasePool(object):
             kwargs: named args to pass to `func`
 
         Returns:
-            Deferred: The result of func
+            The result of func
         """
         parent_context = current_context()  # type: Optional[LoggingContextOrSentinel]
         if not parent_context:
@@ -571,18 +586,16 @@ class DatabasePool(object):
 
                 return func(conn, *args, **kwargs)
 
-        result = yield make_deferred_yieldable(
+        return await make_deferred_yieldable(
             self._db_pool.runWithConnection(inner_func, *args, **kwargs)
         )
 
-        return result
-
     @staticmethod
-    def cursor_to_dict(cursor):
+    def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
         """Converts a SQL cursor into an list of dicts.
 
         Args:
-            cursor : The DBAPI cursor which has executed a query.
+            cursor: The DBAPI cursor which has executed a query.
         Returns:
             A list of dicts where the key is the column header.
         """
@@ -590,10 +603,29 @@ class DatabasePool(object):
         results = [dict(zip(col_headers, row)) for row in cursor]
         return results
 
-    def execute(self, desc, decoder, query, *args):
+    @overload
+    async def execute(
+        self, desc: str, decoder: Literal[None], query: str, *args: Any
+    ) -> List[Tuple[Any, ...]]:
+        ...
+
+    @overload
+    async def execute(
+        self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
+    ) -> R:
+        ...
+
+    async def execute(
+        self,
+        desc: str,
+        decoder: Optional[Callable[[Cursor], R]],
+        query: str,
+        *args: Any
+    ) -> R:
         """Runs a single query for a result set.
 
         Args:
+            desc: description of the transaction, for logging and metrics
             decoder - The function which can resolve the cursor results to
                 something meaningful.
             query - The query string to execute
@@ -609,29 +641,33 @@ class DatabasePool(object):
             else:
                 return txn.fetchall()
 
-        return self.runInteraction(desc, interaction)
+        return await self.runInteraction(desc, interaction)
 
     # "Simple" SQL API methods that operate on a single table with no JOINs,
     # no complex WHERE clauses, just a dict of values for columns.
 
-    @defer.inlineCallbacks
-    def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
+    async def simple_insert(
+        self,
+        table: str,
+        values: Dict[str, Any],
+        or_ignore: bool = False,
+        desc: str = "simple_insert",
+    ) -> bool:
         """Executes an INSERT query on the named table.
 
         Args:
-            table : string giving the table name
-            values : dict of new column names and values for them
-            or_ignore : bool stating whether an exception should be raised
+            table: string giving the table name
+            values: dict of new column names and values for them
+            or_ignore: bool stating whether an exception should be raised
                 when a conflicting row already exists. If True, False will be
                 returned by the function instead
-            desc : string giving a description of the transaction
+            desc: description of the transaction, for logging and metrics
 
         Returns:
-            bool: Whether the row was inserted or not. Only useful when
-            `or_ignore` is True
+             Whether the row was inserted or not. Only useful when `or_ignore` is True
         """
         try:
-            yield self.runInteraction(desc, self.simple_insert_txn, table, values)
+            await self.runInteraction(desc, self.simple_insert_txn, table, values)
         except self.engine.module.IntegrityError:
             # We have to do or_ignore flag at this layer, since we can't reuse
             # a cursor after we receive an error from the db.
@@ -641,7 +677,9 @@ class DatabasePool(object):
         return True
 
     @staticmethod
-    def simple_insert_txn(txn, table, values):
+    def simple_insert_txn(
+        txn: LoggingTransaction, table: str, values: Dict[str, Any]
+    ) -> None:
         keys, vals = zip(*values.items())
 
         sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@@ -652,11 +690,29 @@ class DatabasePool(object):
 
         txn.execute(sql, vals)
 
-    def simple_insert_many(self, table, values, desc):
-        return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
+    async def simple_insert_many(
+        self, table: str, values: List[Dict[str, Any]], desc: str
+    ) -> None:
+        """Executes an INSERT query on the named table.
+
+        Args:
+            table: string giving the table name
+            values: dict of new column names and values for them
+            desc: description of the transaction, for logging and metrics
+        """
+        await self.runInteraction(desc, self.simple_insert_many_txn, table, values)
 
     @staticmethod
-    def simple_insert_many_txn(txn, table, values):
+    def simple_insert_many_txn(
+        txn: LoggingTransaction, table: str, values: List[Dict[str, Any]]
+    ) -> None:
+        """Executes an INSERT query on the named table.
+
+        Args:
+            txn: The transaction to use.
+            table: string giving the table name
+            values: dict of new column names and values for them
+        """
         if not values:
             return
 
@@ -684,16 +740,15 @@ class DatabasePool(object):
 
         txn.executemany(sql, vals)
 
-    @defer.inlineCallbacks
-    def simple_upsert(
+    async def simple_upsert(
         self,
-        table,
-        keyvalues,
-        values,
-        insertion_values={},
-        desc="simple_upsert",
-        lock=True,
-    ):
+        table: str,
+        keyvalues: Dict[str, Any],
+        values: Dict[str, Any],
+        insertion_values: Dict[str, Any] = {},
+        desc: str = "simple_upsert",
+        lock: bool = True,
+    ) -> Optional[bool]:
         """
 
         `lock` should generally be set to True (the default), but can be set
@@ -707,21 +762,20 @@ class DatabasePool(object):
           this table.
 
         Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key columns and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
+            table: The table to upsert into
+            keyvalues: The unique key columns and their new values
+            values: The nonunique columns and their new values
+            insertion_values: additional key/values to use only when inserting
+            desc: description of the transaction, for logging and metrics
+            lock: True to lock the table when doing the upsert.
         Returns:
-            Deferred(None or bool): Native upserts always return None. Emulated
-            upserts return True if a new entry was created, False if an existing
-            one was updated.
+            Native upserts always return None. Emulated upserts return True if a
+            new entry was created, False if an existing one was updated.
         """
         attempts = 0
         while True:
             try:
-                result = yield self.runInteraction(
+                return await self.runInteraction(
                     desc,
                     self.simple_upsert_txn,
                     table,
@@ -730,7 +784,6 @@ class DatabasePool(object):
                     insertion_values,
                     lock=lock,
                 )
-                return result
             except self.engine.module.IntegrityError as e:
                 attempts += 1
                 if attempts >= 5:
@@ -744,29 +797,34 @@ class DatabasePool(object):
                 )
 
     def simple_upsert_txn(
-        self, txn, table, keyvalues, values, insertion_values={}, lock=True
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        values: Dict[str, Any],
+        insertion_values: Dict[str, Any] = {},
+        lock: bool = True,
+    ) -> Optional[bool]:
         """
         Pick the UPSERT method which works best on the platform. Either the
         native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
 
         Args:
             txn: The transaction to use.
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key tables and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
+            table: The table to upsert into
+            keyvalues: The unique key tables and their new values
+            values: The nonunique columns and their new values
+            insertion_values: additional key/values to use only when inserting
+            lock: True to lock the table when doing the upsert.
         Returns:
-            None or bool: Native upserts always return None. Emulated
-            upserts return True if a new entry was created, False if an existing
-            one was updated.
+            Native upserts always return None. Emulated upserts return True if a
+            new entry was created, False if an existing one was updated.
         """
         if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
-            return self.simple_upsert_txn_native_upsert(
+            self.simple_upsert_txn_native_upsert(
                 txn, table, keyvalues, values, insertion_values=insertion_values
             )
+            return None
         else:
             return self.simple_upsert_txn_emulated(
                 txn,
@@ -778,18 +836,23 @@ class DatabasePool(object):
             )
 
     def simple_upsert_txn_emulated(
-        self, txn, table, keyvalues, values, insertion_values={}, lock=True
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        values: Dict[str, Any],
+        insertion_values: Dict[str, Any] = {},
+        lock: bool = True,
+    ) -> bool:
         """
         Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key tables and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
+            table: The table to upsert into
+            keyvalues: The unique key tables and their new values
+            values: The nonunique columns and their new values
+            insertion_values: additional key/values to use only when inserting
+            lock: True to lock the table when doing the upsert.
         Returns:
-            bool: Return True if a new entry was created, False if an existing
+            Returns True if a new entry was created, False if an existing
             one was updated.
         """
         # We need to lock the table :(, unless we're *really* careful
@@ -847,19 +910,21 @@ class DatabasePool(object):
         return True
 
     def simple_upsert_txn_native_upsert(
-        self, txn, table, keyvalues, values, insertion_values={}
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        values: Dict[str, Any],
+        insertion_values: Dict[str, Any] = {},
+    ) -> None:
         """
         Use the native UPSERT functionality in recent PostgreSQL versions.
 
         Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key tables and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-        Returns:
-            None
+            table: The table to upsert into
+            keyvalues: The unique key tables and their new values
+            values: The nonunique columns and their new values
+            insertion_values: additional key/values to use only when inserting
         """
         allvalues = {}  # type: Dict[str, Any]
         allvalues.update(keyvalues)
@@ -887,7 +952,7 @@ class DatabasePool(object):
         key_names: Collection[str],
         key_values: Collection[Iterable[Any]],
         value_names: Collection[str],
-        value_values: Iterable[Iterable[str]],
+        value_values: Iterable[Iterable[Any]],
     ) -> None:
         """
         Upsert, many times.
@@ -916,7 +981,7 @@ class DatabasePool(object):
         key_names: Iterable[str],
         key_values: Collection[Iterable[Any]],
         value_names: Collection[str],
-        value_values: Iterable[Iterable[str]],
+        value_values: Iterable[Iterable[Any]],
     ) -> None:
         """
         Upsert, many times, but without native UPSERT support or batching.
@@ -989,41 +1054,93 @@ class DatabasePool(object):
 
         return txn.execute_batch(sql, args)
 
-    def simple_select_one(
-        self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
-    ):
+    @overload
+    async def simple_select_one(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+        allow_none: Literal[False] = False,
+        desc: str = "simple_select_one",
+    ) -> Dict[str, Any]:
+        ...
+
+    @overload
+    async def simple_select_one(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+        allow_none: Literal[True] = True,
+        desc: str = "simple_select_one",
+    ) -> Optional[Dict[str, Any]]:
+        ...
+
+    async def simple_select_one(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+        allow_none: bool = False,
+        desc: str = "simple_select_one",
+    ) -> Optional[Dict[str, Any]]:
         """Executes a SELECT query on the named table, which is expected to
         return a single row, returning multiple columns from it.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            retcols : list of strings giving the names of the columns to return
-
-            allow_none : If true, return None instead of failing if the SELECT
-              statement returns no rows
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
+            retcols: list of strings giving the names of the columns to return
+            allow_none: If true, return None instead of failing if the SELECT
+                statement returns no rows
+            desc: description of the transaction, for logging and metrics
         """
-        return self.runInteraction(
+        return await self.runInteraction(
             desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
         )
 
-    def simple_select_one_onecol(
+    @overload
+    async def simple_select_one_onecol(
         self,
-        table,
-        keyvalues,
-        retcol,
-        allow_none=False,
-        desc="simple_select_one_onecol",
-    ):
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: str,
+        allow_none: Literal[False] = False,
+        desc: str = "simple_select_one_onecol",
+    ) -> Any:
+        ...
+
+    @overload
+    async def simple_select_one_onecol(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: str,
+        allow_none: Literal[True] = True,
+        desc: str = "simple_select_one_onecol",
+    ) -> Optional[Any]:
+        ...
+
+    async def simple_select_one_onecol(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: str,
+        allow_none: bool = False,
+        desc: str = "simple_select_one_onecol",
+    ) -> Optional[Any]:
         """Executes a SELECT query on the named table, which is expected to
         return a single row, returning a single column from it.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            retcol : string giving the name of the column to return
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
+            retcol: string giving the name of the column to return
+            allow_none: If true, return None instead of failing if the SELECT
+                statement returns no rows
+            desc: description of the transaction, for logging and metrics
         """
-        return self.runInteraction(
+        return await self.runInteraction(
             desc,
             self.simple_select_one_onecol_txn,
             table,
@@ -1032,10 +1149,39 @@ class DatabasePool(object):
             allow_none=allow_none,
         )
 
+    @overload
     @classmethod
     def simple_select_one_onecol_txn(
-        cls, txn, table, keyvalues, retcol, allow_none=False
-    ):
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: str,
+        allow_none: Literal[False] = False,
+    ) -> Any:
+        ...
+
+    @overload
+    @classmethod
+    def simple_select_one_onecol_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: str,
+        allow_none: Literal[True] = True,
+    ) -> Optional[Any]:
+        ...
+
+    @classmethod
+    def simple_select_one_onecol_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: str,
+        allow_none: bool = False,
+    ) -> Optional[Any]:
         ret = cls.simple_select_onecol_txn(
             txn, table=table, keyvalues=keyvalues, retcol=retcol
         )
@@ -1049,7 +1195,9 @@ class DatabasePool(object):
                 raise StoreError(404, "No row found")
 
     @staticmethod
-    def simple_select_onecol_txn(txn, table, keyvalues, retcol):
+    def simple_select_onecol_txn(
+        txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any], retcol: str,
+    ) -> List[Any]:
         sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
 
         if keyvalues:
@@ -1060,53 +1208,72 @@ class DatabasePool(object):
 
         return [r[0] for r in txn]
 
-    def simple_select_onecol(
-        self, table, keyvalues, retcol, desc="simple_select_onecol"
-    ):
+    async def simple_select_onecol(
+        self,
+        table: str,
+        keyvalues: Optional[Dict[str, Any]],
+        retcol: str,
+        desc: str = "simple_select_onecol",
+    ) -> List[Any]:
         """Executes a SELECT query on the named table, which returns a list
         comprising of the values of the named column from the selected rows.
 
         Args:
-            table (str): table name
-            keyvalues (dict|None): column names and values to select the rows with
-            retcol (str): column whos value we wish to retrieve.
+            table: table name
+            keyvalues: column names and values to select the rows with
+            retcol: column whos value we wish to retrieve.
+            desc: description of the transaction, for logging and metrics
 
         Returns:
-            Deferred: Results in a list
+            Results in a list
         """
-        return self.runInteraction(
+        return await self.runInteraction(
             desc, self.simple_select_onecol_txn, table, keyvalues, retcol
         )
 
-    def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
+    async def simple_select_list(
+        self,
+        table: str,
+        keyvalues: Optional[Dict[str, Any]],
+        retcols: Iterable[str],
+        desc: str = "simple_select_list",
+    ) -> List[Dict[str, Any]]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Args:
-            table (str): the table name
-            keyvalues (dict[str, Any] | None):
+            table: the table name
+            keyvalues:
                 column names and values to select the rows with, or None to not
                 apply a WHERE clause.
-            retcols (iterable[str]): the names of the columns to return
+            retcols: the names of the columns to return
+            desc: description of the transaction, for logging and metrics
+
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
+            A list of dictionaries.
         """
-        return self.runInteraction(
+        return await self.runInteraction(
             desc, self.simple_select_list_txn, table, keyvalues, retcols
         )
 
     @classmethod
-    def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
+    def simple_select_list_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Optional[Dict[str, Any]],
+        retcols: Iterable[str],
+    ) -> List[Dict[str, Any]]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Args:
-            txn : Transaction object
-            table (str): the table name
-            keyvalues (dict[str, T] | None):
+            txn: Transaction object
+            table: the table name
+            keyvalues:
                 column names and values to select the rows with, or None to not
                 apply a WHERE clause.
-            retcols (iterable[str]): the names of the columns to return
+            retcols: the names of the columns to return
         """
         if keyvalues:
             sql = "SELECT %s FROM %s WHERE %s" % (
@@ -1121,28 +1288,29 @@ class DatabasePool(object):
 
         return cls.cursor_to_dict(txn)
 
-    @defer.inlineCallbacks
-    def simple_select_many_batch(
+    async def simple_select_many_batch(
         self,
-        table,
-        column,
-        iterable,
-        retcols,
-        keyvalues={},
-        desc="simple_select_many_batch",
-        batch_size=100,
-    ):
+        table: str,
+        column: str,
+        iterable: Iterable[Any],
+        retcols: Iterable[str],
+        keyvalues: Dict[str, Any] = {},
+        desc: str = "simple_select_many_batch",
+        batch_size: int = 100,
+    ) -> List[Any]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
-        Filters rows by if value of `column` is in `iterable`.
+        Filters rows by whether the value of `column` is in `iterable`.
 
         Args:
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
-            retcols : list of strings giving the names of the columns to return
+            table: string giving the table name
+            column: column name to test for inclusion against `iterable`
+            iterable: list
+            retcols: list of strings giving the names of the columns to return
+            keyvalues: dict of column names and values to select the rows with
+            desc: description of the transaction, for logging and metrics
+            batch_size: the number of rows for each select query
         """
         results = []  # type: List[Dict[str, Any]]
 
@@ -1156,7 +1324,7 @@ class DatabasePool(object):
             it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
         ]
         for chunk in chunks:
-            rows = yield self.runInteraction(
+            rows = await self.runInteraction(
                 desc,
                 self.simple_select_many_txn,
                 table,
@@ -1171,19 +1339,27 @@ class DatabasePool(object):
         return results
 
     @classmethod
-    def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
+    def simple_select_many_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        column: str,
+        iterable: Iterable[Any],
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+    ) -> List[Dict[str, Any]]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
-        Filters rows by if value of `column` is in `iterable`.
+        Filters rows by whether the value of `column` is in `iterable`.
 
         Args:
-            txn : Transaction object
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
-            retcols : list of strings giving the names of the columns to return
+            txn: Transaction object
+            table: string giving the table name
+            column: column name to test for inclusion against `iterable`
+            iterable: list
+            keyvalues: dict of column names and values to select the rows with
+            retcols: list of strings giving the names of the columns to return
         """
         if not iterable:
             return []
@@ -1204,13 +1380,24 @@ class DatabasePool(object):
         txn.execute(sql, values)
         return cls.cursor_to_dict(txn)
 
-    def simple_update(self, table, keyvalues, updatevalues, desc):
-        return self.runInteraction(
+    async def simple_update(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        updatevalues: Dict[str, Any],
+        desc: str,
+    ) -> int:
+        return await self.runInteraction(
             desc, self.simple_update_txn, table, keyvalues, updatevalues
         )
 
     @staticmethod
-    def simple_update_txn(txn, table, keyvalues, updatevalues):
+    def simple_update_txn(
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        updatevalues: Dict[str, Any],
+    ) -> int:
         if keyvalues:
             where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
         else:
@@ -1226,32 +1413,34 @@ class DatabasePool(object):
 
         return txn.rowcount
 
-    def simple_update_one(
-        self, table, keyvalues, updatevalues, desc="simple_update_one"
-    ):
+    async def simple_update_one(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        updatevalues: Dict[str, Any],
+        desc: str = "simple_update_one",
+    ) -> None:
         """Executes an UPDATE query on the named table, setting new values for
         columns in a row matching the key values.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            updatevalues : dict giving column names and values to update
-            retcols : optional list of column names to return
-
-        If present, retcols gives a list of column names on which to perform
-        a SELECT statement *before* performing the UPDATE statement. The values
-        of these will be returned in a dict.
-
-        These are performed within the same transaction, allowing an atomic
-        get-and-set.  This can be used to implement compare-and-set by putting
-        the update column in the 'keyvalues' dict as well.
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
+            updatevalues: dict giving column names and values to update
+            desc: description of the transaction, for logging and metrics
         """
-        return self.runInteraction(
+        await self.runInteraction(
             desc, self.simple_update_one_txn, table, keyvalues, updatevalues
         )
 
     @classmethod
-    def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
+    def simple_update_one_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        updatevalues: Dict[str, Any],
+    ) -> None:
         rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
 
         if rowcount == 0:
@@ -1259,8 +1448,18 @@ class DatabasePool(object):
         if rowcount > 1:
             raise StoreError(500, "More than one row matched (%s)" % (table,))
 
+    # Ideally we could use the overload decorator here to specify that the
+    # return type is only optional if allow_none is True, but this does not work
+    # when you call a static method from an instance.
+    # See https://github.com/python/mypy/issues/7781
     @staticmethod
-    def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
+    def simple_select_one_txn(
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+        allow_none: bool = False,
+    ) -> Optional[Dict[str, Any]]:
         select_sql = "SELECT %s FROM %s WHERE %s" % (
             ", ".join(retcols),
             table,
@@ -1279,24 +1478,29 @@ class DatabasePool(object):
 
         return dict(zip(retcols, row))
 
-    def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
+    async def simple_delete_one(
+        self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
+    ) -> None:
         """Executes a DELETE query on the named table, expecting to delete a
         single row.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
+            desc: description of the transaction, for logging and metrics
         """
-        return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
+        await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
 
     @staticmethod
-    def simple_delete_one_txn(txn, table, keyvalues):
+    def simple_delete_one_txn(
+        txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
+    ) -> None:
         """Executes a DELETE query on the named table, expecting to delete a
         single row.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
         """
         sql = "DELETE FROM %s WHERE %s" % (
             table,
@@ -1309,11 +1513,38 @@ class DatabasePool(object):
         if txn.rowcount > 1:
             raise StoreError(500, "More than one row matched (%s)" % (table,))
 
-    def simple_delete(self, table, keyvalues, desc):
-        return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
+    async def simple_delete(
+        self, table: str, keyvalues: Dict[str, Any], desc: str
+    ) -> int:
+        """Executes a DELETE query on the named table.
+
+        Filters rows by the key-value pairs.
+
+        Args:
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
+            desc: description of the transaction, for logging and metrics
+
+        Returns:
+            The number of deleted rows.
+        """
+        return await self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
 
     @staticmethod
-    def simple_delete_txn(txn, table, keyvalues):
+    def simple_delete_txn(
+        txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
+    ) -> int:
+        """Executes a DELETE query on the named table.
+
+        Filters rows by the key-value pairs.
+
+        Args:
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
+
+        Returns:
+            The number of deleted rows.
+        """
         sql = "DELETE FROM %s WHERE %s" % (
             table,
             " AND ".join("%s = ?" % (k,) for k in keyvalues),
@@ -1322,26 +1553,53 @@ class DatabasePool(object):
         txn.execute(sql, list(keyvalues.values()))
         return txn.rowcount
 
-    def simple_delete_many(self, table, column, iterable, keyvalues, desc):
-        return self.runInteraction(
+    async def simple_delete_many(
+        self,
+        table: str,
+        column: str,
+        iterable: Iterable[Any],
+        keyvalues: Dict[str, Any],
+        desc: str,
+    ) -> int:
+        """Executes a DELETE query on the named table.
+
+        Filters rows by if value of `column` is in `iterable`.
+
+        Args:
+            table: string giving the table name
+            column: column name to test for inclusion against `iterable`
+            iterable: list
+            keyvalues: dict of column names and values to select the rows with
+            desc: description of the transaction, for logging and metrics
+
+        Returns:
+            Number rows deleted
+        """
+        return await self.runInteraction(
             desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
         )
 
     @staticmethod
-    def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+    def simple_delete_many_txn(
+        txn: LoggingTransaction,
+        table: str,
+        column: str,
+        iterable: Iterable[Any],
+        keyvalues: Dict[str, Any],
+    ) -> int:
         """Executes a DELETE query on the named table.
 
         Filters rows by if value of `column` is in `iterable`.
 
         Args:
-            txn : Transaction object
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
+            txn: Transaction object
+            table: string giving the table name
+            column: column name to test for inclusion against `iterable`
+            iterable: list
+            keyvalues: dict of column names and values to select the rows with
 
         Returns:
-            int: Number rows deleted
+            Number rows deleted
         """
         if not iterable:
             return 0
@@ -1362,8 +1620,14 @@ class DatabasePool(object):
         return txn.rowcount
 
     def get_cache_dict(
-        self, db_conn, table, entity_column, stream_column, max_value, limit=100000
-    ):
+        self,
+        db_conn: Connection,
+        table: str,
+        entity_column: str,
+        stream_column: str,
+        max_value: int,
+        limit: int = 100000,
+    ) -> Tuple[Dict[Any, int], int]:
         # Fetch a mapping of room_id -> max stream position for "recent" rooms.
         # It doesn't really matter how many we get, the StreamChangeCache will
         # do the right thing to ensure it respects the max size of cache.
@@ -1394,65 +1658,19 @@ class DatabasePool(object):
 
         return cache, min_val
 
-    def simple_select_list_paginate(
-        self,
-        table,
-        orderby,
-        start,
-        limit,
-        retcols,
-        filters=None,
-        keyvalues=None,
-        order_direction="ASC",
-        desc="simple_select_list_paginate",
-    ):
-        """
-        Executes a SELECT query on the named table with start and limit,
-        of row numbers, which may return zero or number of rows from start to limit,
-        returning the result as a list of dicts.
-
-        Args:
-            table (str): the table name
-            filters (dict[str, T] | None):
-                column names and values to filter the rows with, or None to not
-                apply a WHERE ? LIKE ? clause.
-            keyvalues (dict[str, T] | None):
-                column names and values to select the rows with, or None to not
-                apply a WHERE clause.
-            orderby (str): Column to order the results by.
-            start (int): Index to begin the query at.
-            limit (int): Number of results to return.
-            retcols (iterable[str]): the names of the columns to return
-            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
-        """
-        return self.runInteraction(
-            desc,
-            self.simple_select_list_paginate_txn,
-            table,
-            orderby,
-            start,
-            limit,
-            retcols,
-            filters=filters,
-            keyvalues=keyvalues,
-            order_direction=order_direction,
-        )
-
     @classmethod
     def simple_select_list_paginate_txn(
         cls,
-        txn,
-        table,
-        orderby,
-        start,
-        limit,
-        retcols,
-        filters=None,
-        keyvalues=None,
-        order_direction="ASC",
-    ):
+        txn: LoggingTransaction,
+        table: str,
+        orderby: str,
+        start: int,
+        limit: int,
+        retcols: Iterable[str],
+        filters: Optional[Dict[str, Any]] = None,
+        keyvalues: Optional[Dict[str, Any]] = None,
+        order_direction: str = "ASC",
+    ) -> List[Dict[str, Any]]:
         """
         Executes a SELECT query on the named table with start and limit,
         of row numbers, which may return zero or number of rows from start to limit,
@@ -1463,21 +1681,22 @@ class DatabasePool(object):
         using 'AND'.
 
         Args:
-            txn : Transaction object
-            table (str): the table name
-            orderby (str): Column to order the results by.
-            start (int): Index to begin the query at.
-            limit (int): Number of results to return.
-            retcols (iterable[str]): the names of the columns to return
-            filters (dict[str, T] | None):
+            txn: Transaction object
+            table: the table name
+            orderby: Column to order the results by.
+            start: Index to begin the query at.
+            limit: Number of results to return.
+            retcols: the names of the columns to return
+            filters:
                 column names and values to filter the rows with, or None to not
                 apply a WHERE ? LIKE ? clause.
-            keyvalues (dict[str, T] | None):
+            keyvalues:
                 column names and values to select the rows with, or None to not
                 apply a WHERE clause.
-            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+            order_direction: Whether the results should be ordered "ASC" or "DESC".
+
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
+            The result as a list of dictionaries.
         """
         if order_direction not in ["ASC", "DESC"]:
             raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
@@ -1503,51 +1722,65 @@ class DatabasePool(object):
 
         return cls.cursor_to_dict(txn)
 
-    def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
+    async def simple_search_list(
+        self,
+        table: str,
+        term: Optional[str],
+        col: str,
+        retcols: Iterable[str],
+        desc="simple_search_list",
+    ) -> Optional[List[Dict[str, Any]]]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Args:
-            table (str): the table name
-            term (str | None):
-                term for searching the table matched to a column.
-            col (str): column to query term should be matched to
-            retcols (iterable[str]): the names of the columns to return
+            table: the table name
+            term: term for searching the table matched to a column.
+            col: column to query term should be matched to
+            retcols: the names of the columns to return
+
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]] or None
+            A list of dictionaries or None.
         """
 
-        return self.runInteraction(
+        return await self.runInteraction(
             desc, self.simple_search_list_txn, table, term, col, retcols
         )
 
     @classmethod
-    def simple_search_list_txn(cls, txn, table, term, col, retcols):
+    def simple_search_list_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        term: Optional[str],
+        col: str,
+        retcols: Iterable[str],
+    ) -> Optional[List[Dict[str, Any]]]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Args:
-            txn : Transaction object
-            table (str): the table name
-            term (str | None):
-                term for searching the table matched to a column.
-            col (str): column to query term should be matched to
-            retcols (iterable[str]): the names of the columns to return
+            txn: Transaction object
+            table: the table name
+            term: term for searching the table matched to a column.
+            col: column to query term should be matched to
+            retcols: the names of the columns to return
+
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]] or None
+            None if no term is given, otherwise a list of dictionaries.
         """
         if term:
             sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
             termvalues = ["%%" + term + "%%"]
             txn.execute(sql, termvalues)
         else:
-            return 0
+            return None
 
         return cls.cursor_to_dict(txn)
 
 
 def make_in_list_sql_clause(
-    database_engine, column: str, iterable: Iterable
+    database_engine: BaseDatabaseEngine, column: str, iterable: Iterable
 ) -> Tuple[str, list]:
     """Returns an SQL clause that checks the given column is in the iterable.
 
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 4406e58273..aa5d490624 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -24,7 +24,7 @@ from synapse.storage.prepare_database import prepare_database
 logger = logging.getLogger(__name__)
 
 
-class Databases(object):
+class Databases:
     """The various databases.
 
     These are low level interfaces to physical databases.
@@ -47,9 +47,14 @@ class Databases(object):
             engine = create_engine(database_config.config)
 
             with make_conn(database_config, engine) as db_conn:
-                logger.info("Preparing database %r...", db_name)
-
+                logger.info("[database config %r]: Checking database server", db_name)
                 engine.check_database(db_conn)
+
+                logger.info(
+                    "[database config %r]: Preparing for databases %r",
+                    db_name,
+                    database_config.databases,
+                )
                 prepare_database(
                     db_conn, engine, hs.config, databases=database_config.databases,
                 )
@@ -57,7 +62,9 @@ class Databases(object):
                 database = DatabasePool(hs, database_config, engine)
 
                 if "main" in database_config.databases:
-                    logger.info("Starting 'main' data store")
+                    logger.info(
+                        "[database config %r]: Starting 'main' database", db_name
+                    )
 
                     # Sanity check we don't try and configure the main store on
                     # multiple databases.
@@ -68,11 +75,13 @@ class Databases(object):
 
                     # If we're on a process that can persist events also
                     # instantiate a `PersistEventsStore`
-                    if hs.config.worker.writers.events == hs.get_instance_name():
+                    if hs.get_instance_name() in hs.config.worker.writers.events:
                         persist_events = PersistEventsStore(hs, database, main)
 
                 if "state" in database_config.databases:
-                    logger.info("Starting 'state' data store")
+                    logger.info(
+                        "[database config %r]: Starting 'state' database", db_name
+                    )
 
                     # Sanity check we don't try and configure the state store on
                     # multiple databases.
@@ -85,14 +94,23 @@ class Databases(object):
 
                 self.databases.append(database)
 
-                logger.info("Database %r prepared", db_name)
+                logger.info("[database config %r]: prepared", db_name)
+
+            # Closing the context manager doesn't close the connection.
+            # psycopg will close the connection when the object gets GCed, but *only*
+            # if the PID is the same as when the connection was opened [1], and
+            # it may not be if we fork in the meantime.
+            #
+            # [1]: https://github.com/psycopg/psycopg2/blob/2_8_5/psycopg/connection_type.c#L1378
+
+            db_conn.close()
 
         # Sanity check that we have actually configured all the required stores.
         if not main:
-            raise Exception("No 'main' data store configured")
+            raise Exception("No 'main' database configured")
 
         if not state:
-            raise Exception("No 'main' data store configured")
+            raise Exception("No 'state' database configured")
 
         # We use local variables here to ensure that the databases do not have
         # optional types.
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 17fa470919..2ae2fbd5d7 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -18,6 +18,7 @@
 import calendar
 import logging
 import time
+from typing import Any, Dict, List, Optional, Tuple
 
 from synapse.api.constants import PresenceState
 from synapse.config.homeserver import HomeServerConfig
@@ -28,6 +29,7 @@ from synapse.storage.util.id_generators import (
     MultiWriterIdGenerator,
     StreamIdGenerator,
 )
+from synapse.types import get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from .account_data import AccountDataStore
@@ -263,6 +265,9 @@ class DataStore(
         # Used in _generate_user_daily_visits to keep track of progress
         self._last_user_visit_update = self._get_start_of_day()
 
+    def get_device_stream_token(self) -> int:
+        return self._device_list_id_gen.get_current_token()
+
     def take_presence_startup_info(self):
         active_on_startup = self._presence_on_startup
         self._presence_on_startup = None
@@ -290,16 +295,16 @@ class DataStore(
 
         return [UserPresenceState(**row) for row in rows]
 
-    def count_daily_users(self):
+    async def count_daily_users(self) -> int:
         """
         Counts the number of users who used this homeserver in the last 24 hours.
         """
         yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "count_daily_users", self._count_users, yesterday
         )
 
-    def count_monthly_users(self):
+    async def count_monthly_users(self) -> int:
         """
         Counts the number of users who used this homeserver in the last 30 days.
         Note this method is intended for phonehome metrics only and is different
@@ -307,7 +312,7 @@ class DataStore(
         amongst other things, includes a 3 day grace period before a user counts.
         """
         thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "count_monthly_users", self._count_users, thirty_days_ago
         )
 
@@ -326,15 +331,15 @@ class DataStore(
         (count,) = txn.fetchone()
         return count
 
-    def count_r30_users(self):
+    async def count_r30_users(self) -> Dict[str, int]:
         """
         Counts the number of 30 day retained users, defined as:-
          * Users who have created their accounts more than 30 days ago
          * Where last seen at most 30 days ago
          * Where account creation and last_seen are > 30 days apart
 
-         Returns counts globaly for a given user as well as breaking
-         by platform
+        Returns:
+             A mapping of counts globally as well as broken out by platform.
         """
 
         def _count_r30_users(txn):
@@ -407,7 +412,7 @@ class DataStore(
 
             return results
 
-        return self.db_pool.runInteraction("count_r30_users", _count_r30_users)
+        return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
 
     def _get_start_of_day(self):
         """
@@ -417,7 +422,7 @@ class DataStore(
         today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
         return today_start * 1000
 
-    def generate_user_daily_visits(self):
+    async def generate_user_daily_visits(self) -> None:
         """
         Generates daily visit data for use in cohort/ retention analysis
         """
@@ -472,18 +477,17 @@ class DataStore(
             # frequently
             self._last_user_visit_update = now
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "generate_user_daily_visits", _generate_user_daily_visits
         )
 
-    def get_users(self):
+    async def get_users(self) -> List[Dict[str, Any]]:
         """Function to retrieve a list of users in users table.
 
-        Args:
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
+            A list of dictionaries representing users.
         """
-        return self.db_pool.simple_select_list(
+        return await self.db_pool.simple_select_list(
             table="users",
             keyvalues={},
             retcols=[
@@ -497,30 +501,40 @@ class DataStore(
             desc="get_users",
         )
 
-    def get_users_paginate(
-        self, start, limit, name=None, guests=True, deactivated=False
-    ):
+    async def get_users_paginate(
+        self,
+        start: int,
+        limit: int,
+        user_id: Optional[str] = None,
+        name: Optional[str] = None,
+        guests: bool = True,
+        deactivated: bool = False,
+    ) -> Tuple[List[Dict[str, Any]], int]:
         """Function to retrieve a paginated list of users from
         users list. This will return a json list of users and the
         total number of users matching the filter criteria.
 
         Args:
-            start (int): start number to begin the query from
-            limit (int): number of rows to retrieve
-            name (string): filter for user names
-            guests (bool): whether to in include guest users
-            deactivated (bool): whether to include deactivated users
+            start: start number to begin the query from
+            limit: number of rows to retrieve
+            user_id: search for user_id. ignored if name is not None
+            name: search for local part of user_id or display name
+            guests: whether to in include guest users
+            deactivated: whether to include deactivated users
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]], int
+            A tuple of a list of mappings from user to information and a count of total users.
         """
 
         def get_users_paginate_txn(txn):
             filters = []
-            args = []
+            args = [self.hs.config.server_name]
 
             if name:
+                filters.append("(name LIKE ? OR displayname LIKE ?)")
+                args.extend(["@%" + name + "%:%", "%" + name + "%"])
+            elif user_id:
                 filters.append("name LIKE ?")
-                args.append("%" + name + "%")
+                args.extend(["%" + user_id + "%"])
 
             if not guests:
                 filters.append("is_guest = 0")
@@ -530,39 +544,42 @@ class DataStore(
 
             where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
 
-            sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause)
-            txn.execute(sql, args)
-            count = txn.fetchone()[0]
-
-            args = [self.hs.config.server_name] + args + [limit, start]
-            sql = """
-                SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url
+            sql_base = """
                 FROM users as u
                 LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
                 {}
-                ORDER BY u.name LIMIT ? OFFSET ?
                 """.format(
                 where_clause
             )
+            sql = "SELECT COUNT(*) as total_users " + sql_base
+            txn.execute(sql, args)
+            count = txn.fetchone()[0]
+
+            sql = (
+                "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url "
+                + sql_base
+                + " ORDER BY u.name LIMIT ? OFFSET ?"
+            )
+            args += [limit, start]
             txn.execute(sql, args)
             users = self.db_pool.cursor_to_dict(txn)
             return users, count
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_users_paginate_txn", get_users_paginate_txn
         )
 
-    def search_users(self, term):
+    async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]:
         """Function to search users list for one or more users with
         the matched term.
 
         Args:
-            term (str): search term
-            col (str): column to query term should be matched to
+            term: search term
+
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
+            A list of dictionaries or None.
         """
-        return self.db_pool.simple_search_list(
+        return await self.db_pool.simple_search_list(
             table="users",
             term=term,
             col="name",
@@ -575,21 +592,24 @@ def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig
     """Called before upgrading an existing database to check that it is broadly sane
     compared with the configuration.
     """
-    domain = config.server_name
+    logger.info("Checking database for consistency with configuration...")
 
-    sql = database_engine.convert_param_style(
-        "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
-    )
-    pat = "%:" + domain
-    cur.execute(sql, (pat,))
-    num_not_matching = cur.fetchall()[0][0]
-    if num_not_matching == 0:
+    # if there are any users in the database, check that the username matches our
+    # configured server name.
+
+    cur.execute("SELECT name FROM users LIMIT 1")
+    rows = cur.fetchall()
+    if not rows:
+        return
+
+    user_domain = get_domain_from_id(rows[0][0])
+    if user_domain == config.server_name:
         return
 
     raise Exception(
         "Found users in database not native to %s!\n"
-        "You cannot changed a synapse server_name after it's been configured"
-        % (domain,)
+        "You cannot change a synapse server_name after it's been configured"
+        % (config.server_name,)
     )
 
 
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 82aac2bbf3..4436b1a83d 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -16,9 +16,7 @@
 
 import abc
 import logging
-from typing import List, Optional, Tuple
-
-from twisted.internet import defer
+from typing import Dict, List, Optional, Tuple
 
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
@@ -58,14 +56,16 @@ class AccountDataWorkerStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cached()
-    def get_account_data_for_user(self, user_id):
+    async def get_account_data_for_user(
+        self, user_id: str
+    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
         """Get all the client account_data for a user.
 
         Args:
-            user_id(str): The user to get the account_data for.
+            user_id: The user to get the account_data for.
         Returns:
-            A deferred pair of a dict of global account_data and a dict
-            mapping from room_id string to per room account_data dicts.
+            A 2-tuple of a dict of global account_data and a dict mapping from
+            room_id string to per room account_data dicts.
         """
 
         def get_account_data_for_user_txn(txn):
@@ -94,7 +94,7 @@ class AccountDataWorkerStore(SQLBaseStore):
 
             return global_account_data, by_room
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_account_data_for_user", get_account_data_for_user_txn
         )
 
@@ -120,14 +120,16 @@ class AccountDataWorkerStore(SQLBaseStore):
             return None
 
     @cached(num_args=2)
-    def get_account_data_for_room(self, user_id, room_id):
+    async def get_account_data_for_room(
+        self, user_id: str, room_id: str
+    ) -> Dict[str, JsonDict]:
         """Get all the client account_data for a user for a room.
 
         Args:
-            user_id(str): The user to get the account_data for.
-            room_id(str): The room to get the account_data for.
+            user_id: The user to get the account_data for.
+            room_id: The room to get the account_data for.
         Returns:
-            A deferred dict of the room account_data
+            A dict of the room account_data
         """
 
         def get_account_data_for_room_txn(txn):
@@ -142,21 +144,22 @@ class AccountDataWorkerStore(SQLBaseStore):
                 row["account_data_type"]: db_to_json(row["content"]) for row in rows
             }
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_account_data_for_room", get_account_data_for_room_txn
         )
 
     @cached(num_args=3, max_entries=5000)
-    def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
+    async def get_account_data_for_room_and_type(
+        self, user_id: str, room_id: str, account_data_type: str
+    ) -> Optional[JsonDict]:
         """Get the client account_data of given type for a user for a room.
 
         Args:
-            user_id(str): The user to get the account_data for.
-            room_id(str): The room to get the account_data for.
-            account_data_type (str): The account data type to get.
+            user_id: The user to get the account_data for.
+            room_id: The room to get the account_data for.
+            account_data_type: The account data type to get.
         Returns:
-            A deferred of the room account_data for that type, or None if
-            there isn't any set.
+            The room account_data for that type, or None if there isn't any set.
         """
 
         def get_account_data_for_room_and_type_txn(txn):
@@ -174,7 +177,7 @@ class AccountDataWorkerStore(SQLBaseStore):
 
             return db_to_json(content_json) if content_json else None
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
         )
 
@@ -238,12 +241,14 @@ class AccountDataWorkerStore(SQLBaseStore):
             "get_updated_room_account_data", get_updated_room_account_data_txn
         )
 
-    def get_updated_account_data_for_user(self, user_id, stream_id):
+    async def get_updated_account_data_for_user(
+        self, user_id: str, stream_id: int
+    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
         """Get all the client account_data for a that's changed for a user
 
         Args:
-            user_id(str): The user to get the account_data for.
-            stream_id(int): The point in the stream since which to get updates
+            user_id: The user to get the account_data for.
+            stream_id: The point in the stream since which to get updates
         Returns:
             A deferred pair of a dict of global account_data and a dict
             mapping from room_id string to per room account_data dicts.
@@ -277,9 +282,9 @@ class AccountDataWorkerStore(SQLBaseStore):
             user_id, int(stream_id)
         )
         if not changed:
-            return defer.succeed(({}, {}))
+            return ({}, {})
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
         )
 
@@ -336,7 +341,7 @@ class AccountDataStore(AccountDataWorkerStore):
         """
         content_json = json_encoder.encode(content)
 
-        with self._account_data_id_gen.get_next() as next_id:
+        with await self._account_data_id_gen.get_next() as next_id:
             # no need to lock here as room_account_data has a unique constraint
             # on (user_id, room_id, account_data_type) so simple_upsert will
             # retry if there is a conflict.
@@ -384,7 +389,7 @@ class AccountDataStore(AccountDataWorkerStore):
         """
         content_json = json_encoder.encode(content)
 
-        with self._account_data_id_gen.get_next() as next_id:
+        with await self._account_data_id_gen.get_next() as next_id:
             # no need to lock here as account_data has a unique constraint on
             # (user_id, account_data_type) so simple_upsert will retry if
             # there is a conflict.
@@ -416,7 +421,7 @@ class AccountDataStore(AccountDataWorkerStore):
 
         return self._account_data_id_gen.get_current_token()
 
-    def _update_max_stream_id(self, next_id: int):
+    async def _update_max_stream_id(self, next_id: int) -> None:
         """Update the max stream_id
 
         Args:
@@ -435,4 +440,4 @@ class AccountDataStore(AccountDataWorkerStore):
             )
             txn.execute(update_max_id_sql, (next_id, next_id))
 
-        return self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
+        await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 5cf1a88399..454c0bc50c 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -16,13 +16,12 @@
 import logging
 import re
 
-from canonicaljson import json
-
 from synapse.appservice import AppServiceTransaction
 from synapse.config.appservice import load_appservices
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.util import json_encoder
 
 logger = logging.getLogger(__name__)
 
@@ -162,20 +161,18 @@ class ApplicationServiceTransactionWorkerStore(
             return result.get("state")
         return None
 
-    def set_appservice_state(self, service, state):
+    async def set_appservice_state(self, service, state) -> None:
         """Set the application service state.
 
         Args:
             service(ApplicationService): The service whose state to set.
             state(ApplicationServiceState): The connectivity state to apply.
-        Returns:
-            A Deferred which resolves when the state was set successfully.
         """
-        return self.db_pool.simple_upsert(
+        await self.db_pool.simple_upsert(
             "application_services_state", {"as_id": service.id}, {"state": state}
         )
 
-    def create_appservice_txn(self, service, events):
+    async def create_appservice_txn(self, service, events):
         """Atomically creates a new transaction for this application service
         with the given list of events.
 
@@ -204,7 +201,7 @@ class ApplicationServiceTransactionWorkerStore(
             new_txn_id = max(highest_txn_id, last_txn_id) + 1
 
             # Insert new txn into txn table
-            event_ids = json.dumps([e.event_id for e in events])
+            event_ids = json_encoder.encode([e.event_id for e in events])
             txn.execute(
                 "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
                 "VALUES(?,?,?)",
@@ -212,20 +209,17 @@ class ApplicationServiceTransactionWorkerStore(
             )
             return AppServiceTransaction(service=service, id=new_txn_id, events=events)
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "create_appservice_txn", _create_appservice_txn
         )
 
-    def complete_appservice_txn(self, txn_id, service):
+    async def complete_appservice_txn(self, txn_id, service) -> None:
         """Completes an application service transaction.
 
         Args:
             txn_id(str): The transaction ID being completed.
             service(ApplicationService): The application service which was sent
             this transaction.
-        Returns:
-            A Deferred which resolves if this transaction was stored
-            successfully.
         """
         txn_id = int(txn_id)
 
@@ -261,7 +255,7 @@ class ApplicationServiceTransactionWorkerStore(
                 {"txn_id": txn_id, "as_id": service.id},
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "complete_appservice_txn", _complete_appservice_txn
         )
 
@@ -315,13 +309,13 @@ class ApplicationServiceTransactionWorkerStore(
         else:
             return int(last_txn_id[0])  # select 'last_txn' col
 
-    def set_appservice_last_pos(self, pos):
+    async def set_appservice_last_pos(self, pos) -> None:
         def set_appservice_last_pos_txn(txn):
             txn.execute(
                 "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "set_appservice_last_pos", set_appservice_last_pos_txn
         )
 
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 10de446065..1e7637a6f5 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -299,8 +299,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 },
             )
 
-    def get_cache_stream_token(self, instance_name):
+    def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
         if self._cache_id_gen:
-            return self._cache_id_gen.get_current_token(instance_name)
+            return self._cache_id_gen.get_current_token_for_writer(instance_name)
         else:
             return 0
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 216a5925fc..c2fc847fbc 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -396,7 +396,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
         self._batch_row_update[key] = (user_agent, device_id, now)
 
     @wrap_as_background_process("update_client_ips")
-    def _update_client_ips_batch(self):
+    async def _update_client_ips_batch(self) -> None:
 
         # If the DB pool has already terminated, don't try updating
         if not self.db_pool.is_running():
@@ -405,7 +405,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
         to_update = self._batch_row_update
         self._batch_row_update = {}
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
         )
 
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 1f6e995c4f..0044433110 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -190,15 +190,15 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         )
 
     @trace
-    def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
+    async def delete_device_msgs_for_remote(
+        self, destination: str, up_to_stream_id: int
+    ) -> None:
         """Used to delete messages when the remote destination acknowledges
         their receipt.
 
         Args:
-            destination(str): The destination server_name
-            up_to_stream_id(int): Where to delete messages up to.
-        Returns:
-            A deferred that resolves when the messages have been deleted.
+            destination: The destination server_name
+            up_to_stream_id: Where to delete messages up to.
         """
 
         def delete_messages_for_remote_destination_txn(txn):
@@ -209,7 +209,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             )
             txn.execute(sql, (destination, up_to_stream_id))
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
         )
 
@@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                 rows.append((destination, stream_id, now_ms, edu_json))
             txn.executemany(sql, rows)
 
-        with self._device_inbox_id_gen.get_next() as stream_id:
+        with await self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                 txn, stream_id, local_messages_by_user_then_device
             )
 
-        with self._device_inbox_id_gen.get_next() as stream_id:
+        with await self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_from_remote_to_device_inbox",
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 2b33060480..306fc6947c 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -14,8 +14,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
 import logging
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api.errors import Codes, StoreError
 from synapse.logging.opentracing import (
@@ -47,7 +48,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
 
 class DeviceWorkerStore(SQLBaseStore):
-    def get_device(self, user_id: str, device_id: str):
+    async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
         """Retrieve a device. Only returns devices that are not marked as
         hidden.
 
@@ -55,11 +56,11 @@ class DeviceWorkerStore(SQLBaseStore):
             user_id: The ID of the user which owns the device
             device_id: The ID of the device to retrieve
         Returns:
-            defer.Deferred for a dict containing the device information
+            A dict containing the device information
         Raises:
             StoreError: if the device is not found
         """
-        return self.db_pool.simple_select_one(
+        return await self.db_pool.simple_select_one(
             table="devices",
             keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             retcols=("user_id", "device_id", "display_name"),
@@ -101,7 +102,7 @@ class DeviceWorkerStore(SQLBaseStore):
             update included in the response), and the list of updates, where
             each update is a pair of EDU type and EDU contents.
         """
-        now_stream_id = self._device_list_id_gen.get_current_token()
+        now_stream_id = self.get_device_stream_token()
 
         has_changed = self._device_list_federation_stream_cache.has_entity_changed(
             destination, int(from_stream_id)
@@ -254,9 +255,7 @@ class DeviceWorkerStore(SQLBaseStore):
             List of objects representing an device update EDU
         """
         devices = (
-            await self.db_pool.runInteraction(
-                "_get_e2e_device_keys_txn",
-                self._get_e2e_device_keys_txn,
+            await self.get_e2e_device_keys_and_signatures(
                 query_map.keys(),
                 include_all_devices=True,
                 include_deleted_devices=True,
@@ -292,17 +291,11 @@ class DeviceWorkerStore(SQLBaseStore):
                 prev_id = stream_id
 
                 if device is not None:
-                    key_json = device.get("key_json", None)
-                    if key_json:
-                        result["keys"] = db_to_json(key_json)
-
-                        if "signatures" in device:
-                            for sig_user_id, sigs in device["signatures"].items():
-                                result["keys"].setdefault("signatures", {}).setdefault(
-                                    sig_user_id, {}
-                                ).update(sigs)
+                    keys = device.keys
+                    if keys:
+                        result["keys"] = keys
 
-                    device_display_name = device.get("device_display_name", None)
+                    device_display_name = device.display_name
                     if device_display_name:
                         result["device_display_name"] = device_display_name
                 else:
@@ -312,9 +305,9 @@ class DeviceWorkerStore(SQLBaseStore):
 
         return results
 
-    def _get_last_device_update_for_remote_user(
+    async def _get_last_device_update_for_remote_user(
         self, destination: str, user_id: str, from_stream_id: int
-    ):
+    ) -> int:
         def f(txn):
             prev_sent_id_sql = """
                 SELECT coalesce(max(stream_id), 0) as stream_id
@@ -325,12 +318,16 @@ class DeviceWorkerStore(SQLBaseStore):
             rows = txn.fetchall()
             return rows[0][0]
 
-        return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f)
+        return await self.db_pool.runInteraction(
+            "get_last_device_update_for_remote_user", f
+        )
 
-    def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int):
+    async def mark_as_sent_devices_by_remote(
+        self, destination: str, stream_id: int
+    ) -> None:
         """Mark that updates have successfully been sent to the destination.
         """
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "mark_as_sent_devices_by_remote",
             self._mark_as_sent_devices_by_remote_txn,
             destination,
@@ -380,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore):
             THe new stream ID.
         """
 
-        with self._device_list_id_gen.get_next() as stream_id:
+        with await self._device_list_id_gen.get_next() as stream_id:
             await self.db_pool.runInteraction(
                 "add_user_sig_change_to_streams",
                 self._add_user_signature_change_txn,
@@ -412,8 +409,10 @@ class DeviceWorkerStore(SQLBaseStore):
             },
         )
 
+    @abc.abstractmethod
     def get_device_stream_token(self) -> int:
-        return self._device_list_id_gen.get_current_token()
+        """Get the current stream id from the _device_list_id_gen"""
+        ...
 
     @trace
     async def get_user_devices_from_cache(
@@ -481,55 +480,8 @@ class DeviceWorkerStore(SQLBaseStore):
             device["device_id"]: db_to_json(device["content"]) for device in devices
         }
 
-    def get_devices_with_keys_by_user(self, user_id: str):
-        """Get all devices (with any device keys) for a user
-
-        Returns:
-            Deferred which resolves to (stream_id, devices)
-        """
-        return self.db_pool.runInteraction(
-            "get_devices_with_keys_by_user",
-            self._get_devices_with_keys_by_user_txn,
-            user_id,
-        )
-
-    def _get_devices_with_keys_by_user_txn(
-        self, txn: LoggingTransaction, user_id: str
-    ) -> Tuple[int, List[JsonDict]]:
-        now_stream_id = self._device_list_id_gen.get_current_token()
-
-        devices = self._get_e2e_device_keys_txn(
-            txn, [(user_id, None)], include_all_devices=True
-        )
-
-        if devices:
-            user_devices = devices[user_id]
-            results = []
-            for device_id, device in user_devices.items():
-                result = {"device_id": device_id}
-
-                key_json = device.get("key_json", None)
-                if key_json:
-                    result["keys"] = db_to_json(key_json)
-
-                    if "signatures" in device:
-                        for sig_user_id, sigs in device["signatures"].items():
-                            result["keys"].setdefault("signatures", {}).setdefault(
-                                sig_user_id, {}
-                            ).update(sigs)
-
-                device_display_name = device.get("device_display_name", None)
-                if device_display_name:
-                    result["device_display_name"] = device_display_name
-
-                results.append(result)
-
-            return now_stream_id, results
-
-        return now_stream_id, []
-
     async def get_users_whose_devices_changed(
-        self, from_key: str, user_ids: Iterable[str]
+        self, from_key: int, user_ids: Iterable[str]
     ) -> Set[str]:
         """Get set of users whose devices have changed since `from_key` that
         are in the given list of user_ids.
@@ -541,7 +493,6 @@ class DeviceWorkerStore(SQLBaseStore):
         Returns:
             The set of user_ids whose devices have changed since `from_key`
         """
-        from_key = int(from_key)
 
         # Get set of users who *may* have changed. Users not in the returned
         # list have definitely not changed.
@@ -575,7 +526,7 @@ class DeviceWorkerStore(SQLBaseStore):
         )
 
     async def get_users_whose_signatures_changed(
-        self, user_id: str, from_key: str
+        self, user_id: str, from_key: int
     ) -> Set[str]:
         """Get the users who have new cross-signing signatures made by `user_id` since
         `from_key`.
@@ -587,7 +538,7 @@ class DeviceWorkerStore(SQLBaseStore):
         Returns:
             A set of user IDs with updated signatures.
         """
-        from_key = int(from_key)
+
         if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
             sql = """
                 SELECT DISTINCT user_ids FROM user_signature_stream
@@ -656,11 +607,13 @@ class DeviceWorkerStore(SQLBaseStore):
         )
 
     @cached(max_entries=10000)
-    def get_device_list_last_stream_id_for_remote(self, user_id: str):
+    async def get_device_list_last_stream_id_for_remote(
+        self, user_id: str
+    ) -> Optional[Any]:
         """Get the last stream_id we got for a user. May be None if we haven't
         got any information for them.
         """
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="device_lists_remote_extremeties",
             keyvalues={"user_id": user_id},
             retcol="stream_id",
@@ -671,10 +624,9 @@ class DeviceWorkerStore(SQLBaseStore):
     @cachedList(
         cached_method_name="get_device_list_last_stream_id_for_remote",
         list_name="user_ids",
-        inlineCallbacks=True,
     )
-    def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
-        rows = yield self.db_pool.simple_select_many_batch(
+    async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
+        rows = await self.db_pool.simple_select_many_batch(
             table="device_lists_remote_extremeties",
             column="user_id",
             iterable=user_ids,
@@ -715,11 +667,11 @@ class DeviceWorkerStore(SQLBaseStore):
 
         return {row["user_id"] for row in rows}
 
-    def mark_remote_user_device_cache_as_stale(self, user_id: str):
+    async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None:
         """Records that the server has reason to believe the cache of the devices
         for the remote users is out of date.
         """
-        return self.db_pool.simple_upsert(
+        await self.db_pool.simple_upsert(
             table="device_lists_remote_resync",
             keyvalues={"user_id": user_id},
             values={},
@@ -727,7 +679,7 @@ class DeviceWorkerStore(SQLBaseStore):
             desc="make_remote_user_device_cache_as_stale",
         )
 
-    def mark_remote_user_device_list_as_unsubscribed(self, user_id: str):
+    async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
         """Mark that we no longer track device lists for remote user.
         """
 
@@ -741,7 +693,7 @@ class DeviceWorkerStore(SQLBaseStore):
                 txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "mark_remote_user_device_list_as_unsubscribed",
             _mark_remote_user_device_list_as_unsubscribed_txn,
         )
@@ -1002,9 +954,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             desc="update_device",
         )
 
-    def update_remote_device_list_cache_entry(
+    async def update_remote_device_list_cache_entry(
         self, user_id: str, device_id: str, content: JsonDict, stream_id: int
-    ):
+    ) -> None:
         """Updates a single device in the cache of a remote user's devicelist.
 
         Note: assumes that we are the only thread that can be updating this user's
@@ -1015,11 +967,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             device_id: ID of decivice being updated
             content: new data on this device
             stream_id: the version of the device list
-
-        Returns:
-            Deferred[None]
         """
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "update_remote_device_list_cache_entry",
             self._update_remote_device_list_cache_entry_txn,
             user_id,
@@ -1071,9 +1020,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             lock=False,
         )
 
-    def update_remote_device_list_cache(
+    async def update_remote_device_list_cache(
         self, user_id: str, devices: List[dict], stream_id: int
-    ):
+    ) -> None:
         """Replace the entire cache of the remote user's devices.
 
         Note: assumes that we are the only thread that can be updating this user's
@@ -1083,11 +1032,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             user_id: User to update device list for
             devices: list of device objects supplied over federation
             stream_id: the version of the device list
-
-        Returns:
-            Deferred[None]
         """
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "update_remote_device_list_cache",
             self._update_remote_device_list_cache_txn,
             user_id,
@@ -1097,7 +1043,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
     def _update_remote_device_list_cache_txn(
         self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
-    ):
+    ) -> None:
         self.db_pool.simple_delete_txn(
             txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
         )
@@ -1147,7 +1093,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         if not device_ids:
             return
 
-        with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
+        with await self._device_list_id_gen.get_next_mult(
+            len(device_ids)
+        ) as stream_ids:
             await self.db_pool.runInteraction(
                 "add_device_change_to_stream",
                 self._add_device_change_to_stream_txn,
@@ -1160,7 +1108,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             return stream_ids[-1]
 
         context = get_active_span_text_map()
-        with self._device_list_id_gen.get_next_mult(
+        with await self._device_list_id_gen.get_next_mult(
             len(hosts) * len(device_ids)
         ) as stream_ids:
             await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 037e02603c..e5060d4c46 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 from collections import namedtuple
-from typing import Iterable, Optional
+from typing import Iterable, List, Optional
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore
@@ -59,8 +59,8 @@ class DirectoryWorkerStore(SQLBaseStore):
 
         return RoomAliasMapping(room_id, room_alias.to_string(), servers)
 
-    def get_room_alias_creator(self, room_alias):
-        return self.db_pool.simple_select_one_onecol(
+    async def get_room_alias_creator(self, room_alias: str) -> str:
+        return await self.db_pool.simple_select_one_onecol(
             table="room_aliases",
             keyvalues={"room_alias": room_alias},
             retcol="creator",
@@ -68,8 +68,8 @@ class DirectoryWorkerStore(SQLBaseStore):
         )
 
     @cached(max_entries=5000)
-    def get_aliases_for_room(self, room_id):
-        return self.db_pool.simple_select_onecol(
+    async def get_aliases_for_room(self, room_id: str) -> List[str]:
+        return await self.db_pool.simple_select_onecol(
             "room_aliases",
             {"room_id": room_id},
             "room_alias",
@@ -159,9 +159,9 @@ class DirectoryStore(DirectoryWorkerStore):
 
         return room_id
 
-    def update_aliases_for_room(
+    async def update_aliases_for_room(
         self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
-    ):
+    ) -> None:
         """Repoint all of the aliases for a given room, to a different room.
 
         Args:
@@ -189,6 +189,6 @@ class DirectoryStore(DirectoryWorkerStore):
                 txn, self.get_aliases_for_room, (new_room_id,)
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "_update_aliases_for_room_txn", _update_aliases_for_room_txn
         )
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 2eeb9f97dc..12cecceec2 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional
+
 from synapse.api.errors import StoreError
 from synapse.logging.opentracing import log_kv, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -149,7 +151,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
 
         return sessions
 
-    def get_e2e_room_keys_multi(self, user_id, version, room_keys):
+    async def get_e2e_room_keys_multi(self, user_id, version, room_keys):
         """Get multiple room keys at a time.  The difference between this function and
         get_e2e_room_keys is that this function can be used to retrieve
         multiple specific keys at a time, whereas get_e2e_room_keys is used for
@@ -164,10 +166,10 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 that we want to query
 
         Returns:
-           Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key
+           dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key
         """
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_e2e_room_keys_multi",
             self._get_e2e_room_keys_multi_txn,
             user_id,
@@ -223,15 +225,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
 
         return ret
 
-    def count_e2e_room_keys(self, user_id, version):
+    async def count_e2e_room_keys(self, user_id: str, version: str) -> int:
         """Get the number of keys in a backup version.
 
         Args:
-            user_id (str): the user whose backup we're querying
-            version (str): the version ID of the backup we're querying about
+            user_id: the user whose backup we're querying
+            version: the version ID of the backup we're querying about
         """
 
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="e2e_room_keys",
             keyvalues={"user_id": user_id, "version": version},
             retcol="COUNT(*)",
@@ -281,7 +283,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             raise StoreError(404, "No current backup version")
         return row[0]
 
-    def get_e2e_room_keys_version_info(self, user_id, version=None):
+    async def get_e2e_room_keys_version_info(self, user_id, version=None):
         """Get info metadata about a version of our room_keys backup.
 
         Args:
@@ -291,7 +293,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         Raises:
             StoreError: with code 404 if there are no e2e_room_keys_versions present
         Returns:
-            A deferred dict giving the info metadata for this backup version, with
+            A dict giving the info metadata for this backup version, with
             fields including:
                 version(str)
                 algorithm(str)
@@ -322,12 +324,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 result["etag"] = 0
             return result
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
         )
 
     @trace
-    def create_e2e_room_keys_version(self, user_id, info):
+    async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str:
         """Atomically creates a new version of this user's e2e_room_keys store
         with the given version info.
 
@@ -336,7 +338,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             info(dict): the info about the backup version to be created
 
         Returns:
-            A deferred string for the newly created version ID
+            The newly created version ID
         """
 
         def _create_e2e_room_keys_version_txn(txn):
@@ -363,23 +365,27 @@ class EndToEndRoomKeyStore(SQLBaseStore):
 
             return new_version
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
         )
 
     @trace
-    def update_e2e_room_keys_version(
-        self, user_id, version, info=None, version_etag=None
-    ):
+    async def update_e2e_room_keys_version(
+        self,
+        user_id: str,
+        version: str,
+        info: Optional[dict] = None,
+        version_etag: Optional[int] = None,
+    ) -> None:
         """Update a given backup version
 
         Args:
-            user_id(str): the user whose backup version we're updating
-            version(str): the version ID of the backup version we're updating
-            info (dict): the new backup version info to store.  If None, then
-                the backup version info is not updated
-            version_etag (Optional[int]): etag of the keys in the backup.  If
-                None, then the etag is not updated
+            user_id: the user whose backup version we're updating
+            version: the version ID of the backup version we're updating
+            info: the new backup version info to store. If None, then the backup
+                version info is not updated.
+            version_etag: etag of the keys in the backup. If None, then the etag
+                is not updated.
         """
         updatevalues = {}
 
@@ -389,7 +395,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             updatevalues["etag"] = version_etag
 
         if updatevalues:
-            return self.db_pool.simple_update(
+            await self.db_pool.simple_update(
                 table="e2e_room_keys_versions",
                 keyvalues={"user_id": user_id, "version": version},
                 updatevalues=updatevalues,
@@ -397,13 +403,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             )
 
     @trace
-    def delete_e2e_room_keys_version(self, user_id, version=None):
+    async def delete_e2e_room_keys_version(
+        self, user_id: str, version: Optional[str] = None
+    ) -> None:
         """Delete a given backup version of the user's room keys.
         Doesn't delete their actual key data.
 
         Args:
-            user_id(str): the user whose backup version we're deleting
-            version(str): Optional. the version ID of the backup version we're deleting
+            user_id: the user whose backup version we're deleting
+            version: Optional. the version ID of the backup version we're deleting
                 If missing, we delete the current backup version info.
         Raises:
             StoreError: with code 404 if there are no e2e_room_keys_versions present,
@@ -424,13 +432,13 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 keyvalues={"user_id": user_id, "version": this_version},
             )
 
-            return self.db_pool.simple_update_one_txn(
+            self.db_pool.simple_update_one_txn(
                 txn,
                 table="e2e_room_keys_versions",
                 keyvalues={"user_id": user_id, "version": this_version},
                 updatevalues={"deleted": 1},
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
         )
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index f93e0d320d..c8df0bcb3f 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,8 +14,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Iterable, List, Optional, Tuple
+import abc
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
 
+import attr
 from canonicaljson import encode_canonical_json
 
 from twisted.enterprise.adbapi import Connection
@@ -23,24 +25,68 @@ from twisted.enterprise.adbapi import Connection
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import make_in_list_sql_clause
+from synapse.storage.types import Cursor
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.iterutils import batch_iter
 
+if TYPE_CHECKING:
+    from synapse.handlers.e2e_keys import SignatureListItem
+
+
+@attr.s(slots=True)
+class DeviceKeyLookupResult:
+    """The type returned by get_e2e_device_keys_and_signatures"""
+
+    display_name = attr.ib(type=Optional[str])
+
+    # the key data from e2e_device_keys_json. Typically includes fields like
+    # "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
+    # key) and "signatures" (a map from (user id) to (key id/device_id) to signature.)
+    keys = attr.ib(type=Optional[JsonDict])
+
 
 class EndToEndKeyWorkerStore(SQLBaseStore):
+    async def get_e2e_device_keys_for_federation_query(
+        self, user_id: str
+    ) -> Tuple[int, List[JsonDict]]:
+        """Get all devices (with any device keys) for a user
+
+        Returns:
+            (stream_id, devices)
+        """
+        now_stream_id = self.get_device_stream_token()
+
+        devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])
+
+        if devices:
+            user_devices = devices[user_id]
+            results = []
+            for device_id, device in user_devices.items():
+                result = {"device_id": device_id}
+
+                keys = device.keys
+                if keys:
+                    result["keys"] = keys
+
+                device_display_name = device.display_name
+                if device_display_name:
+                    result["device_display_name"] = device_display_name
+
+                results.append(result)
+
+            return now_stream_id, results
+
+        return now_stream_id, []
+
     @trace
-    async def get_e2e_device_keys(
-        self, query_list, include_all_devices=False, include_deleted_devices=False
-    ):
-        """Fetch a list of device keys.
+    async def get_e2e_device_keys_for_cs_api(
+        self, query_list: List[Tuple[str, Optional[str]]]
+    ) -> Dict[str, Dict[str, JsonDict]]:
+        """Fetch a list of device keys, formatted suitably for the C/S API.
         Args:
             query_list(list): List of pairs of user_ids and device_ids.
-            include_all_devices (bool): whether to include entries for devices
-                that don't have device keys
-            include_deleted_devices (bool): whether to include null entries for
-                devices which no longer exist (but were in the query_list).
-                This option only takes effect if include_all_devices is true.
         Returns:
             Dict mapping from user-id to dict mapping from device_id to
             key data.  The key data will be a dict in the same format as the
@@ -50,13 +96,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         if not query_list:
             return {}
 
-        results = await self.db_pool.runInteraction(
-            "get_e2e_device_keys",
-            self._get_e2e_device_keys_txn,
-            query_list,
-            include_all_devices,
-            include_deleted_devices,
-        )
+        results = await self.get_e2e_device_keys_and_signatures(query_list)
 
         # Build the result structure, un-jsonify the results, and add the
         # "unsigned" section
@@ -64,31 +104,95 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         for user_id, device_keys in results.items():
             rv[user_id] = {}
             for device_id, device_info in device_keys.items():
-                r = db_to_json(device_info.pop("key_json"))
+                r = device_info.keys
                 r["unsigned"] = {}
-                display_name = device_info["device_display_name"]
+                display_name = device_info.display_name
                 if display_name is not None:
                     r["unsigned"]["device_display_name"] = display_name
-                if "signatures" in device_info:
-                    for sig_user_id, sigs in device_info["signatures"].items():
-                        r.setdefault("signatures", {}).setdefault(
-                            sig_user_id, {}
-                        ).update(sigs)
                 rv[user_id][device_id] = r
 
         return rv
 
     @trace
-    def _get_e2e_device_keys_txn(
-        self, txn, query_list, include_all_devices=False, include_deleted_devices=False
-    ):
+    async def get_e2e_device_keys_and_signatures(
+        self,
+        query_list: List[Tuple[str, Optional[str]]],
+        include_all_devices: bool = False,
+        include_deleted_devices: bool = False,
+    ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+        """Fetch a list of device keys
+
+        Any cross-signatures made on the keys by the owner of the device are also
+        included.
+
+        The cross-signatures are added to the `signatures` field within the `keys`
+        object in the response.
+
+        Args:
+            query_list: List of pairs of user_ids and device_ids. Device id can be None
+                to indicate "all devices for this user"
+
+            include_all_devices: whether to return devices without device keys
+
+            include_deleted_devices: whether to include null entries for
+                devices which no longer exist (but were in the query_list).
+                This option only takes effect if include_all_devices is true.
+
+        Returns:
+            Dict mapping from user-id to dict mapping from device_id to
+            key data.
+        """
         set_tag("include_all_devices", include_all_devices)
         set_tag("include_deleted_devices", include_deleted_devices)
 
+        result = await self.db_pool.runInteraction(
+            "get_e2e_device_keys",
+            self._get_e2e_device_keys_txn,
+            query_list,
+            include_all_devices,
+            include_deleted_devices,
+        )
+
+        # get the (user_id, device_id) tuples to look up cross-signatures for
+        signature_query = (
+            (user_id, device_id)
+            for user_id, dev in result.items()
+            for device_id, d in dev.items()
+            if d is not None and d.keys is not None
+        )
+
+        for batch in batch_iter(signature_query, 50):
+            cross_sigs_result = await self.db_pool.runInteraction(
+                "get_e2e_cross_signing_signatures",
+                self._get_e2e_cross_signing_signatures_for_devices_txn,
+                batch,
+            )
+
+            # add each cross-signing signature to the correct device in the result dict.
+            for (user_id, key_id, device_id, signature) in cross_sigs_result:
+                target_device_result = result[user_id][device_id]
+                target_device_signatures = target_device_result.keys.setdefault(
+                    "signatures", {}
+                )
+                signing_user_signatures = target_device_signatures.setdefault(
+                    user_id, {}
+                )
+                signing_user_signatures[key_id] = signature
+
+        log_kv(result)
+        return result
+
+    def _get_e2e_device_keys_txn(
+        self, txn, query_list, include_all_devices=False, include_deleted_devices=False
+    ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+        """Get information on devices from the database
+
+        The results include the device's keys and self-signatures, but *not* any
+        cross-signing signatures which have been added subsequently (for which, see
+        get_e2e_device_keys_and_signatures)
+        """
         query_clauses = []
         query_params = []
-        signature_query_clauses = []
-        signature_query_params = []
 
         if include_all_devices is False:
             include_deleted_devices = False
@@ -99,24 +203,16 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         for (user_id, device_id) in query_list:
             query_clause = "user_id = ?"
             query_params.append(user_id)
-            signature_query_clause = "target_user_id = ?"
-            signature_query_params.append(user_id)
 
             if device_id is not None:
                 query_clause += " AND device_id = ?"
                 query_params.append(device_id)
-                signature_query_clause += " AND target_device_id = ?"
-                signature_query_params.append(device_id)
-
-            signature_query_clause += " AND user_id = ?"
-            signature_query_params.append(user_id)
 
             query_clauses.append(query_clause)
-            signature_query_clauses.append(signature_query_clause)
 
         sql = (
             "SELECT user_id, device_id, "
-            "    d.display_name AS device_display_name, "
+            "    d.display_name, "
             "    k.key_json"
             " FROM devices d"
             "    %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
@@ -127,51 +223,49 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         )
 
         txn.execute(sql, query_params)
-        rows = self.db_pool.cursor_to_dict(txn)
 
-        result = {}
-        for row in rows:
+        result = {}  # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
+        for (user_id, device_id, display_name, key_json) in txn:
             if include_deleted_devices:
-                deleted_devices.remove((row["user_id"], row["device_id"]))
-            result.setdefault(row["user_id"], {})[row["device_id"]] = row
+                deleted_devices.remove((user_id, device_id))
+            result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
+                display_name, db_to_json(key_json) if key_json else None
+            )
 
         if include_deleted_devices:
             for user_id, device_id in deleted_devices:
                 result.setdefault(user_id, {})[device_id] = None
 
-        # get signatures on the device
-        signature_sql = ("SELECT *  FROM e2e_cross_signing_signatures WHERE %s") % (
-            " OR ".join("(" + q + ")" for q in signature_query_clauses)
-        )
+        return result
 
-        txn.execute(signature_sql, signature_query_params)
-        rows = self.db_pool.cursor_to_dict(txn)
-
-        # add each cross-signing signature to the correct device in the result dict.
-        for row in rows:
-            signing_user_id = row["user_id"]
-            signing_key_id = row["key_id"]
-            target_user_id = row["target_user_id"]
-            target_device_id = row["target_device_id"]
-            signature = row["signature"]
-
-            target_user_result = result.get(target_user_id)
-            if not target_user_result:
-                continue
+    def _get_e2e_cross_signing_signatures_for_devices_txn(
+        self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
+    ) -> List[Tuple[str, str, str, str]]:
+        """Get cross-signing signatures for a given list of devices
 
-            target_device_result = target_user_result.get(target_device_id)
-            if not target_device_result:
-                # note that target_device_result will be None for deleted devices.
-                continue
+        Returns signatures made by the owners of the devices.
 
-            target_device_signatures = target_device_result.setdefault("signatures", {})
-            signing_user_signatures = target_device_signatures.setdefault(
-                signing_user_id, {}
+        Returns: a list of results; each entry in the list is a tuple of
+            (user_id, key_id, target_device_id, signature).
+        """
+        signature_query_clauses = []
+        signature_query_params = []
+
+        for (user_id, device_id) in device_query:
+            signature_query_clauses.append(
+                "target_user_id = ? AND target_device_id = ? AND user_id = ?"
             )
-            signing_user_signatures[signing_key_id] = signature
+            signature_query_params.extend([user_id, device_id, user_id])
 
-        log_kv(result)
-        return result
+        signature_sql = """
+            SELECT user_id, key_id, target_device_id, signature
+            FROM e2e_cross_signing_signatures WHERE %s
+            """ % (
+            " OR ".join("(" + q + ")" for q in signature_query_clauses)
+        )
+
+        txn.execute(signature_sql, signature_query_params)
+        return txn.fetchall()
 
     async def get_e2e_one_time_keys(
         self, user_id: str, device_id: str, key_ids: List[str]
@@ -249,10 +343,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         )
 
     @cached(max_entries=10000)
-    def count_e2e_one_time_keys(self, user_id, device_id):
+    async def count_e2e_one_time_keys(
+        self, user_id: str, device_id: str
+    ) -> Dict[str, int]:
         """ Count the number of one time keys the server has for a device
         Returns:
-            Dict mapping from algorithm to number of keys for that algorithm.
+            A mapping from algorithm to number of keys for that algorithm.
         """
 
         def _count_e2e_one_time_keys(txn):
@@ -267,7 +363,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
                 result[algorithm] = key_count
             return result
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "count_e2e_one_time_keys", _count_e2e_one_time_keys
         )
 
@@ -305,7 +401,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         list_name="user_ids",
         num_args=1,
     )
-    def _get_bare_e2e_cross_signing_keys_bulk(
+    async def _get_bare_e2e_cross_signing_keys_bulk(
         self, user_ids: List[str]
     ) -> Dict[str, Dict[str, dict]]:
         """Returns the cross-signing keys for a set of users.  The output of this
@@ -313,16 +409,15 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         the signatures for the calling user need to be fetched.
 
         Args:
-            user_ids (list[str]): the users whose keys are being requested
+            user_ids: the users whose keys are being requested
 
         Returns:
-            dict[str, dict[str, dict]]: mapping from user ID to key type to key
-                data.  If a user's cross-signing keys were not found, either
-                their user ID will not be in the dict, or their user ID will map
-                to None.
+            A mapping from user ID to key type to key data. If a user's cross-signing
+            keys were not found, either their user ID will not be in the dict, or
+            their user ID will map to None.
 
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_bare_e2e_cross_signing_keys_bulk",
             self._get_bare_e2e_cross_signing_keys_bulk_txn,
             user_ids,
@@ -538,9 +633,16 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             _get_all_user_signature_changes_for_remotes_txn,
         )
 
+    @abc.abstractmethod
+    def get_device_stream_token(self) -> int:
+        """Get the current stream id from the _device_list_id_gen"""
+        ...
+
 
 class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
-    def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
+    async def set_e2e_device_keys(
+        self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
+    ) -> bool:
         """Stores device keys for a device. Returns whether there was a change
         or the keys were already in the database.
         """
@@ -576,12 +678,21 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             log_kv({"message": "Device keys stored."})
             return True
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "set_e2e_device_keys", _set_e2e_device_keys_txn
         )
 
-    def claim_e2e_one_time_keys(self, query_list):
-        """Take a list of one time keys out of the database"""
+    async def claim_e2e_one_time_keys(
+        self, query_list: Iterable[Tuple[str, str, str]]
+    ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
+        """Take a list of one time keys out of the database.
+
+        Args:
+            query_list: An iterable of tuples of (user ID, device ID, algorithm).
+
+        Returns:
+            A map of user ID -> a map device ID -> a map of key ID -> JSON bytes.
+        """
 
         @trace
         def _claim_e2e_one_time_keys(txn):
@@ -617,11 +728,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 )
             return result
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
         )
 
-    def delete_e2e_keys_by_device(self, user_id, device_id):
+    async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
         def delete_e2e_keys_by_device_txn(txn):
             log_kv(
                 {
@@ -644,11 +755,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 txn, self.count_e2e_one_time_keys, (user_id, device_id)
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
         )
 
-    def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
+    def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
         """Set a user's cross-signing key.
 
         Args:
@@ -658,6 +769,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 for a master key, 'self_signing' for a self-signing key, or
                 'user_signing' for a user-signing key
             key (dict): the key data
+            stream_id (int)
         """
         # the 'key' dict will look something like:
         # {
@@ -695,23 +807,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             )
 
         # and finally, store the key itself
-        with self._cross_signing_id_gen.get_next() as stream_id:
-            self.db_pool.simple_insert_txn(
-                txn,
-                "e2e_cross_signing_keys",
-                values={
-                    "user_id": user_id,
-                    "keytype": key_type,
-                    "keydata": json_encoder.encode(key),
-                    "stream_id": stream_id,
-                },
-            )
+        self.db_pool.simple_insert_txn(
+            txn,
+            "e2e_cross_signing_keys",
+            values={
+                "user_id": user_id,
+                "keytype": key_type,
+                "keydata": json_encoder.encode(key),
+                "stream_id": stream_id,
+            },
+        )
 
         self._invalidate_cache_and_stream(
             txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
         )
 
-    def set_e2e_cross_signing_key(self, user_id, key_type, key):
+    async def set_e2e_cross_signing_key(self, user_id, key_type, key):
         """Set a user's cross-signing key.
 
         Args:
@@ -719,22 +830,27 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             key_type (str): the type of cross-signing key to set
             key (dict): the key data
         """
-        return self.db_pool.runInteraction(
-            "add_e2e_cross_signing_key",
-            self._set_e2e_cross_signing_key_txn,
-            user_id,
-            key_type,
-            key,
-        )
 
-    def store_e2e_cross_signing_signatures(self, user_id, signatures):
+        with await self._cross_signing_id_gen.get_next() as stream_id:
+            return await self.db_pool.runInteraction(
+                "add_e2e_cross_signing_key",
+                self._set_e2e_cross_signing_key_txn,
+                user_id,
+                key_type,
+                key,
+                stream_id,
+            )
+
+    async def store_e2e_cross_signing_signatures(
+        self, user_id: str, signatures: "Iterable[SignatureListItem]"
+    ) -> None:
         """Stores cross-signing signatures.
 
         Args:
-            user_id (str): the user who made the signatures
-            signatures (iterable[SignatureListItem]): signatures to add
+            user_id: the user who made the signatures
+            signatures: signatures to add
         """
-        return self.db_pool.simple_insert_many(
+        await self.db_pool.simple_insert_many(
             "e2e_cross_signing_signatures",
             [
                 {
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 484875f989..4c3c162acf 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -15,14 +15,16 @@
 import itertools
 import logging
 from queue import Empty, PriorityQueue
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import Dict, Iterable, List, Set, Tuple
 
 from synapse.api.errors import StoreError
+from synapse.events import EventBase
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.signatures import SignatureWorkerStore
+from synapse.types import Collection
 from synapse.util.caches.descriptors import cached
 from synapse.util.iterutils import batch_iter
 
@@ -30,57 +32,51 @@ logger = logging.getLogger(__name__)
 
 
 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
-    def get_auth_chain(self, event_ids, include_given=False):
+    async def get_auth_chain(
+        self, event_ids: Collection[str], include_given: bool = False
+    ) -> List[EventBase]:
         """Get auth events for given event_ids. The events *must* be state events.
 
         Args:
-            event_ids (list): state events
-            include_given (bool): include the given events in result
+            event_ids: state events
+            include_given: include the given events in result
 
         Returns:
             list of events
         """
-        return self.get_auth_chain_ids(
+        event_ids = await self.get_auth_chain_ids(
             event_ids, include_given=include_given
-        ).addCallback(self.get_events_as_list)
-
-    def get_auth_chain_ids(
-        self,
-        event_ids: List[str],
-        include_given: bool = False,
-        ignore_events: Optional[Set[str]] = None,
-    ):
+        )
+        return await self.get_events_as_list(event_ids)
+
+    async def get_auth_chain_ids(
+        self, event_ids: Collection[str], include_given: bool = False,
+    ) -> List[str]:
         """Get auth events for given event_ids. The events *must* be state events.
 
         Args:
             event_ids: state events
             include_given: include the given events in result
-            ignore_events: Set of events to exclude from the returned auth
-                chain. This is useful if the caller will just discard the
-                given events anyway, and saves us from figuring out their auth
-                chains if not required.
 
         Returns:
-            list of event_ids
+            An awaitable which resolve to a list of event_ids
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_auth_chain_ids",
             self._get_auth_chain_ids_txn,
             event_ids,
             include_given,
-            ignore_events,
         )
 
-    def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
-        if ignore_events is None:
-            ignore_events = set()
-
+    def _get_auth_chain_ids_txn(
+        self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
+    ) -> List[str]:
         if include_given:
             results = set(event_ids)
         else:
             results = set()
 
-        base_sql = "SELECT auth_id FROM event_auth WHERE "
+        base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE "
 
         front = set(event_ids)
         while front:
@@ -92,7 +88,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
                 txn.execute(base_sql + clause, args)
                 new_front.update(r[0] for r in txn)
 
-            new_front -= ignore_events
             new_front -= results
 
             front = new_front
@@ -100,7 +95,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         return list(results)
 
-    def get_auth_chain_difference(self, state_sets: List[Set[str]]):
+    async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]:
         """Given sets of state events figure out the auth chain difference (as
         per state res v2 algorithm).
 
@@ -109,10 +104,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         chain.
 
         Returns:
-            Deferred[Set[str]]
+            The set of the difference in auth chains.
         """
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_auth_chain_difference",
             self._get_auth_chain_difference_txn,
             state_sets,
@@ -257,13 +252,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         # Return all events where not all sets can reach them.
         return {eid for eid, n in event_to_missing_sets.items() if n}
 
-    def get_oldest_events_in_room(self, room_id):
-        return self.db_pool.runInteraction(
-            "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
-        )
-
-    def get_oldest_events_with_depth_in_room(self, room_id):
-        return self.db_pool.runInteraction(
+    async def get_oldest_events_with_depth_in_room(self, room_id):
+        return await self.db_pool.runInteraction(
             "get_oldest_events_with_depth_in_room",
             self.get_oldest_events_with_depth_in_room_txn,
             room_id,
@@ -303,15 +293,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         else:
             return max(row["depth"] for row in rows)
 
-    def _get_oldest_events_in_room_txn(self, txn, room_id):
-        return self.db_pool.simple_select_onecol_txn(
-            txn,
-            table="event_backward_extremities",
-            keyvalues={"room_id": room_id},
-            retcol="event_id",
-        )
-
-    def get_prev_events_for_room(self, room_id: str):
+    async def get_prev_events_for_room(self, room_id: str) -> List[str]:
         """
         Gets a subset of the current forward extremities in the given room.
 
@@ -319,14 +301,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         events which refer to hundreds of prev_events.
 
         Args:
-            room_id (str): room_id
+            room_id: room_id
 
         Returns:
-            Deferred[List[str]]: the event ids of the forward extremites
+            The event ids of the forward extremities.
 
         """
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
         )
 
@@ -346,17 +328,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         return [row[0] for row in txn]
 
-    def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter):
+    async def get_rooms_with_many_extremities(
+        self, min_count: int, limit: int, room_id_filter: Iterable[str]
+    ) -> List[str]:
         """Get the top rooms with at least N extremities.
 
         Args:
-            min_count (int): The minimum number of extremities
-            limit (int): The maximum number of rooms to return.
-            room_id_filter (iterable[str]): room_ids to exclude from the results
+            min_count: The minimum number of extremities
+            limit: The maximum number of rooms to return.
+            room_id_filter: room_ids to exclude from the results
 
         Returns:
-            Deferred[list]: At most `limit` room IDs that have at least
-            `min_count` extremities, sorted by extremity count.
+            At most `limit` room IDs that have at least `min_count` extremities,
+            sorted by extremity count.
         """
 
         def _get_rooms_with_many_extremities_txn(txn):
@@ -381,23 +365,23 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             txn.execute(sql, query_args)
             return [room_id for room_id, in txn]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn
         )
 
     @cached(max_entries=5000, iterable=True)
-    def get_latest_event_ids_in_room(self, room_id):
-        return self.db_pool.simple_select_onecol(
+    async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]:
+        return await self.db_pool.simple_select_onecol(
             table="event_forward_extremities",
             keyvalues={"room_id": room_id},
             retcol="event_id",
             desc="get_latest_event_ids_in_room",
         )
 
-    def get_min_depth(self, room_id):
-        """ For hte given room, get the minimum depth we have seen for it.
+    async def get_min_depth(self, room_id: str) -> int:
+        """For the given room, get the minimum depth we have seen for it.
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_min_depth", self._get_min_depth_interaction, room_id
         )
 
@@ -412,7 +396,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         return int(min_depth) if min_depth is not None else None
 
-    def get_forward_extremeties_for_room(self, room_id, stream_ordering):
+    async def get_forward_extremeties_for_room(
+        self, room_id: str, stream_ordering: int
+    ) -> List[str]:
         """For a given room_id and stream_ordering, return the forward
         extremeties of the room at that point in "time".
 
@@ -420,11 +406,11 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         stream_orderings from that point.
 
         Args:
-            room_id (str):
-            stream_ordering (int):
+            room_id:
+            stream_ordering:
 
         Returns:
-            deferred, which resolves to a list of event_ids
+            A list of event_ids
         """
         # We want to make the cache more effective, so we clamp to the last
         # change before the given ordering.
@@ -440,10 +426,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         if last_change > self.stream_ordering_month_ago:
             stream_ordering = min(last_change, stream_ordering)
 
-        return self._get_forward_extremeties_for_room(room_id, stream_ordering)
+        return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
 
     @cached(max_entries=5000, num_args=2)
-    def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
+    async def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
         """For a given room_id and stream_ordering, return the forward
         extremeties of the room at that point in "time".
 
@@ -452,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         """
 
         if stream_ordering <= self.stream_ordering_month_ago:
-            raise StoreError(400, "stream_ordering too old")
+            raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
 
         sql = """
                 SELECT event_id FROM stream_ordering_to_exterm
@@ -468,31 +454,28 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             txn.execute(sql, (stream_ordering, room_id))
             return [event_id for event_id, in txn]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
         )
 
-    def get_backfill_events(self, room_id, event_list, limit):
+    async def get_backfill_events(self, room_id: str, event_list: list, limit: int):
         """Get a list of Events for a given topic that occurred before (and
         including) the events in event_list. Return a list of max size `limit`
 
         Args:
-            txn
-            room_id (str)
-            event_list (list)
-            limit (int)
+            room_id
+            event_list
+            limit
         """
-        return (
-            self.db_pool.runInteraction(
-                "get_backfill_events",
-                self._get_backfill_events,
-                room_id,
-                event_list,
-                limit,
-            )
-            .addCallback(self.get_events_as_list)
-            .addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
+        event_ids = await self.db_pool.runInteraction(
+            "get_backfill_events",
+            self._get_backfill_events,
+            room_id,
+            event_list,
+            limit,
         )
+        events = await self.get_events_as_list(event_ids)
+        return sorted(events, key=lambda e: -e.depth)
 
     def _get_backfill_events(self, txn, room_id, event_list, limit):
         logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
@@ -553,8 +536,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             latest_events,
             limit,
         )
-        events = await self.get_events_as_list(ids)
-        return events
+        return await self.get_events_as_list(ids)
 
     def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
 
@@ -652,8 +634,8 @@ class EventFederationStore(EventFederationWorkerStore):
             _delete_old_forward_extrem_cache_txn,
         )
 
-    def clean_room_for_join(self, room_id):
-        return self.db_pool.runInteraction(
+    async def clean_room_for_join(self, room_id):
+        return await self.db_pool.runInteraction(
             "clean_room_for_join", self._clean_room_for_join_txn, room_id
         )
 
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 7c246d3e4c..7805fb814e 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -15,13 +15,15 @@
 # limitations under the License.
 
 import logging
-from typing import List
+from typing import Dict, List, Optional, Tuple, Union
+
+import attr
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
 from synapse.util import json_encoder
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
 
@@ -86,83 +88,107 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         self._rotate_delay = 3
         self._rotate_count = 10000
 
-    @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
-    def get_unread_event_push_actions_by_room_for_user(
-        self, room_id, user_id, last_read_event_id
-    ):
-        ret = yield self.db_pool.runInteraction(
+    @cached(num_args=3, tree=True, max_entries=5000)
+    async def get_unread_event_push_actions_by_room_for_user(
+        self, room_id: str, user_id: str, last_read_event_id: Optional[str],
+    ) -> Dict[str, int]:
+        """Get the notification count, the highlight count and the unread message count
+        for a given user in a given room after the given read receipt.
+
+        Note that this function assumes the user to be a current member of the room,
+        since it's either called by the sync handler to handle joined room entries, or by
+        the HTTP pusher to calculate the badge of unread joined rooms.
+
+        Args:
+            room_id: The room to retrieve the counts in.
+            user_id: The user to retrieve the counts for.
+            last_read_event_id: The event associated with the latest read receipt for
+                this user in this room. None if no receipt for this user in this room.
+
+        Returns
+            A dict containing the counts mentioned earlier in this docstring,
+            respectively under the keys "notify_count", "highlight_count" and
+            "unread_count".
+        """
+        return await self.db_pool.runInteraction(
             "get_unread_event_push_actions_by_room",
             self._get_unread_counts_by_receipt_txn,
             room_id,
             user_id,
             last_read_event_id,
         )
-        return ret
 
     def _get_unread_counts_by_receipt_txn(
-        self, txn, room_id, user_id, last_read_event_id
+        self, txn, room_id, user_id, last_read_event_id,
     ):
-        sql = (
-            "SELECT stream_ordering"
-            " FROM events"
-            " WHERE room_id = ? AND event_id = ?"
-        )
-        txn.execute(sql, (room_id, last_read_event_id))
-        results = txn.fetchall()
-        if len(results) == 0:
-            return {"notify_count": 0, "highlight_count": 0}
+        stream_ordering = None
 
-        stream_ordering = results[0][0]
+        if last_read_event_id is not None:
+            stream_ordering = self.get_stream_id_for_event_txn(
+                txn, last_read_event_id, allow_none=True,
+            )
+
+        if stream_ordering is None:
+            # Either last_read_event_id is None, or it's an event we don't have (e.g.
+            # because it's been purged), in which case retrieve the stream ordering for
+            # the latest membership event from this user in this room (which we assume is
+            # a join).
+            event_id = self.db_pool.simple_select_one_onecol_txn(
+                txn=txn,
+                table="local_current_membership",
+                keyvalues={"room_id": room_id, "user_id": user_id},
+                retcol="event_id",
+            )
+
+            stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
 
         return self._get_unread_counts_by_pos_txn(
             txn, room_id, user_id, stream_ordering
         )
 
     def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
-
-        # First get number of notifications.
-        # We don't need to put a notif=1 clause as all rows always have
-        # notif=1
         sql = (
-            "SELECT count(*)"
+            "SELECT"
+            "   COUNT(CASE WHEN notif = 1 THEN 1 END),"
+            "   COUNT(CASE WHEN highlight = 1 THEN 1 END),"
+            "   COUNT(CASE WHEN unread = 1 THEN 1 END)"
             " FROM event_push_actions ea"
-            " WHERE"
-            " user_id = ?"
-            " AND room_id = ?"
-            " AND stream_ordering > ?"
+            " WHERE user_id = ?"
+            "   AND room_id = ?"
+            "   AND stream_ordering > ?"
         )
 
         txn.execute(sql, (user_id, room_id, stream_ordering))
         row = txn.fetchone()
-        notify_count = row[0] if row else 0
+
+        (notif_count, highlight_count, unread_count) = (0, 0, 0)
+
+        if row:
+            (notif_count, highlight_count, unread_count) = row
 
         txn.execute(
             """
-            SELECT notif_count FROM event_push_summary
-            WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
-        """,
+                SELECT notif_count, unread_count FROM event_push_summary
+                WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
+            """,
             (room_id, user_id, stream_ordering),
         )
-        rows = txn.fetchall()
-        if rows:
-            notify_count += rows[0][0]
+        row = txn.fetchone()
 
-        # Now get the number of highlights
-        sql = (
-            "SELECT count(*)"
-            " FROM event_push_actions ea"
-            " WHERE"
-            " highlight = 1"
-            " AND user_id = ?"
-            " AND room_id = ?"
-            " AND stream_ordering > ?"
-        )
+        if row:
+            notif_count += row[0]
 
-        txn.execute(sql, (user_id, room_id, stream_ordering))
-        row = txn.fetchone()
-        highlight_count = row[0] if row else 0
+            if row[1] is not None:
+                # The unread_count column of event_push_summary is NULLable, so we need
+                # to make sure we don't try increasing the unread counts if it's NULL
+                # for this row.
+                unread_count += row[1]
 
-        return {"notify_count": notify_count, "highlight_count": highlight_count}
+        return {
+            "notify_count": notif_count,
+            "unread_count": unread_count,
+            "highlight_count": highlight_count,
+        }
 
     async def get_push_action_users_in_range(
         self, min_stream_ordering, max_stream_ordering
@@ -170,7 +196,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         def f(txn):
             sql = (
                 "SELECT DISTINCT(user_id) FROM event_push_actions WHERE"
-                " stream_ordering >= ? AND stream_ordering <= ?"
+                " stream_ordering >= ? AND stream_ordering <= ? AND notif = 1"
             )
             txn.execute(sql, (min_stream_ordering, max_stream_ordering))
             return [r[0] for r in txn]
@@ -223,6 +249,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
+                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering ASC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -251,6 +278,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
+                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering ASC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -325,6 +353,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
+                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering DESC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -353,6 +382,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
+                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering DESC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -384,62 +414,66 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # Now return the first `limit`
         return notifs[:limit]
 
-    def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering):
+    async def get_if_maybe_push_in_range_for_user(
+        self, user_id: str, min_stream_ordering: int
+    ) -> bool:
         """A fast check to see if there might be something to push for the
         user since the given stream ordering. May return false positives.
 
         Useful to know whether to bother starting a pusher on start up or not.
 
         Args:
-            user_id (str)
-            min_stream_ordering (int)
+            user_id
+            min_stream_ordering
 
         Returns:
-            Deferred[bool]: True if there may be push to process, False if
-            there definitely isn't.
+            True if there may be push to process, False if there definitely isn't.
         """
 
         def _get_if_maybe_push_in_range_for_user_txn(txn):
             sql = """
                 SELECT 1 FROM event_push_actions
-                WHERE user_id = ? AND stream_ordering > ?
+                WHERE user_id = ? AND stream_ordering > ? AND notif = 1
                 LIMIT 1
             """
 
             txn.execute(sql, (user_id, min_stream_ordering))
             return bool(txn.fetchone())
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_if_maybe_push_in_range_for_user",
             _get_if_maybe_push_in_range_for_user_txn,
         )
 
-    async def add_push_actions_to_staging(self, event_id, user_id_actions):
+    async def add_push_actions_to_staging(
+        self,
+        event_id: str,
+        user_id_actions: Dict[str, List[Union[dict, str]]],
+        count_as_unread: bool,
+    ) -> None:
         """Add the push actions for the event to the push action staging area.
 
         Args:
-            event_id (str)
-            user_id_actions (dict[str, list[dict|str])]): A dictionary mapping
-                user_id to list of push actions, where an action can either be
-                a string or dict.
-
-        Returns:
-            Deferred
+            event_id
+            user_id_actions: A mapping of user_id to list of push actions, where
+                an action can either be a string or dict.
+            count_as_unread: Whether this event should increment unread counts.
         """
-
         if not user_id_actions:
             return
 
         # This is a helper function for generating the necessary tuple that
-        # can be used to inert into the `event_push_actions_staging` table.
+        # can be used to insert into the `event_push_actions_staging` table.
         def _gen_entry(user_id, actions):
             is_highlight = 1 if _action_has_highlight(actions) else 0
+            notif = 1 if "notify" in actions else 0
             return (
                 event_id,  # event_id column
                 user_id,  # user_id column
                 _serialize_action(actions, is_highlight),  # actions column
-                1,  # notif column
+                notif,  # notif column
                 is_highlight,  # highlight column
+                int(count_as_unread),  # unread column
             )
 
         def _add_push_actions_to_staging_txn(txn):
@@ -448,8 +482,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
             sql = """
                 INSERT INTO event_push_actions_staging
-                    (event_id, user_id, actions, notif, highlight)
-                VALUES (?, ?, ?, ?, ?)
+                    (event_id, user_id, actions, notif, highlight, unread)
+                VALUES (?, ?, ?, ?, ?, ?)
             """
 
             txn.executemany(
@@ -508,7 +542,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             "Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago
         )
 
-    def find_first_stream_ordering_after_ts(self, ts):
+    async def find_first_stream_ordering_after_ts(self, ts: int) -> int:
         """Gets the stream ordering corresponding to a given timestamp.
 
         Specifically, finds the stream_ordering of the first event that was
@@ -517,13 +551,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         relatively slow.
 
         Args:
-            ts (int): timestamp in millis
+            ts: timestamp in millis
 
         Returns:
-            Deferred[int]: stream ordering of the first event received on/after
-                the timestamp
+            stream ordering of the first event received on/after the timestamp
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "_find_first_stream_ordering_after_ts_txn",
             self._find_first_stream_ordering_after_ts_txn,
             ts,
@@ -611,7 +644,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "SELECT e.received_ts"
                 " FROM event_push_actions AS ep"
                 " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
-                " WHERE ep.stream_ordering > ?"
+                " WHERE ep.stream_ordering > ? AND notif = 1"
                 " ORDER BY ep.stream_ordering ASC"
                 " LIMIT 1"
             )
@@ -675,6 +708,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
                 " FROM event_push_actions epa, events e"
                 " WHERE epa.event_id = e.event_id"
                 " AND epa.user_id = ? %s"
+                " AND epa.notif = 1"
                 " ORDER BY epa.stream_ordering DESC"
                 " LIMIT ?" % (before_clause,)
             )
@@ -814,24 +848,63 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         # Calculate the new counts that should be upserted into event_push_summary
         sql = """
             SELECT user_id, room_id,
-                coalesce(old.notif_count, 0) + upd.notif_count,
+                coalesce(old.%s, 0) + upd.cnt,
                 upd.stream_ordering,
                 old.user_id
             FROM (
-                SELECT user_id, room_id, count(*) as notif_count,
+                SELECT user_id, room_id, count(*) as cnt,
                     max(stream_ordering) as stream_ordering
                 FROM event_push_actions
                 WHERE ? <= stream_ordering AND stream_ordering < ?
                     AND highlight = 0
+                    AND %s = 1
                 GROUP BY user_id, room_id
             ) AS upd
             LEFT JOIN event_push_summary AS old USING (user_id, room_id)
         """
 
-        txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering))
-        rows = txn.fetchall()
+        # First get the count of unread messages.
+        txn.execute(
+            sql % ("unread_count", "unread"),
+            (old_rotate_stream_ordering, rotate_to_stream_ordering),
+        )
+
+        # We need to merge results from the two requests (the one that retrieves the
+        # unread count and the one that retrieves the notifications count) into a single
+        # object because we might not have the same amount of rows in each of them. To do
+        # this, we use a dict indexed on the user ID and room ID to make it easier to
+        # populate.
+        summaries = {}  # type: Dict[Tuple[str, str], _EventPushSummary]
+        for row in txn:
+            summaries[(row[0], row[1])] = _EventPushSummary(
+                unread_count=row[2],
+                stream_ordering=row[3],
+                old_user_id=row[4],
+                notif_count=0,
+            )
+
+        # Then get the count of notifications.
+        txn.execute(
+            sql % ("notif_count", "notif"),
+            (old_rotate_stream_ordering, rotate_to_stream_ordering),
+        )
+
+        for row in txn:
+            if (row[0], row[1]) in summaries:
+                summaries[(row[0], row[1])].notif_count = row[2]
+            else:
+                # Because the rules on notifying are different than the rules on marking
+                # a message unread, we might end up with messages that notify but aren't
+                # marked unread, so we might not have a summary for this (user, room)
+                # tuple to complete.
+                summaries[(row[0], row[1])] = _EventPushSummary(
+                    unread_count=0,
+                    stream_ordering=row[3],
+                    old_user_id=row[4],
+                    notif_count=row[2],
+                )
 
-        logger.info("Rotating notifications, handling %d rows", len(rows))
+        logger.info("Rotating notifications, handling %d rows", len(summaries))
 
         # If the `old.user_id` above is NULL then we know there isn't already an
         # entry in the table, so we simply insert it. Otherwise we update the
@@ -841,22 +914,34 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             table="event_push_summary",
             values=[
                 {
-                    "user_id": row[0],
-                    "room_id": row[1],
-                    "notif_count": row[2],
-                    "stream_ordering": row[3],
+                    "user_id": user_id,
+                    "room_id": room_id,
+                    "notif_count": summary.notif_count,
+                    "unread_count": summary.unread_count,
+                    "stream_ordering": summary.stream_ordering,
                 }
-                for row in rows
-                if row[4] is None
+                for ((user_id, room_id), summary) in summaries.items()
+                if summary.old_user_id is None
             ],
         )
 
         txn.executemany(
             """
-                UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
+                UPDATE event_push_summary
+                SET notif_count = ?, unread_count = ?, stream_ordering = ?
                 WHERE user_id = ? AND room_id = ?
             """,
-            ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None),
+            (
+                (
+                    summary.notif_count,
+                    summary.unread_count,
+                    summary.stream_ordering,
+                    user_id,
+                    room_id,
+                )
+                for ((user_id, room_id), summary) in summaries.items()
+                if summary.old_user_id is not None
+            ),
         )
 
         txn.execute(
@@ -882,3 +967,15 @@ def _action_has_highlight(actions):
             pass
 
     return False
+
+
+@attr.s(slots=True)
+class _EventPushSummary:
+    """Summary of pending event push actions for a given user in a given room.
+    Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
+    """
+
+    unread_count = attr.ib(type=int)
+    stream_ordering = attr.ib(type=int)
+    old_user_id = attr.ib(type=str)
+    notif_count = attr.ib(type=int)
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1a68bf32cb..9a80f419e3 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,13 +17,11 @@
 import itertools
 import logging
 from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
 
 import attr
 from prometheus_client import Counter
 
-from twisted.internet import defer
-
 import synapse.metrics
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.api.room_versions import RoomVersions
@@ -34,7 +32,7 @@ from synapse.logging.utils import log_function
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.search import SearchEntry
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import StateMap, get_domain_from_id
 from synapse.util.frozenutils import frozendict_json_encoder
 from synapse.util.iterutils import batch_iter
@@ -99,29 +97,31 @@ class PersistEventsStore:
         self.store = main_data_store
         self.database_engine = db.engine
         self._clock = hs.get_clock()
+        self._instance_name = hs.get_instance_name()
 
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
         self.is_mine_id = hs.is_mine_id
 
         # Ideally we'd move these ID gens here, unfortunately some other ID
         # generators are chained off them so doing so is a bit of a PITA.
-        self._backfill_id_gen = self.store._backfill_id_gen  # type: StreamIdGenerator
-        self._stream_id_gen = self.store._stream_id_gen  # type: StreamIdGenerator
+        self._backfill_id_gen = (
+            self.store._backfill_id_gen
+        )  # type: MultiWriterIdGenerator
+        self._stream_id_gen = self.store._stream_id_gen  # type: MultiWriterIdGenerator
 
         # This should only exist on instances that are configured to write
         assert (
-            hs.config.worker.writers.events == hs.get_instance_name()
+            hs.get_instance_name() in hs.config.worker.writers.events
         ), "Can only instantiate EventsStore on master"
 
-    @defer.inlineCallbacks
-    def _persist_events_and_state_updates(
+    async def _persist_events_and_state_updates(
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         current_state_for_room: Dict[str, StateMap[str]],
         state_delta_for_room: Dict[str, DeltaState],
         new_forward_extremeties: Dict[str, List[str]],
         backfilled: bool = False,
-    ):
+    ) -> None:
         """Persist a set of events alongside updates to the current state and
         forward extremities tables.
 
@@ -136,7 +136,7 @@ class PersistEventsStore:
             backfilled
 
         Returns:
-            Deferred: resolves when the events have been persisted
+            Resolves when the events have been persisted
         """
 
         # We want to calculate the stream orderings as late as possible, as
@@ -156,11 +156,11 @@ class PersistEventsStore:
         # Note: Multiple instances of this function cannot be in flight at
         # the same time for the same room.
         if backfilled:
-            stream_ordering_manager = self._backfill_id_gen.get_next_mult(
+            stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
                 len(events_and_contexts)
             )
         else:
-            stream_ordering_manager = self._stream_id_gen.get_next_mult(
+            stream_ordering_manager = await self._stream_id_gen.get_next_mult(
                 len(events_and_contexts)
             )
 
@@ -168,7 +168,7 @@ class PersistEventsStore:
             for (event, context), stream in zip(events_and_contexts, stream_orderings):
                 event.internal_metadata.stream_ordering = stream
 
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "persist_events",
                 self._persist_events_txn,
                 events_and_contexts=events_and_contexts,
@@ -206,18 +206,17 @@ class PersistEventsStore:
                     (room_id,), list(latest_event_ids)
                 )
 
-    @defer.inlineCallbacks
-    def _get_events_which_are_prevs(self, event_ids):
+    async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
         """Filter the supplied list of event_ids to get those which are prev_events of
         existing (non-outlier/rejected) events.
 
         Args:
-            event_ids (Iterable[str]): event ids to filter
+            event_ids: event ids to filter
 
         Returns:
-            Deferred[List[str]]: filtered event ids
+            Filtered event ids
         """
-        results = []
+        results = []  # type: List[str]
 
         def _get_events_which_are_prevs_txn(txn, batch):
             sql = """
@@ -240,14 +239,13 @@ class PersistEventsStore:
             results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed"))
 
         for chunk in batch_iter(event_ids, 100):
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
             )
 
         return results
 
-    @defer.inlineCallbacks
-    def _get_prevs_before_rejected(self, event_ids):
+    async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> Set[str]:
         """Get soft-failed ancestors to remove from the extremities.
 
         Given a set of events, find all those that have been soft-failed or
@@ -259,11 +257,11 @@ class PersistEventsStore:
         are separated by soft failed events.
 
         Args:
-            event_ids (Iterable[str]): Events to find prev events for. Note
-                that these must have already been persisted.
+            event_ids: Events to find prev events for. Note that these must have
+                already been persisted.
 
         Returns:
-            Deferred[set[str]]
+            The previous events.
         """
 
         # The set of event_ids to return. This includes all soft-failed events
@@ -304,7 +302,7 @@ class PersistEventsStore:
                         existing_prevs.add(prev_event_id)
 
         for chunk in batch_iter(event_ids, 100):
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
             )
 
@@ -636,7 +634,9 @@ class PersistEventsStore:
         )
 
     @classmethod
-    def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
+    def _filter_events_and_contexts_for_duplicates(
+        cls, events_and_contexts: List[Tuple[EventBase, EventContext]]
+    ) -> List[Tuple[EventBase, EventContext]]:
         """Ensure that we don't have the same event twice.
 
         Pick the earliest non-outlier if there is one, else the earliest one.
@@ -646,7 +646,9 @@ class PersistEventsStore:
         Returns:
             list[(EventBase, EventContext)]: filtered list
         """
-        new_events_and_contexts = OrderedDict()
+        new_events_and_contexts = (
+            OrderedDict()
+        )  # type: OrderedDict[str, Tuple[EventBase, EventContext]]
         for event, context in events_and_contexts:
             prev_event_context = new_events_and_contexts.get(event.event_id)
             if prev_event_context:
@@ -660,7 +662,12 @@ class PersistEventsStore:
                 new_events_and_contexts[event.event_id] = (event, context)
         return list(new_events_and_contexts.values())
 
-    def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
+    def _update_room_depths_txn(
+        self,
+        txn,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        backfilled: bool,
+    ):
         """Update min_depth for each room
 
         Args:
@@ -669,7 +676,7 @@ class PersistEventsStore:
                 we are persisting
             backfilled (bool): True if the events were backfilled
         """
-        depth_updates = {}
+        depth_updates = {}  # type: Dict[str, int]
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
             txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
@@ -805,6 +812,7 @@ class PersistEventsStore:
             table="events",
             values=[
                 {
+                    "instance_name": self._instance_name,
                     "stream_ordering": event.internal_metadata.stream_ordering,
                     "topological_ordering": event.depth,
                     "depth": event.depth,
@@ -1301,9 +1309,9 @@ class PersistEventsStore:
         sql = """
             INSERT INTO event_push_actions (
                 room_id, event_id, user_id, actions, stream_ordering,
-                topological_ordering, notif, highlight
+                topological_ordering, notif, highlight, unread
             )
-            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
+            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
             FROM event_push_actions_staging
             WHERE event_id = ?
         """
@@ -1441,7 +1449,7 @@ class PersistEventsStore:
 
         Forward extremities are handled when we first start persisting the events.
         """
-        events_by_room = {}
+        events_by_room = {}  # type: Dict[str, List[EventBase]]
         for ev in events:
             events_by_room.setdefault(ev.room_id, []).append(ev)
 
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 35a0e09e3c..e53c6373a8 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -15,8 +15,6 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventContentFields
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
@@ -94,8 +92,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             where_clause="NOT have_censored",
         )
 
-    @defer.inlineCallbacks
-    def _background_reindex_fields_sender(self, progress, batch_size):
+    async def _background_reindex_fields_sender(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
@@ -155,19 +152,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             return len(rows)
 
-        result = yield self.db_pool.runInteraction(
+        result = await self.db_pool.runInteraction(
             self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
         )
 
         if not result:
-            yield self.db_pool.updates._end_background_update(
+            await self.db_pool.updates._end_background_update(
                 self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
             )
 
         return result
 
-    @defer.inlineCallbacks
-    def _background_reindex_origin_server_ts(self, progress, batch_size):
+    async def _background_reindex_origin_server_ts(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
@@ -234,19 +230,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             return len(rows_to_update)
 
-        result = yield self.db_pool.runInteraction(
+        result = await self.db_pool.runInteraction(
             self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
         )
 
         if not result:
-            yield self.db_pool.updates._end_background_update(
+            await self.db_pool.updates._end_background_update(
                 self.EVENT_ORIGIN_SERVER_TS_NAME
             )
 
         return result
 
-    @defer.inlineCallbacks
-    def _cleanup_extremities_bg_update(self, progress, batch_size):
+    async def _cleanup_extremities_bg_update(self, progress, batch_size):
         """Background update to clean out extremities that should have been
         deleted previously.
 
@@ -414,26 +409,25 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             return len(original_set)
 
-        num_handled = yield self.db_pool.runInteraction(
+        num_handled = await self.db_pool.runInteraction(
             "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
         )
 
         if not num_handled:
-            yield self.db_pool.updates._end_background_update(
+            await self.db_pool.updates._end_background_update(
                 self.DELETE_SOFT_FAILED_EXTREMITIES
             )
 
             def _drop_table_txn(txn):
                 txn.execute("DROP TABLE _extremities_to_check")
 
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "_cleanup_extremities_bg_update_drop_table", _drop_table_txn
             )
 
         return num_handled
 
-    @defer.inlineCallbacks
-    def _redactions_received_ts(self, progress, batch_size):
+    async def _redactions_received_ts(self, progress, batch_size):
         """Handles filling out the `received_ts` column in redactions.
         """
         last_event_id = progress.get("last_event_id", "")
@@ -480,17 +474,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             return len(rows)
 
-        count = yield self.db_pool.runInteraction(
+        count = await self.db_pool.runInteraction(
             "_redactions_received_ts", _redactions_received_ts_txn
         )
 
         if not count:
-            yield self.db_pool.updates._end_background_update("redactions_received_ts")
+            await self.db_pool.updates._end_background_update("redactions_received_ts")
 
         return count
 
-    @defer.inlineCallbacks
-    def _event_fix_redactions_bytes(self, progress, batch_size):
+    async def _event_fix_redactions_bytes(self, progress, batch_size):
         """Undoes hex encoded censored redacted event JSON.
         """
 
@@ -511,16 +504,15 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             txn.execute("DROP INDEX redactions_censored_redacts")
 
-        yield self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
         )
 
-        yield self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
+        await self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
 
         return 1
 
-    @defer.inlineCallbacks
-    def _event_store_labels(self, progress, batch_size):
+    async def _event_store_labels(self, progress, batch_size):
         """Background update handler which will store labels for existing events."""
         last_event_id = progress.get("last_event_id", "")
 
@@ -575,11 +567,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             return nbrows
 
-        num_rows = yield self.db_pool.runInteraction(
+        num_rows = await self.db_pool.runInteraction(
             desc="event_store_labels", func=_event_store_labels_txn
         )
 
         if not num_rows:
-            yield self.db_pool.updates._end_background_update("event_store_labels")
+            await self.db_pool.updates._end_background_update("event_store_labels")
 
         return num_rows
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 755b7a2a85..17f5997b89 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -19,9 +19,10 @@ import itertools
 import logging
 import threading
 from collections import namedtuple
-from typing import List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, overload
 
 from constantly import NamedConstant, Names
+from typing_extensions import Literal
 
 from twisted.internet import defer
 
@@ -32,7 +33,7 @@ from synapse.api.room_versions import (
     EventFormatVersions,
     RoomVersions,
 )
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
 from synapse.events.utils import prune_event
 from synapse.logging.context import PreserveLoggingContext, current_context
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -41,9 +42,10 @@ from synapse.replication.tcp.streams import BackfillStream
 from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
-from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.types import Collection, get_domain_from_id
+from synapse.util.caches.descriptors import Cache, cached
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
@@ -77,27 +79,54 @@ class EventsWorkerStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super(EventsWorkerStore, self).__init__(database, db_conn, hs)
 
-        if hs.config.worker.writers.events == hs.get_instance_name():
-            # We are the process in charge of generating stream ids for events,
-            # so instantiate ID generators based on the database
-            self._stream_id_gen = StreamIdGenerator(
-                db_conn, "events", "stream_ordering",
+        if isinstance(database.engine, PostgresEngine):
+            # If we're using Postgres than we can use `MultiWriterIdGenerator`
+            # regardless of whether this process writes to the streams or not.
+            self._stream_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                instance_name=hs.get_instance_name(),
+                table="events",
+                instance_column="instance_name",
+                id_column="stream_ordering",
+                sequence_name="events_stream_seq",
             )
-            self._backfill_id_gen = StreamIdGenerator(
-                db_conn,
-                "events",
-                "stream_ordering",
-                step=-1,
-                extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+            self._backfill_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                instance_name=hs.get_instance_name(),
+                table="events",
+                instance_column="instance_name",
+                id_column="stream_ordering",
+                sequence_name="events_backfill_stream_seq",
+                positive=False,
             )
         else:
-            # Another process is in charge of persisting events and generating
-            # stream IDs: rely on the replication streams to let us know which
-            # IDs we can process.
-            self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
-            self._backfill_id_gen = SlavedIdTracker(
-                db_conn, "events", "stream_ordering", step=-1
-            )
+            # We shouldn't be running in worker mode with SQLite, but its useful
+            # to support it for unit tests.
+            #
+            # If this process is the writer than we need to use
+            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+            # updated over replication. (Multiple writers are not supported for
+            # SQLite).
+            if hs.get_instance_name() in hs.config.worker.writers.events:
+                self._stream_id_gen = StreamIdGenerator(
+                    db_conn, "events", "stream_ordering",
+                )
+                self._backfill_id_gen = StreamIdGenerator(
+                    db_conn,
+                    "events",
+                    "stream_ordering",
+                    step=-1,
+                    extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+                )
+            else:
+                self._stream_id_gen = SlavedIdTracker(
+                    db_conn, "events", "stream_ordering"
+                )
+                self._backfill_id_gen = SlavedIdTracker(
+                    db_conn, "events", "stream_ordering", step=-1
+                )
 
         self._get_event_cache = Cache(
             "*getEvent*",
@@ -112,69 +141,58 @@ class EventsWorkerStore(SQLBaseStore):
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == EventsStream.NAME:
-            self._stream_id_gen.advance(token)
+            self._stream_id_gen.advance(instance_name, token)
         elif stream_name == BackfillStream.NAME:
-            self._backfill_id_gen.advance(-token)
+            self._backfill_id_gen.advance(instance_name, -token)
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
-    def get_received_ts(self, event_id):
+    async def get_received_ts(self, event_id: str) -> Optional[int]:
         """Get received_ts (when it was persisted) for the event.
 
         Raises an exception for unknown events.
 
         Args:
-            event_id (str)
+            event_id: The event ID to query.
 
         Returns:
-            Deferred[int|None]: Timestamp in milliseconds, or None for events
-            that were persisted before received_ts was implemented.
+            Timestamp in milliseconds, or None for events that were persisted
+            before received_ts was implemented.
         """
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="events",
             keyvalues={"event_id": event_id},
             retcol="received_ts",
             desc="get_received_ts",
         )
 
-    def get_received_ts_by_stream_pos(self, stream_ordering):
-        """Given a stream ordering get an approximate timestamp of when it
-        happened.
-
-        This is done by simply taking the received ts of the first event that
-        has a stream ordering greater than or equal to the given stream pos.
-        If none exists returns the current time, on the assumption that it must
-        have happened recently.
-
-        Args:
-            stream_ordering (int)
-
-        Returns:
-            Deferred[int]
-        """
-
-        def _get_approximate_received_ts_txn(txn):
-            sql = """
-                SELECT received_ts FROM events
-                WHERE stream_ordering >= ?
-                LIMIT 1
-            """
-
-            txn.execute(sql, (stream_ordering,))
-            row = txn.fetchone()
-            if row and row[0]:
-                ts = row[0]
-            else:
-                ts = self.clock.time_msec()
-
-            return ts
+    # Inform mypy that if allow_none is False (the default) then get_event
+    # always returns an EventBase.
+    @overload
+    async def get_event(
+        self,
+        event_id: str,
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
+        allow_none: Literal[False] = False,
+        check_room_id: Optional[str] = None,
+    ) -> EventBase:
+        ...
 
-        return self.db_pool.runInteraction(
-            "get_approximate_received_ts", _get_approximate_received_ts_txn
-        )
+    @overload
+    async def get_event(
+        self,
+        event_id: str,
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
+        allow_none: Literal[True] = False,
+        check_room_id: Optional[str] = None,
+    ) -> Optional[EventBase]:
+        ...
 
-    @defer.inlineCallbacks
-    def get_event(
+    async def get_event(
         self,
         event_id: str,
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
@@ -182,7 +200,7 @@ class EventsWorkerStore(SQLBaseStore):
         allow_rejected: bool = False,
         allow_none: bool = False,
         check_room_id: Optional[str] = None,
-    ):
+    ) -> Optional[EventBase]:
         """Get an event from the database by event_id.
 
         Args:
@@ -207,12 +225,12 @@ class EventsWorkerStore(SQLBaseStore):
                 If there is a mismatch, behave as per allow_none.
 
         Returns:
-            Deferred[EventBase|None]
+            The event, or None if the event was not found.
         """
         if not isinstance(event_id, str):
             raise TypeError("Invalid event event_id %r" % (event_id,))
 
-        events = yield self.get_events_as_list(
+        events = await self.get_events_as_list(
             [event_id],
             redact_behaviour=redact_behaviour,
             get_prev_content=get_prev_content,
@@ -230,14 +248,13 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event
 
-    @defer.inlineCallbacks
-    def get_events(
+    async def get_events(
         self,
-        event_ids: List[str],
+        event_ids: Iterable[str],
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
         get_prev_content: bool = False,
         allow_rejected: bool = False,
-    ):
+    ) -> Dict[str, EventBase]:
         """Get events from the database
 
         Args:
@@ -256,9 +273,9 @@ class EventsWorkerStore(SQLBaseStore):
                 omits rejeted events from the response.
 
         Returns:
-            Deferred : Dict from event_id to event.
+            A mapping from event_id to event.
         """
-        events = yield self.get_events_as_list(
+        events = await self.get_events_as_list(
             event_ids,
             redact_behaviour=redact_behaviour,
             get_prev_content=get_prev_content,
@@ -267,14 +284,13 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {e.event_id: e for e in events}
 
-    @defer.inlineCallbacks
-    def get_events_as_list(
+    async def get_events_as_list(
         self,
-        event_ids: List[str],
+        event_ids: Collection[str],
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
         get_prev_content: bool = False,
         allow_rejected: bool = False,
-    ):
+    ) -> List[EventBase]:
         """Get events from the database and return in a list in the same order
         as given by `event_ids` arg.
 
@@ -295,8 +311,8 @@ class EventsWorkerStore(SQLBaseStore):
                 omits rejected events from the response.
 
         Returns:
-            Deferred[list[EventBase]]: List of events fetched from the database. The
-            events are in the same order as `event_ids` arg.
+            List of events fetched from the database. The events are in the same
+            order as `event_ids` arg.
 
             Note that the returned list may be smaller than the list of event
             IDs if not all events could be fetched.
@@ -306,7 +322,7 @@ class EventsWorkerStore(SQLBaseStore):
             return []
 
         # there may be duplicates so we cast the list to a set
-        event_entry_map = yield self._get_events_from_cache_or_db(
+        event_entry_map = await self._get_events_from_cache_or_db(
             set(event_ids), allow_rejected=allow_rejected
         )
 
@@ -341,7 +357,7 @@ class EventsWorkerStore(SQLBaseStore):
                     continue
 
                 redacted_event_id = entry.event.redacts
-                event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
+                event_map = await self._get_events_from_cache_or_db([redacted_event_id])
                 original_event_entry = event_map.get(redacted_event_id)
                 if not original_event_entry:
                     # we don't have the redacted event (or it was rejected).
@@ -407,7 +423,7 @@ class EventsWorkerStore(SQLBaseStore):
 
             if get_prev_content:
                 if "replaces_state" in event.unsigned:
-                    prev = yield self.get_event(
+                    prev = await self.get_event(
                         event.unsigned["replaces_state"],
                         get_prev_content=False,
                         allow_none=True,
@@ -419,8 +435,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return events
 
-    @defer.inlineCallbacks
-    def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
+    async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
         """Fetch a bunch of events from the cache or the database.
 
         If events are pulled from the database, they will be cached for future lookups.
@@ -435,7 +450,7 @@ class EventsWorkerStore(SQLBaseStore):
                 rejected events are omitted from the response.
 
         Returns:
-            Deferred[Dict[str, _EventCacheEntry]]:
+            Dict[str, _EventCacheEntry]:
                 map from event id to result
         """
         event_entry_map = self._get_events_from_cache(
@@ -453,7 +468,7 @@ class EventsWorkerStore(SQLBaseStore):
             # the events have been redacted, and if so pulling the redaction event out
             # of the database to check it.
             #
-            missing_events = yield self._get_events_from_db(
+            missing_events = await self._get_events_from_db(
                 missing_events_ids, allow_rejected=allow_rejected
             )
 
@@ -561,8 +576,7 @@ class EventsWorkerStore(SQLBaseStore):
                 with PreserveLoggingContext():
                     self.hs.get_reactor().callFromThread(fire, event_list, e)
 
-    @defer.inlineCallbacks
-    def _get_events_from_db(self, event_ids, allow_rejected=False):
+    async def _get_events_from_db(self, event_ids, allow_rejected=False):
         """Fetch a bunch of events from the database.
 
         Returned events will be added to the cache for future lookups.
@@ -576,7 +590,7 @@ class EventsWorkerStore(SQLBaseStore):
                 rejected events are omitted from the response.
 
         Returns:
-            Deferred[Dict[str, _EventCacheEntry]]:
+            Dict[str, _EventCacheEntry]:
                 map from event id to result. May return extra events which
                 weren't asked for.
         """
@@ -584,7 +598,7 @@ class EventsWorkerStore(SQLBaseStore):
         events_to_fetch = event_ids
 
         while events_to_fetch:
-            row_map = yield self._enqueue_events(events_to_fetch)
+            row_map = await self._enqueue_events(events_to_fetch)
 
             # we need to recursively fetch any redactions of those events
             redaction_ids = set()
@@ -610,8 +624,20 @@ class EventsWorkerStore(SQLBaseStore):
             if not allow_rejected and rejected_reason:
                 continue
 
-            d = db_to_json(row["json"])
-            internal_metadata = db_to_json(row["internal_metadata"])
+            # If the event or metadata cannot be parsed, log the error and act
+            # as if the event is unknown.
+            try:
+                d = db_to_json(row["json"])
+            except ValueError:
+                logger.error("Unable to parse json from event: %s", event_id)
+                continue
+            try:
+                internal_metadata = db_to_json(row["internal_metadata"])
+            except ValueError:
+                logger.error(
+                    "Unable to parse internal_metadata from event: %s", event_id
+                )
+                continue
 
             format_version = row["format_version"]
             if format_version is None:
@@ -622,19 +648,38 @@ class EventsWorkerStore(SQLBaseStore):
             room_version_id = row["room_version_id"]
 
             if not room_version_id:
-                # this should only happen for out-of-band membership events
-                if not internal_metadata.get("out_of_band_membership"):
-                    logger.warning(
-                        "Room %s for event %s is unknown", d["room_id"], event_id
+                # this should only happen for out-of-band membership events which
+                # arrived before #6983 landed. For all other events, we should have
+                # an entry in the 'rooms' table.
+                #
+                # However, the 'out_of_band_membership' flag is unreliable for older
+                # invites, so just accept it for all membership events.
+                #
+                if d["type"] != EventTypes.Member:
+                    raise Exception(
+                        "Room %s for event %s is unknown" % (d["room_id"], event_id)
                     )
-                    continue
 
-                # take a wild stab at the room version based on the event format
+                # so, assuming this is an out-of-band-invite that arrived before #6983
+                # landed, we know that the room version must be v5 or earlier (because
+                # v6 hadn't been invented at that point, so invites from such rooms
+                # would have been rejected.)
+                #
+                # The main reason we need to know the room version here (other than
+                # choosing the right python Event class) is in case the event later has
+                # to be redacted - and all the room versions up to v5 used the same
+                # redaction algorithm.
+                #
+                # So, the following approximations should be adequate.
+
                 if format_version == EventFormatVersions.V1:
+                    # if it's event format v1 then it must be room v1 or v2
                     room_version = RoomVersions.V1
                 elif format_version == EventFormatVersions.V2:
+                    # if it's event format v2 then it must be room v3
                     room_version = RoomVersions.V3
                 else:
+                    # if it's event format v3 then it must be room v4 or v5
                     room_version = RoomVersions.V5
             else:
                 room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
@@ -686,8 +731,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return result_map
 
-    @defer.inlineCallbacks
-    def _enqueue_events(self, events):
+    async def _enqueue_events(self, events):
         """Fetches events from the database using the _event_fetch_list. This
         allows batch and bulk fetching of events - it allows us to fetch events
         without having to create a new transaction for each request for events.
@@ -696,7 +740,7 @@ class EventsWorkerStore(SQLBaseStore):
             events (Iterable[str]): events to be fetched.
 
         Returns:
-            Deferred[Dict[str, Dict]]: map from event id to row data from the database.
+            Dict[str, Dict]: map from event id to row data from the database.
                 May contain events that weren't requested.
         """
 
@@ -719,7 +763,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         logger.debug("Loading %d events: %s", len(events), events)
         with PreserveLoggingContext():
-            row_map = yield events_d
+            row_map = await events_d
         logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
 
         return row_map
@@ -807,20 +851,24 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_dict
 
-    def _maybe_redact_event_row(self, original_ev, redactions, event_map):
+    def _maybe_redact_event_row(
+        self,
+        original_ev: EventBase,
+        redactions: Iterable[str],
+        event_map: Dict[str, EventBase],
+    ) -> Optional[EventBase]:
         """Given an event object and a list of possible redacting event ids,
         determine whether to honour any of those redactions and if so return a redacted
         event.
 
         Args:
-             original_ev (EventBase):
-             redactions (iterable[str]): list of event ids of potential redaction events
-             event_map (dict[str, EventBase]): other events which have been fetched, in
-                 which we can look up the redaaction events. Map from event id to event.
+             original_ev: The original event.
+             redactions: list of event ids of potential redaction events
+             event_map: other events which have been fetched, in which we can
+                look up the redaaction events. Map from event id to event.
 
         Returns:
-            Deferred[EventBase|None]: if the event should be redacted, a pruned
-                event object. Otherwise, None.
+            If the event should be redacted, a pruned event object. Otherwise, None.
         """
         if original_ev.type == "m.room.create":
             # we choose to ignore redactions of m.room.create events.
@@ -878,12 +926,11 @@ class EventsWorkerStore(SQLBaseStore):
         # no valid redaction found for this event
         return None
 
-    @defer.inlineCallbacks
-    def have_events_in_timeline(self, event_ids):
+    async def have_events_in_timeline(self, event_ids):
         """Given a list of event ids, check if we have already processed and
         stored them as non outliers.
         """
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="events",
             retcols=("event_id",),
             column="event_id",
@@ -894,15 +941,14 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {r["event_id"] for r in rows}
 
-    @defer.inlineCallbacks
-    def have_seen_events(self, event_ids):
+    async def have_seen_events(self, event_ids):
         """Given a list of event ids, check if we have already processed them.
 
         Args:
             event_ids (iterable[str]):
 
         Returns:
-            Deferred[set[str]]: The events we have already seen.
+            set[str]: The events we have already seen.
         """
         results = set()
 
@@ -918,41 +964,11 @@ class EventsWorkerStore(SQLBaseStore):
         # break the input up into chunks of 100
         input_iterator = iter(event_ids)
         for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "have_seen_events", have_seen_events_txn, chunk
             )
         return results
 
-    def _get_total_state_event_counts_txn(self, txn, room_id):
-        """
-        See get_total_state_event_counts.
-        """
-        # We join against the events table as that has an index on room_id
-        sql = """
-            SELECT COUNT(*) FROM state_events
-            INNER JOIN events USING (room_id, event_id)
-            WHERE room_id=?
-        """
-        txn.execute(sql, (room_id,))
-        row = txn.fetchone()
-        return row[0] if row else 0
-
-    def get_total_state_event_counts(self, room_id):
-        """
-        Gets the total number of state events in a room.
-
-        Args:
-            room_id (str)
-
-        Returns:
-            Deferred[int]
-        """
-        return self.db_pool.runInteraction(
-            "get_total_state_event_counts",
-            self._get_total_state_event_counts_txn,
-            room_id,
-        )
-
     def _get_current_state_event_counts_txn(self, txn, room_id):
         """
         See get_current_state_event_counts.
@@ -962,24 +978,23 @@ class EventsWorkerStore(SQLBaseStore):
         row = txn.fetchone()
         return row[0] if row else 0
 
-    def get_current_state_event_counts(self, room_id):
+    async def get_current_state_event_counts(self, room_id: str) -> int:
         """
         Gets the current number of state events in a room.
 
         Args:
-            room_id (str)
+            room_id: The room ID to query.
 
         Returns:
-            Deferred[int]
+            The current number of state events.
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_current_state_event_counts",
             self._get_current_state_event_counts_txn,
             room_id,
         )
 
-    @defer.inlineCallbacks
-    def get_room_complexity(self, room_id):
+    async def get_room_complexity(self, room_id):
         """
         Get a rough approximation of the complexity of the room. This is used by
         remote servers to decide whether they wish to join the room or not.
@@ -990,9 +1005,9 @@ class EventsWorkerStore(SQLBaseStore):
             room_id (str)
 
         Returns:
-            Deferred[dict[str:int]] of complexity version to complexity.
+            dict[str:int] of complexity version to complexity.
         """
-        state_events = yield self.get_current_state_event_counts(room_id)
+        state_events = await self.get_current_state_event_counts(room_id)
 
         # Call this one "v1", so we can introduce new ones as we want to develop
         # it.
@@ -1008,7 +1023,9 @@ class EventsWorkerStore(SQLBaseStore):
         """The current maximum token that events have reached"""
         return self._stream_id_gen.get_current_token()
 
-    def get_all_new_forward_event_rows(self, last_id, current_id, limit):
+    async def get_all_new_forward_event_rows(
+        self, last_id: int, current_id: int, limit: int
+    ) -> List[Tuple]:
         """Returns new events, for the Events replication stream
 
         Args:
@@ -1016,7 +1033,7 @@ class EventsWorkerStore(SQLBaseStore):
             current_id: the maximum stream_id to return up to
             limit: the maximum number of rows to return
 
-        Returns: Deferred[List[Tuple]]
+        Returns:
             a list of events stream rows. Each tuple consists of a stream id as
             the first element, followed by fields suitable for casting into an
             EventsStreamRow.
@@ -1037,18 +1054,20 @@ class EventsWorkerStore(SQLBaseStore):
             txn.execute(sql, (last_id, current_id, limit))
             return txn.fetchall()
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_all_new_forward_event_rows", get_all_new_forward_event_rows
         )
 
-    def get_ex_outlier_stream_rows(self, last_id, current_id):
+    async def get_ex_outlier_stream_rows(
+        self, last_id: int, current_id: int
+    ) -> List[Tuple]:
         """Returns de-outliered events, for the Events replication stream
 
         Args:
             last_id: the last stream_id from the previous batch.
             current_id: the maximum stream_id to return up to
 
-        Returns: Deferred[List[Tuple]]
+        Returns:
             a list of events stream rows. Each tuple consists of a stream id as
             the first element, followed by fields suitable for casting into an
             EventsStreamRow.
@@ -1071,7 +1090,7 @@ class EventsWorkerStore(SQLBaseStore):
             txn.execute(sql, (last_id, current_id))
             return txn.fetchall()
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
         )
 
@@ -1222,97 +1241,6 @@ class EventsWorkerStore(SQLBaseStore):
 
         return rows, to_token, True
 
-    @cached(num_args=5, max_entries=10)
-    def get_all_new_events(
-        self,
-        last_backfill_id,
-        last_forward_id,
-        current_backfill_id,
-        current_forward_id,
-        limit,
-    ):
-        """Get all the new events that have arrived at the server either as
-        new events or as backfilled events"""
-        have_backfill_events = last_backfill_id != current_backfill_id
-        have_forward_events = last_forward_id != current_forward_id
-
-        if not have_backfill_events and not have_forward_events:
-            return defer.succeed(AllNewEventsResult([], [], [], [], []))
-
-        def get_all_new_events_txn(txn):
-            sql = (
-                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts"
-                " FROM events AS e"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " WHERE ? < stream_ordering AND stream_ordering <= ?"
-                " ORDER BY stream_ordering ASC"
-                " LIMIT ?"
-            )
-            if have_forward_events:
-                txn.execute(sql, (last_forward_id, current_forward_id, limit))
-                new_forward_events = txn.fetchall()
-
-                if len(new_forward_events) == limit:
-                    upper_bound = new_forward_events[-1][0]
-                else:
-                    upper_bound = current_forward_id
-
-                sql = (
-                    "SELECT event_stream_ordering, event_id, state_group"
-                    " FROM ex_outlier_stream"
-                    " WHERE ? > event_stream_ordering"
-                    " AND event_stream_ordering >= ?"
-                    " ORDER BY event_stream_ordering DESC"
-                )
-                txn.execute(sql, (last_forward_id, upper_bound))
-                forward_ex_outliers = txn.fetchall()
-            else:
-                new_forward_events = []
-                forward_ex_outliers = []
-
-            sql = (
-                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts"
-                " FROM events AS e"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " WHERE ? > stream_ordering AND stream_ordering >= ?"
-                " ORDER BY stream_ordering DESC"
-                " LIMIT ?"
-            )
-            if have_backfill_events:
-                txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
-                new_backfill_events = txn.fetchall()
-
-                if len(new_backfill_events) == limit:
-                    upper_bound = new_backfill_events[-1][0]
-                else:
-                    upper_bound = current_backfill_id
-
-                sql = (
-                    "SELECT -event_stream_ordering, event_id, state_group"
-                    " FROM ex_outlier_stream"
-                    " WHERE ? > event_stream_ordering"
-                    " AND event_stream_ordering >= ?"
-                    " ORDER BY event_stream_ordering DESC"
-                )
-                txn.execute(sql, (-last_backfill_id, -upper_bound))
-                backward_ex_outliers = txn.fetchall()
-            else:
-                new_backfill_events = []
-                backward_ex_outliers = []
-
-            return AllNewEventsResult(
-                new_forward_events,
-                new_backfill_events,
-                forward_ex_outliers,
-                backward_ex_outliers,
-            )
-
-        return self.db_pool.runInteraction("get_all_new_events", get_all_new_events_txn)
-
     async def is_event_after(self, event_id1, event_id2):
         """Returns True if event_id1 is after event_id2 in the stream
         """
@@ -1320,9 +1248,9 @@ class EventsWorkerStore(SQLBaseStore):
         to_2, so_2 = await self.get_event_ordering(event_id2)
         return (to_1, so_1) > (to_2, so_2)
 
-    @cachedInlineCallbacks(max_entries=5000)
-    def get_event_ordering(self, event_id):
-        res = yield self.db_pool.simple_select_one(
+    @cached(max_entries=5000)
+    async def get_event_ordering(self, event_id):
+        res = await self.db_pool.simple_select_one(
             table="events",
             retcols=["topological_ordering", "stream_ordering"],
             keyvalues={"event_id": event_id},
@@ -1334,11 +1262,11 @@ class EventsWorkerStore(SQLBaseStore):
 
         return (int(res["topological_ordering"]), int(res["stream_ordering"]))
 
-    def get_next_event_to_expire(self):
+    async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
         """Retrieve the entry with the lowest expiry timestamp in the event_expiry
         table, or None if there's no more event to expire.
 
-        Returns: Deferred[Optional[Tuple[str, int]]]
+        Returns:
             A tuple containing the event ID as its first element and an expiry timestamp
             as its second one, if there's at least one row in the event_expiry table.
             None otherwise.
@@ -1354,17 +1282,6 @@ class EventsWorkerStore(SQLBaseStore):
 
             return txn.fetchone()
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
         )
-
-
-AllNewEventsResult = namedtuple(
-    "AllNewEventsResult",
-    [
-        "new_forward_events",
-        "new_backfill_events",
-        "forward_ex_outliers",
-        "backward_ex_outliers",
-    ],
-)
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 45a1760170..d2f5b9a502 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -17,6 +17,7 @@ from canonicaljson import encode_canonical_json
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
 
 
@@ -40,7 +41,7 @@ class FilteringStore(SQLBaseStore):
 
         return db_to_json(def_json)
 
-    def add_user_filter(self, user_localpart, user_filter):
+    async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str:
         def_json = encode_canonical_json(user_filter)
 
         # Need an atomic transaction to SELECT the maximal ID so far then
@@ -71,4 +72,4 @@ class FilteringStore(SQLBaseStore):
 
             return filter_id
 
-        return self.db_pool.runInteraction("add_user_filter", _do_txn)
+        return await self.db_pool.runInteraction("add_user_filter", _do_txn)
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 380db3a3f3..ccfbb2135e 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -28,8 +28,8 @@ _DEFAULT_ROLE_ID = ""
 
 
 class GroupServerWorkerStore(SQLBaseStore):
-    def get_group(self, group_id):
-        return self.db_pool.simple_select_one(
+    async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
+        return await self.db_pool.simple_select_one(
             table="groups",
             keyvalues={"group_id": group_id},
             retcols=(
@@ -44,31 +44,35 @@ class GroupServerWorkerStore(SQLBaseStore):
             desc="get_group",
         )
 
-    def get_users_in_group(self, group_id, include_private=False):
+    async def get_users_in_group(
+        self, group_id: str, include_private: bool = False
+    ) -> List[Dict[str, Any]]:
         # TODO: Pagination
 
         keyvalues = {"group_id": group_id}
         if not include_private:
             keyvalues["is_public"] = True
 
-        return self.db_pool.simple_select_list(
+        return await self.db_pool.simple_select_list(
             table="group_users",
             keyvalues=keyvalues,
             retcols=("user_id", "is_public", "is_admin"),
             desc="get_users_in_group",
         )
 
-    def get_invited_users_in_group(self, group_id):
+    async def get_invited_users_in_group(self, group_id: str) -> List[str]:
         # TODO: Pagination
 
-        return self.db_pool.simple_select_onecol(
+        return await self.db_pool.simple_select_onecol(
             table="group_invites",
             keyvalues={"group_id": group_id},
             retcol="user_id",
             desc="get_invited_users_in_group",
         )
 
-    def get_rooms_in_group(self, group_id: str, include_private: bool = False):
+    async def get_rooms_in_group(
+        self, group_id: str, include_private: bool = False
+    ) -> List[Dict[str, Union[str, bool]]]:
         """Retrieve the rooms that belong to a given group. Does not return rooms that
         lack members.
 
@@ -77,8 +81,7 @@ class GroupServerWorkerStore(SQLBaseStore):
             include_private: Whether to return private rooms in results
 
         Returns:
-            Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the
-            form of:
+            A list of dictionaries, each in the form of:
 
             {
               "room_id": "!a_room_id:example.com",  # The ID of the room
@@ -115,13 +118,13 @@ class GroupServerWorkerStore(SQLBaseStore):
                 for room_id, is_public in txn
             ]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_rooms_in_group", _get_rooms_in_group_txn
         )
 
-    def get_rooms_for_summary_by_category(
+    async def get_rooms_for_summary_by_category(
         self, group_id: str, include_private: bool = False,
-    ):
+    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
         """Get the rooms and categories that should be included in a summary request
 
         Args:
@@ -129,7 +132,7 @@ class GroupServerWorkerStore(SQLBaseStore):
             include_private: Whether to return private rooms in results
 
         Returns:
-            Deferred[Tuple[List, Dict]]: A tuple containing:
+            A tuple containing:
 
                 * A list of dictionaries with the keys:
                     * "room_id": str, the room ID
@@ -205,7 +208,7 @@ class GroupServerWorkerStore(SQLBaseStore):
 
             return rooms, categories
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_rooms_for_summary", _get_rooms_for_summary_txn
         )
 
@@ -265,25 +268,25 @@ class GroupServerWorkerStore(SQLBaseStore):
 
         return role
 
-    def get_local_groups_for_room(self, room_id):
+    async def get_local_groups_for_room(self, room_id: str) -> List[str]:
         """Get all of the local group that contain a given room
         Args:
-            room_id (str): The ID of a room
+            room_id: The ID of a room
         Returns:
-            Deferred[list[str]]: A twisted.Deferred containing a list of group ids
-                containing this room
+            A list of group ids containing this room
         """
-        return self.db_pool.simple_select_onecol(
+        return await self.db_pool.simple_select_onecol(
             table="group_rooms",
             keyvalues={"room_id": room_id},
             retcol="group_id",
             desc="get_local_groups_for_room",
         )
 
-    def get_users_for_summary_by_role(self, group_id, include_private=False):
+    async def get_users_for_summary_by_role(self, group_id, include_private=False):
         """Get the users and roles that should be included in a summary request
 
-        Returns ([users], [roles])
+        Returns:
+            ([users], [roles])
         """
 
         def _get_users_for_summary_txn(txn):
@@ -337,21 +340,24 @@ class GroupServerWorkerStore(SQLBaseStore):
 
             return users, roles
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_users_for_summary_by_role", _get_users_for_summary_txn
         )
 
-    def is_user_in_group(self, user_id, group_id):
-        return self.db_pool.simple_select_one_onecol(
+    async def is_user_in_group(self, user_id: str, group_id: str) -> bool:
+        result = await self.db_pool.simple_select_one_onecol(
             table="group_users",
             keyvalues={"group_id": group_id, "user_id": user_id},
             retcol="user_id",
             allow_none=True,
             desc="is_user_in_group",
-        ).addCallback(lambda r: bool(r))
+        )
+        return bool(result)
 
-    def is_user_admin_in_group(self, group_id, user_id):
-        return self.db_pool.simple_select_one_onecol(
+    async def is_user_admin_in_group(
+        self, group_id: str, user_id: str
+    ) -> Optional[bool]:
+        return await self.db_pool.simple_select_one_onecol(
             table="group_users",
             keyvalues={"group_id": group_id, "user_id": user_id},
             retcol="is_admin",
@@ -359,10 +365,12 @@ class GroupServerWorkerStore(SQLBaseStore):
             desc="is_user_admin_in_group",
         )
 
-    def is_user_invited_to_local_group(self, group_id, user_id):
+    async def is_user_invited_to_local_group(
+        self, group_id: str, user_id: str
+    ) -> Optional[bool]:
         """Has the group server invited a user?
         """
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="group_invites",
             keyvalues={"group_id": group_id, "user_id": user_id},
             retcol="user_id",
@@ -370,7 +378,7 @@ class GroupServerWorkerStore(SQLBaseStore):
             allow_none=True,
         )
 
-    def get_users_membership_info_in_group(self, group_id, user_id):
+    async def get_users_membership_info_in_group(self, group_id, user_id):
         """Get a dict describing the membership of a user in a group.
 
         Example if joined:
@@ -381,7 +389,8 @@ class GroupServerWorkerStore(SQLBaseStore):
                 "is_privileged": False,
             }
 
-        Returns an empty dict if the user is not join/invite/etc
+        Returns:
+             An empty dict if the user is not join/invite/etc
         """
 
         def _get_users_membership_in_group_txn(txn):
@@ -413,21 +422,21 @@ class GroupServerWorkerStore(SQLBaseStore):
 
             return {}
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_users_membership_info_in_group", _get_users_membership_in_group_txn
         )
 
-    def get_publicised_groups_for_user(self, user_id):
+    async def get_publicised_groups_for_user(self, user_id: str) -> List[str]:
         """Get all groups a user is publicising
         """
-        return self.db_pool.simple_select_onecol(
+        return await self.db_pool.simple_select_onecol(
             table="local_group_membership",
             keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
             retcol="group_id",
             desc="get_publicised_groups_for_user",
         )
 
-    def get_attestations_need_renewals(self, valid_until_ms):
+    async def get_attestations_need_renewals(self, valid_until_ms):
         """Get all attestations that need to be renewed until givent time
         """
 
@@ -439,7 +448,7 @@ class GroupServerWorkerStore(SQLBaseStore):
             txn.execute(sql, (valid_until_ms,))
             return self.db_pool.cursor_to_dict(txn)
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_attestations_need_renewals", _get_attestations_need_renewals_txn
         )
 
@@ -461,15 +470,15 @@ class GroupServerWorkerStore(SQLBaseStore):
 
         return None
 
-    def get_joined_groups(self, user_id):
-        return self.db_pool.simple_select_onecol(
+    async def get_joined_groups(self, user_id: str) -> List[str]:
+        return await self.db_pool.simple_select_onecol(
             table="local_group_membership",
             keyvalues={"user_id": user_id, "membership": "join"},
             retcol="group_id",
             desc="get_joined_groups",
         )
 
-    def get_all_groups_for_user(self, user_id, now_token):
+    async def get_all_groups_for_user(self, user_id, now_token):
         def _get_all_groups_for_user_txn(txn):
             sql = """
                 SELECT group_id, type, membership, u.content
@@ -489,7 +498,7 @@ class GroupServerWorkerStore(SQLBaseStore):
                 for row in txn
             ]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_all_groups_for_user", _get_all_groups_for_user_txn
         )
 
@@ -580,22 +589,41 @@ class GroupServerWorkerStore(SQLBaseStore):
 
 
 class GroupServerStore(GroupServerWorkerStore):
-    def set_group_join_policy(self, group_id, join_policy):
+    async def set_group_join_policy(self, group_id: str, join_policy: str) -> None:
         """Set the join policy of a group.
 
         join_policy can be one of:
          * "invite"
          * "open"
         """
-        return self.db_pool.simple_update_one(
+        await self.db_pool.simple_update_one(
             table="groups",
             keyvalues={"group_id": group_id},
             updatevalues={"join_policy": join_policy},
             desc="set_group_join_policy",
         )
 
-    def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
-        return self.db_pool.runInteraction(
+    async def add_room_to_summary(
+        self,
+        group_id: str,
+        room_id: str,
+        category_id: str,
+        order: int,
+        is_public: Optional[bool],
+    ) -> None:
+        """Add (or update) room's entry in summary.
+
+        Args:
+            group_id
+            room_id
+            category_id: If not None then adds the category to the end of
+                the summary if its not already there.
+            order: If not None inserts the room at that position, e.g. an order
+                of 1 will put the room first. Otherwise, the room gets added to
+                the end.
+            is_public
+        """
+        await self.db_pool.runInteraction(
             "add_room_to_summary",
             self._add_room_to_summary_txn,
             group_id,
@@ -606,18 +634,26 @@ class GroupServerStore(GroupServerWorkerStore):
         )
 
     def _add_room_to_summary_txn(
-        self, txn, group_id, room_id, category_id, order, is_public
-    ):
+        self,
+        txn,
+        group_id: str,
+        room_id: str,
+        category_id: str,
+        order: int,
+        is_public: Optional[bool],
+    ) -> None:
         """Add (or update) room's entry in summary.
 
         Args:
-            group_id (str)
-            room_id (str)
-            category_id (str): If not None then adds the category to the end of
-                the summary if its not already there. [Optional]
-            order (int): If not None inserts the room at that position, e.g.
-                an order of 1 will put the room first. Otherwise, the room gets
-                added to the end.
+            txn
+            group_id
+            room_id
+            category_id: If not None then adds the category to the end of
+                the summary if its not already there.
+            order: If not None inserts the room at that position, e.g. an order
+                of 1 will put the room first. Otherwise, the room gets added to
+                the end.
+            is_public
         """
         room_in_group = self.db_pool.simple_select_one_onecol_txn(
             txn,
@@ -722,11 +758,13 @@ class GroupServerStore(GroupServerWorkerStore):
                 },
             )
 
-    def remove_room_from_summary(self, group_id, room_id, category_id):
+    async def remove_room_from_summary(
+        self, group_id: str, room_id: str, category_id: str
+    ) -> int:
         if category_id is None:
             category_id = _DEFAULT_CATEGORY_ID
 
-        return self.db_pool.simple_delete(
+        return await self.db_pool.simple_delete(
             table="group_summary_rooms",
             keyvalues={
                 "group_id": group_id,
@@ -736,7 +774,13 @@ class GroupServerStore(GroupServerWorkerStore):
             desc="remove_room_from_summary",
         )
 
-    def upsert_group_category(self, group_id, category_id, profile, is_public):
+    async def upsert_group_category(
+        self,
+        group_id: str,
+        category_id: str,
+        profile: Optional[JsonDict],
+        is_public: Optional[bool],
+    ) -> None:
         """Add/update room category for group
         """
         insertion_values = {}
@@ -752,7 +796,7 @@ class GroupServerStore(GroupServerWorkerStore):
         else:
             update_values["is_public"] = is_public
 
-        return self.db_pool.simple_upsert(
+        await self.db_pool.simple_upsert(
             table="group_room_categories",
             keyvalues={"group_id": group_id, "category_id": category_id},
             values=update_values,
@@ -760,14 +804,20 @@ class GroupServerStore(GroupServerWorkerStore):
             desc="upsert_group_category",
         )
 
-    def remove_group_category(self, group_id, category_id):
-        return self.db_pool.simple_delete(
+    async def remove_group_category(self, group_id: str, category_id: str) -> int:
+        return await self.db_pool.simple_delete(
             table="group_room_categories",
             keyvalues={"group_id": group_id, "category_id": category_id},
             desc="remove_group_category",
         )
 
-    def upsert_group_role(self, group_id, role_id, profile, is_public):
+    async def upsert_group_role(
+        self,
+        group_id: str,
+        role_id: str,
+        profile: Optional[JsonDict],
+        is_public: Optional[bool],
+    ) -> None:
         """Add/remove user role
         """
         insertion_values = {}
@@ -783,7 +833,7 @@ class GroupServerStore(GroupServerWorkerStore):
         else:
             update_values["is_public"] = is_public
 
-        return self.db_pool.simple_upsert(
+        await self.db_pool.simple_upsert(
             table="group_roles",
             keyvalues={"group_id": group_id, "role_id": role_id},
             values=update_values,
@@ -791,15 +841,34 @@ class GroupServerStore(GroupServerWorkerStore):
             desc="upsert_group_role",
         )
 
-    def remove_group_role(self, group_id, role_id):
-        return self.db_pool.simple_delete(
+    async def remove_group_role(self, group_id: str, role_id: str) -> int:
+        return await self.db_pool.simple_delete(
             table="group_roles",
             keyvalues={"group_id": group_id, "role_id": role_id},
             desc="remove_group_role",
         )
 
-    def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
-        return self.db_pool.runInteraction(
+    async def add_user_to_summary(
+        self,
+        group_id: str,
+        user_id: str,
+        role_id: str,
+        order: int,
+        is_public: Optional[bool],
+    ) -> None:
+        """Add (or update) user's entry in summary.
+
+        Args:
+            group_id
+            user_id
+            role_id: If not None then adds the role to the end of the summary if
+                its not already there.
+            order: If not None inserts the user at that position, e.g. an order
+                of 1 will put the user first. Otherwise, the user gets added to
+                the end.
+            is_public
+        """
+        await self.db_pool.runInteraction(
             "add_user_to_summary",
             self._add_user_to_summary_txn,
             group_id,
@@ -810,18 +879,26 @@ class GroupServerStore(GroupServerWorkerStore):
         )
 
     def _add_user_to_summary_txn(
-        self, txn, group_id, user_id, role_id, order, is_public
+        self,
+        txn,
+        group_id: str,
+        user_id: str,
+        role_id: str,
+        order: int,
+        is_public: Optional[bool],
     ):
         """Add (or update) user's entry in summary.
 
         Args:
-            group_id (str)
-            user_id (str)
-            role_id (str): If not None then adds the role to the end of
-                the summary if its not already there. [Optional]
-            order (int): If not None inserts the user at that position, e.g.
-                an order of 1 will put the user first. Otherwise, the user gets
-                added to the end.
+            txn
+            group_id
+            user_id
+            role_id: If not None then adds the role to the end of the summary if
+                its not already there.
+            order: If not None inserts the user at that position, e.g. an order
+                of 1 will put the user first. Otherwise, the user gets added to
+                the end.
+            is_public
         """
         user_in_group = self.db_pool.simple_select_one_onecol_txn(
             txn,
@@ -922,46 +999,47 @@ class GroupServerStore(GroupServerWorkerStore):
                 },
             )
 
-    def remove_user_from_summary(self, group_id, user_id, role_id):
+    async def remove_user_from_summary(
+        self, group_id: str, user_id: str, role_id: str
+    ) -> int:
         if role_id is None:
             role_id = _DEFAULT_ROLE_ID
 
-        return self.db_pool.simple_delete(
+        return await self.db_pool.simple_delete(
             table="group_summary_users",
             keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id},
             desc="remove_user_from_summary",
         )
 
-    def add_group_invite(self, group_id, user_id):
+    async def add_group_invite(self, group_id: str, user_id: str) -> None:
         """Record that the group server has invited a user
         """
-        return self.db_pool.simple_insert(
+        await self.db_pool.simple_insert(
             table="group_invites",
             values={"group_id": group_id, "user_id": user_id},
             desc="add_group_invite",
         )
 
-    def add_user_to_group(
+    async def add_user_to_group(
         self,
-        group_id,
-        user_id,
-        is_admin=False,
-        is_public=True,
-        local_attestation=None,
-        remote_attestation=None,
-    ):
+        group_id: str,
+        user_id: str,
+        is_admin: bool = False,
+        is_public: bool = True,
+        local_attestation: dict = None,
+        remote_attestation: dict = None,
+    ) -> None:
         """Add a user to the group server.
 
         Args:
-            group_id (str)
-            user_id (str)
-            is_admin (bool)
-            is_public (bool)
-            local_attestation (dict): The attestation the GS created to give
-                to the remote server. Optional if the user and group are on the
-                same server
-            remote_attestation (dict): The attestation given to GS by remote
+            group_id
+            user_id
+            is_admin
+            is_public
+            local_attestation: The attestation the GS created to give to the remote
                 server. Optional if the user and group are on the same server
+            remote_attestation: The attestation given to GS by remote server.
+                Optional if the user and group are on the same server
         """
 
         def _add_user_to_group_txn(txn):
@@ -1004,9 +1082,9 @@ class GroupServerStore(GroupServerWorkerStore):
                     },
                 )
 
-        return self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
+        await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
 
-    def remove_user_from_group(self, group_id, user_id):
+    async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
         def _remove_user_from_group_txn(txn):
             self.db_pool.simple_delete_txn(
                 txn,
@@ -1034,26 +1112,30 @@ class GroupServerStore(GroupServerWorkerStore):
                 keyvalues={"group_id": group_id, "user_id": user_id},
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "remove_user_from_group", _remove_user_from_group_txn
         )
 
-    def add_room_to_group(self, group_id, room_id, is_public):
-        return self.db_pool.simple_insert(
+    async def add_room_to_group(
+        self, group_id: str, room_id: str, is_public: bool
+    ) -> None:
+        await self.db_pool.simple_insert(
             table="group_rooms",
             values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
             desc="add_room_to_group",
         )
 
-    def update_room_in_group_visibility(self, group_id, room_id, is_public):
-        return self.db_pool.simple_update(
+    async def update_room_in_group_visibility(
+        self, group_id: str, room_id: str, is_public: bool
+    ) -> int:
+        return await self.db_pool.simple_update(
             table="group_rooms",
             keyvalues={"group_id": group_id, "room_id": room_id},
             updatevalues={"is_public": is_public},
             desc="update_room_in_group_visibility",
         )
 
-    def remove_room_from_group(self, group_id, room_id):
+    async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
         def _remove_room_from_group_txn(txn):
             self.db_pool.simple_delete_txn(
                 txn,
@@ -1067,14 +1149,16 @@ class GroupServerStore(GroupServerWorkerStore):
                 keyvalues={"group_id": group_id, "room_id": room_id},
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "remove_room_from_group", _remove_room_from_group_txn
         )
 
-    def update_group_publicity(self, group_id, user_id, publicise):
+    async def update_group_publicity(
+        self, group_id: str, user_id: str, publicise: bool
+    ) -> None:
         """Update whether the user is publicising their membership of the group
         """
-        return self.db_pool.simple_update_one(
+        await self.db_pool.simple_update_one(
             table="local_group_membership",
             keyvalues={"group_id": group_id, "user_id": user_id},
             updatevalues={"is_publicised": publicise},
@@ -1181,7 +1265,7 @@ class GroupServerStore(GroupServerWorkerStore):
 
             return next_id
 
-        with self._group_updates_id_gen.get_next() as next_id:
+        with await self._group_updates_id_gen.get_next() as next_id:
             res = await self.db_pool.runInteraction(
                 "register_user_group_membership",
                 _register_user_group_membership_txn,
@@ -1213,20 +1297,24 @@ class GroupServerStore(GroupServerWorkerStore):
             desc="update_group_profile",
         )
 
-    def update_attestation_renewal(self, group_id, user_id, attestation):
+    async def update_attestation_renewal(
+        self, group_id: str, user_id: str, attestation: dict
+    ) -> None:
         """Update an attestation that we have renewed
         """
-        return self.db_pool.simple_update_one(
+        await self.db_pool.simple_update_one(
             table="group_attestations_renewals",
             keyvalues={"group_id": group_id, "user_id": user_id},
             updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
             desc="update_attestation_renewal",
         )
 
-    def update_remote_attestion(self, group_id, user_id, attestation):
+    async def update_remote_attestion(
+        self, group_id: str, user_id: str, attestation: dict
+    ) -> None:
         """Update an attestation that a remote has renewed
         """
-        return self.db_pool.simple_update_one(
+        await self.db_pool.simple_update_one(
             table="group_attestations_remote",
             keyvalues={"group_id": group_id, "user_id": user_id},
             updatevalues={
@@ -1236,16 +1324,16 @@ class GroupServerStore(GroupServerWorkerStore):
             desc="update_remote_attestion",
         )
 
-    def remove_attestation_renewal(self, group_id, user_id):
+    async def remove_attestation_renewal(self, group_id: str, user_id: str) -> int:
         """Remove an attestation that we thought we should renew, but actually
         shouldn't. Ideally this would never get called as we would never
         incorrectly try and do attestations for local users on local groups.
 
         Args:
-            group_id (str)
-            user_id (str)
+            group_id
+            user_id
         """
-        return self.db_pool.simple_delete(
+        return await self.db_pool.simple_delete(
             table="group_attestations_renewals",
             keyvalues={"group_id": group_id, "user_id": user_id},
             desc="remove_attestation_renewal",
@@ -1254,14 +1342,11 @@ class GroupServerStore(GroupServerWorkerStore):
     def get_group_stream_token(self):
         return self._group_updates_id_gen.get_current_token()
 
-    def delete_group(self, group_id):
+    async def delete_group(self, group_id: str) -> None:
         """Deletes a group fully from the database.
 
         Args:
-            group_id (str)
-
-        Returns:
-            Deferred
+            group_id: The group ID to delete.
         """
 
         def _delete_group_txn(txn):
@@ -1285,4 +1370,4 @@ class GroupServerStore(GroupServerWorkerStore):
                     txn, table=table, keyvalues={"group_id": group_id}
                 )
 
-        return self.db_pool.runInteraction("delete_group", _delete_group_txn)
+        await self.db_pool.runInteraction("delete_group", _delete_group_txn)
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 384e9c5eb0..ad43bb05ab 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,6 +16,7 @@
 
 import itertools
 import logging
+from typing import Dict, Iterable, List, Optional, Tuple
 
 from signedjson.key import decode_verify_key_bytes
 
@@ -41,16 +42,17 @@ class KeyStore(SQLBaseStore):
     @cachedList(
         cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
     )
-    def get_server_verify_keys(self, server_name_and_key_ids):
+    async def get_server_verify_keys(
+        self, server_name_and_key_ids: Iterable[Tuple[str, str]]
+    ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]:
         """
         Args:
-            server_name_and_key_ids (iterable[Tuple[str, str]]):
+            server_name_and_key_ids:
                 iterable of (server_name, key-id) tuples to fetch keys for
 
         Returns:
-            Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
-                map from (server_name, key_id) -> FetchKeyResult, or None if the key is
-                unknown
+            A map from (server_name, key_id) -> FetchKeyResult, or None if the
+            key is unknown
         """
         keys = {}
 
@@ -86,14 +88,19 @@ class KeyStore(SQLBaseStore):
                 _get_keys(txn, batch)
             return keys
 
-        return self.db_pool.runInteraction("get_server_verify_keys", _txn)
+        return await self.db_pool.runInteraction("get_server_verify_keys", _txn)
 
-    def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
+    async def store_server_verify_keys(
+        self,
+        from_server: str,
+        ts_added_ms: int,
+        verify_keys: Iterable[Tuple[str, str, FetchKeyResult]],
+    ) -> None:
         """Stores NACL verification keys for remote servers.
         Args:
-            from_server (str): Where the verification keys were looked up
-            ts_added_ms (int): The time to record that the key was added
-            verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
+            from_server: Where the verification keys were looked up
+            ts_added_ms: The time to record that the key was added
+            verify_keys:
                 keys to be stored. Each entry is a triplet of
                 (server_name, key_id, key).
         """
@@ -115,13 +122,7 @@ class KeyStore(SQLBaseStore):
             # param, which is itself the 2-tuple (server_name, key_id).
             invalidations.append((server_name, key_id))
 
-        def _invalidate(res):
-            f = self._get_server_verify_key.invalidate
-            for i in invalidations:
-                f((i,))
-            return res
-
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "store_server_verify_keys",
             self.db_pool.simple_upsert_many_txn,
             table="server_signature_keys",
@@ -134,24 +135,34 @@ class KeyStore(SQLBaseStore):
                 "verify_key",
             ),
             value_values=value_values,
-        ).addCallback(_invalidate)
+        )
 
-    def store_server_keys_json(
-        self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
-    ):
+        invalidate = self._get_server_verify_key.invalidate
+        for i in invalidations:
+            invalidate((i,))
+
+    async def store_server_keys_json(
+        self,
+        server_name: str,
+        key_id: str,
+        from_server: str,
+        ts_now_ms: int,
+        ts_expires_ms: int,
+        key_json_bytes: bytes,
+    ) -> None:
         """Stores the JSON bytes for a set of keys from a server
         The JSON should be signed by the originating server, the intermediate
         server, and by this server. Updates the value for the
         (server_name, key_id, from_server) triplet if one already existed.
         Args:
-            server_name (str): The name of the server.
-            key_id (str): The identifer of the key this JSON is for.
-            from_server (str): The server this JSON was fetched from.
-            ts_now_ms (int): The time now in milliseconds.
-            ts_valid_until_ms (int): The time when this json stops being valid.
-            key_json (bytes): The encoded JSON.
+            server_name: The name of the server.
+            key_id: The identifer of the key this JSON is for.
+            from_server: The server this JSON was fetched from.
+            ts_now_ms: The time now in milliseconds.
+            ts_valid_until_ms: The time when this json stops being valid.
+            key_json_bytes: The encoded JSON.
         """
-        return self.db_pool.simple_upsert(
+        await self.db_pool.simple_upsert(
             table="server_keys_json",
             keyvalues={
                 "server_name": server_name,
@@ -169,7 +180,9 @@ class KeyStore(SQLBaseStore):
             desc="store_server_keys_json",
         )
 
-    def get_server_keys_json(self, server_keys):
+    async def get_server_keys_json(
+        self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
+    ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]:
         """Retrive the key json for a list of server_keys and key ids.
         If no keys are found for a given server, key_id and source then
         that server, key_id, and source triplet entry will be an empty list.
@@ -178,8 +191,7 @@ class KeyStore(SQLBaseStore):
         Args:
             server_keys (list): List of (server_name, key_id, source) triplets.
         Returns:
-            Deferred[dict[Tuple[str, str, str|None], list[dict]]]:
-                Dict mapping (server_name, key_id, source) triplets to lists of dicts
+            A mapping from (server_name, key_id, source) triplets to a list of dicts
         """
 
         def _get_server_keys_json_txn(txn):
@@ -205,6 +217,6 @@ class KeyStore(SQLBaseStore):
                 results[(server_name, key_id, from_server)] = rows
             return results
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_server_keys_json", _get_server_keys_json_txn
         )
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 80fc1cd009..1d76c761a6 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -12,9 +12,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 
+BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
+    "media_repository_drop_index_wo_method"
+)
+
 
 class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
@@ -30,6 +36,59 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
             where_clause="url_cache IS NOT NULL",
         )
 
+        # The following the updates add the method to the unique constraint of
+        # the thumbnail databases. That fixes an issue, where thumbnails of the
+        # same resolution, but different methods could overwrite one another.
+        # This can happen with custom thumbnail configs or with dynamic thumbnailing.
+        self.db_pool.updates.register_background_index_update(
+            update_name="local_media_repository_thumbnails_method_idx",
+            index_name="local_media_repository_thumbn_media_id_width_height_method_key",
+            table="local_media_repository_thumbnails",
+            columns=[
+                "media_id",
+                "thumbnail_width",
+                "thumbnail_height",
+                "thumbnail_type",
+                "thumbnail_method",
+            ],
+            unique=True,
+        )
+
+        self.db_pool.updates.register_background_index_update(
+            update_name="remote_media_repository_thumbnails_method_idx",
+            index_name="remote_media_repository_thumbn_media_origin_id_width_height_method_key",
+            table="remote_media_cache_thumbnails",
+            columns=[
+                "media_origin",
+                "media_id",
+                "thumbnail_width",
+                "thumbnail_height",
+                "thumbnail_type",
+                "thumbnail_method",
+            ],
+            unique=True,
+        )
+
+        self.db_pool.updates.register_background_update_handler(
+            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD,
+            self._drop_media_index_without_method,
+        )
+
+    async def _drop_media_index_without_method(self, progress, batch_size):
+        def f(txn):
+            txn.execute(
+                "ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key"
+            )
+            txn.execute(
+                "ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_repository_thumbn_media_id_thumbnail_width_thum_key"
+            )
+
+        await self.db_pool.runInteraction("drop_media_indices_without_method", f)
+        await self.db_pool.updates._end_background_update(
+            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD
+        )
+        return 1
+
 
 class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
     """Persistence for attachments and avatars"""
@@ -37,12 +96,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
 
-    def get_local_media(self, media_id):
+    async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
         """Get the metadata for a local piece of media
+
         Returns:
             None if the media_id doesn't exist.
         """
-        return self.db_pool.simple_select_one(
+        return await self.db_pool.simple_select_one(
             "local_media_repository",
             {"media_id": media_id},
             (
@@ -57,7 +117,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="get_local_media",
         )
 
-    def store_local_media(
+    async def store_local_media(
         self,
         media_id,
         media_type,
@@ -66,8 +126,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         media_length,
         user_id,
         url_cache=None,
-    ):
-        return self.db_pool.simple_insert(
+    ) -> None:
+        await self.db_pool.simple_insert(
             "local_media_repository",
             {
                 "media_id": media_id,
@@ -81,16 +141,16 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_local_media",
         )
 
-    def mark_local_media_as_safe(self, media_id: str):
+    async def mark_local_media_as_safe(self, media_id: str) -> None:
         """Mark a local media as safe from quarantining."""
-        return self.db_pool.simple_update_one(
+        await self.db_pool.simple_update_one(
             table="local_media_repository",
             keyvalues={"media_id": media_id},
             updatevalues={"safe_from_quarantine": True},
             desc="mark_local_media_as_safe",
         )
 
-    def get_url_cache(self, url, ts):
+    async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
         """Get the media_id and ts for a cached URL as of the given timestamp
         Returns:
             None if the URL isn't cached.
@@ -136,12 +196,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 )
             )
 
-        return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
+        return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
 
-    def store_url_cache(
+    async def store_url_cache(
         self, url, response_code, etag, expires_ts, og, media_id, download_ts
     ):
-        return self.db_pool.simple_insert(
+        await self.db_pool.simple_insert(
             "local_media_repository_url_cache",
             {
                 "url": url,
@@ -155,8 +215,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_url_cache",
         )
 
-    def get_local_media_thumbnails(self, media_id):
-        return self.db_pool.simple_select_list(
+    async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]]:
+        return await self.db_pool.simple_select_list(
             "local_media_repository_thumbnails",
             {"media_id": media_id},
             (
@@ -169,7 +229,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="get_local_media_thumbnails",
         )
 
-    def store_local_thumbnail(
+    async def store_local_thumbnail(
         self,
         media_id,
         thumbnail_width,
@@ -178,7 +238,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         thumbnail_method,
         thumbnail_length,
     ):
-        return self.db_pool.simple_insert(
+        await self.db_pool.simple_insert(
             "local_media_repository_thumbnails",
             {
                 "media_id": media_id,
@@ -191,8 +251,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_local_thumbnail",
         )
 
-    def get_cached_remote_media(self, origin, media_id):
-        return self.db_pool.simple_select_one(
+    async def get_cached_remote_media(
+        self, origin, media_id: str
+    ) -> Optional[Dict[str, Any]]:
+        return await self.db_pool.simple_select_one(
             "remote_media_cache",
             {"media_origin": origin, "media_id": media_id},
             (
@@ -207,7 +269,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="get_cached_remote_media",
         )
 
-    def store_cached_remote_media(
+    async def store_cached_remote_media(
         self,
         origin,
         media_id,
@@ -217,7 +279,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         upload_name,
         filesystem_id,
     ):
-        return self.db_pool.simple_insert(
+        await self.db_pool.simple_insert(
             "remote_media_cache",
             {
                 "media_origin": origin,
@@ -232,12 +294,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_cached_remote_media",
         )
 
-    def update_cached_last_access_time(self, local_media, remote_media, time_ms):
+    async def update_cached_last_access_time(
+        self,
+        local_media: Iterable[str],
+        remote_media: Iterable[Tuple[str, str]],
+        time_ms: int,
+    ):
         """Updates the last access time of the given media
 
         Args:
-            local_media (iterable[str]): Set of media_ids
-            remote_media (iterable[(str, str)]): Set of (server_name, media_id)
+            local_media: Set of media_ids
+            remote_media: Set of (server_name, media_id)
             time_ms: Current time in milliseconds
         """
 
@@ -262,12 +329,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
             txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "update_cached_last_access_time", update_cache_txn
         )
 
-    def get_remote_media_thumbnails(self, origin, media_id):
-        return self.db_pool.simple_select_list(
+    async def get_remote_media_thumbnails(
+        self, origin: str, media_id: str
+    ) -> List[Dict[str, Any]]:
+        return await self.db_pool.simple_select_list(
             "remote_media_cache_thumbnails",
             {"media_origin": origin, "media_id": media_id},
             (
@@ -281,7 +350,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="get_remote_media_thumbnails",
         )
 
-    def store_remote_media_thumbnail(
+    async def store_remote_media_thumbnail(
         self,
         origin,
         media_id,
@@ -292,7 +361,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         thumbnail_method,
         thumbnail_length,
     ):
-        return self.db_pool.simple_insert(
+        await self.db_pool.simple_insert(
             "remote_media_cache_thumbnails",
             {
                 "media_origin": origin,
@@ -307,18 +376,18 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_remote_media_thumbnail",
         )
 
-    def get_remote_media_before(self, before_ts):
+    async def get_remote_media_before(self, before_ts):
         sql = (
             "SELECT media_origin, media_id, filesystem_id"
             " FROM remote_media_cache"
             " WHERE last_access_ts < ?"
         )
 
-        return self.db_pool.execute(
+        return await self.db_pool.execute(
             "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
         )
 
-    def delete_remote_media(self, media_origin, media_id):
+    async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
         def delete_remote_media_txn(txn):
             self.db_pool.simple_delete_txn(
                 txn,
@@ -331,11 +400,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 keyvalues={"media_origin": media_origin, "media_id": media_id},
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_remote_media", delete_remote_media_txn
         )
 
-    def get_expired_url_cache(self, now_ts):
+    async def get_expired_url_cache(self, now_ts: int) -> List[str]:
         sql = (
             "SELECT media_id FROM local_media_repository_url_cache"
             " WHERE expires_ts < ?"
@@ -347,7 +416,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             txn.execute(sql, (now_ts,))
             return [row[0] for row in txn]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_expired_url_cache", _get_expired_url_cache_txn
         )
 
@@ -364,7 +433,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             "delete_url_cache", _delete_url_cache_txn
         )
 
-    def get_url_cache_media_before(self, before_ts):
+    async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
         sql = (
             "SELECT media_id FROM local_media_repository"
             " WHERE created_ts < ? AND url_cache IS NOT NULL"
@@ -376,7 +445,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             txn.execute(sql, (before_ts,))
             return [row[0] for row in txn]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_url_cache_media_before", _get_url_cache_media_before_txn
         )
 
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index e71cdd2cb4..1d793d3deb 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import List
+from typing import Dict, List
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool, make_in_list_sql_clause
@@ -33,11 +33,11 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
         self.hs = hs
 
     @cached(num_args=0)
-    def get_monthly_active_count(self):
+    async def get_monthly_active_count(self) -> int:
         """Generates current count of monthly active users
 
         Returns:
-            Defered[int]: Number of current monthly active users
+            Number of current monthly active users
         """
 
         def _count_users(txn):
@@ -46,10 +46,10 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
             (count,) = txn.fetchone()
             return count
 
-        return self.db_pool.runInteraction("count_users", _count_users)
+        return await self.db_pool.runInteraction("count_users", _count_users)
 
     @cached(num_args=0)
-    def get_monthly_active_count_by_service(self):
+    async def get_monthly_active_count_by_service(self) -> Dict[str, int]:
         """Generates current count of monthly active users broken down by service.
         A service is typically an appservice but also includes native matrix users.
         Since the `monthly_active_users` table is populated from the `user_ips` table
@@ -57,8 +57,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
         method to return anything other than native matrix users.
 
         Returns:
-            Deferred[dict]: dict that includes a mapping between app_service_id
-                and the number of occurrences.
+            A mapping between app_service_id and the number of occurrences.
 
         """
 
@@ -74,7 +73,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
             result = txn.fetchall()
             return dict(result)
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "count_users_by_service", _count_users_by_service
         )
 
@@ -99,17 +98,18 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
         return users
 
     @cached(num_args=1)
-    def user_last_seen_monthly_active(self, user_id):
+    async def user_last_seen_monthly_active(self, user_id: str) -> int:
         """
-            Checks if a given user is part of the monthly active user group
-            Arguments:
-                user_id (str): user to add/update
-            Return:
-                Deferred[int] : timestamp since last seen, None if never seen
+        Checks if a given user is part of the monthly active user group
 
+        Arguments:
+            user_id: user to add/update
+
+        Return:
+            Timestamp since last seen, None if never seen
         """
 
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="monthly_active_users",
             keyvalues={"user_id": user_id},
             retcol="timestamp",
diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py
index dcd1ff911a..2aac64901b 100644
--- a/synapse/storage/databases/main/openid.py
+++ b/synapse/storage/databases/main/openid.py
@@ -1,9 +1,13 @@
+from typing import Optional
+
 from synapse.storage._base import SQLBaseStore
 
 
 class OpenIdStore(SQLBaseStore):
-    def insert_open_id_token(self, token, ts_valid_until_ms, user_id):
-        return self.db_pool.simple_insert(
+    async def insert_open_id_token(
+        self, token: str, ts_valid_until_ms: int, user_id: str
+    ) -> None:
+        await self.db_pool.simple_insert(
             table="open_id_tokens",
             values={
                 "token": token,
@@ -13,7 +17,9 @@ class OpenIdStore(SQLBaseStore):
             desc="insert_open_id_token",
         )
 
-    def get_user_id_for_open_id_token(self, token, ts_now_ms):
+    async def get_user_id_for_open_id_token(
+        self, token: str, ts_now_ms: int
+    ) -> Optional[str]:
         def get_user_id_for_token_txn(txn):
             sql = (
                 "SELECT user_id FROM open_id_tokens"
@@ -28,6 +34,6 @@ class OpenIdStore(SQLBaseStore):
             else:
                 return rows[0][0]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_user_id_for_token", get_user_id_for_token_txn
         )
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 59ba12820a..c9f655dfb7 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -15,15 +15,15 @@
 
 from typing import List, Tuple
 
+from synapse.api.presence import UserPresenceState
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.presence import UserPresenceState
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.iterutils import batch_iter
 
 
 class PresenceStore(SQLBaseStore):
     async def update_presence(self, presence_states):
-        stream_ordering_manager = self._presence_id_gen.get_next_mult(
+        stream_ordering_manager = await self._presence_id_gen.get_next_mult(
             len(presence_states)
         )
 
@@ -130,13 +130,10 @@ class PresenceStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="_get_presence_for_user",
-        list_name="user_ids",
-        num_args=1,
-        inlineCallbacks=True,
+        cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
     )
-    def get_presence_for_users(self, user_ids):
-        rows = yield self.db_pool.simple_select_many_batch(
+    async def get_presence_for_users(self, user_ids):
+        rows = await self.db_pool.simple_select_many_batch(
             table="presence_stream",
             column="user_id",
             iterable=user_ids,
@@ -160,24 +157,3 @@ class PresenceStore(SQLBaseStore):
 
     def get_current_presence_token(self):
         return self._presence_id_gen.get_current_token()
-
-    def allow_presence_visible(self, observed_localpart, observer_userid):
-        return self.db_pool.simple_insert(
-            table="presence_allow_inbound",
-            values={
-                "observed_user_id": observed_localpart,
-                "observer_user_id": observer_userid,
-            },
-            desc="allow_presence_visible",
-            or_ignore=True,
-        )
-
-    def disallow_presence_visible(self, observed_localpart, observer_userid):
-        return self.db_pool.simple_delete_one(
-            table="presence_allow_inbound",
-            keyvalues={
-                "observed_user_id": observed_localpart,
-                "observer_user_id": observer_userid,
-            },
-            desc="disallow_presence_visible",
-        )
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index b8261357d4..d2e0685e9e 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Dict, Optional
 
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore
@@ -19,7 +20,7 @@ from synapse.storage.databases.main.roommember import ProfileInfo
 
 
 class ProfileWorkerStore(SQLBaseStore):
-    async def get_profileinfo(self, user_localpart):
+    async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
         try:
             profile = await self.db_pool.simple_select_one(
                 table="profiles",
@@ -38,24 +39,26 @@ class ProfileWorkerStore(SQLBaseStore):
             avatar_url=profile["avatar_url"], display_name=profile["displayname"]
         )
 
-    def get_profile_displayname(self, user_localpart):
-        return self.db_pool.simple_select_one_onecol(
+    async def get_profile_displayname(self, user_localpart: str) -> str:
+        return await self.db_pool.simple_select_one_onecol(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             retcol="displayname",
             desc="get_profile_displayname",
         )
 
-    def get_profile_avatar_url(self, user_localpart):
-        return self.db_pool.simple_select_one_onecol(
+    async def get_profile_avatar_url(self, user_localpart: str) -> str:
+        return await self.db_pool.simple_select_one_onecol(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             retcol="avatar_url",
             desc="get_profile_avatar_url",
         )
 
-    def get_from_remote_profile_cache(self, user_id):
-        return self.db_pool.simple_select_one(
+    async def get_from_remote_profile_cache(
+        self, user_id: str
+    ) -> Optional[Dict[str, Any]]:
+        return await self.db_pool.simple_select_one(
             table="remote_profile_cache",
             keyvalues={"user_id": user_id},
             retcols=("displayname", "avatar_url"),
@@ -63,21 +66,25 @@ class ProfileWorkerStore(SQLBaseStore):
             desc="get_from_remote_profile_cache",
         )
 
-    def create_profile(self, user_localpart):
-        return self.db_pool.simple_insert(
+    async def create_profile(self, user_localpart: str) -> None:
+        await self.db_pool.simple_insert(
             table="profiles", values={"user_id": user_localpart}, desc="create_profile"
         )
 
-    def set_profile_displayname(self, user_localpart, new_displayname):
-        return self.db_pool.simple_update_one(
+    async def set_profile_displayname(
+        self, user_localpart: str, new_displayname: str
+    ) -> None:
+        await self.db_pool.simple_update_one(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             updatevalues={"displayname": new_displayname},
             desc="set_profile_displayname",
         )
 
-    def set_profile_avatar_url(self, user_localpart, new_avatar_url):
-        return self.db_pool.simple_update_one(
+    async def set_profile_avatar_url(
+        self, user_localpart: str, new_avatar_url: str
+    ) -> None:
+        await self.db_pool.simple_update_one(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             updatevalues={"avatar_url": new_avatar_url},
@@ -86,13 +93,15 @@ class ProfileWorkerStore(SQLBaseStore):
 
 
 class ProfileStore(ProfileWorkerStore):
-    def add_remote_profile_cache(self, user_id, displayname, avatar_url):
+    async def add_remote_profile_cache(
+        self, user_id: str, displayname: str, avatar_url: str
+    ) -> None:
         """Ensure we are caching the remote user's profiles.
 
         This should only be called when `is_subscribed_remote_profile_for_user`
         would return true for the user.
         """
-        return self.db_pool.simple_upsert(
+        await self.db_pool.simple_upsert(
             table="remote_profile_cache",
             keyvalues={"user_id": user_id},
             values={
@@ -103,8 +112,10 @@ class ProfileStore(ProfileWorkerStore):
             desc="add_remote_profile_cache",
         )
 
-    def update_remote_profile_cache(self, user_id, displayname, avatar_url):
-        return self.db_pool.simple_update(
+    async def update_remote_profile_cache(
+        self, user_id: str, displayname: str, avatar_url: str
+    ) -> int:
+        return await self.db_pool.simple_update(
             table="remote_profile_cache",
             keyvalues={"user_id": user_id},
             updatevalues={
@@ -127,7 +138,9 @@ class ProfileStore(ProfileWorkerStore):
                 desc="delete_remote_profile_cache",
             )
 
-    def get_remote_profile_cache_entries_that_expire(self, last_checked):
+    async def get_remote_profile_cache_entries_that_expire(
+        self, last_checked: int
+    ) -> Dict[str, str]:
         """Get all users who haven't been checked since `last_checked`
         """
 
@@ -142,7 +155,7 @@ class ProfileStore(ProfileWorkerStore):
 
             return self.db_pool.cursor_to_dict(txn)
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_remote_profile_cache_entries_that_expire",
             _get_remote_profile_cache_entries_that_expire_txn,
         )
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 3526b6fd66..d7a03cbf7d 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Tuple
+from typing import Any, List, Set, Tuple
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore
@@ -25,25 +25,24 @@ logger = logging.getLogger(__name__)
 
 
 class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
-    def purge_history(self, room_id, token, delete_local_events):
+    async def purge_history(
+        self, room_id: str, token: str, delete_local_events: bool
+    ) -> Set[int]:
         """Deletes room history before a certain point
 
         Args:
-            room_id (str):
-
-            token (str): A topological token to delete events before
-
-            delete_local_events (bool):
+            room_id:
+            token: A topological token to delete events before
+            delete_local_events:
                 if True, we will delete local events as well as remote ones
                 (instead of just marking them as outliers and deleting their
                 state groups).
 
         Returns:
-            Deferred[set[int]]: The set of state groups that are referenced by
-            deleted events.
+            The set of state groups that are referenced by deleted events.
         """
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "purge_history",
             self._purge_history_txn,
             room_id,
@@ -70,6 +69,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
         #     room_depth
         #     state_groups
         #     state_groups_state
+        #     destination_rooms
 
         # we will build a temporary table listing the events so that we don't
         # have to keep shovelling the list back and forth across the
@@ -283,17 +283,18 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
 
         return referenced_state_groups
 
-    def purge_room(self, room_id):
+    async def purge_room(self, room_id: str) -> List[int]:
         """Deletes all record of a room
 
         Args:
-            room_id (str)
+            room_id
 
         Returns:
-            Deferred[List[int]]: The list of state groups to delete.
+            The list of state groups to delete.
         """
-
-        return self.db_pool.runInteraction("purge_room", self._purge_room_txn, room_id)
+        return await self.db_pool.runInteraction(
+            "purge_room", self._purge_room_txn, room_id
+        )
 
     def _purge_room_txn(self, txn, room_id):
         # First we fetch all the state groups that should be deleted, before
@@ -336,6 +337,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
         # and finally, the tables with an index on room_id (or no useful index)
         for table in (
             "current_state_events",
+            "destination_rooms",
             "event_backward_extremities",
             "event_forward_extremities",
             "event_json",
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 6562db5c2b..9790a31998 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -13,13 +13,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import abc
 import logging
 from typing import List, Tuple, Union
 
-from twisted.internet import defer
-
+from synapse.api.errors import NotFoundError, StoreError
 from synapse.push.baserules import list_with_base_rules
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -29,10 +27,11 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.pusher import PusherWorkerStore
 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import ChainedIdGenerator
+from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.util import json_encoder
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 logger = logging.getLogger(__name__)
@@ -82,9 +81,9 @@ class PushRulesWorkerStore(
         super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
-            self._push_rules_stream_id_gen = ChainedIdGenerator(
-                self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
-            )  # type: Union[ChainedIdGenerator, SlavedIdTracker]
+            self._push_rules_stream_id_gen = StreamIdGenerator(
+                db_conn, "push_rules_stream", "stream_id"
+            )  # type: Union[StreamIdGenerator, SlavedIdTracker]
         else:
             self._push_rules_stream_id_gen = SlavedIdTracker(
                 db_conn, "push_rules_stream", "stream_id"
@@ -115,9 +114,9 @@ class PushRulesWorkerStore(
         """
         raise NotImplementedError()
 
-    @cachedInlineCallbacks(max_entries=5000)
-    def get_push_rules_for_user(self, user_id):
-        rows = yield self.db_pool.simple_select_list(
+    @cached(max_entries=5000)
+    async def get_push_rules_for_user(self, user_id):
+        rows = await self.db_pool.simple_select_list(
             table="push_rules",
             keyvalues={"user_name": user_id},
             retcols=(
@@ -133,17 +132,15 @@ class PushRulesWorkerStore(
 
         rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
 
-        enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
+        enabled_map = await self.get_push_rules_enabled_for_user(user_id)
 
         use_new_defaults = user_id in self._users_new_default_push_rules
 
-        rules = _load_rules(rows, enabled_map, use_new_defaults)
+        return _load_rules(rows, enabled_map, use_new_defaults)
 
-        return rules
-
-    @cachedInlineCallbacks(max_entries=5000)
-    def get_push_rules_enabled_for_user(self, user_id):
-        results = yield self.db_pool.simple_select_list(
+    @cached(max_entries=5000)
+    async def get_push_rules_enabled_for_user(self, user_id):
+        results = await self.db_pool.simple_select_list(
             table="push_rules_enable",
             keyvalues={"user_name": user_id},
             retcols=("user_name", "rule_id", "enabled"),
@@ -151,9 +148,11 @@ class PushRulesWorkerStore(
         )
         return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
 
-    def have_push_rules_changed_for_user(self, user_id, last_id):
+    async def have_push_rules_changed_for_user(
+        self, user_id: str, last_id: int
+    ) -> bool:
         if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
-            return defer.succeed(False)
+            return False
         else:
 
             def have_push_rules_changed_txn(txn):
@@ -165,23 +164,20 @@ class PushRulesWorkerStore(
                 (count,) = txn.fetchone()
                 return bool(count)
 
-            return self.db_pool.runInteraction(
+            return await self.db_pool.runInteraction(
                 "have_push_rules_changed", have_push_rules_changed_txn
             )
 
     @cachedList(
-        cached_method_name="get_push_rules_for_user",
-        list_name="user_ids",
-        num_args=1,
-        inlineCallbacks=True,
+        cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
     )
-    def bulk_get_push_rules(self, user_ids):
+    async def bulk_get_push_rules(self, user_ids):
         if not user_ids:
             return {}
 
         results = {user_id: [] for user_id in user_ids}
 
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="push_rules",
             column="user_name",
             iterable=user_ids,
@@ -194,7 +190,7 @@ class PushRulesWorkerStore(
         for row in rows:
             results.setdefault(row["user_name"], []).append(row)
 
-        enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
+        enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
 
         for user_id, rules in results.items():
             use_new_defaults = user_id in self._users_new_default_push_rules
@@ -205,14 +201,15 @@ class PushRulesWorkerStore(
 
         return results
 
-    @defer.inlineCallbacks
-    def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
+    async def copy_push_rule_from_room_to_room(
+        self, new_room_id: str, user_id: str, rule: dict
+    ) -> None:
         """Copy a single push rule from one room to another for a specific user.
 
         Args:
-            new_room_id (str): ID of the new room.
-            user_id (str): ID of user the push rule belongs to.
-            rule (Dict): A push rule.
+            new_room_id: ID of the new room.
+            user_id : ID of user the push rule belongs to.
+            rule: A push rule.
         """
         # Create new rule id
         rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
@@ -224,7 +221,7 @@ class PushRulesWorkerStore(
                 condition["pattern"] = new_room_id
 
         # Add the rule for the new room
-        yield self.add_push_rule(
+        await self.add_push_rule(
             user_id=user_id,
             rule_id=new_rule_id,
             priority_class=rule["priority_class"],
@@ -232,20 +229,19 @@ class PushRulesWorkerStore(
             actions=rule["actions"],
         )
 
-    @defer.inlineCallbacks
-    def copy_push_rules_from_room_to_room_for_user(
-        self, old_room_id, new_room_id, user_id
-    ):
+    async def copy_push_rules_from_room_to_room_for_user(
+        self, old_room_id: str, new_room_id: str, user_id: str
+    ) -> None:
         """Copy all of the push rules from one room to another for a specific
         user.
 
         Args:
-            old_room_id (str): ID of the old room.
-            new_room_id (str): ID of the new room.
-            user_id (str): ID of user to copy push rules for.
+            old_room_id: ID of the old room.
+            new_room_id: ID of the new room.
+            user_id: ID of user to copy push rules for.
         """
         # Retrieve push rules for this user
-        user_push_rules = yield self.get_push_rules_for_user(user_id)
+        user_push_rules = await self.get_push_rules_for_user(user_id)
 
         # Get rules relating to the old room and copy them to the new room
         for rule in user_push_rules:
@@ -254,21 +250,20 @@ class PushRulesWorkerStore(
                 (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
                 for c in conditions
             ):
-                yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
+                await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
 
     @cachedList(
         cached_method_name="get_push_rules_enabled_for_user",
         list_name="user_ids",
         num_args=1,
-        inlineCallbacks=True,
     )
-    def bulk_get_push_rules_enabled(self, user_ids):
+    async def bulk_get_push_rules_enabled(self, user_ids):
         if not user_ids:
             return {}
 
         results = {user_id: {} for user_id in user_ids}
 
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="push_rules_enable",
             column="user_name",
             iterable=user_ids,
@@ -332,8 +327,7 @@ class PushRulesWorkerStore(
 
 
 class PushRuleStore(PushRulesWorkerStore):
-    @defer.inlineCallbacks
-    def add_push_rule(
+    async def add_push_rule(
         self,
         user_id,
         rule_id,
@@ -342,13 +336,14 @@ class PushRuleStore(PushRulesWorkerStore):
         actions,
         before=None,
         after=None,
-    ):
+    ) -> None:
         conditions_json = json_encoder.encode(conditions)
         actions_json = json_encoder.encode(actions)
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
+        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+
             if before or after:
-                yield self.db_pool.runInteraction(
+                await self.db_pool.runInteraction(
                     "_add_push_rule_relative_txn",
                     self._add_push_rule_relative_txn,
                     stream_id,
@@ -362,7 +357,7 @@ class PushRuleStore(PushRulesWorkerStore):
                     after,
                 )
             else:
-                yield self.db_pool.runInteraction(
+                await self.db_pool.runInteraction(
                     "_add_push_rule_highest_priority_txn",
                     self._add_push_rule_highest_priority_txn,
                     stream_id,
@@ -546,19 +541,43 @@ class PushRuleStore(PushRulesWorkerStore):
                 },
             )
 
-    @defer.inlineCallbacks
-    def delete_push_rule(self, user_id, rule_id):
+        # ensure we have a push_rules_enable row
+        # enabledness defaults to true
+        if isinstance(self.database_engine, PostgresEngine):
+            sql = """
+                INSERT INTO push_rules_enable (id, user_name, rule_id, enabled)
+                VALUES (?, ?, ?, ?)
+                ON CONFLICT DO NOTHING
+            """
+        elif isinstance(self.database_engine, Sqlite3Engine):
+            sql = """
+                INSERT OR IGNORE INTO push_rules_enable (id, user_name, rule_id, enabled)
+                VALUES (?, ?, ?, ?)
+            """
+        else:
+            raise RuntimeError("Unknown database engine")
+
+        new_enable_id = self._push_rules_enable_id_gen.get_next()
+        txn.execute(sql, (new_enable_id, user_id, rule_id, 1))
+
+    async def delete_push_rule(self, user_id: str, rule_id: str) -> None:
         """
         Delete a push rule. Args specify the row to be deleted and can be
         any of the columns in the push_rule table, but below are the
         standard ones
 
         Args:
-            user_id (str): The matrix ID of the push rule owner
-            rule_id (str): The rule_id of the rule to be deleted
+            user_id: The matrix ID of the push rule owner
+            rule_id: The rule_id of the rule to be deleted
         """
 
         def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
+            # we don't use simple_delete_one_txn because that would fail if the
+            # user did not have a push_rule_enable row.
+            self.db_pool.simple_delete_txn(
+                txn, "push_rules_enable", {"user_name": user_id, "rule_id": rule_id}
+            )
+
             self.db_pool.simple_delete_one_txn(
                 txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}
             )
@@ -567,20 +586,40 @@ class PushRuleStore(PushRulesWorkerStore):
                 txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
             )
 
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
-            yield self.db_pool.runInteraction(
+        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+
+            await self.db_pool.runInteraction(
                 "delete_push_rule",
                 delete_push_rule_txn,
                 stream_id,
                 event_stream_ordering,
             )
 
-    @defer.inlineCallbacks
-    def set_push_rule_enabled(self, user_id, rule_id, enabled):
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
-            yield self.db_pool.runInteraction(
+    async def set_push_rule_enabled(
+        self, user_id: str, rule_id: str, enabled: bool, is_default_rule: bool
+    ) -> None:
+        """
+        Sets the `enabled` state of a push rule.
+
+        Args:
+            user_id: the user ID of the user who wishes to enable/disable the rule
+                e.g. '@tina:example.org'
+            rule_id: the full rule ID of the rule to be enabled/disabled
+                e.g. 'global/override/.m.rule.roomnotif'
+                  or 'global/override/myCustomRule'
+            enabled: True if the rule is to be enabled, False if it is to be
+                disabled
+            is_default_rule: True if and only if this is a server-default rule.
+                This skips the check for existence (as only user-created rules
+                are always stored in the database `push_rules` table).
+
+        Raises:
+            NotFoundError if the rule does not exist.
+        """
+        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+            await self.db_pool.runInteraction(
                 "_set_push_rule_enabled_txn",
                 self._set_push_rule_enabled_txn,
                 stream_id,
@@ -588,12 +627,47 @@ class PushRuleStore(PushRulesWorkerStore):
                 user_id,
                 rule_id,
                 enabled,
+                is_default_rule,
             )
 
     def _set_push_rule_enabled_txn(
-        self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
+        self,
+        txn,
+        stream_id,
+        event_stream_ordering,
+        user_id,
+        rule_id,
+        enabled,
+        is_default_rule,
     ):
         new_id = self._push_rules_enable_id_gen.get_next()
+
+        if not is_default_rule:
+            # first check it exists; we need to lock for key share so that a
+            # transaction that deletes the push rule will conflict with this one.
+            # We also need a push_rule_enable row to exist for every push_rules
+            # row, otherwise it is possible to simultaneously delete a push rule
+            # (that has no _enable row) and enable it, resulting in a dangling
+            # _enable row. To solve this: we either need to use SERIALISABLE or
+            # ensure we always have a push_rule_enable row for every push_rule
+            # row. We chose the latter.
+            for_key_share = "FOR KEY SHARE"
+            if not isinstance(self.database_engine, PostgresEngine):
+                # For key share is not applicable/available on SQLite
+                for_key_share = ""
+            sql = (
+                """
+                SELECT 1 FROM push_rules
+                WHERE user_name = ? AND rule_id = ?
+                %s
+            """
+                % for_key_share
+            )
+            txn.execute(sql, (user_id, rule_id))
+            if txn.fetchone() is None:
+                # needed to set NOT_FOUND code.
+                raise NotFoundError("Push rule does not exist.")
+
         self.db_pool.simple_upsert_txn(
             txn,
             "push_rules_enable",
@@ -611,8 +685,31 @@ class PushRuleStore(PushRulesWorkerStore):
             op="ENABLE" if enabled else "DISABLE",
         )
 
-    @defer.inlineCallbacks
-    def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
+    async def set_push_rule_actions(
+        self,
+        user_id: str,
+        rule_id: str,
+        actions: List[Union[dict, str]],
+        is_default_rule: bool,
+    ) -> None:
+        """
+        Sets the `actions` state of a push rule.
+
+        Will throw NotFoundError if the rule does not exist; the Code for this
+        is NOT_FOUND.
+
+        Args:
+            user_id: the user ID of the user who wishes to enable/disable the rule
+                e.g. '@tina:example.org'
+            rule_id: the full rule ID of the rule to be enabled/disabled
+                e.g. 'global/override/.m.rule.roomnotif'
+                  or 'global/override/myCustomRule'
+            actions: A list of actions (each action being a dict or string),
+                e.g. ["notify", {"set_tweak": "highlight", "value": false}]
+            is_default_rule: True if and only if this is a server-default rule.
+                This skips the check for existence (as only user-created rules
+                are always stored in the database `push_rules` table).
+        """
         actions_json = json_encoder.encode(actions)
 
         def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
@@ -634,12 +731,19 @@ class PushRuleStore(PushRulesWorkerStore):
                     update_stream=False,
                 )
             else:
-                self.db_pool.simple_update_one_txn(
-                    txn,
-                    "push_rules",
-                    {"user_name": user_id, "rule_id": rule_id},
-                    {"actions": actions_json},
-                )
+                try:
+                    self.db_pool.simple_update_one_txn(
+                        txn,
+                        "push_rules",
+                        {"user_name": user_id, "rule_id": rule_id},
+                        {"actions": actions_json},
+                    )
+                except StoreError as serr:
+                    if serr.code == 404:
+                        # this sets the NOT_FOUND error Code
+                        raise NotFoundError("Push rule does not exist")
+                    else:
+                        raise
 
             self._insert_push_rules_update_txn(
                 txn,
@@ -651,9 +755,10 @@ class PushRuleStore(PushRulesWorkerStore):
                 data={"actions": actions_json},
             )
 
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
-            yield self.db_pool.runInteraction(
+        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+
+            await self.db_pool.runInteraction(
                 "set_push_rule_actions",
                 set_push_rule_actions_txn,
                 stream_id,
@@ -681,11 +786,5 @@ class PushRuleStore(PushRulesWorkerStore):
             self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
         )
 
-    def get_push_rules_stream_token(self):
-        """Get the position of the push rules stream.
-        Returns a pair of a stream id for the push_rules stream and the
-        room stream ordering it corresponds to."""
-        return self._push_rules_stream_id_gen.get_current_token()
-
     def get_max_push_rules_stream_id(self):
-        return self.get_push_rules_stream_token()[0]
+        return self._push_rules_stream_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index b5200fbe79..c388468273 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -19,10 +19,8 @@ from typing import Iterable, Iterator, List, Tuple
 
 from canonicaljson import encode_canonical_json
 
-from twisted.internet import defer
-
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import cached, cachedList
 
 logger = logging.getLogger(__name__)
 
@@ -34,23 +32,22 @@ class PusherWorkerStore(SQLBaseStore):
         Drops any rows whose data cannot be decoded
         """
         for r in rows:
-            dataJson = r["data"]
+            data_json = r["data"]
             try:
-                r["data"] = db_to_json(dataJson)
+                r["data"] = db_to_json(data_json)
             except Exception as e:
                 logger.warning(
                     "Invalid JSON in data for pusher %d: %s, %s",
                     r["id"],
-                    dataJson,
+                    data_json,
                     e.args[0],
                 )
                 continue
 
             yield r
 
-    @defer.inlineCallbacks
-    def user_has_pusher(self, user_id):
-        ret = yield self.db_pool.simple_select_one_onecol(
+    async def user_has_pusher(self, user_id):
+        ret = await self.db_pool.simple_select_one_onecol(
             "pushers", {"user_name": user_id}, "id", allow_none=True
         )
         return ret is not None
@@ -61,9 +58,8 @@ class PusherWorkerStore(SQLBaseStore):
     def get_pushers_by_user_id(self, user_id):
         return self.get_pushers_by({"user_name": user_id})
 
-    @defer.inlineCallbacks
-    def get_pushers_by(self, keyvalues):
-        ret = yield self.db_pool.simple_select_list(
+    async def get_pushers_by(self, keyvalues):
+        ret = await self.db_pool.simple_select_list(
             "pushers",
             keyvalues,
             [
@@ -87,16 +83,14 @@ class PusherWorkerStore(SQLBaseStore):
         )
         return self._decode_pushers_rows(ret)
 
-    @defer.inlineCallbacks
-    def get_all_pushers(self):
+    async def get_all_pushers(self):
         def get_pushers(txn):
             txn.execute("SELECT * FROM pushers")
             rows = self.db_pool.cursor_to_dict(txn)
 
             return self._decode_pushers_rows(rows)
 
-        rows = yield self.db_pool.runInteraction("get_all_pushers", get_pushers)
-        return rows
+        return await self.db_pool.runInteraction("get_all_pushers", get_pushers)
 
     async def get_all_updated_pushers_rows(
         self, instance_name: str, last_id: int, current_id: int, limit: int
@@ -164,19 +158,16 @@ class PusherWorkerStore(SQLBaseStore):
             "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
         )
 
-    @cachedInlineCallbacks(num_args=1, max_entries=15000)
-    def get_if_user_has_pusher(self, user_id):
+    @cached(num_args=1, max_entries=15000)
+    async def get_if_user_has_pusher(self, user_id):
         # This only exists for the cachedList decorator
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="get_if_user_has_pusher",
-        list_name="user_ids",
-        num_args=1,
-        inlineCallbacks=True,
+        cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
     )
-    def get_if_users_have_pushers(self, user_ids):
-        rows = yield self.db_pool.simple_select_many_batch(
+    async def get_if_users_have_pushers(self, user_ids):
+        rows = await self.db_pool.simple_select_many_batch(
             table="pushers",
             column="user_name",
             iterable=user_ids,
@@ -189,34 +180,38 @@ class PusherWorkerStore(SQLBaseStore):
 
         return result
 
-    @defer.inlineCallbacks
-    def update_pusher_last_stream_ordering(
+    async def update_pusher_last_stream_ordering(
         self, app_id, pushkey, user_id, last_stream_ordering
-    ):
-        yield self.db_pool.simple_update_one(
+    ) -> None:
+        await self.db_pool.simple_update_one(
             "pushers",
             {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
             {"last_stream_ordering": last_stream_ordering},
             desc="update_pusher_last_stream_ordering",
         )
 
-    @defer.inlineCallbacks
-    def update_pusher_last_stream_ordering_and_success(
-        self, app_id, pushkey, user_id, last_stream_ordering, last_success
-    ):
+    async def update_pusher_last_stream_ordering_and_success(
+        self,
+        app_id: str,
+        pushkey: str,
+        user_id: str,
+        last_stream_ordering: int,
+        last_success: int,
+    ) -> bool:
         """Update the last stream ordering position we've processed up to for
         the given pusher.
 
         Args:
-            app_id (str)
-            pushkey (str)
-            last_stream_ordering (int)
-            last_success (int)
+            app_id
+            pushkey
+            user_id
+            last_stream_ordering
+            last_success
 
         Returns:
-            Deferred[bool]: True if the pusher still exists; False if it has been deleted.
+            True if the pusher still exists; False if it has been deleted.
         """
-        updated = yield self.db_pool.simple_update(
+        updated = await self.db_pool.simple_update(
             table="pushers",
             keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
             updatevalues={
@@ -228,18 +223,18 @@ class PusherWorkerStore(SQLBaseStore):
 
         return bool(updated)
 
-    @defer.inlineCallbacks
-    def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
-        yield self.db_pool.simple_update(
+    async def update_pusher_failing_since(
+        self, app_id, pushkey, user_id, failing_since
+    ) -> None:
+        await self.db_pool.simple_update(
             table="pushers",
             keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
             updatevalues={"failing_since": failing_since},
             desc="update_pusher_failing_since",
         )
 
-    @defer.inlineCallbacks
-    def get_throttle_params_by_room(self, pusher_id):
-        res = yield self.db_pool.simple_select_list(
+    async def get_throttle_params_by_room(self, pusher_id):
+        res = await self.db_pool.simple_select_list(
             "pusher_throttle",
             {"pusher": pusher_id},
             ["room_id", "last_sent_ts", "throttle_ms"],
@@ -255,11 +250,10 @@ class PusherWorkerStore(SQLBaseStore):
 
         return params_by_room
 
-    @defer.inlineCallbacks
-    def set_throttle_params(self, pusher_id, room_id, params):
+    async def set_throttle_params(self, pusher_id, room_id, params) -> None:
         # no need to lock because `pusher_throttle` has a primary key on
         # (pusher, room_id) so simple_upsert will retry
-        yield self.db_pool.simple_upsert(
+        await self.db_pool.simple_upsert(
             "pusher_throttle",
             {"pusher": pusher_id, "room_id": room_id},
             params,
@@ -272,8 +266,7 @@ class PusherStore(PusherWorkerStore):
     def get_pushers_stream_token(self):
         return self._pushers_id_gen.get_current_token()
 
-    @defer.inlineCallbacks
-    def add_pusher(
+    async def add_pusher(
         self,
         user_id,
         access_token,
@@ -287,11 +280,11 @@ class PusherStore(PusherWorkerStore):
         data,
         last_stream_ordering,
         profile_tag="",
-    ):
-        with self._pushers_id_gen.get_next() as stream_id:
+    ) -> None:
+        with await self._pushers_id_gen.get_next() as stream_id:
             # no need to lock because `pushers` has a unique key on
             # (app_id, pushkey, user_name) so simple_upsert will retry
-            yield self.db_pool.simple_upsert(
+            await self.db_pool.simple_upsert(
                 table="pushers",
                 keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
                 values={
@@ -316,15 +309,16 @@ class PusherStore(PusherWorkerStore):
 
             if user_has_pusher is not True:
                 # invalidate, since we the user might not have had a pusher before
-                yield self.db_pool.runInteraction(
+                await self.db_pool.runInteraction(
                     "add_pusher",
                     self._invalidate_cache_and_stream,
                     self.get_if_user_has_pusher,
                     (user_id,),
                 )
 
-    @defer.inlineCallbacks
-    def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
+    async def delete_pusher_by_app_id_pushkey_user_id(
+        self, app_id, pushkey, user_id
+    ) -> None:
         def delete_pusher_txn(txn, stream_id):
             self._invalidate_cache_and_stream(
                 txn, self.get_if_user_has_pusher, (user_id,)
@@ -350,7 +344,7 @@ class PusherStore(PusherWorkerStore):
                 },
             )
 
-        with self._pushers_id_gen.get_next() as stream_id:
-            yield self.db_pool.runInteraction(
+        with await self._pushers_id_gen.get_next() as stream_id:
+            await self.db_pool.runInteraction(
                 "delete_pusher", delete_pusher_txn, stream_id
             )
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 1920a8a152..4a0d5a320e 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -16,7 +16,7 @@
 
 import abc
 import logging
-from typing import List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from twisted.internet import defer
 
@@ -25,7 +25,7 @@ from synapse.storage.database import DatabasePool
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.util import json_encoder
 from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 logger = logging.getLogger(__name__)
@@ -56,14 +56,16 @@ class ReceiptsWorkerStore(SQLBaseStore):
         """
         raise NotImplementedError()
 
-    @cachedInlineCallbacks()
-    def get_users_with_read_receipts_in_room(self, room_id):
-        receipts = yield self.get_receipts_for_room(room_id, "m.read")
+    @cached()
+    async def get_users_with_read_receipts_in_room(self, room_id):
+        receipts = await self.get_receipts_for_room(room_id, "m.read")
         return {r["user_id"] for r in receipts}
 
     @cached(num_args=2)
-    def get_receipts_for_room(self, room_id, receipt_type):
-        return self.db_pool.simple_select_list(
+    async def get_receipts_for_room(
+        self, room_id: str, receipt_type: str
+    ) -> List[Dict[str, Any]]:
+        return await self.db_pool.simple_select_list(
             table="receipts_linearized",
             keyvalues={"room_id": room_id, "receipt_type": receipt_type},
             retcols=("user_id", "event_id"),
@@ -71,8 +73,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
         )
 
     @cached(num_args=3)
-    def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
-        return self.db_pool.simple_select_one_onecol(
+    async def get_last_receipt_event_id_for_user(
+        self, user_id: str, room_id: str, receipt_type: str
+    ) -> Optional[str]:
+        return await self.db_pool.simple_select_one_onecol(
             table="receipts_linearized",
             keyvalues={
                 "room_id": room_id,
@@ -84,9 +88,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
             allow_none=True,
         )
 
-    @cachedInlineCallbacks(num_args=2)
-    def get_receipts_for_user(self, user_id, receipt_type):
-        rows = yield self.db_pool.simple_select_list(
+    @cached(num_args=2)
+    async def get_receipts_for_user(self, user_id, receipt_type):
+        rows = await self.db_pool.simple_select_list(
             table="receipts_linearized",
             keyvalues={"user_id": user_id, "receipt_type": receipt_type},
             retcols=("room_id", "event_id"),
@@ -95,8 +99,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         return {row["room_id"]: row["event_id"] for row in rows}
 
-    @defer.inlineCallbacks
-    def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
+    async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
         def f(txn):
             sql = (
                 "SELECT rl.room_id, rl.event_id,"
@@ -110,7 +113,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             txn.execute(sql, (user_id,))
             return txn.fetchall()
 
-        rows = yield self.db_pool.runInteraction(
+        rows = await self.db_pool.runInteraction(
             "get_receipts_for_user_with_orderings", f
         )
         return {
@@ -122,56 +125,61 @@ class ReceiptsWorkerStore(SQLBaseStore):
             for row in rows
         }
 
-    @defer.inlineCallbacks
-    def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+    async def get_linearized_receipts_for_rooms(
+        self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
+    ) -> List[dict]:
         """Get receipts for multiple rooms for sending to clients.
 
         Args:
-            room_ids (list): List of room_ids.
-            to_key (int): Max stream id to fetch receipts upto.
-            from_key (int): Min stream id to fetch receipts from. None fetches
+            room_id: List of room_ids.
+            to_key: Max stream id to fetch receipts upto.
+            from_key: Min stream id to fetch receipts from. None fetches
                 from the start.
 
         Returns:
-            list: A list of receipts.
+            A list of receipts.
         """
         room_ids = set(room_ids)
 
         if from_key is not None:
             # Only ask the database about rooms where there have been new
             # receipts added since `from_key`
-            room_ids = yield self._receipts_stream_cache.get_entities_changed(
+            room_ids = self._receipts_stream_cache.get_entities_changed(
                 room_ids, from_key
             )
 
-        results = yield self._get_linearized_receipts_for_rooms(
+        results = await self._get_linearized_receipts_for_rooms(
             room_ids, to_key, from_key=from_key
         )
 
         return [ev for res in results.values() for ev in res]
 
-    def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+    async def get_linearized_receipts_for_room(
+        self, room_id: str, to_key: int, from_key: Optional[int] = None
+    ) -> List[dict]:
         """Get receipts for a single room for sending to clients.
 
         Args:
-            room_ids (str): The room id.
-            to_key (int): Max stream id to fetch receipts upto.
-            from_key (int): Min stream id to fetch receipts from. None fetches
+            room_ids: The room id.
+            to_key: Max stream id to fetch receipts upto.
+            from_key: Min stream id to fetch receipts from. None fetches
                 from the start.
 
         Returns:
-            Deferred[list]: A list of receipts.
+            A list of receipts.
         """
         if from_key is not None:
             # Check the cache first to see if any new receipts have been added
             # since`from_key`. If not we can no-op.
             if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
-                defer.succeed([])
+                return []
 
-        return self._get_linearized_receipts_for_room(room_id, to_key, from_key)
+        return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
 
-    @cachedInlineCallbacks(num_args=3, tree=True)
-    def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+    @cached(num_args=3, tree=True)
+    async def _get_linearized_receipts_for_room(
+        self, room_id: str, to_key: int, from_key: Optional[int] = None
+    ) -> List[dict]:
         """See get_linearized_receipts_for_room
         """
 
@@ -195,7 +203,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
             return rows
 
-        rows = yield self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
+        rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
 
         if not rows:
             return []
@@ -212,9 +220,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
         cached_method_name="_get_linearized_receipts_for_room",
         list_name="room_ids",
         num_args=3,
-        inlineCallbacks=True,
     )
-    def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+    async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
         if not room_ids:
             return {}
 
@@ -243,7 +250,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
             return self.db_pool.cursor_to_dict(txn)
 
-        txn_results = yield self.db_pool.runInteraction(
+        txn_results = await self.db_pool.runInteraction(
             "_get_linearized_receipts_for_rooms", f
         )
 
@@ -269,12 +276,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
         }
         return results
 
-    def get_users_sent_receipts_between(self, last_id: int, current_id: int):
+    async def get_users_sent_receipts_between(
+        self, last_id: int, current_id: int
+    ) -> List[str]:
         """Get all users who sent receipts between `last_id` exclusive and
         `current_id` inclusive.
 
         Returns:
-            Deferred[List[str]]
+            The list of users.
         """
 
         if last_id == current_id:
@@ -289,7 +298,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
             return [r[0] for r in txn]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn
         )
 
@@ -346,7 +355,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         )
 
     def _invalidate_get_users_with_receipts_in_room(
-        self, room_id, receipt_type, user_id
+        self, room_id: str, receipt_type: str, user_id: str
     ):
         if receipt_type != "m.read":
             return
@@ -472,15 +481,21 @@ class ReceiptsStore(ReceiptsWorkerStore):
 
         return rx_ts
 
-    @defer.inlineCallbacks
-    def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
+    async def insert_receipt(
+        self,
+        room_id: str,
+        receipt_type: str,
+        user_id: str,
+        event_ids: List[str],
+        data: dict,
+    ) -> Optional[Tuple[int, int]]:
         """Insert a receipt, either from local client or remote server.
 
         Automatically does conversion between linearized and graph
         representations.
         """
         if not event_ids:
-            return
+            return None
 
         if len(event_ids) == 1:
             linearized_event_id = event_ids[0]
@@ -507,13 +522,12 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 else:
                     raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
 
-            linearized_event_id = yield self.db_pool.runInteraction(
+            linearized_event_id = await self.db_pool.runInteraction(
                 "insert_receipt_conv", graph_to_linear
             )
 
-        stream_id_manager = self._receipts_id_gen.get_next()
-        with stream_id_manager as stream_id:
-            event_ts = yield self.db_pool.runInteraction(
+        with await self._receipts_id_gen.get_next() as stream_id:
+            event_ts = await self.db_pool.runInteraction(
                 "insert_linearized_receipt",
                 self.insert_linearized_receipt_txn,
                 room_id,
@@ -535,14 +549,16 @@ class ReceiptsStore(ReceiptsWorkerStore):
             now - event_ts,
         )
 
-        yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
+        await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
 
         max_persisted_id = self._receipts_id_gen.get_current_token()
 
         return stream_id, max_persisted_id
 
-    def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
-        return self.db_pool.runInteraction(
+    async def insert_graph_receipt(
+        self, room_id, receipt_type, user_id, event_ids, data
+    ):
+        return await self.db_pool.runInteraction(
             "insert_graph_receipt",
             self.insert_graph_receipt_txn,
             room_id,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 402ae25571..01f20c03c2 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,9 +17,7 @@
 
 import logging
 import re
-from typing import Dict, List, Optional
-
-from twisted.internet.defer import Deferred
+from typing import Any, Dict, List, Optional, Tuple
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -48,8 +46,8 @@ class RegistrationWorkerStore(SQLBaseStore):
         )
 
     @cached()
-    def get_user_by_id(self, user_id):
-        return self.db_pool.simple_select_one(
+    async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
+        return await self.db_pool.simple_select_one(
             table="users",
             keyvalues={"name": user_id},
             retcols=[
@@ -86,22 +84,22 @@ class RegistrationWorkerStore(SQLBaseStore):
         return is_trial
 
     @cached()
-    def get_user_by_access_token(self, token):
+    async def get_user_by_access_token(self, token: str) -> Optional[dict]:
         """Get a user from the given access token.
 
         Args:
-            token (str): The access token of a user.
+            token: The access token of a user.
         Returns:
-            defer.Deferred: None, if the token did not match, otherwise dict
-                including the keys `name`, `is_guest`, `device_id`, `token_id`,
-                `valid_until_ms`.
+            None, if the token did not match, otherwise dict
+            including the keys `name`, `is_guest`, `device_id`, `token_id`,
+            `valid_until_ms`.
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_user_by_access_token", self._query_for_auth, token
         )
 
     @cached()
-    async def get_expiration_ts_for_user(self, user_id: str) -> Optional[None]:
+    async def get_expiration_ts_for_user(self, user_id: str) -> Optional[int]:
         """Get the expiration timestamp for the account bearing a given user ID.
 
         Args:
@@ -283,13 +281,12 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         return bool(res) if res else False
 
-    def set_server_admin(self, user, admin):
+    async def set_server_admin(self, user: UserID, admin: bool) -> None:
         """Sets whether a user is an admin of this homeserver.
 
         Args:
-            user (UserID): user ID of the user to test
-            admin (bool): true iff the user is to be a server admin,
-                false otherwise.
+            user: user ID of the user to test
+            admin: true iff the user is to be a server admin, false otherwise.
         """
 
         def set_server_admin_txn(txn):
@@ -300,11 +297,11 @@ class RegistrationWorkerStore(SQLBaseStore):
                 txn, self.get_user_by_id, (user.to_string(),)
             )
 
-        return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
+        await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
 
     def _query_for_auth(self, txn, token):
         sql = (
-            "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
+            "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
             " access_tokens.device_id, access_tokens.valid_until_ms"
             " FROM users"
             " INNER JOIN access_tokens on users.name = access_tokens.user_id"
@@ -366,9 +363,11 @@ class RegistrationWorkerStore(SQLBaseStore):
         )
         return True if res == UserTypes.SUPPORT else False
 
-    def get_users_by_id_case_insensitive(self, user_id):
+    async def get_users_by_id_case_insensitive(self, user_id: str) -> Dict[str, str]:
         """Gets users that match user_id case insensitively.
-        Returns a mapping of user_id -> password_hash.
+
+        Returns:
+             A mapping of user_id -> password_hash.
         """
 
         def f(txn):
@@ -376,7 +375,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             txn.execute(sql, (user_id,))
             return dict(txn)
 
-        return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
+        return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
 
     async def get_user_by_external_id(
         self, auth_provider: str, external_id: str
@@ -410,7 +409,7 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         return await self.db_pool.runInteraction("count_users", _count_users)
 
-    def count_daily_user_type(self):
+    async def count_daily_user_type(self) -> Dict[str, int]:
         """
         Counts 1) native non guest users
                2) native guests users
@@ -439,7 +438,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 results[row[0]] = row[1]
             return results
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "count_daily_user_type", _count_daily_user_type
         )
 
@@ -531,43 +530,42 @@ class RegistrationWorkerStore(SQLBaseStore):
             "user_get_threepids",
         )
 
-    def user_delete_threepid(self, user_id, medium, address):
-        return self.db_pool.simple_delete(
+    async def user_delete_threepid(self, user_id, medium, address) -> None:
+        await self.db_pool.simple_delete(
             "user_threepids",
             keyvalues={"user_id": user_id, "medium": medium, "address": address},
             desc="user_delete_threepid",
         )
 
-    def user_delete_threepids(self, user_id: str):
+    async def user_delete_threepids(self, user_id: str) -> None:
         """Delete all threepid this user has bound
 
         Args:
              user_id: The user id to delete all threepids of
 
         """
-        return self.db_pool.simple_delete(
+        await self.db_pool.simple_delete(
             "user_threepids",
             keyvalues={"user_id": user_id},
             desc="user_delete_threepids",
         )
 
-    def add_user_bound_threepid(self, user_id, medium, address, id_server):
+    async def add_user_bound_threepid(
+        self, user_id: str, medium: str, address: str, id_server: str
+    ):
         """The server proxied a bind request to the given identity server on
         behalf of the given user. We need to remember this in case the user
         asks us to unbind the threepid.
 
         Args:
-            user_id (str)
-            medium (str)
-            address (str)
-            id_server (str)
-
-        Returns:
-            Deferred
+            user_id
+            medium
+            address
+            id_server
         """
         # We need to use an upsert, in case they user had already bound the
         # threepid
-        return self.db_pool.simple_upsert(
+        await self.db_pool.simple_upsert(
             table="user_threepid_id_server",
             keyvalues={
                 "user_id": user_id,
@@ -580,41 +578,40 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="add_user_bound_threepid",
         )
 
-    def user_get_bound_threepids(self, user_id):
+    async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]:
         """Get the threepids that a user has bound to an identity server through the homeserver
         The homeserver remembers where binds to an identity server occurred. Using this
         method can retrieve those threepids.
 
         Args:
-            user_id (str): The ID of the user to retrieve threepids for
+            user_id: The ID of the user to retrieve threepids for
 
         Returns:
-            Deferred[list[dict]]: List of dictionaries containing the following:
+            List of dictionaries containing the following keys:
                 medium (str): The medium of the threepid (e.g "email")
                 address (str): The address of the threepid (e.g "bob@example.com")
         """
-        return self.db_pool.simple_select_list(
+        return await self.db_pool.simple_select_list(
             table="user_threepid_id_server",
             keyvalues={"user_id": user_id},
             retcols=["medium", "address"],
             desc="user_get_bound_threepids",
         )
 
-    def remove_user_bound_threepid(self, user_id, medium, address, id_server):
+    async def remove_user_bound_threepid(
+        self, user_id: str, medium: str, address: str, id_server: str
+    ) -> None:
         """The server proxied an unbind request to the given identity server on
         behalf of the given user, so we remove the mapping of threepid to
         identity server.
 
         Args:
-            user_id (str)
-            medium (str)
-            address (str)
-            id_server (str)
-
-        Returns:
-            Deferred
+            user_id
+            medium
+            address
+            id_server
         """
-        return self.db_pool.simple_delete(
+        await self.db_pool.simple_delete(
             table="user_threepid_id_server",
             keyvalues={
                 "user_id": user_id,
@@ -625,19 +622,21 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="remove_user_bound_threepid",
         )
 
-    def get_id_servers_user_bound(self, user_id, medium, address):
+    async def get_id_servers_user_bound(
+        self, user_id: str, medium: str, address: str
+    ) -> List[str]:
         """Get the list of identity servers that the server proxied bind
         requests to for given user and threepid
 
         Args:
-            user_id (str)
-            medium (str)
-            address (str)
+            user_id: The user to query for identity servers.
+            medium: The medium to query for identity servers.
+            address: The address to query for identity servers.
 
         Returns:
-            Deferred[list[str]]: Resolves to a list of identity servers
+            A list of identity servers
         """
-        return self.db_pool.simple_select_onecol(
+        return await self.db_pool.simple_select_onecol(
             table="user_threepid_id_server",
             keyvalues={"user_id": user_id, "medium": medium, "address": address},
             retcol="id_server",
@@ -665,24 +664,29 @@ class RegistrationWorkerStore(SQLBaseStore):
         # Convert the integer into a boolean.
         return res == 1
 
-    def get_threepid_validation_session(
-        self, medium, client_secret, address=None, sid=None, validated=True
-    ):
+    async def get_threepid_validation_session(
+        self,
+        medium: Optional[str],
+        client_secret: str,
+        address: Optional[str] = None,
+        sid: Optional[str] = None,
+        validated: Optional[bool] = True,
+    ) -> Optional[Dict[str, Any]]:
         """Gets a session_id and last_send_attempt (if available) for a
         combination of validation metadata
 
         Args:
-            medium (str|None): The medium of the 3PID
-            address (str|None): The address of the 3PID
-            sid (str|None): The ID of the validation session
-            client_secret (str): A unique string provided by the client to help identify this
+            medium: The medium of the 3PID
+            client_secret: A unique string provided by the client to help identify this
                 validation attempt
-            validated (bool|None): Whether sessions should be filtered by
+            address: The address of the 3PID
+            sid: The ID of the validation session
+            validated: Whether sessions should be filtered by
                 whether they have been validated already or not. None to
                 perform no filtering
 
         Returns:
-            Deferred[dict|None]: A dict containing the following:
+            A dict containing the following:
                 * address - address of the 3pid
                 * medium - medium of the 3pid
                 * client_secret - a secret provided by the client for this validation session
@@ -728,17 +732,17 @@ class RegistrationWorkerStore(SQLBaseStore):
 
             return rows[0]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_threepid_validation_session", get_threepid_validation_session_txn
         )
 
-    def delete_threepid_session(self, session_id):
+    async def delete_threepid_session(self, session_id: str) -> None:
         """Removes a threepid validation session from the database. This can
         be done after validation has been performed and whatever action was
         waiting on it has been carried out
 
         Args:
-            session_id (str): The ID of the session to delete
+            session_id: The ID of the session to delete
         """
 
         def delete_threepid_session_txn(txn):
@@ -753,7 +757,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 keyvalues={"session_id": session_id},
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_threepid_session", delete_threepid_session_txn
         )
 
@@ -891,6 +895,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         super(RegistrationStore, self).__init__(database, db_conn, hs)
 
         self._account_validity = hs.config.account_validity
+        self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
 
         if self._account_validity.enabled:
             self._clock.call_later(
@@ -942,40 +947,40 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             desc="add_access_token_to_user",
         )
 
-    def register_user(
+    async def register_user(
         self,
-        user_id,
-        password_hash=None,
-        was_guest=False,
-        make_guest=False,
-        appservice_id=None,
-        create_profile_with_displayname=None,
-        admin=False,
-        user_type=None,
-    ):
+        user_id: str,
+        password_hash: Optional[str] = None,
+        was_guest: bool = False,
+        make_guest: bool = False,
+        appservice_id: Optional[str] = None,
+        create_profile_with_displayname: Optional[str] = None,
+        admin: bool = False,
+        user_type: Optional[str] = None,
+        shadow_banned: bool = False,
+    ) -> None:
         """Attempts to register an account.
 
         Args:
-            user_id (str): The desired user ID to register.
-            password_hash (str|None): Optional. The password hash for this user.
-            was_guest (bool): Optional. Whether this is a guest account being
-                upgraded to a non-guest account.
-            make_guest (boolean): True if the the new user should be guest,
-                false to add a regular user account.
-            appservice_id (str): The ID of the appservice registering the user.
-            create_profile_with_displayname (unicode): Optionally create a profile for
+            user_id: The desired user ID to register.
+            password_hash: Optional. The password hash for this user.
+            was_guest: Whether this is a guest account being upgraded to a
+                non-guest account.
+            make_guest: True if the the new user should be guest, false to add a
+                regular user account.
+            appservice_id: The ID of the appservice registering the user.
+            create_profile_with_displayname: Optionally create a profile for
                 the user, setting their displayname to the given value
-            admin (boolean): is an admin user?
-            user_type (str|None): type of user. One of the values from
-                api.constants.UserTypes, or None for a normal user.
+            admin: is an admin user?
+            user_type: type of user. One of the values from api.constants.UserTypes,
+                or None for a normal user.
+            shadow_banned: Whether the user is shadow-banned, i.e. they may be
+                told their requests succeeded but we ignore them.
 
         Raises:
             StoreError if the user_id could not be registered.
-
-        Returns:
-            Deferred
         """
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "register_user",
             self._register_user,
             user_id,
@@ -986,6 +991,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             create_profile_with_displayname,
             admin,
             user_type,
+            shadow_banned,
         )
 
     def _register_user(
@@ -999,6 +1005,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         create_profile_with_displayname,
         admin,
         user_type,
+        shadow_banned,
     ):
         user_id_obj = UserID.from_string(user_id)
 
@@ -1028,6 +1035,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                         "appservice_id": appservice_id,
                         "admin": 1 if admin else 0,
                         "user_type": user_type,
+                        "shadow_banned": shadow_banned,
                     },
                 )
             else:
@@ -1042,6 +1050,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                         "appservice_id": appservice_id,
                         "admin": 1 if admin else 0,
                         "user_type": user_type,
+                        "shadow_banned": shadow_banned,
                     },
                 )
 
@@ -1075,9 +1084,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
         self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-    def record_user_external_id(
+    async def record_user_external_id(
         self, auth_provider: str, external_id: str, user_id: str
-    ) -> Deferred:
+    ) -> None:
         """Record a mapping from an external user id to a mxid
 
         Args:
@@ -1085,7 +1094,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             external_id: id on that system
             user_id: complete mxid that it is mapped to
         """
-        return self.db_pool.simple_insert(
+        await self.db_pool.simple_insert(
             table="user_external_ids",
             values={
                 "auth_provider": auth_provider,
@@ -1095,7 +1104,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             desc="record_user_external_id",
         )
 
-    def user_set_password_hash(self, user_id, password_hash):
+    async def user_set_password_hash(self, user_id: str, password_hash: str) -> None:
         """
         NB. This does *not* evict any cache because the one use for this
             removes most of the entries subsequently anyway so it would be
@@ -1108,17 +1117,18 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "user_set_password_hash", user_set_password_hash_txn
         )
 
-    def user_set_consent_version(self, user_id, consent_version):
+    async def user_set_consent_version(
+        self, user_id: str, consent_version: str
+    ) -> None:
         """Updates the user table to record privacy policy consent
 
         Args:
-            user_id (str): full mxid of the user to update
-            consent_version (str): version of the policy the user has consented
-                to
+            user_id: full mxid of the user to update
+            consent_version: version of the policy the user has consented to
 
         Raises:
             StoreError(404) if user not found
@@ -1133,16 +1143,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-        return self.db_pool.runInteraction("user_set_consent_version", f)
+        await self.db_pool.runInteraction("user_set_consent_version", f)
 
-    def user_set_consent_server_notice_sent(self, user_id, consent_version):
+    async def user_set_consent_server_notice_sent(
+        self, user_id: str, consent_version: str
+    ) -> None:
         """Updates the user table to record that we have sent the user a server
         notice about privacy policy consent
 
         Args:
-            user_id (str): full mxid of the user to update
-            consent_version (str): version of the policy we have notified the
-                user about
+            user_id: full mxid of the user to update
+            consent_version: version of the policy we have notified the user about
 
         Raises:
             StoreError(404) if user not found
@@ -1157,22 +1168,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-        return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
+        await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
 
-    def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
+    async def user_delete_access_tokens(
+        self,
+        user_id: str,
+        except_token_id: Optional[str] = None,
+        device_id: Optional[str] = None,
+    ) -> List[Tuple[str, int, Optional[str]]]:
         """
         Invalidate access tokens belonging to a user
 
         Args:
-            user_id (str):  ID of user the tokens belong to
-            except_token_id (str): list of access_tokens IDs which should
-                *not* be deleted
-            device_id (str|None):  ID of device the tokens are associated with.
+            user_id: ID of user the tokens belong to
+            except_token_id: access_tokens ID which should *not* be deleted
+            device_id: ID of device the tokens are associated with.
                 If None, tokens associated with any device (or no device) will
                 be deleted
         Returns:
-            defer.Deferred[list[str, int, str|None, int]]: a list of
-                (token, token id, device id) for each of the deleted tokens
+            A tuple of (token, token id, device id) for each of the deleted tokens
         """
 
         def f(txn):
@@ -1203,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
             return tokens_and_devices
 
-        return self.db_pool.runInteraction("user_delete_access_tokens", f)
+        return await self.db_pool.runInteraction("user_delete_access_tokens", f)
 
-    def delete_access_token(self, access_token):
+    async def delete_access_token(self, access_token: str) -> None:
         def f(txn):
             self.db_pool.simple_delete_one_txn(
                 txn, table="access_tokens", keyvalues={"token": access_token}
@@ -1215,7 +1229,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                 txn, self.get_user_by_access_token, (access_token,)
             )
 
-        return self.db_pool.runInteraction("delete_access_token", f)
+        await self.db_pool.runInteraction("delete_access_token", f)
 
     @cached()
     async def is_guest(self, user_id: str) -> bool:
@@ -1229,36 +1243,36 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
         return res if res else False
 
-    def add_user_pending_deactivation(self, user_id):
+    async def add_user_pending_deactivation(self, user_id: str) -> None:
         """
         Adds a user to the table of users who need to be parted from all the rooms they're
         in
         """
-        return self.db_pool.simple_insert(
+        await self.db_pool.simple_insert(
             "users_pending_deactivation",
             values={"user_id": user_id},
             desc="add_user_pending_deactivation",
         )
 
-    def del_user_pending_deactivation(self, user_id):
+    async def del_user_pending_deactivation(self, user_id: str) -> None:
         """
         Removes the given user to the table of users who need to be parted from all the
         rooms they're in, effectively marking that user as fully deactivated.
         """
         # XXX: This should be simple_delete_one but we failed to put a unique index on
         # the table, so somehow duplicate entries have ended up in it.
-        return self.db_pool.simple_delete(
+        await self.db_pool.simple_delete(
             "users_pending_deactivation",
             keyvalues={"user_id": user_id},
             desc="del_user_pending_deactivation",
         )
 
-    def get_user_pending_deactivation(self):
+    async def get_user_pending_deactivation(self) -> Optional[str]:
         """
         Gets one user from the table of users waiting to be parted from all the rooms
         they're in.
         """
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             "users_pending_deactivation",
             keyvalues={},
             retcol="user_id",
@@ -1266,24 +1280,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             desc="get_users_pending_deactivation",
         )
 
-    def validate_threepid_session(self, session_id, client_secret, token, current_ts):
+    async def validate_threepid_session(
+        self, session_id: str, client_secret: str, token: str, current_ts: int
+    ) -> Optional[str]:
         """Attempt to validate a threepid session using a token
 
         Args:
-            session_id (str): The id of a validation session
-            client_secret (str): A unique string provided by the client to
-                help identify this validation attempt
-            token (str): A validation token
-            current_ts (int): The current unix time in milliseconds. Used for
-                checking token expiry status
+            session_id: The id of a validation session
+            client_secret: A unique string provided by the client to help identify
+                this validation attempt
+            token: A validation token
+            current_ts: The current unix time in milliseconds. Used for checking
+                token expiry status
 
         Raises:
             ThreepidValidationError: if a matching validation token was not found or has
                 expired
 
         Returns:
-            deferred str|None: A str representing a link to redirect the user
-            to if there is one.
+            A str representing a link to redirect the user to if there is one.
         """
 
         # Insert everything into a transaction in order to run atomically
@@ -1297,15 +1312,22 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             )
 
             if not row:
-                raise ThreepidValidationError(400, "Unknown session_id")
+                if self._ignore_unknown_session_error:
+                    # If we need to inhibit the error caused by an incorrect session ID,
+                    # use None as placeholder values for the client secret and the
+                    # validation timestamp.
+                    # It shouldn't be an issue because they're both only checked after
+                    # the token check, which should fail. And if it doesn't for some
+                    # reason, the next check is on the client secret, which is NOT NULL,
+                    # so we don't have to worry about the client secret matching by
+                    # accident.
+                    row = {"client_secret": None, "validated_at": None}
+                else:
+                    raise ThreepidValidationError(400, "Unknown session_id")
+
             retrieved_client_secret = row["client_secret"]
             validated_at = row["validated_at"]
 
-            if retrieved_client_secret != client_secret:
-                raise ThreepidValidationError(
-                    400, "This client_secret does not match the provided session_id"
-                )
-
             row = self.db_pool.simple_select_one_txn(
                 txn,
                 table="threepid_validation_token",
@@ -1321,6 +1343,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             expires = row["expires"]
             next_link = row["next_link"]
 
+            if retrieved_client_secret != client_secret:
+                raise ThreepidValidationError(
+                    400, "This client_secret does not match the provided session_id"
+                )
+
             # If the session is already validated, no need to revalidate
             if validated_at:
                 return next_link
@@ -1341,73 +1368,35 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             return next_link
 
         # Return next_link if it exists
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "validate_threepid_session_txn", validate_threepid_session_txn
         )
 
-    def upsert_threepid_validation_session(
-        self,
-        medium,
-        address,
-        client_secret,
-        send_attempt,
-        session_id,
-        validated_at=None,
-    ):
-        """Upsert a threepid validation session
-        Args:
-            medium (str): The medium of the 3PID
-            address (str): The address of the 3PID
-            client_secret (str): A unique string provided by the client to
-                help identify this validation attempt
-            send_attempt (int): The latest send_attempt on this session
-            session_id (str): The id of this validation session
-            validated_at (int|None): The unix timestamp in milliseconds of
-                when the session was marked as valid
-        """
-        insertion_values = {
-            "medium": medium,
-            "address": address,
-            "client_secret": client_secret,
-        }
-
-        if validated_at:
-            insertion_values["validated_at"] = validated_at
-
-        return self.db_pool.simple_upsert(
-            table="threepid_validation_session",
-            keyvalues={"session_id": session_id},
-            values={"last_send_attempt": send_attempt},
-            insertion_values=insertion_values,
-            desc="upsert_threepid_validation_session",
-        )
-
-    def start_or_continue_validation_session(
+    async def start_or_continue_validation_session(
         self,
-        medium,
-        address,
-        session_id,
-        client_secret,
-        send_attempt,
-        next_link,
-        token,
-        token_expires,
-    ):
+        medium: str,
+        address: str,
+        session_id: str,
+        client_secret: str,
+        send_attempt: int,
+        next_link: Optional[str],
+        token: str,
+        token_expires: int,
+    ) -> None:
         """Creates a new threepid validation session if it does not already
         exist and associates a new validation token with it
 
         Args:
-            medium (str): The medium of the 3PID
-            address (str): The address of the 3PID
-            session_id (str): The id of this validation session
-            client_secret (str): A unique string provided by the client to
-                help identify this validation attempt
-            send_attempt (int): The latest send_attempt on this session
-            next_link (str|None): The link to redirect the user to upon
-                successful validation
-            token (str): The validation token
-            token_expires (int): The timestamp for which after the token
-                will no longer be valid
+            medium: The medium of the 3PID
+            address: The address of the 3PID
+            session_id: The id of this validation session
+            client_secret: A unique string provided by the client to help
+                identify this validation attempt
+            send_attempt: The latest send_attempt on this session
+            next_link: The link to redirect the user to upon successful validation
+            token: The validation token
+            token_expires: The timestamp for which after the token will no
+                longer be valid
         """
 
         def start_or_continue_validation_session_txn(txn):
@@ -1436,12 +1425,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                 },
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "start_or_continue_validation_session",
             start_or_continue_validation_session_txn,
         )
 
-    def cull_expired_threepid_validation_tokens(self):
+    async def cull_expired_threepid_validation_tokens(self) -> None:
         """Remove threepid validation tokens with expiry dates that have passed"""
 
         def cull_expired_threepid_validation_tokens_txn(txn, ts):
@@ -1449,9 +1438,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             DELETE FROM threepid_validation_token WHERE
             expires < ?
             """
-            return txn.execute(sql, (ts,))
+            txn.execute(sql, (ts,))
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "cull_expired_threepid_validation_tokens",
             cull_expired_threepid_validation_tokens_txn,
             self.clock.time_msec(),
diff --git a/synapse/storage/databases/main/rejections.py b/synapse/storage/databases/main/rejections.py
index cf9ba51205..1e361aaa9a 100644
--- a/synapse/storage/databases/main/rejections.py
+++ b/synapse/storage/databases/main/rejections.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import Optional
 
 from synapse.storage._base import SQLBaseStore
 
@@ -21,8 +22,8 @@ logger = logging.getLogger(__name__)
 
 
 class RejectionsStore(SQLBaseStore):
-    def get_rejection_reason(self, event_id):
-        return self.db_pool.simple_select_one_onecol(
+    async def get_rejection_reason(self, event_id: str) -> Optional[str]:
+        return await self.db_pool.simple_select_one_onecol(
             table="rejections",
             retcol="reason",
             keyvalues={"event_id": event_id},
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index a9ceffc20e..5cd61547f7 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -34,38 +34,33 @@ logger = logging.getLogger(__name__)
 
 class RelationsWorkerStore(SQLBaseStore):
     @cached(tree=True)
-    def get_relations_for_event(
+    async def get_relations_for_event(
         self,
-        event_id,
-        relation_type=None,
-        event_type=None,
-        aggregation_key=None,
-        limit=5,
-        direction="b",
-        from_token=None,
-        to_token=None,
-    ):
+        event_id: str,
+        relation_type: Optional[str] = None,
+        event_type: Optional[str] = None,
+        aggregation_key: Optional[str] = None,
+        limit: int = 5,
+        direction: str = "b",
+        from_token: Optional[RelationPaginationToken] = None,
+        to_token: Optional[RelationPaginationToken] = None,
+    ) -> PaginationChunk:
         """Get a list of relations for an event, ordered by topological ordering.
 
         Args:
-            event_id (str): Fetch events that relate to this event ID.
-            relation_type (str|None): Only fetch events with this relation
-                type, if given.
-            event_type (str|None): Only fetch events with this event type, if
-                given.
-            aggregation_key (str|None): Only fetch events with this aggregation
-                key, if given.
-            limit (int): Only fetch the most recent `limit` events.
-            direction (str): Whether to fetch the most recent first (`"b"`) or
-                the oldest first (`"f"`).
-            from_token (RelationPaginationToken|None): Fetch rows from the given
-                token, or from the start if None.
-            to_token (RelationPaginationToken|None): Fetch rows up to the given
-                token, or up to the end if None.
+            event_id: Fetch events that relate to this event ID.
+            relation_type: Only fetch events with this relation type, if given.
+            event_type: Only fetch events with this event type, if given.
+            aggregation_key: Only fetch events with this aggregation key, if given.
+            limit: Only fetch the most recent `limit` events.
+            direction: Whether to fetch the most recent first (`"b"`) or the
+                oldest first (`"f"`).
+            from_token: Fetch rows from the given token, or from the start if None.
+            to_token: Fetch rows up to the given token, or up to the end if None.
 
         Returns:
-            Deferred[PaginationChunk]: List of event IDs that match relations
-            requested. The rows are of the form `{"event_id": "..."}`.
+            List of event IDs that match relations requested. The rows are of
+            the form `{"event_id": "..."}`.
         """
 
         where_clause = ["relates_to_id = ?"]
@@ -131,20 +126,20 @@ class RelationsWorkerStore(SQLBaseStore):
                 chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
             )
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_recent_references_for_event", _get_recent_references_for_event_txn
         )
 
     @cached(tree=True)
-    def get_aggregation_groups_for_event(
+    async def get_aggregation_groups_for_event(
         self,
-        event_id,
-        event_type=None,
-        limit=5,
-        direction="b",
-        from_token=None,
-        to_token=None,
-    ):
+        event_id: str,
+        event_type: Optional[str] = None,
+        limit: int = 5,
+        direction: str = "b",
+        from_token: Optional[AggregationPaginationToken] = None,
+        to_token: Optional[AggregationPaginationToken] = None,
+    ) -> PaginationChunk:
         """Get a list of annotations on the event, grouped by event type and
         aggregation key, sorted by count.
 
@@ -152,21 +147,17 @@ class RelationsWorkerStore(SQLBaseStore):
         on an event.
 
         Args:
-            event_id (str): Fetch events that relate to this event ID.
-            event_type (str|None): Only fetch events with this event type, if
-                given.
-            limit (int): Only fetch the `limit` groups.
-            direction (str): Whether to fetch the highest count first (`"b"`) or
+            event_id: Fetch events that relate to this event ID.
+            event_type: Only fetch events with this event type, if given.
+            limit: Only fetch the `limit` groups.
+            direction: Whether to fetch the highest count first (`"b"`) or
                 the lowest count first (`"f"`).
-            from_token (AggregationPaginationToken|None): Fetch rows from the
-                given token, or from the start if None.
-            to_token (AggregationPaginationToken|None): Fetch rows up to the
-                given token, or up to the end if None.
-
+            from_token: Fetch rows from the given token, or from the start if None.
+            to_token: Fetch rows up to the given token, or up to the end if None.
 
         Returns:
-            Deferred[PaginationChunk]: List of groups of annotations that
-            match. Each row is a dict with `type`, `key` and `count` fields.
+            List of groups of annotations that match. Each row is a dict with
+            `type`, `key` and `count` fields.
         """
 
         where_clause = ["relates_to_id = ?", "relation_type = ?"]
@@ -225,7 +216,7 @@ class RelationsWorkerStore(SQLBaseStore):
                 chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
             )
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
         )
 
@@ -279,18 +270,20 @@ class RelationsWorkerStore(SQLBaseStore):
 
         return await self.get_event(edit_id, allow_none=True)
 
-    def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
+    async def has_user_annotated_event(
+        self, parent_id: str, event_type: str, aggregation_key: str, sender: str
+    ) -> bool:
         """Check if a user has already annotated an event with the same key
         (e.g. already liked an event).
 
         Args:
-            parent_id (str): The event being annotated
-            event_type (str): The event type of the annotation
-            aggregation_key (str): The aggregation key of the annotation
-            sender (str): The sender of the annotation
+            parent_id: The event being annotated
+            event_type: The event type of the annotation
+            aggregation_key: The aggregation key of the annotation
+            sender: The sender of the annotation
 
         Returns:
-            Deferred[bool]
+            True if the event is already annotated.
         """
 
         sql = """
@@ -319,7 +312,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
             return bool(txn.fetchone())
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
         )
 
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index f4008e6221..127588ce4c 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -21,24 +21,19 @@ from abc import abstractmethod
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple
 
-from canonicaljson import json
-
 from synapse.api.constants import EventTypes
 from synapse.api.errors import StoreError
 from synapse.api.room_versions import RoomVersion, RoomVersions
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.search import SearchStore
-from synapse.types import ThirdPartyInstanceID
+from synapse.types import JsonDict, ThirdPartyInstanceID
+from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
 
 
-OpsLevel = collections.namedtuple(
-    "OpsLevel", ("ban_level", "kick_level", "redact_level")
-)
-
 RatelimitOverride = collections.namedtuple(
     "RatelimitOverride", ("messages_per_second", "burst_count")
 )
@@ -78,15 +73,15 @@ class RoomWorkerStore(SQLBaseStore):
 
         self.config = hs.config
 
-    def get_room(self, room_id):
+    async def get_room(self, room_id: str) -> dict:
         """Retrieve a room.
 
         Args:
-            room_id (str): The ID of the room to retrieve.
+            room_id: The ID of the room to retrieve.
         Returns:
             A dict containing the room information, or None if the room is unknown.
         """
-        return self.db_pool.simple_select_one(
+        return await self.db_pool.simple_select_one(
             table="rooms",
             keyvalues={"room_id": room_id},
             retcols=("room_id", "is_public", "creator"),
@@ -94,7 +89,7 @@ class RoomWorkerStore(SQLBaseStore):
             allow_none=True,
         )
 
-    def get_room_with_stats(self, room_id: str):
+    async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
         """Retrieve room with statistics.
 
         Args:
@@ -109,7 +104,8 @@ class RoomWorkerStore(SQLBaseStore):
                   curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
                   rooms.creator, state.encryption, state.is_federatable AS federatable,
                   rooms.is_public AS public, state.join_rules, state.guest_access,
-                  state.history_visibility, curr.current_state_events AS state_events
+                  state.history_visibility, curr.current_state_events AS state_events,
+                  state.avatar, state.topic
                 FROM rooms
                 LEFT JOIN room_stats_state state USING (room_id)
                 LEFT JOIN room_stats_current curr USING (room_id)
@@ -126,25 +122,29 @@ class RoomWorkerStore(SQLBaseStore):
             res["public"] = bool(res["public"])
             return res
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_room_with_stats", get_room_with_stats_txn, room_id
         )
 
-    def get_public_room_ids(self):
-        return self.db_pool.simple_select_onecol(
+    async def get_public_room_ids(self) -> List[str]:
+        return await self.db_pool.simple_select_onecol(
             table="rooms",
             keyvalues={"is_public": True},
             retcol="room_id",
             desc="get_public_room_ids",
         )
 
-    def count_public_rooms(self, network_tuple, ignore_non_federatable):
+    async def count_public_rooms(
+        self,
+        network_tuple: Optional[ThirdPartyInstanceID],
+        ignore_non_federatable: bool,
+    ) -> int:
         """Counts the number of public rooms as tracked in the room_stats_current
         and room_stats_state table.
 
         Args:
-            network_tuple (ThirdPartyInstanceID|None)
-            ignore_non_federatable (bool): If true filters out non-federatable rooms
+            network_tuple
+            ignore_non_federatable: If true filters out non-federatable rooms
         """
 
         def _count_public_rooms_txn(txn):
@@ -188,7 +188,7 @@ class RoomWorkerStore(SQLBaseStore):
             txn.execute(sql, query_args)
             return txn.fetchone()[0]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "count_public_rooms", _count_public_rooms_txn
         )
 
@@ -335,8 +335,8 @@ class RoomWorkerStore(SQLBaseStore):
         return ret_val
 
     @cached(max_entries=10000)
-    def is_room_blocked(self, room_id):
-        return self.db_pool.simple_select_one_onecol(
+    async def is_room_blocked(self, room_id: str) -> Optional[bool]:
+        return await self.db_pool.simple_select_one_onecol(
             table="blocked_rooms",
             keyvalues={"room_id": room_id},
             retcol="1",
@@ -591,15 +591,14 @@ class RoomWorkerStore(SQLBaseStore):
 
         return row
 
-    def get_media_mxcs_in_room(self, room_id):
+    async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
         """Retrieves all the local and remote media MXC URIs in a given room
 
         Args:
-            room_id (str)
+            room_id
 
         Returns:
-            The local and remote media as a lists of tuples where the key is
-            the hostname and the value is the media ID.
+            The local and remote media as a lists of the media IDs.
         """
 
         def _get_media_mxcs_in_room_txn(txn):
@@ -615,11 +614,13 @@ class RoomWorkerStore(SQLBaseStore):
 
             return local_media_mxcs, remote_media_mxcs
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_media_ids_in_room", _get_media_mxcs_in_room_txn
         )
 
-    def quarantine_media_ids_in_room(self, room_id, quarantined_by):
+    async def quarantine_media_ids_in_room(
+        self, room_id: str, quarantined_by: str
+    ) -> int:
         """For a room loops through all events with media and quarantines
         the associated media
         """
@@ -632,7 +633,7 @@ class RoomWorkerStore(SQLBaseStore):
                 txn, local_mxcs, remote_mxcs, quarantined_by
             )
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "quarantine_media_in_room", _quarantine_media_in_room_txn
         )
 
@@ -695,9 +696,9 @@ class RoomWorkerStore(SQLBaseStore):
 
         return local_media_mxcs, remote_media_mxcs
 
-    def quarantine_media_by_id(
+    async def quarantine_media_by_id(
         self, server_name: str, media_id: str, quarantined_by: str,
-    ):
+    ) -> int:
         """quarantines a single local or remote media id
 
         Args:
@@ -716,11 +717,13 @@ class RoomWorkerStore(SQLBaseStore):
                 txn, local_mxcs, remote_mxcs, quarantined_by
             )
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "quarantine_media_by_user", _quarantine_media_by_id_txn
         )
 
-    def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
+    async def quarantine_media_ids_by_user(
+        self, user_id: str, quarantined_by: str
+    ) -> int:
         """quarantines all local media associated with a single user
 
         Args:
@@ -732,7 +735,7 @@ class RoomWorkerStore(SQLBaseStore):
             local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
             return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "quarantine_media_by_user", _quarantine_media_by_user_txn
         )
 
@@ -1134,7 +1137,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                         },
                     )
 
-            with self._public_room_id_gen.get_next() as next_id:
+            with await self._public_room_id_gen.get_next() as next_id:
                 await self.db_pool.runInteraction(
                     "store_room_txn", store_room_txn, next_id
                 )
@@ -1201,7 +1204,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                     },
                 )
 
-        with self._public_room_id_gen.get_next() as next_id:
+        with await self._public_room_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
                 "set_room_is_public", set_room_is_public_txn, next_id
             )
@@ -1281,7 +1284,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                     },
                 )
 
-        with self._public_room_id_gen.get_next() as next_id:
+        with await self._public_room_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
                 "set_room_is_public_appservice",
                 set_room_is_public_appservice_txn,
@@ -1289,8 +1292,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             )
         self.hs.get_notifier().on_new_replication_data()
 
-    def get_room_count(self):
-        """Retrieve a list of all rooms
+    async def get_room_count(self) -> int:
+        """Retrieve the total number of rooms.
         """
 
         def f(txn):
@@ -1299,13 +1302,19 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             row = txn.fetchone()
             return row[0] or 0
 
-        return self.db_pool.runInteraction("get_rooms", f)
+        return await self.db_pool.runInteraction("get_rooms", f)
 
-    def add_event_report(
-        self, room_id, event_id, user_id, reason, content, received_ts
-    ):
+    async def add_event_report(
+        self,
+        room_id: str,
+        event_id: str,
+        user_id: str,
+        reason: str,
+        content: JsonDict,
+        received_ts: int,
+    ) -> None:
         next_id = self._event_reports_id_gen.get_next()
-        return self.db_pool.simple_insert(
+        await self.db_pool.simple_insert(
             table="event_reports",
             values={
                 "id": next_id,
@@ -1314,7 +1323,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                 "event_id": event_id,
                 "user_id": user_id,
                 "reason": reason,
-                "content": json.dumps(content),
+                "content": json_encoder.encode(content),
             },
             desc="add_event_report",
         )
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index b2fcfc9bfe..91a8b43da3 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -15,9 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
-
-from twisted.internet import defer
+from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
@@ -92,8 +90,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 lambda: self._known_servers_count,
             )
 
-    @defer.inlineCallbacks
-    def _count_known_servers(self):
+    async def _count_known_servers(self):
         """
         Count the servers that this server knows about.
 
@@ -121,7 +118,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             txn.execute(query)
             return list(txn)[0][0]
 
-        count = yield self.db_pool.runInteraction("get_known_servers", _transact)
+        count = await self.db_pool.runInteraction("get_known_servers", _transact)
 
         # We always know about ourselves, even if we have nothing in
         # room_memberships (for example, the server is new).
@@ -155,8 +152,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             )
 
     @cached(max_entries=100000, iterable=True)
-    def get_users_in_room(self, room_id: str):
-        return self.db_pool.runInteraction(
+    async def get_users_in_room(self, room_id: str) -> List[str]:
+        return await self.db_pool.runInteraction(
             "get_users_in_room", self.get_users_in_room_txn, room_id
         )
 
@@ -183,14 +180,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return [r[0] for r in txn]
 
     @cached(max_entries=100000)
-    def get_room_summary(self, room_id: str):
+    async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
         """ Get the details of a room roughly suitable for use by the room
         summary extension to /sync. Useful when lazy loading room members.
         Args:
             room_id: The room ID to query
         Returns:
-            Deferred[dict[str, MemberSummary]:
-                dict of membership states, pointing to a MemberSummary named tuple.
+            dict of membership states, pointing to a MemberSummary named tuple.
         """
 
         def _get_room_summary_txn(txn):
@@ -264,20 +260,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
             return res
 
-        return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn)
+        return await self.db_pool.runInteraction(
+            "get_room_summary", _get_room_summary_txn
+        )
 
     @cached()
-    def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]:
+    async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser:
         """Get all the rooms the *local* user is invited to.
 
         Args:
             user_id: The user ID.
 
         Returns:
-            A awaitable list of RoomsForUser.
+            A list of RoomsForUser.
         """
 
-        return self.get_rooms_for_local_user_where_membership_is(
+        return await self.get_rooms_for_local_user_where_membership_is(
             user_id, [Membership.INVITE]
         )
 
@@ -300,8 +298,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return None
 
     async def get_rooms_for_local_user_where_membership_is(
-        self, user_id: str, membership_list: List[str]
-    ) -> Optional[List[RoomsForUser]]:
+        self, user_id: str, membership_list: Collection[str]
+    ) -> List[RoomsForUser]:
         """Get all the rooms for this *local* user where the membership for this user
         matches one in the membership list.
 
@@ -316,7 +314,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             The RoomsForUser that the user matches the membership types.
         """
         if not membership_list:
-            return None
+            return []
 
         rooms = await self.db_pool.runInteraction(
             "get_rooms_for_local_user_where_membership_is",
@@ -360,7 +358,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return results
 
     @cached(max_entries=500000, iterable=True)
-    def get_rooms_for_user_with_stream_ordering(self, user_id: str):
+    async def get_rooms_for_user_with_stream_ordering(
+        self, user_id: str
+    ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
         """Returns a set of room_ids the user is currently joined to.
 
         If a remote user only returns rooms this server is currently
@@ -370,17 +370,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             user_id
 
         Returns:
-            Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
-            the rooms the user is in currently, along with the stream ordering
-            of the most recent join for that user and room.
+            Returns the rooms the user is in currently, along with the stream
+            ordering of the most recent join for that user and room.
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_rooms_for_user_with_stream_ordering",
             self._get_rooms_for_user_with_stream_ordering_txn,
             user_id,
         )
 
-    def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str):
+    def _get_rooms_for_user_with_stream_ordering_txn(
+        self, txn, user_id: str
+    ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
         # We use `current_state_events` here and not `local_current_membership`
         # as a) this gets called with remote users and b) this only gets called
         # for rooms the server is participating in.
@@ -407,9 +408,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             """
 
         txn.execute(sql, (user_id, Membership.JOIN))
-        results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
-
-        return results
+        return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
 
     async def get_users_server_still_shares_room_with(
         self, user_ids: Collection[str]
@@ -589,11 +588,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="_get_joined_profile_from_event_id",
-        list_name="event_ids",
-        inlineCallbacks=True,
+        cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
     )
-    def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
+    async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
         """For given set of member event_ids check if they point to a join
         event and if so return the associated user and profile info.
 
@@ -601,11 +598,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             event_ids: The member event IDs to lookup
 
         Returns:
-            Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
+            dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
             to `user_id` and ProfileInfo (or None if not join event).
         """
 
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="room_memberships",
             column="event_id",
             iterable=event_ids,
@@ -716,14 +713,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return count == 0
 
     @cached()
-    def get_forgotten_rooms_for_user(self, user_id: str):
+    async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]:
         """Gets all rooms the user has forgotten.
 
         Args:
-            user_id
+            user_id: The user ID to query the rooms of.
 
         Returns:
-            Deferred[set[str]]
+            The forgotten rooms.
         """
 
         def _get_forgotten_rooms_for_user_txn(txn):
@@ -749,7 +746,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             txn.execute(sql, (user_id,))
             return {row[0] for row in txn if row[1] == 0}
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
         )
 
@@ -772,13 +769,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return set(room_ids)
 
-    def get_membership_from_event_ids(
+    async def get_membership_from_event_ids(
         self, member_event_ids: Iterable[str]
     ) -> List[dict]:
         """Get user_id and membership of a set of event IDs.
         """
 
-        return self.db_pool.simple_select_many_batch(
+        return await self.db_pool.simple_select_many_batch(
             table="room_memberships",
             column="event_id",
             iterable=member_event_ids,
@@ -978,7 +975,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super(RoomMemberStore, self).__init__(database, db_conn, hs)
 
-    def forget(self, user_id: str, room_id: str):
+    async def forget(self, user_id: str, room_id: str) -> None:
         """Indicate that user_id wishes to discard history for room_id."""
 
         def f(txn):
@@ -999,10 +996,10 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
                 txn, self.get_forgotten_rooms_for_user, (user_id,)
             )
 
-        return self.db_pool.runInteraction("forget_membership", f)
+        await self.db_pool.runInteraction("forget_membership", f)
 
 
-class _JoinedHostsCache(object):
+class _JoinedHostsCache:
     """Cache for joined hosts in a room that is optimised to handle updates
     via state deltas.
     """
diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
new file mode 100644
index 0000000000..b64926e9c9
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
@@ -0,0 +1,33 @@
+/* Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * This adds the method to the unique key constraint of the thumbnail databases.
+ * Otherwise you can't have a scaled and a cropped thumbnail with the same
+ * resolution, which happens quite often with dynamic thumbnailing.
+ * This is the postgres specific migration modifying the table with a background
+ * migration.
+ */
+
+-- add new index that includes method to local media
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('local_media_repository_thumbnails_method_idx', '{}');
+
+-- add new index that includes method to remote media
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+  ('remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
+
+-- drop old index
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+  ('media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');
+
diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite
new file mode 100644
index 0000000000..1d0c04b53a
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite
@@ -0,0 +1,44 @@
+/* Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * This adds the method to the unique key constraint of the thumbnail databases.
+ * Otherwise you can't have a scaled and a cropped thumbnail with the same
+ * resolution, which happens quite often with dynamic thumbnailing.
+ * This is a sqlite specific migration, since sqlite can't modify the unique
+ * constraint of a table without recreating it.
+ */
+
+CREATE TABLE local_media_repository_thumbnails_new ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method ) );
+
+INSERT INTO local_media_repository_thumbnails_new
+    SELECT media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method, thumbnail_length
+    FROM local_media_repository_thumbnails;
+
+DROP TABLE local_media_repository_thumbnails;
+
+ALTER TABLE local_media_repository_thumbnails_new RENAME TO local_media_repository_thumbnails;
+
+CREATE INDEX local_media_repository_thumbnails_media_id ON local_media_repository_thumbnails (media_id);
+
+
+
+CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails_new ( media_origin TEXT, media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_method TEXT, thumbnail_type TEXT, thumbnail_length INTEGER, filesystem_id TEXT, UNIQUE ( media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method ) );
+
+INSERT INTO remote_media_cache_thumbnails_new
+    SELECT media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_method, thumbnail_type, thumbnail_length, filesystem_id
+    FROM remote_media_cache_thumbnails;
+
+DROP TABLE remote_media_cache_thumbnails;
+
+ALTER TABLE remote_media_cache_thumbnails_new RENAME TO remote_media_cache_thumbnails;
diff --git a/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql
new file mode 100644
index 0000000000..4cc96a5341
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql
@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- A table of the IP address and user-agent used to complete each step of a
+-- user-interactive authentication session.
+CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips(
+    session_id TEXT NOT NULL,
+    ip TEXT NOT NULL,
+    user_agent TEXT NOT NULL,
+    UNIQUE (session_id, ip, user_agent),
+    FOREIGN KEY (session_id)
+        REFERENCES ui_auth_sessions (session_id)
+);
diff --git a/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql
new file mode 100644
index 0000000000..260b009b48
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- A shadow-banned user may be told that their requests succeeded when they were
+-- actually ignored.
+ALTER TABLE users ADD COLUMN shadow_banned BOOLEAN;
diff --git a/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql b/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql
new file mode 100644
index 0000000000..847aebd85e
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql
@@ -0,0 +1,28 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+  Delete stuck 'enabled' bits that correspond to deleted or non-existent push rules.
+  We ignore rules that are server-default rules because they are not defined
+  in the `push_rules` table.
+**/
+
+DELETE FROM push_rules_enable WHERE
+  rule_id NOT LIKE 'global/%/.m.rule.%'
+  AND NOT EXISTS (
+    SELECT 1 FROM push_rules
+    WHERE push_rules.user_name = push_rules_enable.user_name
+      AND push_rules.rule_id = push_rules_enable.rule_id
+  );
diff --git a/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql
new file mode 100644
index 0000000000..15421b99ac
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This table is no longer used.
+DROP TABLE IF EXISTS presence_allow_inbound;
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql
new file mode 100644
index 0000000000..98ff76d709
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql
@@ -0,0 +1,16 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE events ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
new file mode 100644
index 0000000000..97c1e6a0c5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
@@ -0,0 +1,26 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS events_stream_seq;
+
+SELECT setval('events_stream_seq', (
+    SELECT COALESCE(MAX(stream_ordering), 1) FROM events
+));
+
+CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
+
+SELECT setval('events_backfill_stream_seq', (
+    SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
+));
diff --git a/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql b/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql
new file mode 100644
index 0000000000..ebfbed7925
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql
@@ -0,0 +1,42 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+-- This schema delta alters the schema to enable 'catching up' remote homeservers
+-- after there has been a connectivity problem for any reason.
+
+-- This stores, for each (destination, room) pair, the stream_ordering of the
+-- latest event for that destination.
+CREATE TABLE IF NOT EXISTS destination_rooms (
+  -- the destination in question.
+  destination TEXT NOT NULL REFERENCES destinations (destination),
+  -- the ID of the room in question
+  room_id TEXT NOT NULL REFERENCES rooms (room_id),
+  -- the stream_ordering of the event
+  stream_ordering BIGINT NOT NULL,
+  PRIMARY KEY (destination, room_id)
+  -- We don't declare a foreign key on stream_ordering here because that'd mean
+  -- we'd need to either maintain an index (expensive) or do a table scan of
+  -- destination_rooms whenever we delete an event (also potentially expensive).
+  -- In addition to that, a foreign key on stream_ordering would be redundant
+  -- as this row doesn't need to refer to a specific event; if the event gets
+  -- deleted then it doesn't affect the validity of the stream_ordering here.
+);
+
+-- This index is needed to make it so that a deletion of a room (in the rooms
+-- table) can be efficient, as otherwise a table scan would need to be performed
+-- to check that no destination_rooms rows point to the room to be deleted.
+-- Also: it makes it efficient to delete all the entries for a given room ID,
+-- such as when purging a room.
+CREATE INDEX IF NOT EXISTS destination_rooms_room_id
+    ON destination_rooms (room_id);
diff --git a/synapse/storage/databases/main/schema/delta/58/15unread_count.sql b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql
new file mode 100644
index 0000000000..317fba8a5d
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql
@@ -0,0 +1,26 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We're hijacking the push actions to store unread messages and unread counts (specified
+-- in MSC2654) because doing otherwise would result in either performance issues or
+-- reimplementing a consequent bit of the push actions.
+
+-- Add columns to event_push_actions and event_push_actions_staging to track unread
+-- messages and calculate unread counts.
+ALTER TABLE event_push_actions_staging ADD COLUMN unread SMALLINT;
+ALTER TABLE event_push_actions ADD COLUMN unread SMALLINT;
+
+-- Add column to event_push_summary
+ALTER TABLE event_push_summary ADD COLUMN unread_count BIGINT;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql b/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql
new file mode 100644
index 0000000000..55f5d0f732
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+-- This delta file fixes a regression introduced by 58/12room_stats.sql, removing the hacky
+-- populate_stats_process_rooms_2 background job and restores the functionality under the
+-- original name.
+-- See https://github.com/matrix-org/synapse/issues/8238 for details
+
+DELETE FROM background_updates WHERE update_name = 'populate_stats_process_rooms';
+UPDATE background_updates SET update_name = 'populate_stats_process_rooms'
+    WHERE update_name = 'populate_stats_process_rooms_2';
diff --git a/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql b/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql
new file mode 100644
index 0000000000..a67aa5e500
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql
@@ -0,0 +1,21 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This column tracks the stream_ordering of the event that was most recently
+-- successfully transmitted to the destination.
+-- A value of NULL means that we have not sent an event successfully yet
+-- (at least, not since the introduction of this column).
+ALTER TABLE destinations
+    ADD COLUMN last_successful_stream_ordering BIGINT;
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 7f8d1880e5..f01cf2fd02 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -16,9 +16,10 @@
 import logging
 import re
 from collections import namedtuple
-from typing import List, Optional
+from typing import List, Optional, Set
 
 from synapse.api.errors import SynapseError
+from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
@@ -620,7 +621,9 @@ class SearchStore(SearchBackgroundUpdateStore):
             "count": count,
         }
 
-    def _find_highlights_in_postgres(self, search_query, events):
+    async def _find_highlights_in_postgres(
+        self, search_query: str, events: List[EventBase]
+    ) -> Set[str]:
         """Given a list of events and a search term, return a list of words
         that match from the content of the event.
 
@@ -628,11 +631,11 @@ class SearchStore(SearchBackgroundUpdateStore):
         highlight the matching parts.
 
         Args:
-            search_query (str)
-            events (list): A list of events
+            search_query
+            events: A list of events
 
         Returns:
-            deferred : A set of strings.
+            A set of strings.
         """
 
         def f(txn):
@@ -685,7 +688,7 @@ class SearchStore(SearchBackgroundUpdateStore):
 
             return highlight_words
 
-        return self.db_pool.runInteraction("_find_highlights", f)
+        return await self.db_pool.runInteraction("_find_highlights", f)
 
 
 def _to_postgres_options(options_dict):
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index be191dd870..c8c67953e4 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -13,9 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Dict, Iterable, List, Tuple
+
 from unpaddedbase64 import encode_base64
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.types import Cursor
 from synapse.util.caches.descriptors import cached, cachedList
 
 
@@ -29,16 +32,37 @@ class SignatureWorkerStore(SQLBaseStore):
     @cachedList(
         cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1
     )
-    def get_event_reference_hashes(self, event_ids):
+    async def get_event_reference_hashes(
+        self, event_ids: Iterable[str]
+    ) -> Dict[str, Dict[str, bytes]]:
+        """Get all hashes for given events.
+
+        Args:
+            event_ids: The event IDs to get hashes for.
+
+        Returns:
+             A mapping of event ID to a mapping of algorithm to hash.
+        """
+
         def f(txn):
             return {
                 event_id: self._get_event_reference_hashes_txn(txn, event_id)
                 for event_id in event_ids
             }
 
-        return self.db_pool.runInteraction("get_event_reference_hashes", f)
+        return await self.db_pool.runInteraction("get_event_reference_hashes", f)
 
-    async def add_event_hashes(self, event_ids):
+    async def add_event_hashes(
+        self, event_ids: Iterable[str]
+    ) -> List[Tuple[str, Dict[str, str]]]:
+        """
+
+        Args:
+            event_ids: The event IDs
+
+        Returns:
+            A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash.
+        """
         hashes = await self.get_event_reference_hashes(event_ids)
         hashes = {
             e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
@@ -47,13 +71,15 @@ class SignatureWorkerStore(SQLBaseStore):
 
         return list(hashes.items())
 
-    def _get_event_reference_hashes_txn(self, txn, event_id):
+    def _get_event_reference_hashes_txn(
+        self, txn: Cursor, event_id: str
+    ) -> Dict[str, bytes]:
         """Get all the hashes for a given PDU.
         Args:
-            txn (cursor):
-            event_id (str): Id for the Event.
+            txn:
+            event_id: Id for the Event.
         Returns:
-            A dict[unicode, bytes] of algorithm -> hash.
+            A mapping of algorithm -> hash.
         """
         query = (
             "SELECT algorithm, hash"
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 96e0378e50..5c6168e301 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -27,6 +27,7 @@ from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.state import StateFilter
+from synapse.types import StateMap
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedList
 
@@ -163,15 +164,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         return create_event
 
     @cached(max_entries=100000, iterable=True)
-    def get_current_state_ids(self, room_id):
+    async def get_current_state_ids(self, room_id: str) -> StateMap[str]:
         """Get the current state event ids for a room based on the
         current_state_events table.
 
         Args:
-            room_id (str)
+            room_id: The room to get the state IDs of.
 
         Returns:
-            deferred: dict of (type, state_key) -> event_id
+            The current state of the room.
         """
 
         def _get_current_state_ids_txn(txn):
@@ -184,14 +185,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
             return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_current_state_ids", _get_current_state_ids_txn
         )
 
     # FIXME: how should this be cached?
-    def get_filtered_current_state_ids(
+    async def get_filtered_current_state_ids(
         self, room_id: str, state_filter: StateFilter = StateFilter.all()
-    ):
+    ) -> StateMap[str]:
         """Get the current state event of a given type for a room based on the
         current_state_events table.  This may not be as up-to-date as the result
         of doing a fresh state resolution as per state_handler.get_current_state
@@ -202,14 +203,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 from the database.
 
         Returns:
-            defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
+            Map from type/state_key to event ID.
         """
 
         where_clause, where_args = state_filter.make_sql_filter_clause()
 
         if not where_clause:
             # We delegate to the cached version
-            return self.get_current_state_ids(room_id)
+            return await self.get_current_state_ids(room_id)
 
         def _get_filtered_current_state_ids_txn(txn):
             results = {}
@@ -231,7 +232,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
             return results
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
         )
 
@@ -260,8 +261,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         return event.content.get("canonical_alias")
 
     @cached(max_entries=50000)
-    def _get_state_group_for_event(self, event_id):
-        return self.db_pool.simple_select_one_onecol(
+    async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
+        return await self.db_pool.simple_select_one_onecol(
             table="event_to_state_groups",
             keyvalues={"event_id": event_id},
             retcol="state_group",
@@ -273,12 +274,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         cached_method_name="_get_state_group_for_event",
         list_name="event_ids",
         num_args=1,
-        inlineCallbacks=True,
     )
-    def _get_state_group_for_events(self, event_ids):
+    async def _get_state_group_for_events(self, event_ids):
         """Returns mapping event_id -> state_group
         """
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="event_to_state_groups",
             column="event_id",
             iterable=event_ids,
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 0d963c98ff..356623fc6e 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -14,8 +14,7 @@
 # limitations under the License.
 
 import logging
-
-from twisted.internet import defer
+from typing import Any, Dict, List, Tuple
 
 from synapse.storage._base import SQLBaseStore
 
@@ -23,7 +22,9 @@ logger = logging.getLogger(__name__)
 
 
 class StateDeltasStore(SQLBaseStore):
-    def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
+    async def get_current_state_deltas(
+        self, prev_stream_id: int, max_stream_id: int
+    ) -> Tuple[int, List[Dict[str, Any]]]:
         """Fetch a list of room state changes since the given stream id
 
         Each entry in the result contains the following fields:
@@ -37,12 +38,12 @@ class StateDeltasStore(SQLBaseStore):
                 if it's new state.
 
         Args:
-            prev_stream_id (int): point to get changes since (exclusive)
-            max_stream_id (int): the point that we know has been correctly persisted
+            prev_stream_id: point to get changes since (exclusive)
+            max_stream_id: the point that we know has been correctly persisted
                - ie, an upper limit to return changes from.
 
         Returns:
-            Deferred[tuple[int, list[dict]]: A tuple consisting of:
+            A tuple consisting of:
                - the stream id which these results go up to
                - list of current_state_delta_stream rows. If it is empty, we are
                  up to date.
@@ -58,7 +59,7 @@ class StateDeltasStore(SQLBaseStore):
             # if the CSDs haven't changed between prev_stream_id and now, we
             # know for certain that they haven't changed between prev_stream_id and
             # max_stream_id.
-            return defer.succeed((max_stream_id, []))
+            return (max_stream_id, [])
 
         def get_current_state_deltas_txn(txn):
             # First we calculate the max stream id that will give us less than
@@ -102,7 +103,7 @@ class StateDeltasStore(SQLBaseStore):
             txn.execute(sql, (prev_stream_id, clipped_stream_id))
             return clipped_stream_id, self.db_pool.cursor_to_dict(txn)
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_current_state_deltas", get_current_state_deltas_txn
         )
 
@@ -114,8 +115,8 @@ class StateDeltasStore(SQLBaseStore):
             retcol="COALESCE(MAX(stream_id), -1)",
         )
 
-    def get_max_stream_id_in_current_state_deltas(self):
-        return self.db_pool.runInteraction(
+    async def get_max_stream_id_in_current_state_deltas(self):
+        return await self.db_pool.runInteraction(
             "get_max_stream_id_in_current_state_deltas",
             self._get_max_stream_id_in_current_state_deltas_txn,
         )
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 802c9019b9..30840dbbaa 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -15,8 +15,9 @@
 # limitations under the License.
 
 import logging
+from collections import Counter
 from itertools import chain
-from typing import Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from twisted.internet.defer import DeferredLock
 
@@ -73,9 +74,6 @@ class StatsStore(StateDeltasStore):
             "populate_stats_process_rooms", self._populate_stats_process_rooms
         )
         self.db_pool.updates.register_background_update_handler(
-            "populate_stats_process_rooms_2", self._populate_stats_process_rooms_2
-        )
-        self.db_pool.updates.register_background_update_handler(
             "populate_stats_process_users", self._populate_stats_process_users
         )
         # we no longer need to perform clean-up, but we will give ourselves
@@ -147,31 +145,10 @@ class StatsStore(StateDeltasStore):
         return len(users_to_work_on)
 
     async def _populate_stats_process_rooms(self, progress, batch_size):
-        """
-        This was a background update which regenerated statistics for rooms.
-
-        It has been replaced by StatsStore._populate_stats_process_rooms_2. This background
-        job has been scheduled to run as part of Synapse v1.0.0, and again now. To ensure
-        someone upgrading from <v1.0.0, this background task has been turned into a no-op
-        so that the potentially expensive task is not run twice.
-
-        Further context: https://github.com/matrix-org/synapse/pull/7977
-        """
-        await self.db_pool.updates._end_background_update(
-            "populate_stats_process_rooms"
-        )
-        return 1
-
-    async def _populate_stats_process_rooms_2(self, progress, batch_size):
-        """
-        This is a background update which regenerates statistics for rooms.
-
-        It replaces StatsStore._populate_stats_process_rooms. See its docstring for the
-        reasoning.
-        """
+        """This is a background update which regenerates statistics for rooms."""
         if not self.stats_enabled:
             await self.db_pool.updates._end_background_update(
-                "populate_stats_process_rooms_2"
+                "populate_stats_process_rooms"
             )
             return 1
 
@@ -188,13 +165,13 @@ class StatsStore(StateDeltasStore):
             return [r for r, in txn]
 
         rooms_to_work_on = await self.db_pool.runInteraction(
-            "populate_stats_rooms_2_get_batch", _get_next_batch
+            "populate_stats_rooms_get_batch", _get_next_batch
         )
 
         # No more rooms -- complete the transaction.
         if not rooms_to_work_on:
             await self.db_pool.updates._end_background_update(
-                "populate_stats_process_rooms_2"
+                "populate_stats_process_rooms"
             )
             return 1
 
@@ -203,34 +180,52 @@ class StatsStore(StateDeltasStore):
             progress["last_room_id"] = room_id
 
         await self.db_pool.runInteraction(
-            "_populate_stats_process_rooms_2",
+            "_populate_stats_process_rooms",
             self.db_pool.updates._background_update_progress_txn,
-            "populate_stats_process_rooms_2",
+            "populate_stats_process_rooms",
             progress,
         )
 
         return len(rooms_to_work_on)
 
-    def get_stats_positions(self):
+    async def get_stats_positions(self) -> int:
         """
         Returns the stats processor positions.
         """
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="stats_incremental_position",
             keyvalues={},
             retcol="stream_id",
             desc="stats_incremental_position",
         )
 
-    def update_room_state(self, room_id, fields):
-        """
+    async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None:
+        """Update the state of a room.
+
+        fields can contain the following keys with string values:
+        * join_rules
+        * history_visibility
+        * encryption
+        * name
+        * topic
+        * avatar
+        * canonical_alias
+
+        A is_federatable key can also be included with a boolean value.
+
         Args:
-            room_id (str)
-            fields (dict[str:Any])
+            room_id: The room ID to update the state of.
+            fields: The fields to update. This can include a partial list of the
+                above fields to only update some room information.
         """
-
-        # For whatever reason some of the fields may contain null bytes, which
-        # postgres isn't a fan of, so we replace those fields with null.
+        # Ensure that the values to update are valid, they should be strings and
+        # not contain any null bytes.
+        #
+        # Invalid data gets overwritten with null.
+        #
+        # Note that a missing value should not be overwritten (it keeps the
+        # previous value).
+        sentinel = object()
         for col in (
             "join_rules",
             "history_visibility",
@@ -240,32 +235,34 @@ class StatsStore(StateDeltasStore):
             "avatar",
             "canonical_alias",
         ):
-            field = fields.get(col)
-            if field and "\0" in field:
+            field = fields.get(col, sentinel)
+            if field is not sentinel and (not isinstance(field, str) or "\0" in field):
                 fields[col] = None
 
-        return self.db_pool.simple_upsert(
+        await self.db_pool.simple_upsert(
             table="room_stats_state",
             keyvalues={"room_id": room_id},
             values=fields,
             desc="update_room_state",
         )
 
-    def get_statistics_for_subject(self, stats_type, stats_id, start, size=100):
+    async def get_statistics_for_subject(
+        self, stats_type: str, stats_id: str, start: str, size: int = 100
+    ) -> List[dict]:
         """
         Get statistics for a given subject.
 
         Args:
-            stats_type (str): The type of subject
-            stats_id (str): The ID of the subject (e.g. room_id or user_id)
-            start (int): Pagination start. Number of entries, not timestamp.
-            size (int): How many entries to return.
+            stats_type: The type of subject
+            stats_id: The ID of the subject (e.g. room_id or user_id)
+            start: Pagination start. Number of entries, not timestamp.
+            size: How many entries to return.
 
         Returns:
-            Deferred[list[dict]], where the dict has the keys of
+            A list of dicts, where the dict has the keys of
             ABSOLUTE_STATS_FIELDS[stats_type],  and "bucket_size" and "end_ts".
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_statistics_for_subject",
             self._get_statistics_for_subject_txn,
             stats_type,
@@ -300,7 +297,7 @@ class StatsStore(StateDeltasStore):
         return slice_list
 
     @cached()
-    def get_earliest_token_for_stats(self, stats_type, id):
+    async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
         """
         Fetch the "earliest token". This is used by the room stats delta
         processor to ignore deltas that have been processed between the
@@ -308,29 +305,28 @@ class StatsStore(StateDeltasStore):
         being calculated.
 
         Returns:
-            Deferred[int]
+            The earliest token.
         """
         table, id_col = TYPE_TO_TABLE[stats_type]
 
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             "%s_current" % (table,),
             keyvalues={id_col: id},
             retcol="completed_delta_stream_id",
             allow_none=True,
         )
 
-    def bulk_update_stats_delta(self, ts, updates, stream_id):
+    async def bulk_update_stats_delta(
+        self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int
+    ) -> None:
         """Bulk update stats tables for a given stream_id and updates the stats
         incremental position.
 
         Args:
-            ts (int): Current timestamp in ms
-            updates(dict[str, dict[str, dict[str, Counter]]]): The updates to
-                commit as a mapping stats_type -> stats_id -> field -> delta.
-            stream_id (int): Current position.
-
-        Returns:
-            Deferred
+            ts: Current timestamp in ms
+            updates: The updates to commit as a mapping of
+                stats_type -> stats_id -> field -> delta.
+            stream_id: Current position.
         """
 
         def _bulk_update_stats_delta_txn(txn):
@@ -355,38 +351,37 @@ class StatsStore(StateDeltasStore):
                 updatevalues={"stream_id": stream_id},
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "bulk_update_stats_delta", _bulk_update_stats_delta_txn
         )
 
-    def update_stats_delta(
+    async def update_stats_delta(
         self,
-        ts,
-        stats_type,
-        stats_id,
-        fields,
-        complete_with_stream_id,
-        absolute_field_overrides=None,
-    ):
+        ts: int,
+        stats_type: str,
+        stats_id: str,
+        fields: Dict[str, int],
+        complete_with_stream_id: Optional[int],
+        absolute_field_overrides: Optional[Dict[str, int]] = None,
+    ) -> None:
         """
         Updates the statistics for a subject, with a delta (difference/relative
         change).
 
         Args:
-            ts (int): timestamp of the change
-            stats_type (str): "room" or "user" – the kind of subject
-            stats_id (str): the subject's ID (room ID or user ID)
-            fields (dict[str, int]): Deltas of stats values.
-            complete_with_stream_id (int, optional):
+            ts: timestamp of the change
+            stats_type: "room" or "user" – the kind of subject
+            stats_id: the subject's ID (room ID or user ID)
+            fields: Deltas of stats values.
+            complete_with_stream_id:
                 If supplied, converts an incomplete row into a complete row,
                 with the supplied stream_id marked as the stream_id where the
                 row was completed.
-            absolute_field_overrides (dict[str, int]): Current stats values
-                (i.e. not deltas) of absolute fields.
-                Does not work with per-slice fields.
+            absolute_field_overrides: Current stats values (i.e. not deltas) of
+                absolute fields. Does not work with per-slice fields.
         """
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "update_stats_delta",
             self._update_stats_delta_txn,
             ts,
@@ -646,19 +641,20 @@ class StatsStore(StateDeltasStore):
                     txn, into_table, all_dest_keyvalues, src_row
                 )
 
-    def get_changes_room_total_events_and_bytes(self, min_pos, max_pos):
+    async def get_changes_room_total_events_and_bytes(
+        self, min_pos: int, max_pos: int
+    ) -> Dict[str, Dict[str, int]]:
         """Fetches the counts of events in the given range of stream IDs.
 
         Args:
-            min_pos (int)
-            max_pos (int)
+            min_pos
+            max_pos
 
         Returns:
-            Deferred[dict[str, dict[str, int]]]: Mapping of room ID to field
-            changes.
+            Mapping of room ID to field changes.
         """
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "stats_incremental_total_events_and_bytes",
             self.get_changes_room_total_events_and_bytes_txn,
             min_pos,
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index aaf225894e..2e95518752 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -39,18 +39,27 @@ what sort order was used:
 import abc
 import logging
 from collections import namedtuple
-from typing import Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
 from twisted.internet import defer
 
+from synapse.api.filtering import Filter
+from synapse.events import EventBase
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingTransaction,
+    make_in_list_sql_clause,
+)
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
-from synapse.storage.engines import PostgresEngine
-from synapse.types import RoomStreamToken
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.types import Collection, RoomStreamToken
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -68,8 +77,12 @@ _EventDictReturn = namedtuple(
 
 
 def generate_pagination_where_clause(
-    direction, column_names, from_token, to_token, engine
-):
+    direction: str,
+    column_names: Tuple[str, str],
+    from_token: Optional[Tuple[Optional[int], int]],
+    to_token: Optional[Tuple[Optional[int], int]],
+    engine: BaseDatabaseEngine,
+) -> str:
     """Creates an SQL expression to bound the columns by the pagination
     tokens.
 
@@ -90,21 +103,19 @@ def generate_pagination_where_clause(
           token, but include those that match the to token.
 
     Args:
-        direction (str): Whether we're paginating backwards("b") or
-            forwards ("f").
-        column_names (tuple[str, str]): The column names to bound. Must *not*
-            be user defined as these get inserted directly into the SQL
-            statement without escapes.
-        from_token (tuple[int, int]|None): The start point for the pagination.
-            This is an exclusive minimum bound if direction is "f", and an
-            inclusive maximum bound if direction is "b".
-        to_token (tuple[int, int]|None): The endpoint point for the pagination.
-            This is an inclusive maximum bound if direction is "f", and an
-            exclusive minimum bound if direction is "b".
+        direction: Whether we're paginating backwards("b") or forwards ("f").
+        column_names: The column names to bound. Must *not* be user defined as
+            these get inserted directly into the SQL statement without escapes.
+        from_token: The start point for the pagination. This is an exclusive
+            minimum bound if direction is "f", and an inclusive maximum bound if
+            direction is "b".
+        to_token: The endpoint point for the pagination. This is an inclusive
+            maximum bound if direction is "f", and an exclusive minimum bound if
+            direction is "b".
         engine: The database engine to generate the clauses for
 
     Returns:
-        str: The sql expression
+        The sql expression
     """
     assert direction in ("b", "f")
 
@@ -132,7 +143,12 @@ def generate_pagination_where_clause(
     return " AND ".join(where_clause)
 
 
-def _make_generic_sql_bound(bound, column_names, values, engine):
+def _make_generic_sql_bound(
+    bound: str,
+    column_names: Tuple[str, str],
+    values: Tuple[Optional[int], int],
+    engine: BaseDatabaseEngine,
+) -> str:
     """Create an SQL expression that bounds the given column names by the
     values, e.g. create the equivalent of `(1, 2) < (col1, col2)`.
 
@@ -142,18 +158,18 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
     out manually.
 
     Args:
-        bound (str): The comparison operator to use. One of ">", "<", ">=",
+        bound: The comparison operator to use. One of ">", "<", ">=",
             "<=", where the values are on the left and columns on the right.
-        names (tuple[str, str]): The column names. Must *not* be user defined
+        names: The column names. Must *not* be user defined
             as these get inserted directly into the SQL statement without
             escapes.
-        values (tuple[int|None, int]): The values to bound the columns by. If
+        values: The values to bound the columns by. If
             the first value is None then only creates a bound on the second
             column.
         engine: The database engine to generate the SQL for
 
     Returns:
-        str
+        The SQL statement
     """
 
     assert bound in (">", "<", ">=", "<=")
@@ -193,7 +209,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
     )
 
 
-def filter_to_clause(event_filter):
+def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
     # NB: This may create SQL clauses that don't optimise well (and we don't
     # have indices on all possible clauses). E.g. it may create
     # "room_id == X AND room_id != X", which postgres doesn't optimise.
@@ -251,7 +267,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super(StreamWorkerStore, self).__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
@@ -284,40 +300,41 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         self._stream_order_on_start = self.get_room_max_stream_ordering()
 
     @abc.abstractmethod
-    def get_room_max_stream_ordering(self):
+    def get_room_max_stream_ordering(self) -> int:
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def get_room_min_stream_ordering(self):
+    def get_room_min_stream_ordering(self) -> int:
         raise NotImplementedError()
 
-    @defer.inlineCallbacks
-    def get_room_events_stream_for_rooms(
-        self, room_ids, from_key, to_key, limit=0, order="DESC"
-    ):
+    async def get_room_events_stream_for_rooms(
+        self,
+        room_ids: Collection[str],
+        from_key: RoomStreamToken,
+        to_key: RoomStreamToken,
+        limit: int = 0,
+        order: str = "DESC",
+    ) -> Dict[str, Tuple[List[EventBase], RoomStreamToken]]:
         """Get new room events in stream ordering since `from_key`.
 
         Args:
-            room_id (str)
-            from_key (str): Token from which no events are returned before
-            to_key (str): Token from which no events are returned after. (This
+            room_ids
+            from_key: Token from which no events are returned before
+            to_key: Token from which no events are returned after. (This
                 is typically the current stream token)
-            limit (int): Maximum number of events to return
-            order (str): Either "DESC" or "ASC". Determines which events are
+            limit: Maximum number of events to return
+            order: Either "DESC" or "ASC". Determines which events are
                 returned when the result is limited. If "DESC" then the most
                 recent `limit` events are returned, otherwise returns the
                 oldest `limit` events.
 
         Returns:
-            Deferred[dict[str,tuple[list[FrozenEvent], str]]]
-                A map from room id to a tuple containing:
-                    - list of recent events in the room
-                    - stream ordering key for the start of the chunk of events returned.
+            A map from room id to a tuple containing:
+                - list of recent events in the room
+                - stream ordering key for the start of the chunk of events returned.
         """
-        from_id = RoomStreamToken.parse_stream_token(from_key).stream
-
-        room_ids = yield self._events_stream_cache.get_entities_changed(
-            room_ids, from_id
+        room_ids = self._events_stream_cache.get_entities_changed(
+            room_ids, from_key.stream
         )
 
         if not room_ids:
@@ -326,7 +343,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         results = {}
         room_ids = list(room_ids)
         for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)):
-            res = yield make_deferred_yieldable(
+            res = await make_deferred_yieldable(
                 defer.gatherResults(
                     [
                         run_in_background(
@@ -346,53 +363,51 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return results
 
-    def get_rooms_that_changed(self, room_ids, from_key):
+    def get_rooms_that_changed(
+        self, room_ids: Collection[str], from_key: RoomStreamToken
+    ) -> Set[str]:
         """Given a list of rooms and a token, return rooms where there may have
         been changes.
-
-        Args:
-            room_ids (list)
-            from_key (str): The room_key portion of a StreamToken
         """
-        from_key = RoomStreamToken.parse_stream_token(from_key).stream
+        from_id = from_key.stream
         return {
             room_id
             for room_id in room_ids
-            if self._events_stream_cache.has_entity_changed(room_id, from_key)
+            if self._events_stream_cache.has_entity_changed(room_id, from_id)
         }
 
-    @defer.inlineCallbacks
-    def get_room_events_stream_for_room(
-        self, room_id, from_key, to_key, limit=0, order="DESC"
-    ):
-
+    async def get_room_events_stream_for_room(
+        self,
+        room_id: str,
+        from_key: RoomStreamToken,
+        to_key: RoomStreamToken,
+        limit: int = 0,
+        order: str = "DESC",
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         """Get new room events in stream ordering since `from_key`.
 
         Args:
-            room_id (str)
-            from_key (str): Token from which no events are returned before
-            to_key (str): Token from which no events are returned after. (This
+            room_id
+            from_key: Token from which no events are returned before
+            to_key: Token from which no events are returned after. (This
                 is typically the current stream token)
-            limit (int): Maximum number of events to return
-            order (str): Either "DESC" or "ASC". Determines which events are
+            limit: Maximum number of events to return
+            order: Either "DESC" or "ASC". Determines which events are
                 returned when the result is limited. If "DESC" then the most
                 recent `limit` events are returned, otherwise returns the
                 oldest `limit` events.
 
         Returns:
-            Deferred[tuple[list[FrozenEvent], str]]: Returns the list of
-            events (in ascending order) and the token from the start of
-            the chunk of events returned.
+            The list of events (in ascending order) and the token from the start
+            of the chunk of events returned.
         """
         if from_key == to_key:
             return [], from_key
 
-        from_id = RoomStreamToken.parse_stream_token(from_key).stream
-        to_id = RoomStreamToken.parse_stream_token(to_key).stream
+        from_id = from_key.stream
+        to_id = to_key.stream
 
-        has_changed = yield self._events_stream_cache.has_entity_changed(
-            room_id, from_id
-        )
+        has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
 
         if not has_changed:
             return [], from_key
@@ -410,9 +425,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
             return rows
 
-        rows = yield self.db_pool.runInteraction("get_room_events_stream_for_room", f)
+        rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
 
-        ret = yield self.get_events_as_list(
+        ret = await self.get_events_as_list(
             [r.event_id for r in rows], get_prev_content=True
         )
 
@@ -422,7 +437,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             ret.reverse()
 
         if rows:
-            key = "s%d" % min(r.stream_ordering for r in rows)
+            key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
         else:
             # Assume we didn't get anything because there was nothing to
             # get.
@@ -430,10 +445,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return ret, key
 
-    @defer.inlineCallbacks
-    def get_membership_changes_for_user(self, user_id, from_key, to_key):
-        from_id = RoomStreamToken.parse_stream_token(from_key).stream
-        to_id = RoomStreamToken.parse_stream_token(to_key).stream
+    async def get_membership_changes_for_user(
+        self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
+    ) -> List[EventBase]:
+        from_id = from_key.stream
+        to_id = to_key.stream
 
         if from_key == to_key:
             return []
@@ -460,9 +476,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
             return rows
 
-        rows = yield self.db_pool.runInteraction("get_membership_changes_for_user", f)
+        rows = await self.db_pool.runInteraction("get_membership_changes_for_user", f)
 
-        ret = yield self.get_events_as_list(
+        ret = await self.get_events_as_list(
             [r.event_id for r in rows], get_prev_content=True
         )
 
@@ -470,27 +486,26 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return ret
 
-    @defer.inlineCallbacks
-    def get_recent_events_for_room(self, room_id, limit, end_token):
+    async def get_recent_events_for_room(
+        self, room_id: str, limit: int, end_token: RoomStreamToken
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         """Get the most recent events in the room in topological ordering.
 
         Args:
-            room_id (str)
-            limit (int)
-            end_token (str): The stream token representing now.
+            room_id
+            limit
+            end_token: The stream token representing now.
 
         Returns:
-            Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
-            events and a token pointing to the start of the returned
-            events.
-            The events returned are in ascending order.
+            A list of events and a token pointing to the start of the returned
+            events. The events returned are in ascending order.
         """
 
-        rows, token = yield self.get_recent_event_ids_for_room(
+        rows, token = await self.get_recent_event_ids_for_room(
             room_id, limit, end_token
         )
 
-        events = yield self.get_events_as_list(
+        events = await self.get_events_as_list(
             [r.event_id for r in rows], get_prev_content=True
         )
 
@@ -498,28 +513,25 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return (events, token)
 
-    @defer.inlineCallbacks
-    def get_recent_event_ids_for_room(self, room_id, limit, end_token):
+    async def get_recent_event_ids_for_room(
+        self, room_id: str, limit: int, end_token: RoomStreamToken
+    ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
         """Get the most recent events in the room in topological ordering.
 
         Args:
-            room_id (str)
-            limit (int)
-            end_token (str): The stream token representing now.
+            room_id
+            limit
+            end_token: The stream token representing now.
 
         Returns:
-            Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
-            _EventDictReturn and a token pointing to the start of the returned
-            events.
-            The events returned are in ascending order.
+            A list of _EventDictReturn and a token pointing to the start of the
+            returned events. The events returned are in ascending order.
         """
         # Allow a zero limit here, and no-op.
         if limit == 0:
             return [], end_token
 
-        end_token = RoomStreamToken.parse(end_token)
-
-        rows, token = yield self.db_pool.runInteraction(
+        rows, token = await self.db_pool.runInteraction(
             "get_recent_event_ids_for_room",
             self._paginate_room_events_txn,
             room_id,
@@ -532,16 +544,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return rows, token
 
-    def get_room_event_before_stream_ordering(self, room_id, stream_ordering):
+    async def get_room_event_before_stream_ordering(
+        self, room_id: str, stream_ordering: int
+    ) -> Tuple[int, int, str]:
         """Gets details of the first event in a room at or before a stream ordering
 
         Args:
-            room_id (str):
-            stream_ordering (int):
+            room_id:
+            stream_ordering:
 
         Returns:
-            Deferred[(int, int, str)]:
-                (stream ordering, topological ordering, event_id)
+            A tuple of (stream ordering, topological ordering, event_id)
         """
 
         def _f(txn):
@@ -556,7 +569,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             txn.execute(sql, (room_id, stream_ordering))
             return txn.fetchone()
 
-        return self.db_pool.runInteraction("get_room_event_before_stream_ordering", _f)
+        return await self.db_pool.runInteraction(
+            "get_room_event_before_stream_ordering", _f
+        )
 
     async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str:
         """Returns the current token for rooms stream.
@@ -574,57 +589,80 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             )
             return "t%d-%d" % (topo, token)
 
-    def get_stream_token_for_event(self, event_id):
+    async def get_stream_id_for_event(self, event_id: str) -> int:
+        """The stream ID for an event
+        Args:
+            event_id: The id of the event to look up a stream token for.
+        Raises:
+            StoreError if the event wasn't in the database.
+        Returns:
+            A stream ID.
+        """
+        return await self.db_pool.runInteraction(
+            "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id,
+        )
+
+    def get_stream_id_for_event_txn(
+        self, txn: LoggingTransaction, event_id: str, allow_none=False,
+    ) -> int:
+        return self.db_pool.simple_select_one_onecol_txn(
+            txn=txn,
+            table="events",
+            keyvalues={"event_id": event_id},
+            retcol="stream_ordering",
+            allow_none=allow_none,
+        )
+
+    async def get_stream_token_for_event(self, event_id: str) -> RoomStreamToken:
         """The stream token for an event
         Args:
-            event_id(str): The id of the event to look up a stream token for.
+            event_id: The id of the event to look up a stream token for.
         Raises:
             StoreError if the event wasn't in the database.
         Returns:
-            A deferred "s%d" stream token.
+            A stream token.
         """
-        return self.db_pool.simple_select_one_onecol(
-            table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
-        ).addCallback(lambda row: "s%d" % (row,))
+        stream_id = await self.get_stream_id_for_event(event_id)
+        return RoomStreamToken(None, stream_id)
 
-    def get_topological_token_for_event(self, event_id):
+    async def get_topological_token_for_event(self, event_id: str) -> str:
         """The stream token for an event
         Args:
-            event_id(str): The id of the event to look up a stream token for.
+            event_id: The id of the event to look up a stream token for.
         Raises:
             StoreError if the event wasn't in the database.
         Returns:
-            A deferred "t%d-%d" topological token.
+            A "t%d-%d" topological token.
         """
-        return self.db_pool.simple_select_one(
+        row = await self.db_pool.simple_select_one(
             table="events",
             keyvalues={"event_id": event_id},
             retcols=("stream_ordering", "topological_ordering"),
             desc="get_topological_token_for_event",
-        ).addCallback(
-            lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
         )
+        return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
 
-    def get_max_topological_token(self, room_id, stream_key):
+    async def get_max_topological_token(self, room_id: str, stream_key: int) -> int:
         """Get the max topological token in a room before the given stream
         ordering.
 
         Args:
-            room_id (str)
-            stream_key (int)
+            room_id
+            stream_key
 
         Returns:
-            Deferred[int]
+            The maximum topological token.
         """
         sql = (
             "SELECT coalesce(max(topological_ordering), 0) FROM events"
             " WHERE room_id = ? AND stream_ordering < ?"
         )
-        return self.db_pool.execute(
+        row = await self.db_pool.execute(
             "get_max_topological_token", None, sql, room_id, stream_key
-        ).addCallback(lambda r: r[0][0] if r else 0)
+        )
+        return row[0][0] if row else 0
 
-    def _get_max_topological_txn(self, txn, room_id):
+    def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int:
         txn.execute(
             "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
             (room_id,),
@@ -634,16 +672,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return rows[0][0] if rows else 0
 
     @staticmethod
-    def _set_before_and_after(events, rows, topo_order=True):
+    def _set_before_and_after(
+        events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
+    ):
         """Inserts ordering information to events' internal metadata from
         the DB rows.
 
         Args:
-            events (list[FrozenEvent])
-            rows (list[_EventDictReturn])
-            topo_order (bool): Whether the events were ordered topologically
-                or by stream ordering. If true then all rows should have a non
-                null topological_ordering.
+            events
+            rows
+            topo_order: Whether the events were ordered topologically or by stream
+                ordering. If true then all rows should have a non null
+                topological_ordering.
         """
         for event, row in zip(events, rows):
             stream = row.stream_ordering
@@ -656,25 +696,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             internal.after = str(RoomStreamToken(topo, stream))
             internal.order = (int(topo) if topo else 0, int(stream))
 
-    @defer.inlineCallbacks
-    def get_events_around(
-        self, room_id, event_id, before_limit, after_limit, event_filter=None
-    ):
+    async def get_events_around(
+        self,
+        room_id: str,
+        event_id: str,
+        before_limit: int,
+        after_limit: int,
+        event_filter: Optional[Filter] = None,
+    ) -> dict:
         """Retrieve events and pagination tokens around a given event in a
         room.
-
-        Args:
-            room_id (str)
-            event_id (str)
-            before_limit (int)
-            after_limit (int)
-            event_filter (Filter|None)
-
-        Returns:
-            dict
         """
 
-        results = yield self.db_pool.runInteraction(
+        results = await self.db_pool.runInteraction(
             "get_events_around",
             self._get_events_around_txn,
             room_id,
@@ -684,11 +718,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             event_filter,
         )
 
-        events_before = yield self.get_events_as_list(
+        events_before = await self.get_events_as_list(
             list(results["before"]["event_ids"]), get_prev_content=True
         )
 
-        events_after = yield self.get_events_as_list(
+        events_after = await self.get_events_as_list(
             list(results["after"]["event_ids"]), get_prev_content=True
         )
 
@@ -700,17 +734,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         }
 
     def _get_events_around_txn(
-        self, txn, room_id, event_id, before_limit, after_limit, event_filter
-    ):
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        event_id: str,
+        before_limit: int,
+        after_limit: int,
+        event_filter: Optional[Filter],
+    ) -> dict:
         """Retrieves event_ids and pagination tokens around a given event in a
         room.
 
         Args:
-            room_id (str)
-            event_id (str)
-            before_limit (int)
-            after_limit (int)
-            event_filter (Filter|None)
+            room_id
+            event_id
+            before_limit
+            after_limit
+            event_filter
 
         Returns:
             dict
@@ -723,6 +763,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             retcols=["stream_ordering", "topological_ordering"],
         )
 
+        # This cannot happen as `allow_none=False`.
+        assert results is not None
+
         # Paginating backwards includes the event at the token, but paginating
         # forward doesn't.
         before_token = RoomStreamToken(
@@ -758,22 +801,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             "after": {"event_ids": events_after, "token": end_token},
         }
 
-    @defer.inlineCallbacks
-    def get_all_new_events_stream(self, from_id, current_id, limit):
+    async def get_all_new_events_stream(
+        self, from_id: int, current_id: int, limit: int
+    ) -> Tuple[int, List[EventBase]]:
         """Get all new events
 
          Returns all events with from_id < stream_ordering <= current_id.
 
          Args:
-             from_id (int):  the stream_ordering of the last event we processed
-             current_id (int):  the stream_ordering of the most recently processed event
-             limit (int): the maximum number of events to return
+             from_id:  the stream_ordering of the last event we processed
+             current_id:  the stream_ordering of the most recently processed event
+             limit: the maximum number of events to return
 
          Returns:
-             Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where
-             `next_id` is the next value to pass as `from_id` (it will either be the
-             stream_ordering of the last returned event, or, if fewer than `limit` events
-             were found, `current_id`.
+             A tuple of (next_id, events), where `next_id` is the next value to
+             pass as `from_id` (it will either be the stream_ordering of the
+             last returned event, or, if fewer than `limit` events were found,
+             the `current_id`).
          """
 
         def get_all_new_events_stream_txn(txn):
@@ -795,11 +839,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
             return upper_bound, [row[1] for row in rows]
 
-        upper_bound, event_ids = yield self.db_pool.runInteraction(
+        upper_bound, event_ids = await self.db_pool.runInteraction(
             "get_all_new_events_stream", get_all_new_events_stream_txn
         )
 
-        events = yield self.get_events_as_list(event_ids)
+        events = await self.get_events_as_list(event_ids)
 
         return upper_bound, events
 
@@ -817,21 +861,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             desc="get_federation_out_pos",
         )
 
-    async def update_federation_out_pos(self, typ, stream_id):
+    async def update_federation_out_pos(self, typ: str, stream_id: int) -> None:
         if self._need_to_reset_federation_stream_positions:
             await self.db_pool.runInteraction(
                 "_reset_federation_positions_txn", self._reset_federation_positions_txn
             )
             self._need_to_reset_federation_stream_positions = False
 
-        return await self.db_pool.simple_update_one(
+        await self.db_pool.simple_update_one(
             table="federation_stream_position",
             keyvalues={"type": typ, "instance_name": self._instance_name},
             updatevalues={"stream_id": stream_id},
             desc="update_federation_out_pos",
         )
 
-    def _reset_federation_positions_txn(self, txn):
+    def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None:
         """Fiddles with the `federation_stream_position` table to make it match
         the configured federation sender instances during start up.
         """
@@ -870,7 +914,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             GROUP BY type
         """
         txn.execute(sql)
-        min_positions = dict(txn)  # Map from type -> min position
+        min_positions = {typ: pos for typ, pos in txn}  # Map from type -> min position
 
         # Ensure we do actually have some values here
         assert set(min_positions) == {"federation", "events"}
@@ -892,39 +936,37 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 values={"stream_id": stream_id},
             )
 
-    def has_room_changed_since(self, room_id, stream_id):
+    def has_room_changed_since(self, room_id: str, stream_id: int) -> bool:
         return self._events_stream_cache.has_entity_changed(room_id, stream_id)
 
     def _paginate_room_events_txn(
         self,
-        txn,
-        room_id,
-        from_token,
-        to_token=None,
-        direction="b",
-        limit=-1,
-        event_filter=None,
-    ):
+        txn: LoggingTransaction,
+        room_id: str,
+        from_token: RoomStreamToken,
+        to_token: Optional[RoomStreamToken] = None,
+        direction: str = "b",
+        limit: int = -1,
+        event_filter: Optional[Filter] = None,
+    ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
         """Returns list of events before or after a given token.
 
         Args:
             txn
-            room_id (str)
-            from_token (RoomStreamToken): The token used to stream from
-            to_token (RoomStreamToken|None): A token which if given limits the
-                results to only those before
-            direction(char): Either 'b' or 'f' to indicate whether we are
-                paginating forwards or backwards from `from_key`.
-            limit (int): The maximum number of events to return.
-            event_filter (Filter|None): If provided filters the events to
+            room_id
+            from_token: The token used to stream from
+            to_token: A token which if given limits the results to only those before
+            direction: Either 'b' or 'f' to indicate whether we are paginating
+                forwards or backwards from `from_key`.
+            limit: The maximum number of events to return.
+            event_filter: If provided filters the events to
                 those that match the filter.
 
         Returns:
-            Deferred[tuple[list[_EventDictReturn], str]]: Returns the results
-            as a list of _EventDictReturn and a token that points to the end
-            of the result set. If no events are returned then the end of the
-            stream has been reached (i.e. there are no events between
-            `from_token` and `to_token`), or `limit` is zero.
+            A list of _EventDictReturn and a token that points to the end of the
+            result set. If no events are returned then the end of the stream has
+            been reached (i.e. there are no events between `from_token` and
+            `to_token`), or `limit` is zero.
         """
 
         assert int(limit) >= 0
@@ -941,8 +983,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         bounds = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
-            from_token=from_token,
-            to_token=to_token,
+            from_token=from_token.as_tuple(),
+            to_token=to_token.as_tuple() if to_token else None,
             engine=self.database_engine,
         )
 
@@ -1006,37 +1048,36 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             # TODO (erikj): We should work out what to do here instead.
             next_token = to_token if to_token else from_token
 
-        return rows, str(next_token)
+        return rows, next_token
 
-    @defer.inlineCallbacks
-    def paginate_room_events(
-        self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None
-    ):
+    async def paginate_room_events(
+        self,
+        room_id: str,
+        from_key: RoomStreamToken,
+        to_key: Optional[RoomStreamToken] = None,
+        direction: str = "b",
+        limit: int = -1,
+        event_filter: Optional[Filter] = None,
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         """Returns list of events before or after a given token.
 
         Args:
-            room_id (str)
-            from_key (str): The token used to stream from
-            to_key (str|None): A token which if given limits the results to
-                only those before
-            direction(char): Either 'b' or 'f' to indicate whether we are
-                paginating forwards or backwards from `from_key`.
-            limit (int): The maximum number of events to return.
-            event_filter (Filter|None): If provided filters the events to
-                those that match the filter.
+            room_id
+            from_key: The token used to stream from
+            to_key: A token which if given limits the results to only those before
+            direction: Either 'b' or 'f' to indicate whether we are paginating
+                forwards or backwards from `from_key`.
+            limit: The maximum number of events to return.
+            event_filter: If provided filters the events to those that match the filter.
 
         Returns:
-            tuple[list[FrozenEvent], str]: Returns the results as a list of
-            events and a token that points to the end of the result set. If no
-            events are returned then the end of the stream has been reached
-            (i.e. there are no events between `from_key` and `to_key`).
+            The results as a list of events and a token that points to the end
+            of the result set. If no events are returned then the end of the
+            stream has been reached (i.e. there are no events between `from_key`
+            and `to_key`).
         """
 
-        from_key = RoomStreamToken.parse(from_key)
-        if to_key:
-            to_key = RoomStreamToken.parse(to_key)
-
-        rows, token = yield self.db_pool.runInteraction(
+        rows, token = await self.db_pool.runInteraction(
             "paginate_room_events",
             self._paginate_room_events_txn,
             room_id,
@@ -1047,7 +1088,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             event_filter,
         )
 
-        events = yield self.get_events_as_list(
+        events = await self.get_events_as_list(
             [r.event_id for r in rows], get_prev_content=True
         )
 
@@ -1057,8 +1098,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
 
 class StreamStore(StreamWorkerStore):
-    def get_room_max_stream_ordering(self):
+    def get_room_max_stream_ordering(self) -> int:
         return self._stream_id_gen.get_current_token()
 
-    def get_room_min_stream_ordering(self):
+    def get_room_min_stream_ordering(self) -> int:
         return self._backfill_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index e4e0a0c433..96ffe26cc9 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -17,11 +17,10 @@
 import logging
 from typing import Dict, List, Tuple
 
-from canonicaljson import json
-
 from synapse.storage._base import db_to_json
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
 from synapse.types import JsonDict
+from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
@@ -44,7 +43,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
         )
 
-        tags_by_room = {}
+        tags_by_room = {}  # type: Dict[str, Dict[str, JsonDict]]
         for row in rows:
             room_tags = tags_by_room.setdefault(row["room_id"], {})
             room_tags[row["tag"]] = db_to_json(row["content"])
@@ -98,7 +97,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
                 txn.execute(sql, (user_id, room_id))
                 tags = []
                 for tag, content in txn:
-                    tags.append(json.dumps(tag) + ":" + content)
+                    tags.append(json_encoder.encode(tag) + ":" + content)
                 tag_json = "{" + ",".join(tags) + "}"
                 results.append((stream_id, (user_id, room_id, tag_json)))
 
@@ -124,7 +123,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
 
     async def get_updated_tags(
         self, user_id: str, stream_id: int
-    ) -> Dict[str, List[str]]:
+    ) -> Dict[str, Dict[str, JsonDict]]:
         """Get all the tags for the rooms where the tags have changed since the
         given version
 
@@ -200,7 +199,7 @@ class TagsStore(TagsWorkerStore):
         Returns:
             The next account data ID.
         """
-        content_json = json.dumps(content)
+        content_json = json_encoder.encode(content)
 
         def add_tag_txn(txn, next_id):
             self.db_pool.simple_upsert_txn(
@@ -211,7 +210,7 @@ class TagsStore(TagsWorkerStore):
             )
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
-        with self._account_data_id_gen.get_next() as next_id:
+        with await self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
@@ -233,7 +232,7 @@ class TagsStore(TagsWorkerStore):
             txn.execute(sql, (user_id, room_id, tag))
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
-        with self._account_data_id_gen.get_next() as next_id:
+        with await self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 52668dbdf9..091367006e 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -15,12 +15,15 @@
 
 import logging
 from collections import namedtuple
+from typing import Iterable, List, Optional, Tuple
 
 from canonicaljson import encode_canonical_json
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import JsonDict
 from synapse.util.caches.expiringcache import ExpiringCache
 
 db_binary_type = memoryview
@@ -55,21 +58,23 @@ class TransactionStore(SQLBaseStore):
             expiry_ms=5 * 60 * 1000,
         )
 
-    def get_received_txn_response(self, transaction_id, origin):
+    async def get_received_txn_response(
+        self, transaction_id: str, origin: str
+    ) -> Optional[Tuple[int, JsonDict]]:
         """For an incoming transaction from a given origin, check if we have
         already responded to it. If so, return the response code and response
         body (as a dict).
 
         Args:
-            transaction_id (str)
-            origin(str)
+            transaction_id
+            origin
 
         Returns:
-            tuple: None if we have not previously responded to
-            this transaction or a 2-tuple of (int, dict)
+            None if we have not previously responded to this transaction or a
+            2-tuple of (int, dict)
         """
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_received_txn_response",
             self._get_received_txn_response,
             transaction_id,
@@ -98,20 +103,21 @@ class TransactionStore(SQLBaseStore):
         else:
             return None
 
-    def set_received_txn_response(self, transaction_id, origin, code, response_dict):
-        """Persist the response we returened for an incoming transaction, and
+    async def set_received_txn_response(
+        self, transaction_id: str, origin: str, code: int, response_dict: JsonDict
+    ) -> None:
+        """Persist the response we returned for an incoming transaction, and
         should return for subsequent transactions with the same transaction_id
         and origin.
 
         Args:
-            txn
-            transaction_id (str)
-            origin (str)
-            code (int)
-            response_json (str)
+            transaction_id: The incoming transaction ID.
+            origin: The origin server.
+            code: The response code.
+            response_dict: The response, to be encoded into JSON.
         """
 
-        return self.db_pool.simple_insert(
+        await self.db_pool.simple_insert(
             table="received_transactions",
             values={
                 "transaction_id": transaction_id,
@@ -159,26 +165,32 @@ class TransactionStore(SQLBaseStore):
             allow_none=True,
         )
 
-        if result and result["retry_last_ts"] > 0:
+        # check we have a row and retry_last_ts is not null or zero
+        # (retry_last_ts can't be negative)
+        if result and result["retry_last_ts"]:
             return result
         else:
             return None
 
-    def set_destination_retry_timings(
-        self, destination, failure_ts, retry_last_ts, retry_interval
-    ):
+    async def set_destination_retry_timings(
+        self,
+        destination: str,
+        failure_ts: Optional[int],
+        retry_last_ts: int,
+        retry_interval: int,
+    ) -> None:
         """Sets the current retry timings for a given destination.
         Both timings should be zero if retrying is no longer occuring.
 
         Args:
-            destination (str)
-            failure_ts (int|None) - when the server started failing (ms since epoch)
-            retry_last_ts (int) - time of last retry attempt in unix epoch ms
-            retry_interval (int) - how long until next retry in ms
+            destination
+            failure_ts: when the server started failing (ms since epoch)
+            retry_last_ts: time of last retry attempt in unix epoch ms
+            retry_interval: how long until next retry in ms
         """
 
         self._destination_retry_cache.pop(destination, None)
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "set_destination_retry_timings",
             self._set_destination_retry_timings,
             destination,
@@ -254,13 +266,149 @@ class TransactionStore(SQLBaseStore):
             "cleanup_transactions", self._cleanup_transactions
         )
 
-    def _cleanup_transactions(self):
+    async def _cleanup_transactions(self) -> None:
         now = self._clock.time_msec()
         month_ago = now - 30 * 24 * 60 * 60 * 1000
 
         def _cleanup_transactions_txn(txn):
             txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "_cleanup_transactions", _cleanup_transactions_txn
         )
+
+    async def store_destination_rooms_entries(
+        self, destinations: Iterable[str], room_id: str, stream_ordering: int,
+    ) -> None:
+        """
+        Updates or creates `destination_rooms` entries in batch for a single event.
+
+        Args:
+            destinations: list of destinations
+            room_id: the room_id of the event
+            stream_ordering: the stream_ordering of the event
+        """
+
+        return await self.db_pool.runInteraction(
+            "store_destination_rooms_entries",
+            self._store_destination_rooms_entries_txn,
+            destinations,
+            room_id,
+            stream_ordering,
+        )
+
+    def _store_destination_rooms_entries_txn(
+        self,
+        txn: LoggingTransaction,
+        destinations: Iterable[str],
+        room_id: str,
+        stream_ordering: int,
+    ) -> None:
+
+        # ensure we have a `destinations` row for this destination, as there is
+        # a foreign key constraint.
+        if isinstance(self.database_engine, PostgresEngine):
+            q = """
+                INSERT INTO destinations (destination)
+                    VALUES (?)
+                    ON CONFLICT DO NOTHING;
+            """
+        elif isinstance(self.database_engine, Sqlite3Engine):
+            q = """
+                INSERT OR IGNORE INTO destinations (destination)
+                    VALUES (?);
+            """
+        else:
+            raise RuntimeError("Unknown database engine")
+
+        txn.execute_batch(q, ((destination,) for destination in destinations))
+
+        rows = [(destination, room_id) for destination in destinations]
+
+        self.db_pool.simple_upsert_many_txn(
+            txn,
+            "destination_rooms",
+            ["destination", "room_id"],
+            rows,
+            ["stream_ordering"],
+            [(stream_ordering,)] * len(rows),
+        )
+
+    async def get_destination_last_successful_stream_ordering(
+        self, destination: str
+    ) -> Optional[int]:
+        """
+        Gets the stream ordering of the PDU most-recently successfully sent
+        to the specified destination, or None if this information has not been
+        tracked yet.
+
+        Args:
+            destination: the destination to query
+        """
+        return await self.db_pool.simple_select_one_onecol(
+            "destinations",
+            {"destination": destination},
+            "last_successful_stream_ordering",
+            allow_none=True,
+            desc="get_last_successful_stream_ordering",
+        )
+
+    async def set_destination_last_successful_stream_ordering(
+        self, destination: str, last_successful_stream_ordering: int
+    ) -> None:
+        """
+        Marks that we have successfully sent the PDUs up to and including the
+        one specified.
+
+        Args:
+            destination: the destination we have successfully sent to
+            last_successful_stream_ordering: the stream_ordering of the most
+                recent successfully-sent PDU
+        """
+        return await self.db_pool.simple_upsert(
+            "destinations",
+            keyvalues={"destination": destination},
+            values={"last_successful_stream_ordering": last_successful_stream_ordering},
+            desc="set_last_successful_stream_ordering",
+        )
+
+    async def get_catch_up_room_event_ids(
+        self, destination: str, last_successful_stream_ordering: int,
+    ) -> List[str]:
+        """
+        Returns at most 50 event IDs and their corresponding stream_orderings
+        that correspond to the oldest events that have not yet been sent to
+        the destination.
+
+        Args:
+            destination: the destination in question
+            last_successful_stream_ordering: the stream_ordering of the
+                most-recently successfully-transmitted event to the destination
+
+        Returns:
+            list of event_ids
+        """
+        return await self.db_pool.runInteraction(
+            "get_catch_up_room_event_ids",
+            self._get_catch_up_room_event_ids_txn,
+            destination,
+            last_successful_stream_ordering,
+        )
+
+    @staticmethod
+    def _get_catch_up_room_event_ids_txn(
+        txn, destination: str, last_successful_stream_ordering: int,
+    ) -> List[str]:
+        q = """
+                SELECT event_id FROM destination_rooms
+                 JOIN events USING (stream_ordering)
+                WHERE destination = ?
+                  AND stream_ordering > ?
+                ORDER BY stream_ordering
+                LIMIT 50
+            """
+        txn.execute(
+            q, (destination, last_successful_stream_ordering),
+        )
+        event_ids = [row[0] for row in txn]
+        return event_ids
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 37276f73f8..3b9211a6d2 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -12,18 +12,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import attr
-from canonicaljson import json
 
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import LoggingTransaction
 from synapse.types import JsonDict
-from synapse.util import stringutils as stringutils
+from synapse.util import json_encoder, stringutils
 
 
-@attr.s
+@attr.s(slots=True)
 class UIAuthSessionData:
     session_id = attr.ib(type=str)
     # The dictionary from the client root level, not the 'auth' key.
@@ -72,7 +72,7 @@ class UIAuthWorkerStore(SQLBaseStore):
             StoreError if a unique session ID cannot be generated.
         """
         # The clientdict gets stored as JSON.
-        clientdict_json = json.dumps(clientdict)
+        clientdict_json = json_encoder.encode(clientdict)
 
         # autogen a session ID and try to create it. We may clash, so just
         # try a few times till one goes through, giving up eventually.
@@ -143,7 +143,7 @@ class UIAuthWorkerStore(SQLBaseStore):
             await self.db_pool.simple_upsert(
                 table="ui_auth_sessions_credentials",
                 keyvalues={"session_id": session_id, "stage_type": stage_type},
-                values={"result": json.dumps(result)},
+                values={"result": json_encoder.encode(result)},
                 desc="mark_ui_auth_stage_complete",
             )
         except self.db_pool.engine.module.IntegrityError:
@@ -184,7 +184,7 @@ class UIAuthWorkerStore(SQLBaseStore):
                 The dictionary from the client root level, not the 'auth' key.
         """
         # The clientdict gets stored as JSON.
-        clientdict_json = json.dumps(clientdict)
+        clientdict_json = json_encoder.encode(clientdict)
 
         await self.db_pool.simple_update_one(
             table="ui_auth_sessions",
@@ -214,14 +214,16 @@ class UIAuthWorkerStore(SQLBaseStore):
             value,
         )
 
-    def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any):
+    def _set_ui_auth_session_data_txn(
+        self, txn: LoggingTransaction, session_id: str, key: str, value: Any
+    ):
         # Get the current value.
         result = self.db_pool.simple_select_one_txn(
             txn,
             table="ui_auth_sessions",
             keyvalues={"session_id": session_id},
             retcols=("serverdict",),
-        )
+        )  # type: Dict[str, Any]  # type: ignore
 
         # Update it and add it back to the database.
         serverdict = db_to_json(result["serverdict"])
@@ -231,7 +233,7 @@ class UIAuthWorkerStore(SQLBaseStore):
             txn,
             table="ui_auth_sessions",
             keyvalues={"session_id": session_id},
-            updatevalues={"serverdict": json.dumps(serverdict)},
+            updatevalues={"serverdict": json_encoder.encode(serverdict)},
         )
 
     async def get_ui_auth_session_data(
@@ -258,9 +260,37 @@ class UIAuthWorkerStore(SQLBaseStore):
 
         return serverdict.get(key, default)
 
+    async def add_user_agent_ip_to_ui_auth_session(
+        self, session_id: str, user_agent: str, ip: str,
+    ):
+        """Add the given user agent / IP to the tracking table
+        """
+        await self.db_pool.simple_upsert(
+            table="ui_auth_sessions_ips",
+            keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
+            values={},
+            desc="add_user_agent_ip_to_ui_auth_session",
+        )
+
+    async def get_user_agents_ips_to_ui_auth_session(
+        self, session_id: str,
+    ) -> List[Tuple[str, str]]:
+        """Get the given user agents / IPs used during the ui auth process
+
+        Returns:
+            List of user_agent/ip pairs
+        """
+        rows = await self.db_pool.simple_select_list(
+            table="ui_auth_sessions_ips",
+            keyvalues={"session_id": session_id},
+            retcols=("user_agent", "ip"),
+            desc="get_user_agents_ips_to_ui_auth_session",
+        )
+        return [(row["user_agent"], row["ip"]) for row in rows]
+
 
 class UIAuthStore(UIAuthWorkerStore):
-    def delete_old_ui_auth_sessions(self, expiration_time: int):
+    async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
         """
         Remove sessions which were last used earlier than the expiration time.
 
@@ -269,18 +299,29 @@ class UIAuthStore(UIAuthWorkerStore):
                 This is an epoch time in milliseconds.
 
         """
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_old_ui_auth_sessions",
             self._delete_old_ui_auth_sessions_txn,
             expiration_time,
         )
 
-    def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int):
+    def _delete_old_ui_auth_sessions_txn(
+        self, txn: LoggingTransaction, expiration_time: int
+    ):
         # Get the expired sessions.
         sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
         txn.execute(sql, [expiration_time])
         session_ids = [r[0] for r in txn.fetchall()]
 
+        # Delete the corresponding IP/user agents.
+        self.db_pool.simple_delete_many_txn(
+            txn,
+            table="ui_auth_sessions_ips",
+            column="session_id",
+            iterable=session_ids,
+            keyvalues={},
+        )
+
         # Delete the corresponding completed credentials.
         self.db_pool.simple_delete_many_txn(
             txn,
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index af21fe457a..f2f9a5799a 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -15,6 +15,7 @@
 
 import logging
 import re
+from typing import Any, Dict, Iterable, Optional, Set, Tuple
 
 from synapse.api.constants import EventTypes, JoinRules
 from synapse.storage.database import DatabasePool
@@ -364,10 +365,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 
         return False
 
-    def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
+    async def update_profile_in_user_dir(
+        self, user_id: str, display_name: str, avatar_url: str
+    ) -> None:
         """
         Update or add a user's profile in the user directory.
         """
+        # If the display name or avatar URL are unexpected types, overwrite them.
+        if not isinstance(display_name, str):
+            display_name = None
+        if not isinstance(avatar_url, str):
+            avatar_url = None
 
         def _update_profile_in_user_dir_txn(txn):
             new_entry = self.db_pool.simple_upsert_txn(
@@ -457,17 +465,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 
             txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "update_profile_in_user_dir", _update_profile_in_user_dir_txn
         )
 
-    def add_users_who_share_private_room(self, room_id, user_id_tuples):
+    async def add_users_who_share_private_room(
+        self, room_id: str, user_id_tuples: Iterable[Tuple[str, str]]
+    ) -> None:
         """Insert entries into the users_who_share_private_rooms table. The first
         user should be a local user.
 
         Args:
-            room_id (str)
-            user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
+            room_id
+            user_id_tuples: iterable of 2-tuple of user IDs.
         """
 
         def _add_users_who_share_room_txn(txn):
@@ -483,17 +493,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                 value_values=None,
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "add_users_who_share_room", _add_users_who_share_room_txn
         )
 
-    def add_users_in_public_rooms(self, room_id, user_ids):
+    async def add_users_in_public_rooms(
+        self, room_id: str, user_ids: Iterable[str]
+    ) -> None:
         """Insert entries into the users_who_share_private_rooms table. The first
         user should be a local user.
 
         Args:
-            room_id (str)
-            user_ids (list[str])
+            room_id
+            user_ids
         """
 
         def _add_users_in_public_rooms_txn(txn):
@@ -507,11 +519,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                 value_values=None,
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "add_users_in_public_rooms", _add_users_in_public_rooms_txn
         )
 
-    def delete_all_from_user_dir(self):
+    async def delete_all_from_user_dir(self) -> None:
         """Delete the entire user directory
         """
 
@@ -522,13 +534,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             txn.execute("DELETE FROM users_who_share_private_rooms")
             txn.call_after(self.get_user_in_directory.invalidate_all)
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_all_from_user_dir", _delete_all_from_user_dir_txn
         )
 
     @cached()
-    def get_user_in_directory(self, user_id):
-        return self.db_pool.simple_select_one(
+    async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
+        return await self.db_pool.simple_select_one(
             table="user_directory",
             keyvalues={"user_id": user_id},
             retcols=("display_name", "avatar_url"),
@@ -536,8 +548,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             desc="get_user_in_directory",
         )
 
-    def update_user_directory_stream_pos(self, stream_id):
-        return self.db_pool.simple_update_one(
+    async def update_user_directory_stream_pos(self, stream_id: str) -> None:
+        await self.db_pool.simple_update_one(
             table="user_directory_stream_pos",
             keyvalues={},
             updatevalues={"stream_id": stream_id},
@@ -554,7 +566,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super(UserDirectoryStore, self).__init__(database, db_conn, hs)
 
-    def remove_from_user_dir(self, user_id):
+    async def remove_from_user_dir(self, user_id: str) -> None:
         def _remove_from_user_dir_txn(txn):
             self.db_pool.simple_delete_txn(
                 txn, table="user_directory", keyvalues={"user_id": user_id}
@@ -577,7 +589,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             )
             txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "remove_from_user_dir", _remove_from_user_dir_txn
         )
 
@@ -604,14 +616,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
 
         return user_ids
 
-    def remove_user_who_share_room(self, user_id, room_id):
+    async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None:
         """
         Deletes entries in the users_who_share_*_rooms table. The first
         user should be a local user.
 
         Args:
-            user_id (str)
-            room_id (str)
+            user_id
+            room_id
         """
 
         def _remove_user_who_share_room_txn(txn):
@@ -631,7 +643,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
                 keyvalues={"user_id": user_id, "room_id": room_id},
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "remove_user_who_share_room", _remove_user_who_share_room_txn
         )
 
@@ -663,8 +675,50 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
         users.update(rows)
         return list(users)
 
-    def get_user_directory_stream_pos(self):
-        return self.db_pool.simple_select_one_onecol(
+    @cached()
+    async def get_shared_rooms_for_users(
+        self, user_id: str, other_user_id: str
+    ) -> Set[str]:
+        """
+        Returns the rooms that a local user shares with another local or remote user.
+
+        Args:
+            user_id: The MXID of a local user
+            other_user_id: The MXID of the other user
+
+        Returns:
+            A set of room ID's that the users share.
+        """
+
+        def _get_shared_rooms_for_users_txn(txn):
+            txn.execute(
+                """
+                SELECT p1.room_id
+                FROM users_in_public_rooms as p1
+                INNER JOIN users_in_public_rooms as p2
+                    ON p1.room_id = p2.room_id
+                    AND p1.user_id = ?
+                    AND p2.user_id = ?
+                UNION
+                SELECT room_id
+                FROM users_who_share_private_rooms
+                WHERE
+                    user_id = ?
+                    AND other_user_id = ?
+                """,
+                (user_id, other_user_id, user_id, other_user_id),
+            )
+            rows = self.db_pool.cursor_to_dict(txn)
+            return rows
+
+        rows = await self.db_pool.runInteraction(
+            "get_shared_rooms_for_users", _get_shared_rooms_for_users_txn
+        )
+
+        return {row["room_id"] for row in rows}
+
+    async def get_user_directory_stream_pos(self) -> int:
+        return await self.db_pool.simple_select_one_onecol(
             table="user_directory_stream_pos",
             keyvalues={},
             retcol="stream_id",
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index ab6cb2c1f6..2f7c95fc74 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -13,35 +13,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import operator
-
 from synapse.storage._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached, cachedList
 
 
 class UserErasureWorkerStore(SQLBaseStore):
     @cached()
-    def is_user_erased(self, user_id):
+    async def is_user_erased(self, user_id: str) -> bool:
         """
         Check if the given user id has requested erasure
 
         Args:
-            user_id (str): full user id to check
+            user_id: full user id to check
 
         Returns:
-            Deferred[bool]: True if the user has requested erasure
+            True if the user has requested erasure
         """
-        return self.db_pool.simple_select_onecol(
+        result = await self.db_pool.simple_select_onecol(
             table="erased_users",
             keyvalues={"user_id": user_id},
             retcol="1",
             desc="is_user_erased",
-        ).addCallback(operator.truth)
+        )
+        return bool(result)
 
-    @cachedList(
-        cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True
-    )
-    def are_users_erased(self, user_ids):
+    @cachedList(cached_method_name="is_user_erased", list_name="user_ids")
+    async def are_users_erased(self, user_ids):
         """
         Checks which users in a list have requested erasure
 
@@ -49,14 +46,14 @@ class UserErasureWorkerStore(SQLBaseStore):
             user_ids (iterable[str]): full user id to check
 
         Returns:
-            Deferred[dict[str, bool]]:
+            dict[str, bool]:
                 for each user, whether the user has requested erasure.
         """
         # this serves the dual purpose of (a) making sure we can do len and
         # iterate it multiple times, and (b) avoiding duplicates.
         user_ids = tuple(set(user_ids))
 
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="erased_users",
             column="user_id",
             iterable=user_ids,
@@ -65,12 +62,11 @@ class UserErasureWorkerStore(SQLBaseStore):
         )
         erased_users = {row["user_id"] for row in rows}
 
-        res = {u: u in erased_users for u in user_ids}
-        return res
+        return {u: u in erased_users for u in user_ids}
 
 
 class UserErasureStore(UserErasureWorkerStore):
-    def mark_user_erased(self, user_id: str) -> None:
+    async def mark_user_erased(self, user_id: str) -> None:
         """Indicate that user_id wishes their message history to be erased.
 
         Args:
@@ -88,9 +84,9 @@ class UserErasureStore(UserErasureWorkerStore):
 
             self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
 
-        return self.db_pool.runInteraction("mark_user_erased", f)
+        await self.db_pool.runInteraction("mark_user_erased", f)
 
-    def mark_user_not_erased(self, user_id: str) -> None:
+    async def mark_user_not_erased(self, user_id: str) -> None:
         """Indicate that user_id is no longer erased.
 
         Args:
@@ -110,4 +106,4 @@ class UserErasureStore(UserErasureWorkerStore):
 
             self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
 
-        return self.db_pool.runInteraction("mark_user_not_erased", f)
+        await self.db_pool.runInteraction("mark_user_not_erased", f)
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 7f104ad936..e924f1ca3b 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -17,8 +17,6 @@ import logging
 from collections import namedtuple
 from typing import Dict, Iterable, List, Set, Tuple
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
@@ -103,7 +101,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         )
 
     @cached(max_entries=10000, iterable=True)
-    def get_state_group_delta(self, state_group):
+    async def get_state_group_delta(self, state_group):
         """Given a state group try to return a previous group and a delta between
         the old and the new.
 
@@ -135,7 +133,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
             )
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_state_group_delta", _get_state_group_delta_txn
         )
 
@@ -367,9 +365,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 fetched_keys=non_member_types,
             )
 
-    def store_state_group(
+    async def store_state_group(
         self, event_id, room_id, prev_group, delta_ids, current_state_ids
-    ):
+    ) -> int:
         """Store a new set of state, returning a newly assigned state group.
 
         Args:
@@ -383,7 +381,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 to event_id.
 
         Returns:
-            Deferred[int]: The state group ID
+            The state group ID
         """
 
         def _store_state_group_txn(txn):
@@ -484,11 +482,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
             return state_group
 
-        return self.db_pool.runInteraction("store_state_group", _store_state_group_txn)
+        return await self.db_pool.runInteraction(
+            "store_state_group", _store_state_group_txn
+        )
 
-    def purge_unreferenced_state_groups(
+    async def purge_unreferenced_state_groups(
         self, room_id: str, state_groups_to_delete
-    ) -> defer.Deferred:
+    ) -> None:
         """Deletes no longer referenced state groups and de-deltas any state
         groups that reference them.
 
@@ -499,7 +499,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 to delete.
         """
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "purge_unreferenced_state_groups",
             self._purge_unreferenced_state_groups,
             room_id,
@@ -594,7 +594,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
         return {row["state_group"]: row["prev_state_group"] for row in rows}
 
-    def purge_room_state(self, room_id, state_groups_to_delete):
+    async def purge_room_state(self, room_id, state_groups_to_delete):
         """Deletes all record of a room from state tables
 
         Args:
@@ -602,7 +602,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             state_groups_to_delete (list[int]): State groups to delete
         """
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "purge_room_state",
             self._purge_room_state_txn,
             room_id,
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 4769b21529..afd10f7bae 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -22,6 +22,6 @@ logger = logging.getLogger(__name__)
 
 
 @attr.s(slots=True, frozen=True)
-class FetchKeyResult(object):
+class FetchKeyResult:
     verify_key = attr.ib()  # VerifyKey: the key itself
     valid_until_ts = attr.ib()  # int: how long we can use this key for
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index f15b95e633..d89f6ed128 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -18,7 +18,7 @@
 import itertools
 import logging
 from collections import deque, namedtuple
-from typing import Iterable, List, Optional, Set, Tuple
+from typing import Dict, Iterable, List, Optional, Set, Tuple
 
 from prometheus_client import Counter, Histogram
 
@@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.databases import Databases
 from synapse.storage.databases.main.events import DeltaState
-from synapse.types import StateMap
+from synapse.types import Collection, StateMap
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.metrics import Measure
 
@@ -69,7 +69,7 @@ stale_forward_extremities_counter = Histogram(
 )
 
 
-class _EventPeristenceQueue(object):
+class _EventPeristenceQueue:
     """Queues up events so that they can be persisted in bulk with only one
     concurrent transaction per room.
     """
@@ -172,7 +172,7 @@ class _EventPeristenceQueue(object):
             pass
 
 
-class EventsPersistenceStorage(object):
+class EventsPersistenceStorage:
     """High level interface for handling persisting newly received events.
 
     Takes care of batching up events by room, and calculating the necessary
@@ -185,6 +185,8 @@ class EventsPersistenceStorage(object):
         # store for now.
         self.main_store = stores.main
         self.state_store = stores.state
+
+        assert stores.persist_events
         self.persist_events_store = stores.persist_events
 
         self._clock = hs.get_clock()
@@ -208,7 +210,7 @@ class EventsPersistenceStorage(object):
         Returns:
             the stream ordering of the latest persisted event
         """
-        partitioned = {}
+        partitioned = {}  # type: Dict[str, List[Tuple[EventBase, EventContext]]]
         for event, ctx in events_and_contexts:
             partitioned.setdefault(event.room_id, []).append((event, ctx))
 
@@ -305,7 +307,9 @@ class EventsPersistenceStorage(object):
                     # Work out the new "current state" for each room.
                     # We do this by working out what the new extremities are and then
                     # calculating the state from that.
-                    events_by_room = {}
+                    events_by_room = (
+                        {}
+                    )  # type: Dict[str, List[Tuple[EventBase, EventContext]]]
                     for event, context in chunk:
                         events_by_room.setdefault(event.room_id, []).append(
                             (event, context)
@@ -436,7 +440,7 @@ class EventsPersistenceStorage(object):
         self,
         room_id: str,
         event_contexts: List[Tuple[EventBase, EventContext]],
-        latest_event_ids: List[str],
+        latest_event_ids: Collection[str],
     ):
         """Calculates the new forward extremities for a room given events to
         persist.
@@ -470,7 +474,7 @@ class EventsPersistenceStorage(object):
         # Remove any events which are prev_events of any existing events.
         existing_prevs = await self.persist_events_store._get_events_which_are_prevs(
             result
-        )
+        )  # type: Collection[str]
         result.difference_update(existing_prevs)
 
         # Finally handle the case where the new events have soft-failed prev
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 1c5f305132..4957e77f4c 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -19,12 +19,15 @@ import logging
 import os
 import re
 from collections import Counter
-from typing import TextIO
+from typing import Optional, TextIO
 
 import attr
 
+from synapse.config.homeserver import HomeServerConfig
+from synapse.storage.engines import BaseDatabaseEngine
 from synapse.storage.engines.postgres import PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.storage.types import Connection, Cursor
+from synapse.types import Collection
 
 logger = logging.getLogger(__name__)
 
@@ -47,7 +50,28 @@ class UpgradeDatabaseException(PrepareDatabaseException):
     pass
 
 
-def prepare_database(db_conn, database_engine, config, databases=["main", "state"]):
+OUTDATED_SCHEMA_ON_WORKER_ERROR = (
+    "Expected database schema version %i but got %i: run the main synapse process to "
+    "upgrade the database schema before starting worker processes."
+)
+
+EMPTY_DATABASE_ON_WORKER_ERROR = (
+    "Uninitialised database: run the main synapse process to prepare the database "
+    "schema before starting worker processes."
+)
+
+UNAPPLIED_DELTA_ON_WORKER_ERROR = (
+    "Database schema delta %s has not been applied: run the main synapse process to "
+    "upgrade the database schema before starting worker processes."
+)
+
+
+def prepare_database(
+    db_conn: Connection,
+    database_engine: BaseDatabaseEngine,
+    config: Optional[HomeServerConfig],
+    databases: Collection[str] = ["main", "state"],
+):
     """Prepares a physical database for usage. Will either create all necessary tables
     or upgrade from an older schema version.
 
@@ -57,39 +81,67 @@ def prepare_database(db_conn, database_engine, config, databases=["main", "state
     Args:
         db_conn:
         database_engine:
-        config (synapse.config.homeserver.HomeServerConfig|None):
+        config :
             application config, or None if we are connecting to an existing
             database which we expect to be configured already
-        databases (list[str]): The name of the databases that will be used
+        databases: The name of the databases that will be used
             with this physical database. Defaults to all databases.
     """
 
     try:
         cur = db_conn.cursor()
+
+        # sqlite does not automatically start transactions for DDL / SELECT statements,
+        # so we start one before running anything. This ensures that any upgrades
+        # are either applied completely, or not at all.
+        #
+        # (psycopg2 automatically starts a transaction as soon as we run any statements
+        # at all, so this is redundant but harmless there.)
+        cur.execute("BEGIN TRANSACTION")
+
+        logger.info("%r: Checking existing schema version", databases)
         version_info = _get_or_create_schema_state(cur, database_engine)
 
         if version_info:
             user_version, delta_files, upgraded = version_info
+            logger.info(
+                "%r: Existing schema is %i (+%i deltas)",
+                databases,
+                user_version,
+                len(delta_files),
+            )
 
+            # config should only be None when we are preparing an in-memory SQLite db,
+            # which should be empty.
             if config is None:
-                if user_version != SCHEMA_VERSION:
-                    # If we don't pass in a config file then we are expecting to
-                    # have already upgraded the DB.
-                    raise UpgradeDatabaseException(
-                        "Expected database schema version %i but got %i"
-                        % (SCHEMA_VERSION, user_version)
-                    )
-            else:
-                _upgrade_existing_database(
-                    cur,
-                    user_version,
-                    delta_files,
-                    upgraded,
-                    database_engine,
-                    config,
-                    databases=databases,
+                raise ValueError(
+                    "config==None in prepare_database, but databse is not empty"
                 )
+
+            # if it's a worker app, refuse to upgrade the database, to avoid multiple
+            # workers doing it at once.
+            if config.worker_app is not None and user_version != SCHEMA_VERSION:
+                raise UpgradeDatabaseException(
+                    OUTDATED_SCHEMA_ON_WORKER_ERROR % (SCHEMA_VERSION, user_version)
+                )
+
+            _upgrade_existing_database(
+                cur,
+                user_version,
+                delta_files,
+                upgraded,
+                database_engine,
+                config,
+                databases=databases,
+            )
         else:
+            logger.info("%r: Initialising new database", databases)
+
+            # if it's a worker app, refuse to upgrade the database, to avoid multiple
+            # workers doing it at once.
+            if config and config.worker_app is not None:
+                raise UpgradeDatabaseException(EMPTY_DATABASE_ON_WORKER_ERROR)
+
             _setup_new_database(cur, database_engine, databases=databases)
 
         # check if any of our configured dynamic modules want a database
@@ -295,6 +347,8 @@ def _upgrade_existing_database(
     else:
         assert config
 
+    is_worker = config and config.worker_app is not None
+
     if current_version > SCHEMA_VERSION:
         raise ValueError(
             "Cannot use this database as it is too "
@@ -322,7 +376,7 @@ def _upgrade_existing_database(
     specific_engine_extensions = (".sqlite", ".postgres")
 
     for v in range(start_ver, SCHEMA_VERSION + 1):
-        logger.info("Upgrading schema to v%d", v)
+        logger.info("Applying schema deltas for v%d", v)
 
         # We need to search both the global and per data store schema
         # directories for schema updates.
@@ -382,9 +436,15 @@ def _upgrade_existing_database(
                 continue
 
             root_name, ext = os.path.splitext(file_name)
+
             if ext == ".py":
                 # This is a python upgrade module. We need to import into some
                 # package and then execute its `run_upgrade` function.
+                if is_worker:
+                    raise PrepareDatabaseException(
+                        UNAPPLIED_DELTA_ON_WORKER_ERROR % relative_path
+                    )
+
                 module_name = "synapse.storage.v%d_%s" % (v, root_name)
                 with open(absolute_path) as python_file:
                     module = imp.load_source(module_name, absolute_path, python_file)
@@ -399,10 +459,18 @@ def _upgrade_existing_database(
                 continue
             elif ext == ".sql":
                 # A plain old .sql file, just read and execute it
+                if is_worker:
+                    raise PrepareDatabaseException(
+                        UNAPPLIED_DELTA_ON_WORKER_ERROR % relative_path
+                    )
                 logger.info("Applying schema %s", relative_path)
                 executescript(cur, absolute_path)
             elif ext == specific_engine_extension and root_name.endswith(".sql"):
                 # A .sql file specific to our engine; just read and execute it
+                if is_worker:
+                    raise PrepareDatabaseException(
+                        UNAPPLIED_DELTA_ON_WORKER_ERROR % relative_path
+                    )
                 logger.info("Applying engine-specific schema %s", relative_path)
                 executescript(cur, absolute_path)
             elif ext in specific_engine_extensions and root_name.endswith(".sql"):
@@ -432,6 +500,8 @@ def _upgrade_existing_database(
                 (v, True),
             )
 
+    logger.info("Schema now up to date")
+
 
 def _apply_module_schemas(txn, database_engine, config):
     """Apply the module schemas for the dynamic modules, if any
@@ -568,8 +638,8 @@ def _get_or_create_schema_state(txn, database_engine):
     return None
 
 
-@attr.s()
-class _DirectoryListing(object):
+@attr.s(slots=True)
+class _DirectoryListing:
     """Helper class to store schema file name and the
     absolute path to it.
 
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index 79d9f06e2e..bfa0a9fd06 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -20,7 +20,7 @@ from typing import Set
 logger = logging.getLogger(__name__)
 
 
-class PurgeEventsStorage(object):
+class PurgeEventsStorage:
     """High level interface for purging rooms and event history.
     """
 
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index d471ec9860..cec96ad6a7 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -22,8 +22,8 @@ from synapse.api.errors import SynapseError
 logger = logging.getLogger(__name__)
 
 
-@attr.s
-class PaginationChunk(object):
+@attr.s(slots=True)
+class PaginationChunk:
     """Returned by relation pagination APIs.
 
     Attributes:
@@ -51,7 +51,7 @@ class PaginationChunk(object):
 
 
 @attr.s(frozen=True, slots=True)
-class RelationPaginationToken(object):
+class RelationPaginationToken:
     """Pagination token for relation pagination API.
 
     As the results are in topological order, we can use the
@@ -82,7 +82,7 @@ class RelationPaginationToken(object):
 
 
 @attr.s(frozen=True, slots=True)
-class AggregationPaginationToken(object):
+class AggregationPaginationToken:
     """Pagination token for relation aggregation pagination API.
 
     As the results are order by count and then MAX(stream_ordering) of the
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 534883361f..8f68d968f0 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -29,7 +29,7 @@ T = TypeVar("T")
 
 
 @attr.s(slots=True)
-class StateFilter(object):
+class StateFilter:
     """A filter used when querying for state.
 
     Attributes:
@@ -326,14 +326,14 @@ class StateFilter(object):
         return member_filter, non_member_filter
 
 
-class StateGroupStorage(object):
+class StateGroupStorage:
     """High level interface to fetching state for event.
     """
 
     def __init__(self, hs, stores):
         self.stores = stores
 
-    def get_state_group_delta(self, state_group: int):
+    async def get_state_group_delta(self, state_group: int):
         """Given a state group try to return a previous group and a delta between
         the old and the new.
 
@@ -341,11 +341,11 @@ class StateGroupStorage(object):
             state_group: The state group used to retrieve state deltas.
 
         Returns:
-            Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
+            Tuple[Optional[int], Optional[StateMap[str]]]:
                 (prev_group, delta_ids)
         """
 
-        return self.stores.state.get_state_group_delta(state_group)
+        return await self.stores.state.get_state_group_delta(state_group)
 
     async def get_state_groups_ids(
         self, _room_id: str, event_ids: Iterable[str]
@@ -525,7 +525,7 @@ class StateGroupStorage(object):
             state_filter: The state filter used to fetch state from the database.
 
         Returns:
-            A deferred dict from (type, state_key) -> state_event
+            A dict from (type, state_key) -> state_event
         """
         state_map = await self.get_state_ids_for_events([event_id], state_filter)
         return state_map[event_id]
@@ -546,14 +546,14 @@ class StateGroupStorage(object):
         """
         return self.stores.state._get_state_for_groups(groups, state_filter)
 
-    def store_state_group(
+    async def store_state_group(
         self,
         event_id: str,
         room_id: str,
         prev_group: Optional[int],
         delta_ids: Optional[dict],
         current_state_ids: dict,
-    ):
+    ) -> int:
         """Store a new set of state, returning a newly assigned state group.
 
         Args:
@@ -567,8 +567,8 @@ class StateGroupStorage(object):
                 to event_id.
 
         Returns:
-            Deferred[int]: The state group ID
+            The state group ID
         """
-        return self.stores.state.store_state_group(
+        return await self.stores.state.store_state_group(
             event_id, room_id, prev_group, delta_ids, current_state_ids
         )
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index e2ddd01290..1de2b91587 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -14,17 +14,21 @@
 # limitations under the License.
 
 import contextlib
+import heapq
+import logging
 import threading
 from collections import deque
-from typing import Dict, Set, Tuple
+from typing import Dict, List, Set
 
 from typing_extensions import Deque
 
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.util.sequence import PostgresSequenceGenerator
 
+logger = logging.getLogger(__name__)
 
-class IdGenerator(object):
+
+class IdGenerator:
     def __init__(self, db_conn, table, column):
         self._lock = threading.Lock()
         self._next_id = _load_current_id(db_conn, table, column)
@@ -47,6 +51,8 @@ def _load_current_id(db_conn, table, column, step=1):
     Returns:
         int
     """
+    # debug logging for https://github.com/matrix-org/synapse/issues/7968
+    logger.info("initialising stream generator for %s(%s)", table, column)
     cur = db_conn.cursor()
     if step == 1:
         cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
@@ -58,7 +64,7 @@ def _load_current_id(db_conn, table, column, step=1):
     return (max if step > 0 else min)(current_id, step)
 
 
-class StreamIdGenerator(object):
+class StreamIdGenerator:
     """Used to generate new stream ids when persisting events while keeping
     track of which transactions have been completed.
 
@@ -80,7 +86,7 @@ class StreamIdGenerator(object):
             upwards, -1 to grow downwards.
 
     Usage:
-        with stream_id_gen.get_next() as stream_id:
+        with await stream_id_gen.get_next() as stream_id:
             # ... persist event ...
     """
 
@@ -95,10 +101,10 @@ class StreamIdGenerator(object):
             )
         self._unfinished_ids = deque()  # type: Deque[int]
 
-    def get_next(self):
+    async def get_next(self):
         """
         Usage:
-            with stream_id_gen.get_next() as stream_id:
+            with await stream_id_gen.get_next() as stream_id:
                 # ... persist event ...
         """
         with self._lock:
@@ -117,10 +123,10 @@ class StreamIdGenerator(object):
 
         return manager()
 
-    def get_next_mult(self, n):
+    async def get_next_mult(self, n):
         """
         Usage:
-            with stream_id_gen.get_next(n) as stream_ids:
+            with await stream_id_gen.get_next(n) as stream_ids:
                 # ... persist events ...
         """
         with self._lock:
@@ -158,63 +164,13 @@ class StreamIdGenerator(object):
 
             return self._current
 
+    def get_current_token_for_writer(self, instance_name: str) -> int:
+        """Returns the position of the given writer.
 
-class ChainedIdGenerator(object):
-    """Used to generate new stream ids where the stream must be kept in sync
-    with another stream. It generates pairs of IDs, the first element is an
-    integer ID for this stream, the second element is the ID for the stream
-    that this stream needs to be kept in sync with."""
-
-    def __init__(self, chained_generator, db_conn, table, column):
-        self.chained_generator = chained_generator
-        self._table = table
-        self._lock = threading.Lock()
-        self._current_max = _load_current_id(db_conn, table, column)
-        self._unfinished_ids = deque()  # type: Deque[Tuple[int, int]]
-
-    def get_next(self):
-        """
-        Usage:
-            with stream_id_gen.get_next() as (stream_id, chained_id):
-                # ... persist event ...
-        """
-        with self._lock:
-            self._current_max += 1
-            next_id = self._current_max
-            chained_id = self.chained_generator.get_current_token()
-
-            self._unfinished_ids.append((next_id, chained_id))
-
-        @contextlib.contextmanager
-        def manager():
-            try:
-                yield (next_id, chained_id)
-            finally:
-                with self._lock:
-                    self._unfinished_ids.remove((next_id, chained_id))
-
-        return manager()
-
-    def get_current_token(self):
-        """Returns the maximum stream id such that all stream ids less than or
-        equal to it have been successfully persisted.
-        """
-        with self._lock:
-            if self._unfinished_ids:
-                stream_id, chained_id = self._unfinished_ids[0]
-                return stream_id - 1, chained_id
-
-            return self._current_max, self.chained_generator.get_current_token()
-
-    def advance(self, token: int):
-        """Stub implementation for advancing the token when receiving updates
-        over replication; raises an exception as this instance should be the
-        only source of updates.
+        For streams with single writers this is equivalent to
+        `get_current_token`.
         """
-
-        raise Exception(
-            "Attempted to advance token on source for table %r", self._table
-        )
+        return self.get_current_token()
 
 
 class MultiWriterIdGenerator:
@@ -234,6 +190,8 @@ class MultiWriterIdGenerator:
         id_column: Column that stores the stream ID.
         sequence_name: The name of the postgres sequence used to generate new
             IDs.
+        positive: Whether the IDs are positive (true) or negative (false).
+            When using negative IDs we go backwards from -1 to -2, -3, etc.
     """
 
     def __init__(
@@ -245,13 +203,19 @@ class MultiWriterIdGenerator:
         instance_column: str,
         id_column: str,
         sequence_name: str,
+        positive: bool = True,
     ):
         self._db = db
         self._instance_name = instance_name
+        self._positive = positive
+        self._return_factor = 1 if positive else -1
 
         # We lock as some functions may be called from DB threads.
         self._lock = threading.Lock()
 
+        # Note: If we are a negative stream then we still store all the IDs as
+        # positive to make life easier for us, and simply negate the IDs when we
+        # return them.
         self._current_positions = self._load_current_ids(
             db_conn, table, instance_column, id_column
         )
@@ -260,18 +224,46 @@ class MultiWriterIdGenerator:
         # should be less than the minimum of this set (if not empty).
         self._unfinished_ids = set()  # type: Set[int]
 
+        # Set of local IDs that we've processed that are larger than the current
+        # position, due to there being smaller unpersisted IDs.
+        self._finished_ids = set()  # type: Set[int]
+
+        # We track the max position where we know everything before has been
+        # persisted. This is done by a) looking at the min across all instances
+        # and b) noting that if we have seen a run of persisted positions
+        # without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
+        #
+        # Note: There is no guarentee that the IDs generated by the sequence
+        # will be gapless; gaps can form when e.g. a transaction was rolled
+        # back. This means that sometimes we won't be able to skip forward the
+        # position even though everything has been persisted. However, since
+        # gaps should be relatively rare it's still worth doing the book keeping
+        # that allows us to skip forwards when there are gapless runs of
+        # positions.
+        #
+        # We start at 1 here as a) the first generated stream ID will be 2, and
+        # b) other parts of the code assume that stream IDs are strictly greater
+        # than 0.
+        self._persisted_upto_position = (
+            min(self._current_positions.values()) if self._current_positions else 1
+        )
+        self._known_persisted_positions = []  # type: List[int]
+
         self._sequence_gen = PostgresSequenceGenerator(sequence_name)
 
     def _load_current_ids(
         self, db_conn, table: str, instance_column: str, id_column: str
     ) -> Dict[str, int]:
+        # If positive stream aggregate via MAX. For negative stream use MIN
+        # *and* negate the result to get a positive number.
         sql = """
-            SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
+            SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
             GROUP BY %(instance)s
         """ % {
             "instance": instance_column,
             "id": id_column,
             "table": table,
+            "agg": "MAX" if self._positive else "-MIN",
         }
 
         cur = db_conn.cursor()
@@ -284,9 +276,12 @@ class MultiWriterIdGenerator:
 
         return current_positions
 
-    def _load_next_id_txn(self, txn):
+    def _load_next_id_txn(self, txn) -> int:
         return self._sequence_gen.get_next_id_txn(txn)
 
+    def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
+        return self._sequence_gen.get_next_mult_txn(txn, n)
+
     async def get_next(self):
         """
         Usage:
@@ -298,20 +293,49 @@ class MultiWriterIdGenerator:
         # Assert the fetched ID is actually greater than what we currently
         # believe the ID to be. If not, then the sequence and table have got
         # out of sync somehow.
-        assert self.get_current_token() < next_id
-
         with self._lock:
+            assert self._current_positions.get(self._instance_name, 0) < next_id
+
             self._unfinished_ids.add(next_id)
 
         @contextlib.contextmanager
         def manager():
             try:
-                yield next_id
+                # Multiply by the return factor so that the ID has correct sign.
+                yield self._return_factor * next_id
             finally:
                 self._mark_id_as_finished(next_id)
 
         return manager()
 
+    async def get_next_mult(self, n: int):
+        """
+        Usage:
+            with await stream_id_gen.get_next_mult(5) as stream_ids:
+                # ... persist events ...
+        """
+        next_ids = await self._db.runInteraction(
+            "_load_next_mult_id", self._load_next_mult_id_txn, n
+        )
+
+        # Assert the fetched ID is actually greater than any ID we've already
+        # seen. If not, then the sequence and table have got out of sync
+        # somehow.
+        with self._lock:
+            assert max(self._current_positions.values(), default=0) < min(next_ids)
+
+            self._unfinished_ids.update(next_ids)
+
+        @contextlib.contextmanager
+        def manager():
+            try:
+                yield [self._return_factor * i for i in next_ids]
+            finally:
+                for i in next_ids:
+                    self._mark_id_as_finished(i)
+
+        return manager()
+
     def get_next_txn(self, txn: LoggingTransaction):
         """
         Usage:
@@ -328,49 +352,133 @@ class MultiWriterIdGenerator:
         txn.call_after(self._mark_id_as_finished, next_id)
         txn.call_on_exception(self._mark_id_as_finished, next_id)
 
-        return next_id
+        return self._return_factor * next_id
 
     def _mark_id_as_finished(self, next_id: int):
         """The ID has finished being processed so we should advance the
-        current poistion if possible.
+        current position if possible.
         """
 
         with self._lock:
             self._unfinished_ids.discard(next_id)
+            self._finished_ids.add(next_id)
+
+            new_cur = None
 
-            # Figure out if its safe to advance the position by checking there
-            # aren't any lower allocated IDs that are yet to finish.
-            if all(c > next_id for c in self._unfinished_ids):
+            if self._unfinished_ids:
+                # If there are unfinished IDs then the new position will be the
+                # largest finished ID less than the minimum unfinished ID.
+
+                finished = set()
+
+                min_unfinshed = min(self._unfinished_ids)
+                for s in self._finished_ids:
+                    if s < min_unfinshed:
+                        if new_cur is None or new_cur < s:
+                            new_cur = s
+                    else:
+                        finished.add(s)
+
+                # We clear these out since they're now all less than the new
+                # position.
+                self._finished_ids = finished
+            else:
+                # There are no unfinished IDs so the new position is simply the
+                # largest finished one.
+                new_cur = max(self._finished_ids)
+
+                # We clear these out since they're now all less than the new
+                # position.
+                self._finished_ids.clear()
+
+            if new_cur:
                 curr = self._current_positions.get(self._instance_name, 0)
-                self._current_positions[self._instance_name] = max(curr, next_id)
+                self._current_positions[self._instance_name] = max(curr, new_cur)
 
-    def get_current_token(self, instance_name: str = None) -> int:
-        """Gets the current position of a named writer (defaults to current
-        instance).
+            self._add_persisted_position(next_id)
 
-        Returns 0 if we don't have a position for the named writer (likely due
-        to it being a new writer).
+    def get_current_token(self) -> int:
+        """Returns the maximum stream id such that all stream ids less than or
+        equal to it have been successfully persisted.
         """
 
-        if instance_name is None:
-            instance_name = self._instance_name
+        return self.get_persisted_upto_position()
+
+    def get_current_token_for_writer(self, instance_name: str) -> int:
+        """Returns the position of the given writer.
+        """
 
         with self._lock:
-            return self._current_positions.get(instance_name, 0)
+            return self._return_factor * self._current_positions.get(instance_name, 0)
 
     def get_positions(self) -> Dict[str, int]:
         """Get a copy of the current positon map.
         """
 
         with self._lock:
-            return dict(self._current_positions)
+            return {
+                name: self._return_factor * i
+                for name, i in self._current_positions.items()
+            }
 
     def advance(self, instance_name: str, new_id: int):
         """Advance the postion of the named writer to the given ID, if greater
         than existing entry.
         """
 
+        new_id *= self._return_factor
+
         with self._lock:
             self._current_positions[instance_name] = max(
                 new_id, self._current_positions.get(instance_name, 0)
             )
+
+            self._add_persisted_position(new_id)
+
+    def get_persisted_upto_position(self) -> int:
+        """Get the max position where all previous positions have been
+        persisted.
+
+        Note: In the worst case scenario this will be equal to the minimum
+        position across writers. This means that the returned position here can
+        lag if one writer doesn't write very often.
+        """
+
+        with self._lock:
+            return self._return_factor * self._persisted_upto_position
+
+    def _add_persisted_position(self, new_id: int):
+        """Record that we have persisted a position.
+
+        This is used to keep the `_current_positions` up to date.
+        """
+
+        # We require that the lock is locked by caller
+        assert self._lock.locked()
+
+        heapq.heappush(self._known_persisted_positions, new_id)
+
+        # We move the current min position up if the minimum current positions
+        # of all instances is higher (since by definition all positions less
+        # that that have been persisted).
+        min_curr = min(self._current_positions.values(), default=0)
+        self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
+
+        # We now iterate through the seen positions, discarding those that are
+        # less than the current min positions, and incrementing the min position
+        # if its exactly one greater.
+        #
+        # This is also where we discard items from `_known_persisted_positions`
+        # (to ensure the list doesn't infinitely grow).
+        while self._known_persisted_positions:
+            if self._known_persisted_positions[0] <= self._persisted_upto_position:
+                heapq.heappop(self._known_persisted_positions)
+            elif (
+                self._known_persisted_positions[0] == self._persisted_upto_position + 1
+            ):
+                heapq.heappop(self._known_persisted_positions)
+                self._persisted_upto_position += 1
+            else:
+                # There was a gap in seen positions, so there is nothing more to
+                # do.
+                break
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 63dfea4220..ffc1894748 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import abc
 import threading
-from typing import Callable, Optional
+from typing import Callable, List, Optional
 
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
 from synapse.storage.types import Cursor
@@ -39,6 +39,12 @@ class PostgresSequenceGenerator(SequenceGenerator):
         txn.execute("SELECT nextval(?)", (self._sequence_name,))
         return txn.fetchone()[0]
 
+    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+        txn.execute(
+            "SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
+        )
+        return [i for (i,) in txn]
+
 
 GetFirstCallbackType = Callable[[Cursor], int]
 
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index ca7c16ff65..0bdf846edf 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -14,9 +14,13 @@
 # limitations under the License.
 
 import logging
+from typing import Optional
+
+import attr
 
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import parse_integer, parse_string
+from synapse.http.site import SynapseRequest
 from synapse.types import StreamToken
 
 logger = logging.getLogger(__name__)
@@ -25,38 +29,22 @@ logger = logging.getLogger(__name__)
 MAX_LIMIT = 1000
 
 
-class SourcePaginationConfig(object):
-
-    """A configuration object which stores pagination parameters for a
-    specific event source."""
-
-    def __init__(self, from_key=None, to_key=None, direction="f", limit=None):
-        self.from_key = from_key
-        self.to_key = to_key
-        self.direction = "f" if direction == "f" else "b"
-        self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None
-
-    def __repr__(self):
-        return "StreamConfig(from_key=%r, to_key=%r, direction=%r, limit=%r)" % (
-            self.from_key,
-            self.to_key,
-            self.direction,
-            self.limit,
-        )
-
-
-class PaginationConfig(object):
-
+@attr.s(slots=True)
+class PaginationConfig:
     """A configuration object which stores pagination parameters."""
 
-    def __init__(self, from_token=None, to_token=None, direction="f", limit=None):
-        self.from_token = from_token
-        self.to_token = to_token
-        self.direction = "f" if direction == "f" else "b"
-        self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None
+    from_token = attr.ib(type=Optional[StreamToken])
+    to_token = attr.ib(type=Optional[StreamToken])
+    direction = attr.ib(type=str)
+    limit = attr.ib(type=Optional[int])
 
     @classmethod
-    def from_request(cls, request, raise_invalid_params=True, default_limit=None):
+    def from_request(
+        cls,
+        request: SynapseRequest,
+        raise_invalid_params: bool = True,
+        default_limit: Optional[int] = None,
+    ) -> "PaginationConfig":
         direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
 
         from_tok = parse_string(request, "from")
@@ -78,8 +66,11 @@ class PaginationConfig(object):
 
         limit = parse_integer(request, "limit", default=default_limit)
 
-        if limit and limit < 0:
-            raise SynapseError(400, "Limit must be 0 or above")
+        if limit:
+            if limit < 0:
+                raise SynapseError(400, "Limit must be 0 or above")
+
+            limit = min(int(limit), MAX_LIMIT)
 
         try:
             return PaginationConfig(from_tok, to_tok, direction, limit)
@@ -87,20 +78,10 @@ class PaginationConfig(object):
             logger.exception("Failed to create pagination config")
             raise SynapseError(400, "Invalid request.")
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return ("PaginationConfig(from_tok=%r, to_tok=%r, direction=%r, limit=%r)") % (
             self.from_token,
             self.to_token,
             self.direction,
             self.limit,
         )
-
-    def get_source_config(self, source_name):
-        keyname = "%s_key" % source_name
-
-        return SourcePaginationConfig(
-            from_key=getattr(self.from_token, keyname),
-            to_key=getattr(self.to_token, keyname) if self.to_token else None,
-            direction=self.direction,
-            limit=self.limit,
-        )
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 393e34b9fb..92fd5d489f 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -23,7 +23,7 @@ from synapse.handlers.typing import TypingNotificationEventSource
 from synapse.types import StreamToken
 
 
-class EventSources(object):
+class EventSources:
     SOURCE_TYPES = {
         "room": RoomEventSource,
         "presence": PresenceEventSource,
@@ -39,7 +39,7 @@ class EventSources(object):
         self.store = hs.get_datastore()
 
     def get_current_token(self) -> StreamToken:
-        push_rules_key, _ = self.store.get_push_rules_stream_token()
+        push_rules_key = self.store.get_max_push_rules_stream_id()
         to_device_key = self.store.get_to_device_stream_token()
         device_list_key = self.store.get_device_stream_token()
         groups_key = self.store.get_group_stream_token()
diff --git a/synapse/types.py b/synapse/types.py
index 9e580f4295..dc09448bdc 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -18,7 +18,7 @@ import re
 import string
 import sys
 from collections import namedtuple
-from typing import Any, Dict, Tuple, Type, TypeVar
+from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar
 
 import attr
 from signedjson.key import decode_verify_key_bytes
@@ -41,8 +41,9 @@ else:
 # Define a state map type from type/state_key to T (usually an event ID or
 # event)
 T = TypeVar("T")
-StateMap = Dict[Tuple[str, str], T]
-
+StateKey = Tuple[str, str]
+StateMap = Mapping[StateKey, T]
+MutableStateMap = MutableMapping[StateKey, T]
 
 # the type of a JSON-serialisable dict. This could be made stronger, but it will
 # do for now.
@@ -51,7 +52,15 @@ JsonDict = Dict[str, Any]
 
 class Requester(
     namedtuple(
-        "Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"]
+        "Requester",
+        [
+            "user",
+            "access_token_id",
+            "is_guest",
+            "shadow_banned",
+            "device_id",
+            "app_service",
+        ],
     )
 ):
     """
@@ -62,6 +71,7 @@ class Requester(
         access_token_id (int|None):  *ID* of the access token used for this
             request, or None if it came via the appservice API or similar
         is_guest (bool):  True if the user making this request is a guest user
+        shadow_banned (bool):  True if the user making this request has been shadow-banned.
         device_id (str|None):  device_id which was set at authentication time
         app_service (ApplicationService|None):  the AS requesting on behalf of the user
     """
@@ -77,6 +87,7 @@ class Requester(
             "user_id": self.user.to_string(),
             "access_token_id": self.access_token_id,
             "is_guest": self.is_guest,
+            "shadow_banned": self.shadow_banned,
             "device_id": self.device_id,
             "app_server_id": self.app_service.id if self.app_service else None,
         }
@@ -101,13 +112,19 @@ class Requester(
             user=UserID.from_string(input["user_id"]),
             access_token_id=input["access_token_id"],
             is_guest=input["is_guest"],
+            shadow_banned=input["shadow_banned"],
             device_id=input["device_id"],
             app_service=appservice,
         )
 
 
 def create_requester(
-    user_id, access_token_id=None, is_guest=False, device_id=None, app_service=None
+    user_id,
+    access_token_id=None,
+    is_guest=False,
+    shadow_banned=False,
+    device_id=None,
+    app_service=None,
 ):
     """
     Create a new ``Requester`` object
@@ -117,6 +134,7 @@ def create_requester(
         access_token_id (int|None):  *ID* of the access token used for this
             request, or None if it came via the appservice API or similar
         is_guest (bool):  True if the user making this request is a guest user
+        shadow_banned (bool):  True if the user making this request is shadow-banned.
         device_id (str|None):  device_id which was set at authentication time
         app_service (ApplicationService|None):  the AS requesting on behalf of the user
 
@@ -125,7 +143,9 @@ def create_requester(
     """
     if not isinstance(user_id, UserID):
         user_id = UserID.from_string(user_id)
-    return Requester(user_id, access_token_id, is_guest, device_id, app_service)
+    return Requester(
+        user_id, access_token_id, is_guest, shadow_banned, device_id, app_service
+    )
 
 
 def get_domain_from_id(string):
@@ -342,22 +362,81 @@ def map_username_to_mxid_localpart(username, case_sensitive=False):
     return username.decode("ascii")
 
 
-class StreamToken(
-    namedtuple(
-        "Token",
-        (
-            "room_key",
-            "presence_key",
-            "typing_key",
-            "receipt_key",
-            "account_data_key",
-            "push_rules_key",
-            "to_device_key",
-            "device_list_key",
-            "groups_key",
-        ),
+@attr.s(frozen=True, slots=True)
+class RoomStreamToken:
+    """Tokens are positions between events. The token "s1" comes after event 1.
+
+            s0    s1
+            |     |
+        [0] V [1] V [2]
+
+    Tokens can either be a point in the live event stream or a cursor going
+    through historic events.
+
+    When traversing the live event stream events are ordered by when they
+    arrived at the homeserver.
+
+    When traversing historic events the events are ordered by their depth in
+    the event graph "topological_ordering" and then by when they arrived at the
+    homeserver "stream_ordering".
+
+    Live tokens start with an "s" followed by the "stream_ordering" id of the
+    event it comes after. Historic tokens start with a "t" followed by the
+    "topological_ordering" id of the event it comes after, followed by "-",
+    followed by the "stream_ordering" id of the event it comes after.
+    """
+
+    topological = attr.ib(
+        type=Optional[int],
+        validator=attr.validators.optional(attr.validators.instance_of(int)),
     )
-):
+    stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
+
+    @classmethod
+    def parse(cls, string: str) -> "RoomStreamToken":
+        try:
+            if string[0] == "s":
+                return cls(topological=None, stream=int(string[1:]))
+            if string[0] == "t":
+                parts = string[1:].split("-", 1)
+                return cls(topological=int(parts[0]), stream=int(parts[1]))
+        except Exception:
+            pass
+        raise SynapseError(400, "Invalid token %r" % (string,))
+
+    @classmethod
+    def parse_stream_token(cls, string: str) -> "RoomStreamToken":
+        try:
+            if string[0] == "s":
+                return cls(topological=None, stream=int(string[1:]))
+        except Exception:
+            pass
+        raise SynapseError(400, "Invalid token %r" % (string,))
+
+    def as_tuple(self) -> Tuple[Optional[int], int]:
+        return (self.topological, self.stream)
+
+    def __str__(self) -> str:
+        if self.topological is not None:
+            return "t%d-%d" % (self.topological, self.stream)
+        else:
+            return "s%d" % (self.stream,)
+
+
+@attr.s(slots=True, frozen=True)
+class StreamToken:
+    room_key = attr.ib(
+        type=RoomStreamToken, validator=attr.validators.instance_of(RoomStreamToken)
+    )
+    presence_key = attr.ib(type=int)
+    typing_key = attr.ib(type=int)
+    receipt_key = attr.ib(type=int)
+    account_data_key = attr.ib(type=int)
+    push_rules_key = attr.ib(type=int)
+    to_device_key = attr.ib(type=int)
+    device_list_key = attr.ib(type=int)
+    groups_key = attr.ib(type=int)
+
     _SEPARATOR = "_"
     START = None  # type: StreamToken
 
@@ -365,24 +444,19 @@ class StreamToken(
     def from_string(cls, string):
         try:
             keys = string.split(cls._SEPARATOR)
-            while len(keys) < len(cls._fields):
+            while len(keys) < len(attr.fields(cls)):
                 # i.e. old token from before receipt_key
                 keys.append("0")
-            return cls(*keys)
+            return cls(RoomStreamToken.parse(keys[0]), *(int(k) for k in keys[1:]))
         except Exception:
             raise SynapseError(400, "Invalid Token")
 
     def to_string(self):
-        return self._SEPARATOR.join([str(k) for k in self])
+        return self._SEPARATOR.join([str(k) for k in attr.astuple(self, recurse=False)])
 
     @property
     def room_stream_id(self):
-        # TODO(markjh): Awful hack to work around hacks in the presence tests
-        # which assume that the keys are integers.
-        if type(self.room_key) is int:
-            return self.room_key
-        else:
-            return int(self.room_key[1:].split("-")[-1])
+        return self.room_key.stream
 
     def is_after(self, other):
         """Does this token contain events that the other doesn't?"""
@@ -398,7 +472,7 @@ class StreamToken(
             or (int(other.groups_key) < int(self.groups_key))
         )
 
-    def copy_and_advance(self, key, new_value):
+    def copy_and_advance(self, key, new_value) -> "StreamToken":
         """Advance the given key in the token to a new value if and only if the
         new value is after the old value.
         """
@@ -414,64 +488,11 @@ class StreamToken(
         else:
             return self
 
-    def copy_and_replace(self, key, new_value):
-        return self._replace(**{key: new_value})
-
-
-StreamToken.START = StreamToken(*(["s0"] + ["0"] * (len(StreamToken._fields) - 1)))
-
-
-class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
-    """Tokens are positions between events. The token "s1" comes after event 1.
-
-            s0    s1
-            |     |
-        [0] V [1] V [2]
-
-    Tokens can either be a point in the live event stream or a cursor going
-    through historic events.
-
-    When traversing the live event stream events are ordered by when they
-    arrived at the homeserver.
-
-    When traversing historic events the events are ordered by their depth in
-    the event graph "topological_ordering" and then by when they arrived at the
-    homeserver "stream_ordering".
-
-    Live tokens start with an "s" followed by the "stream_ordering" id of the
-    event it comes after. Historic tokens start with a "t" followed by the
-    "topological_ordering" id of the event it comes after, followed by "-",
-    followed by the "stream_ordering" id of the event it comes after.
-    """
-
-    __slots__ = []  # type: list
-
-    @classmethod
-    def parse(cls, string):
-        try:
-            if string[0] == "s":
-                return cls(topological=None, stream=int(string[1:]))
-            if string[0] == "t":
-                parts = string[1:].split("-", 1)
-                return cls(topological=int(parts[0]), stream=int(parts[1]))
-        except Exception:
-            pass
-        raise SynapseError(400, "Invalid token %r" % (string,))
+    def copy_and_replace(self, key, new_value) -> "StreamToken":
+        return attr.evolve(self, **{key: new_value})
 
-    @classmethod
-    def parse_stream_token(cls, string):
-        try:
-            if string[0] == "s":
-                return cls(topological=None, stream=int(string[1:]))
-        except Exception:
-            pass
-        raise SynapseError(400, "Invalid token %r" % (string,))
 
-    def __str__(self):
-        if self.topological is not None:
-            return "t%d-%d" % (self.topological, self.stream)
-        else:
-            return "s%d" % (self.stream,)
+StreamToken.START = StreamToken.from_string("s0_0")
 
 
 class ThirdPartyInstanceID(
@@ -509,7 +530,7 @@ class ThirdPartyInstanceID(
 
 
 @attr.s(slots=True)
-class ReadReceipt(object):
+class ReadReceipt:
     """Information about a read-receipt"""
 
     room_id = attr.ib()
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index b3f76428b6..d55b93d763 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,11 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import logging
 import re
 
 import attr
-from canonicaljson import json
 
 from twisted.internet import defer, task
 
@@ -25,8 +25,18 @@ from synapse.logging import context
 
 logger = logging.getLogger(__name__)
 
-# Create a custom encoder to reduce the whitespace produced by JSON encoding.
-json_encoder = json.JSONEncoder(separators=(",", ":"))
+
+def _reject_invalid_json(val):
+    """Do not allow Infinity, -Infinity, or NaN values in JSON."""
+    raise ValueError("Invalid JSON value: '%s'" % val)
+
+
+# Create a custom encoder to reduce the whitespace produced by JSON encoding and
+# ensure that valid JSON is produced.
+json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":"))
+
+# Create a custom decoder to reject Python extensions to JSON.
+json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
 
 
 def unwrapFirstError(failure):
@@ -35,8 +45,8 @@ def unwrapFirstError(failure):
     return failure.value.subFailure
 
 
-@attr.s
-class Clock(object):
+@attr.s(slots=True)
+class Clock:
     """
     A Clock wraps a Twisted reactor and provides utilities on top of it.
 
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index f562770922..67ce9a5f39 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -17,12 +17,25 @@
 import collections
 import logging
 from contextlib import contextmanager
-from typing import Dict, Sequence, Set, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Hashable,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    TypeVar,
+    Union,
+)
 
 import attr
+from typing_extensions import ContextManager
 
 from twisted.internet import defer
 from twisted.internet.defer import CancelledError
+from twisted.internet.interfaces import IReactorTime
 from twisted.python import failure
 
 from synapse.logging.context import (
@@ -35,7 +48,7 @@ from synapse.util import Clock, unwrapFirstError
 logger = logging.getLogger(__name__)
 
 
-class ObservableDeferred(object):
+class ObservableDeferred:
     """Wraps a deferred object so that we can add observer deferreds. These
     observer deferreds do not affect the callback chain of the original
     deferred.
@@ -53,7 +66,7 @@ class ObservableDeferred(object):
 
     __slots__ = ["_deferred", "_observers", "_result"]
 
-    def __init__(self, deferred, consumeErrors=False):
+    def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False):
         object.__setattr__(self, "_deferred", deferred)
         object.__setattr__(self, "_result", None)
         object.__setattr__(self, "_observers", set())
@@ -110,25 +123,25 @@ class ObservableDeferred(object):
             success, res = self._result
             return defer.succeed(res) if success else defer.fail(res)
 
-    def observers(self):
+    def observers(self) -> List[defer.Deferred]:
         return self._observers
 
-    def has_called(self):
+    def has_called(self) -> bool:
         return self._result is not None
 
-    def has_succeeded(self):
+    def has_succeeded(self) -> bool:
         return self._result is not None and self._result[0] is True
 
-    def get_result(self):
+    def get_result(self) -> Any:
         return self._result[1]
 
-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Any:
         return getattr(self._deferred, name)
 
-    def __setattr__(self, name, value):
+    def __setattr__(self, name: str, value: Any) -> None:
         setattr(self._deferred, name, value)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
             id(self),
             self._result,
@@ -136,18 +149,20 @@ class ObservableDeferred(object):
         )
 
 
-def concurrently_execute(func, args, limit):
-    """Executes the function with each argument conncurrently while limiting
+def concurrently_execute(
+    func: Callable, args: Iterable[Any], limit: int
+) -> defer.Deferred:
+    """Executes the function with each argument concurrently while limiting
     the number of concurrent executions.
 
     Args:
-        func (func): Function to execute, should return a deferred or coroutine.
-        args (Iterable): List of arguments to pass to func, each invocation of func
+        func: Function to execute, should return a deferred or coroutine.
+        args: List of arguments to pass to func, each invocation of func
             gets a single argument.
-        limit (int): Maximum number of conccurent executions.
+        limit: Maximum number of conccurent executions.
 
     Returns:
-        deferred: Resolved when all function invocations have finished.
+        Deferred[list]: Resolved when all function invocations have finished.
     """
     it = iter(args)
 
@@ -166,14 +181,17 @@ def concurrently_execute(func, args, limit):
     ).addErrback(unwrapFirstError)
 
 
-def yieldable_gather_results(func, iter, *args, **kwargs):
+def yieldable_gather_results(
+    func: Callable, iter: Iterable, *args: Any, **kwargs: Any
+) -> defer.Deferred:
     """Executes the function with each argument concurrently.
 
     Args:
-        func (func): Function to execute that returns a Deferred
-        iter (iter): An iterable that yields items that get passed as the first
+        func: Function to execute that returns a Deferred
+        iter: An iterable that yields items that get passed as the first
             argument to the function
         *args: Arguments to be passed to each call to func
+        **kwargs: Keyword arguments to be passed to each call to func
 
     Returns
         Deferred[list]: Resolved when all functions have been invoked, or errors if
@@ -187,24 +205,37 @@ def yieldable_gather_results(func, iter, *args, **kwargs):
     ).addErrback(unwrapFirstError)
 
 
-class Linearizer(object):
+@attr.s(slots=True)
+class _LinearizerEntry:
+    # The number of things executing.
+    count = attr.ib(type=int)
+    # Deferreds for the things blocked from executing.
+    deferreds = attr.ib(type=collections.OrderedDict)
+
+
+class Linearizer:
     """Limits concurrent access to resources based on a key. Useful to ensure
     only a few things happen at a time on a given resource.
 
     Example:
 
-        with (yield limiter.queue("test_key")):
+        with await limiter.queue("test_key"):
             # do some work.
 
     """
 
-    def __init__(self, name=None, max_count=1, clock=None):
+    def __init__(
+        self,
+        name: Optional[str] = None,
+        max_count: int = 1,
+        clock: Optional[Clock] = None,
+    ):
         """
         Args:
-            max_count(int): The maximum number of concurrent accesses
+            max_count: The maximum number of concurrent accesses
         """
         if name is None:
-            self.name = id(self)
+            self.name = id(self)  # type: Union[str, int]
         else:
             self.name = name
 
@@ -215,15 +246,10 @@ class Linearizer(object):
         self._clock = clock
         self.max_count = max_count
 
-        # key_to_defer is a map from the key to a 2 element list where
-        # the first element is the number of things executing, and
-        # the second element is an OrderedDict, where the keys are deferreds for the
-        # things blocked from executing.
-        self.key_to_defer = (
-            {}
-        )  # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
+        # key_to_defer is a map from the key to a _LinearizerEntry.
+        self.key_to_defer = {}  # type: Dict[Hashable, _LinearizerEntry]
 
-    def is_queued(self, key) -> bool:
+    def is_queued(self, key: Hashable) -> bool:
         """Checks whether there is a process queued up waiting
         """
         entry = self.key_to_defer.get(key)
@@ -233,25 +259,27 @@ class Linearizer(object):
 
         # There are waiting deferreds only in the OrderedDict of deferreds is
         # non-empty.
-        return bool(entry[1])
+        return bool(entry.deferreds)
 
-    def queue(self, key):
+    def queue(self, key: Hashable) -> defer.Deferred:
         # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
         # (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not
         # propagated inside inlineCallbacks until Twisted 18.7)
-        entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()])
+        entry = self.key_to_defer.setdefault(
+            key, _LinearizerEntry(0, collections.OrderedDict())
+        )
 
         # If the number of things executing is greater than the maximum
         # then add a deferred to the list of blocked items
         # When one of the things currently executing finishes it will callback
         # this item so that it can continue executing.
-        if entry[0] >= self.max_count:
+        if entry.count >= self.max_count:
             res = self._await_lock(key)
         else:
             logger.debug(
                 "Acquired uncontended linearizer lock %r for key %r", self.name, key
             )
-            entry[0] += 1
+            entry.count += 1
             res = defer.succeed(None)
 
         # once we successfully get the lock, we need to return a context manager which
@@ -266,15 +294,15 @@ class Linearizer(object):
 
                 # We've finished executing so check if there are any things
                 # blocked waiting to execute and start one of them
-                entry[0] -= 1
+                entry.count -= 1
 
-                if entry[1]:
-                    (next_def, _) = entry[1].popitem(last=False)
+                if entry.deferreds:
+                    (next_def, _) = entry.deferreds.popitem(last=False)
 
                     # we need to run the next thing in the sentinel context.
                     with PreserveLoggingContext():
                         next_def.callback(None)
-                elif entry[0] == 0:
+                elif entry.count == 0:
                     # We were the last thing for this key: remove it from the
                     # map.
                     del self.key_to_defer[key]
@@ -282,7 +310,7 @@ class Linearizer(object):
         res.addCallback(_ctx_manager)
         return res
 
-    def _await_lock(self, key):
+    def _await_lock(self, key: Hashable) -> defer.Deferred:
         """Helper for queue: adds a deferred to the queue
 
         Assumes that we've already checked that we've reached the limit of the number
@@ -297,11 +325,11 @@ class Linearizer(object):
         logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
 
         new_defer = make_deferred_yieldable(defer.Deferred())
-        entry[1][new_defer] = 1
+        entry.deferreds[new_defer] = 1
 
         def cb(_r):
             logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
-            entry[0] += 1
+            entry.count += 1
 
             # if the code holding the lock completes synchronously, then it
             # will recursively run the next claimant on the list. That can
@@ -330,19 +358,19 @@ class Linearizer(object):
                 )
 
             # we just have to take ourselves back out of the queue.
-            del entry[1][new_defer]
+            del entry.deferreds[new_defer]
             return e
 
         new_defer.addCallbacks(cb, eb)
         return new_defer
 
 
-class ReadWriteLock(object):
-    """A deferred style read write lock.
+class ReadWriteLock:
+    """An async read write lock.
 
     Example:
 
-        with (yield read_write_lock.read("test_key")):
+        with await read_write_lock.read("test_key"):
             # do some work
     """
 
@@ -365,8 +393,7 @@ class ReadWriteLock(object):
         # Latest writer queued
         self.key_to_current_writer = {}  # type: Dict[str, defer.Deferred]
 
-    @defer.inlineCallbacks
-    def read(self, key):
+    async def read(self, key: str) -> ContextManager:
         new_defer = defer.Deferred()
 
         curr_readers = self.key_to_current_readers.setdefault(key, set())
@@ -376,7 +403,8 @@ class ReadWriteLock(object):
 
         # We wait for the latest writer to finish writing. We can safely ignore
         # any existing readers... as they're readers.
-        yield make_deferred_yieldable(curr_writer)
+        if curr_writer:
+            await make_deferred_yieldable(curr_writer)
 
         @contextmanager
         def _ctx_manager():
@@ -388,8 +416,7 @@ class ReadWriteLock(object):
 
         return _ctx_manager()
 
-    @defer.inlineCallbacks
-    def write(self, key):
+    async def write(self, key: str) -> ContextManager:
         new_defer = defer.Deferred()
 
         curr_readers = self.key_to_current_readers.get(key, set())
@@ -405,7 +432,7 @@ class ReadWriteLock(object):
         curr_readers.clear()
         self.key_to_current_writer[key] = new_defer
 
-        yield make_deferred_yieldable(defer.gatherResults(to_wait_on))
+        await make_deferred_yieldable(defer.gatherResults(to_wait_on))
 
         @contextmanager
         def _ctx_manager():
@@ -419,14 +446,22 @@ class ReadWriteLock(object):
         return _ctx_manager()
 
 
-def _cancelled_to_timed_out_error(value, timeout):
+R = TypeVar("R")
+
+
+def _cancelled_to_timed_out_error(value: R, timeout: float) -> R:
     if isinstance(value, failure.Failure):
         value.trap(CancelledError)
         raise defer.TimeoutError(timeout, "Deferred")
     return value
 
 
-def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
+def timeout_deferred(
+    deferred: defer.Deferred,
+    timeout: float,
+    reactor: IReactorTime,
+    on_timeout_cancel: Optional[Callable[[Any, float], Any]] = None,
+) -> defer.Deferred:
     """The in built twisted `Deferred.addTimeout` fails to time out deferreds
     that have a canceller that throws exceptions. This method creates a new
     deferred that wraps and times out the given deferred, correctly handling
@@ -437,10 +472,10 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
     NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred
 
     Args:
-        deferred (Deferred)
-        timeout (float): Timeout in seconds
-        reactor (twisted.interfaces.IReactorTime): The twisted reactor to use
-        on_timeout_cancel (callable): A callable which is called immediately
+        deferred: The Deferred to potentially timeout.
+        timeout: Timeout in seconds
+        reactor: The twisted reactor to use
+        on_timeout_cancel: A callable which is called immediately
             after the deferred times out, and not if this deferred is
             otherwise cancelled before the timeout.
 
@@ -452,7 +487,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
             CancelledError Failure into a defer.TimeoutError.
 
     Returns:
-        Deferred
+        A new Deferred.
     """
 
     new_d = defer.Deferred()
@@ -502,7 +537,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
 
 
 @attr.s(slots=True, frozen=True)
-class DoneAwaitable(object):
+class DoneAwaitable:
     """Simple awaitable that returns the provided value.
     """
 
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index dd356bf156..8fc05be278 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -42,8 +42,8 @@ response_cache_evicted = Gauge(
 response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"])
 
 
-@attr.s
-class CacheMetric(object):
+@attr.s(slots=True)
+class CacheMetric:
 
     _cache = attr.ib()
     _cache_type = attr.ib(type=str)
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index c2d72a82cf..98b34f2223 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -18,11 +18,10 @@ import functools
 import inspect
 import logging
 import threading
-from typing import Any, Tuple, Union, cast
+from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
 from weakref import WeakValueDictionary
 
 from prometheus_client import Gauge
-from typing_extensions import Protocol
 
 from twisted.internet import defer
 
@@ -38,8 +37,10 @@ logger = logging.getLogger(__name__)
 
 CacheKey = Union[Tuple, Any]
 
+F = TypeVar("F", bound=Callable[..., Any])
 
-class _CachedFunction(Protocol):
+
+class _CachedFunction(Generic[F]):
     invalidate = None  # type: Any
     invalidate_all = None  # type: Any
     invalidate_many = None  # type: Any
@@ -47,8 +48,11 @@ class _CachedFunction(Protocol):
     cache = None  # type: Any
     num_args = None  # type: Any
 
-    def __name__(self):
-        ...
+    __name__ = None  # type: str
+
+    # Note: This function signature is actually fiddled with by the synapse mypy
+    # plugin to a) make it a bound method, and b) remove any `cache_context` arg.
+    __call__ = None  # type: F
 
 
 cache_pending_metric = Gauge(
@@ -60,7 +64,7 @@ cache_pending_metric = Gauge(
 _CacheSentinel = object()
 
 
-class CacheEntry(object):
+class CacheEntry:
     __slots__ = ["deferred", "callbacks", "invalidated"]
 
     def __init__(self, deferred, callbacks):
@@ -76,7 +80,7 @@ class CacheEntry(object):
             self.callbacks.clear()
 
 
-class Cache(object):
+class Cache:
     __slots__ = (
         "cache",
         "name",
@@ -123,7 +127,7 @@ class Cache(object):
 
         self.name = name
         self.keylen = keylen
-        self.thread = None
+        self.thread = None  # type: Optional[threading.Thread]
         self.metrics = register_cache(
             "cache",
             name,
@@ -284,17 +288,10 @@ class Cache(object):
         self._pending_deferred_cache.clear()
 
 
-class _CacheDescriptorBase(object):
-    def __init__(
-        self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False
-    ):
+class _CacheDescriptorBase:
+    def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
         self.orig = orig
 
-        if inlineCallbacks:
-            self.function_to_call = defer.inlineCallbacks(orig)
-        else:
-            self.function_to_call = orig
-
         arg_spec = inspect.getfullargspec(orig)
         all_args = arg_spec.args
 
@@ -364,7 +361,7 @@ class CacheDescriptor(_CacheDescriptorBase):
     invalidated) by adding a special "cache_context" argument to the function
     and passing that as a kwarg to all caches called. For example::
 
-        @cachedInlineCallbacks(cache_context=True)
+        @cached(cache_context=True)
         def foo(self, key, cache_context):
             r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
             r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
@@ -382,17 +379,11 @@ class CacheDescriptor(_CacheDescriptorBase):
         max_entries=1000,
         num_args=None,
         tree=False,
-        inlineCallbacks=False,
         cache_context=False,
         iterable=False,
     ):
 
-        super(CacheDescriptor, self).__init__(
-            orig,
-            num_args=num_args,
-            inlineCallbacks=inlineCallbacks,
-            cache_context=cache_context,
-        )
+        super().__init__(orig, num_args=num_args, cache_context=cache_context)
 
         self.max_entries = max_entries
         self.tree = tree
@@ -465,9 +456,7 @@ class CacheDescriptor(_CacheDescriptorBase):
                     observer = defer.succeed(cached_result_d)
 
             except KeyError:
-                ret = defer.maybeDeferred(
-                    preserve_fn(self.function_to_call), obj, *args, **kwargs
-                )
+                ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
 
                 def onErr(f):
                     cache.invalidate(cache_key)
@@ -510,9 +499,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
     of results.
     """
 
-    def __init__(
-        self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False
-    ):
+    def __init__(self, orig, cached_method_name, list_name, num_args=None):
         """
         Args:
             orig (function)
@@ -521,12 +508,8 @@ class CacheListDescriptor(_CacheDescriptorBase):
             num_args (int): number of positional arguments (excluding ``self``,
                 but including list_name) to use as cache keys. Defaults to all
                 named args of the function.
-            inlineCallbacks (bool): Whether orig is a generator that should
-                be wrapped by defer.inlineCallbacks
         """
-        super(CacheListDescriptor, self).__init__(
-            orig, num_args=num_args, inlineCallbacks=inlineCallbacks
-        )
+        super().__init__(orig, num_args=num_args)
 
         self.list_name = list_name
 
@@ -631,7 +614,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
 
                 cached_defers.append(
                     defer.maybeDeferred(
-                        preserve_fn(self.function_to_call), **args_to_call
+                        preserve_fn(self.orig), **args_to_call
                     ).addCallbacks(complete_all, errback)
                 )
 
@@ -683,9 +666,13 @@ class _CacheContext:
 
 
 def cached(
-    max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
-):
-    return lambda orig: CacheDescriptor(
+    max_entries: int = 1000,
+    num_args: Optional[int] = None,
+    tree: bool = False,
+    cache_context: bool = False,
+    iterable: bool = False,
+) -> Callable[[F], _CachedFunction[F]]:
+    func = lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
@@ -694,22 +681,12 @@ def cached(
         iterable=iterable,
     )
 
-
-def cachedInlineCallbacks(
-    max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
-):
-    return lambda orig: CacheDescriptor(
-        orig,
-        max_entries=max_entries,
-        num_args=num_args,
-        tree=tree,
-        inlineCallbacks=True,
-        cache_context=cache_context,
-        iterable=iterable,
-    )
+    return cast(Callable[[F], _CachedFunction[F]], func)
 
 
-def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False):
+def cachedList(
+    cached_method_name: str, list_name: str, num_args: Optional[int] = None
+) -> Callable[[F], _CachedFunction[F]]:
     """Creates a descriptor that wraps a function in a `CacheListDescriptor`.
 
     Used to do batch lookups for an already created cache. A single argument
@@ -719,18 +696,16 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
     cache.
 
     Args:
-        cached_method_name (str): The name of the single-item lookup method.
+        cached_method_name: The name of the single-item lookup method.
             This is only used to find the cache to use.
-        list_name (str): The name of the argument that is the list to use to
+        list_name: The name of the argument that is the list to use to
             do batch lookups in the cache.
-        num_args (int): Number of arguments to use as the key in the cache
+        num_args: Number of arguments to use as the key in the cache
             (including list_name). Defaults to all named parameters.
-        inlineCallbacks (bool): Should the function be wrapped in an
-            `defer.inlineCallbacks`?
 
     Example:
 
-        class Example(object):
+        class Example:
             @cached(num_args=2)
             def do_something(self, first_arg):
                 ...
@@ -739,10 +714,11 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
             def batch_do_something(self, first_arg, second_args):
                 ...
     """
-    return lambda orig: CacheListDescriptor(
+    func = lambda orig: CacheListDescriptor(
         orig,
         cached_method_name=cached_method_name,
         list_name=list_name,
         num_args=num_args,
-        inlineCallbacks=inlineCallbacks,
     )
+
+    return cast(Callable[[F], _CachedFunction[F]], func)
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 6834e6f3ae..8592b93689 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -40,7 +40,7 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va
         return len(self.value)
 
 
-class DictionaryCache(object):
+class DictionaryCache:
     """Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
     fetching a subset of dictionary keys for a particular key.
     """
@@ -53,7 +53,7 @@ class DictionaryCache(object):
         self.thread = None
         # caches_by_name[name] = self.cache
 
-        class Sentinel(object):
+        class Sentinel:
             __slots__ = []
 
         self.sentinel = Sentinel()
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 89a3420f92..e15f7ee698 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
 SENTINEL = object()
 
 
-class ExpiringCache(object):
+class ExpiringCache:
     def __init__(
         self,
         cache_name,
@@ -190,7 +190,7 @@ class ExpiringCache(object):
         return False
 
 
-class _CacheEntry(object):
+class _CacheEntry:
     __slots__ = ["time", "value"]
 
     def __init__(self, time, value):
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index df4ea5901d..4bc1a67b58 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -30,7 +30,7 @@ def enumerate_leaves(node, depth):
                 yield m
 
 
-class _Node(object):
+class _Node:
     __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
 
     def __init__(self, prev_node, next_node, key, value, callbacks=set()):
@@ -41,7 +41,7 @@ class _Node(object):
         self.callbacks = callbacks
 
 
-class LruCache(object):
+class LruCache:
     """
     Least-recently-used cache.
     Supports del_multi only if cache_type=TreeCache
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index a6c60888e5..df1a721add 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -23,7 +23,7 @@ from synapse.util.caches import register_cache
 logger = logging.getLogger(__name__)
 
 
-class ResponseCache(object):
+class ResponseCache:
     """
     This caches a deferred response. Until the deferred completes it will be
     returned from the cache. This means that if the client retries the request
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index ecd9948e79..eb4d98f683 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -3,7 +3,7 @@ from typing import Dict
 SENTINEL = object()
 
 
-class TreeCache(object):
+class TreeCache:
     """
     Tree-based backing store for LruCache. Allows subtrees of data to be deleted
     efficiently.
@@ -89,7 +89,7 @@ def iterate_tree_cache_entry(d):
             yield d
 
 
-class _Entry(object):
+class _Entry:
     __slots__ = ["value"]
 
     def __init__(self, value):
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 6437aa907e..3e180cafd3 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
 SENTINEL = object()
 
 
-class TTLCache(object):
+class TTLCache:
     """A key/value cache implementation where each entry has its own TTL"""
 
     def __init__(self, cache_name, timer=time.time):
@@ -154,7 +154,7 @@ class TTLCache(object):
 
 
 @attr.s(frozen=True, slots=True)
-class _CacheEntry(object):
+class _CacheEntry:
     """TTLCache entry"""
 
     # expiry_time is the first attribute, so that entries are sorted by expiry.
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 22a857a306..f73e95393c 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -16,8 +16,6 @@ import inspect
 import logging
 
 from twisted.internet import defer
-from twisted.internet.defer import Deferred, fail, succeed
-from twisted.python import failure
 
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -29,12 +27,7 @@ def user_left_room(distributor, user, room_id):
     distributor.fire("user_left_room", user=user, room_id=room_id)
 
 
-# XXX: this is no longer used. We should probably kill it.
-def user_joined_room(distributor, user, room_id):
-    distributor.fire("user_joined_room", user=user, room_id=room_id)
-
-
-class Distributor(object):
+class Distributor:
     """A central dispatch point for loosely-connected pieces of code to
     register, observe, and fire signals.
 
@@ -81,29 +74,7 @@ class Distributor(object):
         run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
 
 
-def maybeAwaitableDeferred(f, *args, **kw):
-    """
-    Invoke a function that may or may not return a Deferred or an Awaitable.
-
-    This is a modified version of twisted.internet.defer.maybeDeferred.
-    """
-    try:
-        result = f(*args, **kw)
-    except Exception:
-        return fail(failure.Failure(captureVars=Deferred.debug))
-
-    if isinstance(result, Deferred):
-        return result
-    # Handle the additional case of an awaitable being returned.
-    elif inspect.isawaitable(result):
-        return defer.ensureDeferred(result)
-    elif isinstance(result, failure.Failure):
-        return fail(result)
-    else:
-        return succeed(result)
-
-
-class Signal(object):
+class Signal:
     """A Signal is a dispatch point that stores a list of callables as
     observers of it.
 
@@ -132,22 +103,17 @@ class Signal(object):
         Returns a Deferred that will complete when all the observers have
         completed."""
 
-        def do(observer):
-            def eb(failure):
+        async def do(observer):
+            try:
+                result = observer(*args, **kwargs)
+                if inspect.isawaitable(result):
+                    result = await result
+                return result
+            except Exception as e:
                 logger.warning(
-                    "%s signal observer %s failed: %r",
-                    self.name,
-                    observer,
-                    failure,
-                    exc_info=(
-                        failure.type,
-                        failure.value,
-                        failure.getTracebackObject(),
-                    ),
+                    "%s signal observer %s failed: %r", self.name, observer, e,
                 )
 
-            return maybeAwaitableDeferred(observer, *args, **kwargs).addErrback(eb)
-
         deferreds = [run_in_background(do, o) for o in self.observers]
 
         return make_deferred_yieldable(
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
index 6a3f6177b1..733f5e26e6 100644
--- a/synapse/util/file_consumer.py
+++ b/synapse/util/file_consumer.py
@@ -20,7 +20,7 @@ from twisted.internet import threads
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 
 
-class BackgroundFileConsumer(object):
+class BackgroundFileConsumer:
     """A consumer that writes to a file like object. Supports both push
     and pull producers
 
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 0e445e01d7..bf094c9386 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from canonicaljson import json
+import json
+
 from frozendict import frozendict
 
 
@@ -66,5 +67,5 @@ def _handle_frozendict(obj):
 # A JSONEncoder which is capable of encoding frozendicts without barfing.
 # Additionally reduce the whitespace produced by JSON encoding.
 frozendict_json_encoder = json.JSONEncoder(
-    default=_handle_frozendict, separators=(",", ":"),
+    allow_nan=False, separators=(",", ":"), default=_handle_frozendict,
 )
diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
index 6dce03dd3a..50516926f3 100644
--- a/synapse/util/jsonobject.py
+++ b/synapse/util/jsonobject.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 
-class JsonEncodedObject(object):
+class JsonEncodedObject:
     """ A common base class for defining protocol units that are represented
     as JSON.
 
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 13775b43f9..6e57c1ee72 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -93,7 +93,7 @@ def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
     return wrapper
 
 
-class Measure(object):
+class Measure:
     __slots__ = [
         "clock",
         "name",
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index e5efdfcd02..70d11e1ec3 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -29,7 +29,7 @@ from synapse.logging.context import (
 logger = logging.getLogger(__name__)
 
 
-class FederationRateLimiter(object):
+class FederationRateLimiter:
     def __init__(self, clock, config):
         """
         Args:
@@ -60,7 +60,7 @@ class FederationRateLimiter(object):
         return self.ratelimiters[host].ratelimit()
 
 
-class _PerHostRatelimiter(object):
+class _PerHostRatelimiter:
     def __init__(self, clock, config):
         """
         Args:
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 919988d3bc..79869aaa44 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -114,7 +114,7 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k
     )
 
 
-class RetryDestinationLimiter(object):
+class RetryDestinationLimiter:
     def __init__(
         self,
         destination,
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 2e2b40a426..61d96a6c28 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -24,9 +24,7 @@ from synapse.api.errors import Codes, SynapseError
 _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
 
 # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
-# Note: The : character is allowed here for older clients, but will be removed in a
-# future release. Context: https://github.com/matrix-org/synapse/issues/6766
-client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-\:]+$")
+client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
 
 # random_string and random_string_with_symbols are used for a range of things,
 # some cryptographically important, some less so. We use SystemRandom to make sure
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index 023beb5ede..be3b22469d 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 
-class _Entry(object):
+class _Entry:
     __slots__ = ["end_key", "queue"]
 
     def __init__(self, end_key):
@@ -22,7 +22,7 @@ class _Entry(object):
         self.queue = []
 
 
-class WheelTimer(object):
+class WheelTimer:
     """Stores arbitrary objects that will be returned after their timers have
     expired.
     """