summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/_scripts/register_new_matrix_user.py2
-rw-r--r--synapse/api/auth.py6
-rw-r--r--synapse/api/errors.py50
-rw-r--r--synapse/api/filtering.py4
-rw-r--r--synapse/api/urls.py1
-rw-r--r--synapse/app/admin_cmd.py3
-rw-r--r--synapse/app/generic_worker.py6
-rw-r--r--synapse/app/homeserver.py12
-rw-r--r--synapse/appservice/api.py4
-rw-r--r--synapse/config/_base.py21
-rw-r--r--synapse/config/_base.pyi1
-rw-r--r--synapse/config/consent_config.py2
-rw-r--r--synapse/config/emailconfig.py13
-rw-r--r--synapse/config/logger.py25
-rw-r--r--synapse/config/registration.py2
-rw-r--r--synapse/config/saml2_config.py36
-rw-r--r--synapse/config/server.py33
-rw-r--r--synapse/config/server_notices_config.py2
-rw-r--r--synapse/config/stats.py2
-rw-r--r--synapse/config/workers.py37
-rw-r--r--synapse/crypto/context_factory.py6
-rw-r--r--synapse/crypto/keyring.py4
-rw-r--r--synapse/federation/federation_client.py2
-rw-r--r--synapse/federation/federation_server.py2
-rw-r--r--synapse/federation/sender/__init__.py62
-rw-r--r--synapse/federation/sender/per_destination_queue.py140
-rw-r--r--synapse/federation/transport/server.py10
-rw-r--r--synapse/groups/groups_server.py2
-rw-r--r--synapse/handlers/acme_issuing_service.py2
-rw-r--r--synapse/handlers/admin.py8
-rw-r--r--synapse/handlers/auth.py4
-rw-r--r--synapse/handlers/deactivate_account.py2
-rw-r--r--synapse/handlers/device.py46
-rw-r--r--synapse/handlers/directory.py2
-rw-r--r--synapse/handlers/e2e_keys.py2
-rw-r--r--synapse/handlers/events.py8
-rw-r--r--synapse/handlers/federation.py110
-rw-r--r--synapse/handlers/groups_local.py2
-rw-r--r--synapse/handlers/identity.py2
-rw-r--r--synapse/handlers/initial_sync.py17
-rw-r--r--synapse/handlers/message.py127
-rw-r--r--synapse/handlers/oidc_handler.py4
-rw-r--r--synapse/handlers/pagination.py51
-rw-r--r--synapse/handlers/presence.py3
-rw-r--r--synapse/handlers/profile.py4
-rw-r--r--synapse/handlers/read_marker.py2
-rw-r--r--synapse/handlers/receipts.py17
-rw-r--r--synapse/handlers/register.py2
-rw-r--r--synapse/handlers/room.py31
-rw-r--r--synapse/handlers/room_list.py2
-rw-r--r--synapse/handlers/room_member.py53
-rw-r--r--synapse/handlers/room_member_worker.py11
-rw-r--r--synapse/handlers/saml_handler.py171
-rw-r--r--synapse/handlers/search.py2
-rw-r--r--synapse/handlers/set_password.py2
-rw-r--r--synapse/handlers/sync.py60
-rw-r--r--synapse/handlers/user_directory.py2
-rw-r--r--synapse/http/__init__.py2
-rw-r--r--synapse/http/client.py187
-rw-r--r--synapse/http/federation/well_known_resolver.py2
-rw-r--r--synapse/http/matrixfederationclient.py2
-rw-r--r--synapse/logging/context.py4
-rw-r--r--synapse/logging/formatter.py2
-rw-r--r--synapse/logging/opentracing.py2
-rw-r--r--synapse/logging/scopecontextmanager.py6
-rw-r--r--synapse/logging/utils.py6
-rw-r--r--synapse/metrics/__init__.py4
-rw-r--r--synapse/notifier.py108
-rw-r--r--synapse/push/__init__.py2
-rw-r--r--synapse/push/emailpusher.py2
-rw-r--r--synapse/push/httppusher.py2
-rw-r--r--synapse/push/mailer.py2
-rw-r--r--synapse/push/pusherpool.py37
-rw-r--r--synapse/python_dependencies.py12
-rw-r--r--synapse/replication/http/_base.py4
-rw-r--r--synapse/replication/http/devices.py2
-rw-r--r--synapse/replication/http/federation.py20
-rw-r--r--synapse/replication/http/login.py2
-rw-r--r--synapse/replication/http/membership.py16
-rw-r--r--synapse/replication/http/register.py4
-rw-r--r--synapse/replication/http/send_event.py2
-rw-r--r--synapse/replication/slave/storage/_base.py4
-rw-r--r--synapse/replication/slave/storage/account_data.py2
-rw-r--r--synapse/replication/slave/storage/client_ips.py2
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py2
-rw-r--r--synapse/replication/slave/storage/devices.py2
-rw-r--r--synapse/replication/slave/storage/events.py2
-rw-r--r--synapse/replication/slave/storage/filtering.py2
-rw-r--r--synapse/replication/slave/storage/groups.py2
-rw-r--r--synapse/replication/slave/storage/presence.py2
-rw-r--r--synapse/replication/slave/storage/pushers.py2
-rw-r--r--synapse/replication/slave/storage/receipts.py2
-rw-r--r--synapse/replication/slave/storage/room.py2
-rw-r--r--synapse/replication/tcp/client.py16
-rw-r--r--synapse/replication/tcp/handler.py2
-rw-r--r--synapse/replication/tcp/resource.py2
-rw-r--r--synapse/replication/tcp/streams/_base.py6
-rw-r--r--synapse/replication/tcp/streams/events.py4
-rw-r--r--synapse/res/templates/password_reset_confirmation.html16
-rw-r--r--synapse/res/templates/saml_error.html52
-rw-r--r--synapse/res/templates/sso_error.html43
-rw-r--r--synapse/rest/__init__.py6
-rw-r--r--synapse/rest/admin/__init__.py8
-rw-r--r--synapse/rest/admin/_base.py4
-rw-r--r--synapse/rest/admin/devices.py17
-rw-r--r--synapse/rest/admin/event_reports.py88
-rw-r--r--synapse/rest/admin/purge_room_servlet.py5
-rw-r--r--synapse/rest/admin/server_notice_servlet.py9
-rw-r--r--synapse/rest/admin/users.py34
-rw-r--r--synapse/rest/client/v1/directory.py6
-rw-r--r--synapse/rest/client/v1/events.py4
-rw-r--r--synapse/rest/client/v1/initial_sync.py2
-rw-r--r--synapse/rest/client/v1/login.py53
-rw-r--r--synapse/rest/client/v1/logout.py4
-rw-r--r--synapse/rest/client/v1/presence.py2
-rw-r--r--synapse/rest/client/v1/profile.py6
-rw-r--r--synapse/rest/client/v1/push_rule.py17
-rw-r--r--synapse/rest/client/v1/pusher.py6
-rw-r--r--synapse/rest/client/v1/room.py38
-rw-r--r--synapse/rest/client/v1/voip.py2
-rw-r--r--synapse/rest/client/v2_alpha/account.py166
-rw-r--r--synapse/rest/client/v2_alpha/account_data.py4
-rw-r--r--synapse/rest/client/v2_alpha/account_validity.py4
-rw-r--r--synapse/rest/client/v2_alpha/auth.py2
-rw-r--r--synapse/rest/client/v2_alpha/capabilities.py2
-rw-r--r--synapse/rest/client/v2_alpha/devices.py6
-rw-r--r--synapse/rest/client/v2_alpha/filter.py4
-rw-r--r--synapse/rest/client/v2_alpha/groups.py48
-rw-r--r--synapse/rest/client/v2_alpha/keys.py12
-rw-r--r--synapse/rest/client/v2_alpha/notifications.py2
-rw-r--r--synapse/rest/client/v2_alpha/openid.py2
-rw-r--r--synapse/rest/client/v2_alpha/password_policy.py2
-rw-r--r--synapse/rest/client/v2_alpha/read_marker.py2
-rw-r--r--synapse/rest/client/v2_alpha/receipts.py2
-rw-r--r--synapse/rest/client/v2_alpha/register.py23
-rw-r--r--synapse/rest/client/v2_alpha/relations.py8
-rw-r--r--synapse/rest/client/v2_alpha/report_event.py2
-rw-r--r--synapse/rest/client/v2_alpha/room_keys.py6
-rw-r--r--synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py2
-rw-r--r--synapse/rest/client/v2_alpha/sendtodevice.py2
-rw-r--r--synapse/rest/client/v2_alpha/shared_rooms.py2
-rw-r--r--synapse/rest/client/v2_alpha/sync.py2
-rw-r--r--synapse/rest/client/v2_alpha/tags.py4
-rw-r--r--synapse/rest/client/v2_alpha/thirdparty.py8
-rw-r--r--synapse/rest/client/v2_alpha/tokenrefresh.py2
-rw-r--r--synapse/rest/client/v2_alpha/user_directory.py2
-rw-r--r--synapse/rest/client/versions.py2
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py2
-rw-r--r--synapse/rest/media/v1/filepath.py19
-rw-r--r--synapse/rest/media/v1/media_repository.py69
-rw-r--r--synapse/rest/media/v1/media_storage.py28
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py16
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py5
-rw-r--r--synapse/rest/media/v1/thumbnailer.py18
-rw-r--r--synapse/rest/saml2/response_resource.py16
-rw-r--r--synapse/rest/synapse/__init__.py14
-rw-r--r--synapse/rest/synapse/client/__init__.py14
-rw-r--r--synapse/rest/synapse/client/password_reset.py127
-rw-r--r--synapse/state/__init__.py66
-rw-r--r--synapse/storage/__init__.py5
-rw-r--r--synapse/storage/database.py4
-rw-r--r--synapse/storage/databases/__init__.py2
-rw-r--r--synapse/storage/databases/main/__init__.py10
-rw-r--r--synapse/storage/databases/main/account_data.py16
-rw-r--r--synapse/storage/databases/main/appservice.py2
-rw-r--r--synapse/storage/databases/main/client_ips.py4
-rw-r--r--synapse/storage/databases/main/deviceinbox.py8
-rw-r--r--synapse/storage/databases/main/devices.py17
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py4
-rw-r--r--synapse/storage/databases/main/event_federation.py4
-rw-r--r--synapse/storage/databases/main/event_push_actions.py6
-rw-r--r--synapse/storage/databases/main/events.py49
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py2
-rw-r--r--synapse/storage/databases/main/events_worker.py76
-rw-r--r--synapse/storage/databases/main/group_server.py2
-rw-r--r--synapse/storage/databases/main/media_repository.py63
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py4
-rw-r--r--synapse/storage/databases/main/presence.py4
-rw-r--r--synapse/storage/databases/main/purge_events.py2
-rw-r--r--synapse/storage/databases/main/push_rule.py148
-rw-r--r--synapse/storage/databases/main/pusher.py4
-rw-r--r--synapse/storage/databases/main/receipts.py14
-rw-r--r--synapse/storage/databases/main/registration.py20
-rw-r--r--synapse/storage/databases/main/room.py110
-rw-r--r--synapse/storage/databases/main/roommember.py20
-rw-r--r--synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres33
-rw-r--r--synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite44
-rw-r--r--synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql28
-rw-r--r--synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql16
-rw-r--r--synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres26
-rw-r--r--synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql42
-rw-r--r--synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql22
-rw-r--r--synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql21
-rw-r--r--synapse/storage/databases/main/schema/delta/58/18stream_positions.sql22
-rw-r--r--synapse/storage/databases/main/search.py4
-rw-r--r--synapse/storage/databases/main/state.py6
-rw-r--r--synapse/storage/databases/main/stats.py40
-rw-r--r--synapse/storage/databases/main/stream.py82
-rw-r--r--synapse/storage/databases/main/tags.py4
-rw-r--r--synapse/storage/databases/main/transactions.py211
-rw-r--r--synapse/storage/databases/main/ui_auth.py2
-rw-r--r--synapse/storage/databases/main/user_directory.py4
-rw-r--r--synapse/storage/databases/main/user_erasure_store.py2
-rw-r--r--synapse/storage/databases/state/bg_updates.py2
-rw-r--r--synapse/storage/databases/state/store.py2
-rw-r--r--synapse/storage/persist_events.py28
-rw-r--r--synapse/storage/prepare_database.py28
-rw-r--r--synapse/storage/relations.py2
-rw-r--r--synapse/storage/roommember.py2
-rw-r--r--synapse/storage/util/id_generators.py331
-rw-r--r--synapse/streams/config.py61
-rw-r--r--synapse/types.py174
-rw-r--r--synapse/util/__init__.py4
-rw-r--r--synapse/util/async_helpers.py135
-rw-r--r--synapse/util/caches/__init__.py2
-rw-r--r--synapse/util/distributor.py50
-rw-r--r--synapse/util/frozenutils.py5
-rw-r--r--synapse/util/manhole.py2
-rw-r--r--synapse/util/patch_inline_callbacks.py2
-rw-r--r--synapse/util/retryutils.py2
220 files changed, 3272 insertions, 1635 deletions
diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index 55cce2db22..da0996edbc 100644
--- a/synapse/_scripts/register_new_matrix_user.py
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -14,8 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from __future__ import print_function
-
 import argparse
 import getpass
 import hashlib
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 75388643ee..1071a0576e 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -218,11 +218,7 @@ class Auth:
             # Deny the request if the user account has expired.
             if self._account_validity.enabled and not allow_expired:
                 user_id = user.to_string()
-                expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
-                if (
-                    expiration_ts is not None
-                    and self.clock.time_msec() >= expiration_ts
-                ):
+                if await self.store.is_account_expired(user_id, self.clock.time_msec()):
                     raise AuthError(
                         403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
                     )
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 94a9e58eae..cd6670d0a2 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -87,7 +87,7 @@ class CodeMessageException(RuntimeError):
     """
 
     def __init__(self, code: Union[int, HTTPStatus], msg: str):
-        super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
+        super().__init__("%d: %s" % (code, msg))
 
         # Some calls to this method pass instances of http.HTTPStatus for `code`.
         # While HTTPStatus is a subclass of int, it has magic __str__ methods
@@ -138,7 +138,7 @@ class SynapseError(CodeMessageException):
             msg: The human-readable error message.
             errcode: The matrix error code e.g 'M_FORBIDDEN'
         """
-        super(SynapseError, self).__init__(code, msg)
+        super().__init__(code, msg)
         self.errcode = errcode
 
     def error_dict(self):
@@ -159,7 +159,7 @@ class ProxiedRequestError(SynapseError):
         errcode: str = Codes.UNKNOWN,
         additional_fields: Optional[Dict] = None,
     ):
-        super(ProxiedRequestError, self).__init__(code, msg, errcode)
+        super().__init__(code, msg, errcode)
         if additional_fields is None:
             self._additional_fields = {}  # type: Dict
         else:
@@ -181,7 +181,7 @@ class ConsentNotGivenError(SynapseError):
             msg: The human-readable error message
             consent_url: The URL where the user can give their consent
         """
-        super(ConsentNotGivenError, self).__init__(
+        super().__init__(
             code=HTTPStatus.FORBIDDEN, msg=msg, errcode=Codes.CONSENT_NOT_GIVEN
         )
         self._consent_uri = consent_uri
@@ -201,7 +201,7 @@ class UserDeactivatedError(SynapseError):
         Args:
             msg: The human-readable error message
         """
-        super(UserDeactivatedError, self).__init__(
+        super().__init__(
             code=HTTPStatus.FORBIDDEN, msg=msg, errcode=Codes.USER_DEACTIVATED
         )
 
@@ -225,7 +225,7 @@ class FederationDeniedError(SynapseError):
 
         self.destination = destination
 
-        super(FederationDeniedError, self).__init__(
+        super().__init__(
             code=403,
             msg="Federation denied with %s." % (self.destination,),
             errcode=Codes.FORBIDDEN,
@@ -244,9 +244,7 @@ class InteractiveAuthIncompleteError(Exception):
     """
 
     def __init__(self, session_id: str, result: "JsonDict"):
-        super(InteractiveAuthIncompleteError, self).__init__(
-            "Interactive auth not yet complete"
-        )
+        super().__init__("Interactive auth not yet complete")
         self.session_id = session_id
         self.result = result
 
@@ -261,14 +259,14 @@ class UnrecognizedRequestError(SynapseError):
             message = "Unrecognized request"
         else:
             message = args[0]
-        super(UnrecognizedRequestError, self).__init__(400, message, **kwargs)
+        super().__init__(400, message, **kwargs)
 
 
 class NotFoundError(SynapseError):
     """An error indicating we can't find the thing you asked for"""
 
     def __init__(self, msg: str = "Not found", errcode: str = Codes.NOT_FOUND):
-        super(NotFoundError, self).__init__(404, msg, errcode=errcode)
+        super().__init__(404, msg, errcode=errcode)
 
 
 class AuthError(SynapseError):
@@ -279,7 +277,7 @@ class AuthError(SynapseError):
     def __init__(self, *args, **kwargs):
         if "errcode" not in kwargs:
             kwargs["errcode"] = Codes.FORBIDDEN
-        super(AuthError, self).__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
 
 
 class InvalidClientCredentialsError(SynapseError):
@@ -335,7 +333,7 @@ class ResourceLimitError(SynapseError):
     ):
         self.admin_contact = admin_contact
         self.limit_type = limit_type
-        super(ResourceLimitError, self).__init__(code, msg, errcode=errcode)
+        super().__init__(code, msg, errcode=errcode)
 
     def error_dict(self):
         return cs_error(
@@ -352,7 +350,7 @@ class EventSizeError(SynapseError):
     def __init__(self, *args, **kwargs):
         if "errcode" not in kwargs:
             kwargs["errcode"] = Codes.TOO_LARGE
-        super(EventSizeError, self).__init__(413, *args, **kwargs)
+        super().__init__(413, *args, **kwargs)
 
 
 class EventStreamError(SynapseError):
@@ -361,7 +359,7 @@ class EventStreamError(SynapseError):
     def __init__(self, *args, **kwargs):
         if "errcode" not in kwargs:
             kwargs["errcode"] = Codes.BAD_PAGINATION
-        super(EventStreamError, self).__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
 
 
 class LoginError(SynapseError):
@@ -384,7 +382,7 @@ class InvalidCaptchaError(SynapseError):
         error_url: Optional[str] = None,
         errcode: str = Codes.CAPTCHA_INVALID,
     ):
-        super(InvalidCaptchaError, self).__init__(code, msg, errcode)
+        super().__init__(code, msg, errcode)
         self.error_url = error_url
 
     def error_dict(self):
@@ -402,7 +400,7 @@ class LimitExceededError(SynapseError):
         retry_after_ms: Optional[int] = None,
         errcode: str = Codes.LIMIT_EXCEEDED,
     ):
-        super(LimitExceededError, self).__init__(code, msg, errcode)
+        super().__init__(code, msg, errcode)
         self.retry_after_ms = retry_after_ms
 
     def error_dict(self):
@@ -418,9 +416,7 @@ class RoomKeysVersionError(SynapseError):
         Args:
             current_version: the current version of the store they should have used
         """
-        super(RoomKeysVersionError, self).__init__(
-            403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION
-        )
+        super().__init__(403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION)
         self.current_version = current_version
 
 
@@ -429,7 +425,7 @@ class UnsupportedRoomVersionError(SynapseError):
     not support."""
 
     def __init__(self, msg: str = "Homeserver does not support this room version"):
-        super(UnsupportedRoomVersionError, self).__init__(
+        super().__init__(
             code=400, msg=msg, errcode=Codes.UNSUPPORTED_ROOM_VERSION,
         )
 
@@ -440,7 +436,7 @@ class ThreepidValidationError(SynapseError):
     def __init__(self, *args, **kwargs):
         if "errcode" not in kwargs:
             kwargs["errcode"] = Codes.FORBIDDEN
-        super(ThreepidValidationError, self).__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
 
 
 class IncompatibleRoomVersionError(SynapseError):
@@ -451,7 +447,7 @@ class IncompatibleRoomVersionError(SynapseError):
     """
 
     def __init__(self, room_version: str):
-        super(IncompatibleRoomVersionError, self).__init__(
+        super().__init__(
             code=400,
             msg="Your homeserver does not support the features required to "
             "join this room",
@@ -473,7 +469,7 @@ class PasswordRefusedError(SynapseError):
         msg: str = "This password doesn't comply with the server's policy",
         errcode: str = Codes.WEAK_PASSWORD,
     ):
-        super(PasswordRefusedError, self).__init__(
+        super().__init__(
             code=400, msg=msg, errcode=errcode,
         )
 
@@ -488,7 +484,7 @@ class RequestSendFailed(RuntimeError):
     """
 
     def __init__(self, inner_exception, can_retry):
-        super(RequestSendFailed, self).__init__(
+        super().__init__(
             "Failed to send request: %s: %s"
             % (type(inner_exception).__name__, inner_exception)
         )
@@ -542,7 +538,7 @@ class FederationError(RuntimeError):
         self.source = source
 
         msg = "%s %s: %s" % (level, code, reason)
-        super(FederationError, self).__init__(msg)
+        super().__init__(msg)
 
     def get_dict(self):
         return {
@@ -570,7 +566,7 @@ class HttpResponseException(CodeMessageException):
             msg: reason phrase from HTTP response status line
             response: body of response
         """
-        super(HttpResponseException, self).__init__(code, msg)
+        super().__init__(code, msg)
         self.response = response
 
     def to_synapse_error(self):
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 2a2c9e6f13..5caf336fd0 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -15,10 +15,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 from typing import List
 
 import jsonschema
-from canonicaljson import json
 from jsonschema import FormatChecker
 
 from synapse.api.constants import EventContentFields
@@ -132,7 +132,7 @@ def matrix_user_id_validator(user_id_str):
 
 class Filtering:
     def __init__(self, hs):
-        super(Filtering, self).__init__()
+        super().__init__()
         self.store = hs.get_datastore()
 
     async def get_user_filter(self, user_localpart, filter_id):
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index bbfccf955e..6379c86dde 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -21,6 +21,7 @@ from urllib.parse import urlencode
 
 from synapse.config import ConfigError
 
+SYNAPSE_CLIENT_API_PREFIX = "/_synapse/client"
 CLIENT_API_PREFIX = "/_matrix/client"
 FEDERATION_PREFIX = "/_matrix/federation"
 FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1"
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index b6c9085670..7d309b1bb0 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -14,13 +14,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+import json
 import logging
 import os
 import sys
 import tempfile
 
-from canonicaljson import json
-
 from twisted.internet import defer, task
 
 import synapse
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index f985810e88..c38413c893 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -152,7 +152,7 @@ class PresenceStatusStubServlet(RestServlet):
     PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status")
 
     def __init__(self, hs):
-        super(PresenceStatusStubServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
 
     async def on_GET(self, request, user_id):
@@ -176,7 +176,7 @@ class KeyUploadServlet(RestServlet):
         Args:
             hs (synapse.server.HomeServer): server
         """
-        super(KeyUploadServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.http_client = hs.get_simple_http_client()
@@ -646,7 +646,7 @@ class GenericWorkerServer(HomeServer):
 
 class GenericWorkerReplicationHandler(ReplicationDataHandler):
     def __init__(self, hs):
-        super(GenericWorkerReplicationHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.store = hs.get_datastore()
         self.presence_handler = hs.get_presence_handler()  # type: GenericWorkerPresence
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 6014adc850..dff739e106 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -15,8 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from __future__ import print_function
-
 import gc
 import logging
 import math
@@ -48,6 +46,7 @@ from synapse.api.urls import (
 from synapse.app import _base
 from synapse.app._base import listen_ssl, listen_tcp, quit_with_error
 from synapse.config._base import ConfigError
+from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import ListenerConfig
 from synapse.federation.transport.server import TransportLayerServer
@@ -209,6 +208,15 @@ class SynapseHomeServer(HomeServer):
 
                 resources["/_matrix/saml2"] = SAML2Resource(self)
 
+            if self.get_config().threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+                from synapse.rest.synapse.client.password_reset import (
+                    PasswordResetSubmitTokenResource,
+                )
+
+                resources[
+                    "/_synapse/client/password_reset/email/submit_token"
+                ] = PasswordResetSubmitTokenResource(self)
+
         if name == "consent":
             from synapse.rest.consent.consent_resource import ConsentResource
 
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index bb6fa8299a..c526c28b93 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -88,7 +88,7 @@ class ApplicationServiceApi(SimpleHttpClient):
     """
 
     def __init__(self, hs):
-        super(ApplicationServiceApi, self).__init__(hs)
+        super().__init__(hs)
         self.clock = hs.get_clock()
 
         self.protocol_meta_cache = ResponseCache(
@@ -178,7 +178,7 @@ class ApplicationServiceApi(SimpleHttpClient):
                 urllib.parse.quote(protocol),
             )
             try:
-                info = await self.get_json(uri, {})
+                info = await self.get_json(uri)
 
                 if not _is_valid_3pe_metadata(info):
                     logger.warning(
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index f8ab8e38df..05a66841c3 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -838,11 +838,26 @@ class ShardedWorkerHandlingConfig:
     def should_handle(self, instance_name: str, key: str) -> bool:
         """Whether this instance is responsible for handling the given key.
         """
-
-        # If multiple instances are not defined we always return true.
+        # If multiple instances are not defined we always return true
         if not self.instances or len(self.instances) == 1:
             return True
 
+        return self.get_instance(key) == instance_name
+
+    def get_instance(self, key: str) -> str:
+        """Get the instance responsible for handling the given key.
+
+        Note: For things like federation sending the config for which instance
+        is sending is known only to the sender instance if there is only one.
+        Therefore `should_handle` should be used where possible.
+        """
+
+        if not self.instances:
+            return "master"
+
+        if len(self.instances) == 1:
+            return self.instances[0]
+
         # We shard by taking the hash, modulo it by the number of instances and
         # then checking whether this instance matches the instance at that
         # index.
@@ -852,7 +867,7 @@ class ShardedWorkerHandlingConfig:
         dest_hash = sha256(key.encode("utf8")).digest()
         dest_int = int.from_bytes(dest_hash, byteorder="little")
         remainder = dest_int % (len(self.instances))
-        return self.instances[remainder] == instance_name
+        return self.instances[remainder]
 
 
 __all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index eb911e8f9f..b8faafa9bd 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -142,3 +142,4 @@ class ShardedWorkerHandlingConfig:
     instances: List[str]
     def __init__(self, instances: List[str]) -> None: ...
     def should_handle(self, instance_name: str, key: str) -> bool: ...
+    def get_instance(self, key: str) -> str: ...
diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py
index aec9c4bbce..fbddebeeab 100644
--- a/synapse/config/consent_config.py
+++ b/synapse/config/consent_config.py
@@ -77,7 +77,7 @@ class ConsentConfig(Config):
     section = "consent"
 
     def __init__(self, *args):
-        super(ConsentConfig, self).__init__(*args)
+        super().__init__(*args)
 
         self.user_consent_version = None
         self.user_consent_template_dir = None
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 7a796996c0..cceffbfee2 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -14,7 +14,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import print_function
 
 # This file can't be called email.py because if it is, we cannot:
 import email.utils
@@ -228,6 +227,7 @@ class EmailConfig(Config):
                 self.email_registration_template_text,
                 self.email_add_threepid_template_html,
                 self.email_add_threepid_template_text,
+                self.email_password_reset_template_confirmation_html,
                 self.email_password_reset_template_failure_html,
                 self.email_registration_template_failure_html,
                 self.email_add_threepid_template_failure_html,
@@ -242,6 +242,7 @@ class EmailConfig(Config):
                     registration_template_text,
                     add_threepid_template_html,
                     add_threepid_template_text,
+                    "password_reset_confirmation.html",
                     password_reset_template_failure_html,
                     registration_template_failure_html,
                     add_threepid_template_failure_html,
@@ -404,9 +405,13 @@ class EmailConfig(Config):
           # * The contents of password reset emails sent by the homeserver:
           #   'password_reset.html' and 'password_reset.txt'
           #
-          # * HTML pages for success and failure that a user will see when they follow
-          #   the link in the password reset email: 'password_reset_success.html' and
-          #   'password_reset_failure.html'
+          # * An HTML page that a user will see when they follow the link in the password
+          #   reset email. The user will be asked to confirm the action before their
+          #   password is reset: 'password_reset_confirmation.html'
+          #
+          # * HTML pages for success and failure that a user will see when they confirm
+          #   the password reset flow using the page above: 'password_reset_success.html'
+          #   and 'password_reset_failure.html'
           #
           # * The contents of address verification emails sent during registration:
           #   'registration.html' and 'registration.txt'
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index c96e6ef62a..13d6f6a3ea 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -17,6 +17,7 @@ import logging
 import logging.config
 import os
 import sys
+import threading
 from string import Template
 
 import yaml
@@ -25,6 +26,7 @@ from twisted.logger import (
     ILogObserver,
     LogBeginner,
     STDLibLogObserver,
+    eventAsText,
     globalLogBeginner,
 )
 
@@ -216,8 +218,9 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
     # system.
     observer = STDLibLogObserver()
 
-    def _log(event):
+    threadlocal = threading.local()
 
+    def _log(event):
         if "log_text" in event:
             if event["log_text"].startswith("DNSDatagramProtocol starting on "):
                 return
@@ -228,7 +231,25 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
             if event["log_text"].startswith("Timing out client"):
                 return
 
-        return observer(event)
+        # this is a workaround to make sure we don't get stack overflows when the
+        # logging system raises an error which is written to stderr which is redirected
+        # to the logging system, etc.
+        if getattr(threadlocal, "active", False):
+            # write the text of the event, if any, to the *real* stderr (which may
+            # be redirected to /dev/null, but there's not much we can do)
+            try:
+                event_text = eventAsText(event)
+                print("logging during logging: %s" % event_text, file=sys.__stderr__)
+            except Exception:
+                # gah.
+                pass
+            return
+
+        try:
+            threadlocal.active = True
+            return observer(event)
+        finally:
+            threadlocal.active = False
 
     logBeginner.beginLoggingTo([_log], redirectStandardIO=not config.no_redirect_stdio)
     if not config.no_redirect_stdio:
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index a185655774..5ffbb934fe 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -30,7 +30,7 @@ class AccountValidityConfig(Config):
     def __init__(self, config, synapse_config):
         if config is None:
             return
-        super(AccountValidityConfig, self).__init__()
+        super().__init__()
         self.enabled = config.get("enabled", False)
         self.renew_by_email_enabled = "renew_at" in config
 
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index 755478e2ff..99aa8b3bf1 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -169,12 +169,6 @@ class SAML2Config(Config):
             saml2_config.get("saml_session_lifetime", "15m")
         )
 
-        # We enable autoescape here as the message may potentially come from a
-        # remote resource
-        self.saml2_error_html_template = self.read_templates(
-            ["saml_error.html"], saml2_config.get("template_dir"), autoescape=True
-        )[0]
-
     def _default_saml_config_dict(
         self, required_attributes: set, optional_attributes: set
     ):
@@ -227,11 +221,14 @@ class SAML2Config(Config):
         # At least one of `sp_config` or `config_path` must be set in this section to
         # enable SAML login.
         #
-        # (You will probably also want to set the following options to `false` to
+        # You will probably also want to set the following options to `false` to
         # disable the regular login/registration flows:
         #   * enable_registration
         #   * password_config.enabled
         #
+        # You will also want to investigate the settings under the "sso" configuration
+        # section below.
+        #
         # Once SAML support is enabled, a metadata file will be exposed at
         # https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to
         # use to configure your SAML IdP with. Alternatively, you can manually configure
@@ -353,31 +350,6 @@ class SAML2Config(Config):
           #    value: "staff"
           #  - attribute: department
           #    value: "sales"
-
-          # Directory in which Synapse will try to find the template files below.
-          # If not set, default templates from within the Synapse package will be used.
-          #
-          # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
-          # If you *do* uncomment it, you will need to make sure that all the templates
-          # below are in the directory.
-          #
-          # Synapse will look for the following templates in this directory:
-          #
-          # * HTML page to display to users if something goes wrong during the
-          #   authentication process: 'saml_error.html'.
-          #
-          #   When rendering, this template is given the following variables:
-          #     * code: an HTML error code corresponding to the error that is being
-          #       returned (typically 400 or 500)
-          #
-          #     * msg: a textual message describing the error.
-          #
-          #   The variables will automatically be HTML-escaped.
-          #
-          # You can see the default templates at:
-          # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
-          #
-          #template_dir: "res/templates"
         """ % {
             "config_dir_path": config_dir_path
         }
diff --git a/synapse/config/server.py b/synapse/config/server.py
index e85c6a0840..532b910470 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -19,7 +19,7 @@ import logging
 import os.path
 import re
 from textwrap import indent
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Set
 
 import attr
 import yaml
@@ -542,6 +542,19 @@ class ServerConfig(Config):
             users_new_default_push_rules
         )  # type: set
 
+        # Whitelist of domain names that given next_link parameters must have
+        next_link_domain_whitelist = config.get(
+            "next_link_domain_whitelist"
+        )  # type: Optional[List[str]]
+
+        self.next_link_domain_whitelist = None  # type: Optional[Set[str]]
+        if next_link_domain_whitelist is not None:
+            if not isinstance(next_link_domain_whitelist, list):
+                raise ConfigError("'next_link_domain_whitelist' must be a list")
+
+            # Turn the list into a set to improve lookup speed.
+            self.next_link_domain_whitelist = set(next_link_domain_whitelist)
+
     def has_tls_listener(self) -> bool:
         return any(listener.tls for listener in self.listeners)
 
@@ -1014,6 +1027,24 @@ class ServerConfig(Config):
         # act as if no error happened and return a fake session ID ('sid') to clients.
         #
         #request_token_inhibit_3pid_errors: true
+
+        # A list of domains that the domain portion of 'next_link' parameters
+        # must match.
+        #
+        # This parameter is optionally provided by clients while requesting
+        # validation of an email or phone number, and maps to a link that
+        # users will be automatically redirected to after validation
+        # succeeds. Clients can make use this parameter to aid the validation
+        # process.
+        #
+        # The whitelist is applied whether the homeserver or an
+        # identity server is handling validation.
+        #
+        # The default value is no whitelist functionality; all domains are
+        # allowed. Setting this value to an empty list will instead disallow
+        # all domains.
+        #
+        #next_link_domain_whitelist: ["matrix.org"]
         """
             % locals()
         )
diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py
index 6c427b6f92..57f69dc8e2 100644
--- a/synapse/config/server_notices_config.py
+++ b/synapse/config/server_notices_config.py
@@ -62,7 +62,7 @@ class ServerNoticesConfig(Config):
     section = "servernotices"
 
     def __init__(self, *args):
-        super(ServerNoticesConfig, self).__init__(*args)
+        super().__init__(*args)
         self.server_notices_mxid = None
         self.server_notices_mxid_display_name = None
         self.server_notices_mxid_avatar_url = None
diff --git a/synapse/config/stats.py b/synapse/config/stats.py
index 62485189ea..b559bfa411 100644
--- a/synapse/config/stats.py
+++ b/synapse/config/stats.py
@@ -13,8 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from __future__ import division
-
 import sys
 
 from ._base import Config
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index c784a71508..f23e42cdf9 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -13,12 +13,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import List, Union
+
 import attr
 
 from ._base import Config, ConfigError, ShardedWorkerHandlingConfig
 from .server import ListenerConfig, parse_listener_def
 
 
+def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
+    """Helper for allowing parsing a string or list of strings to a config
+    option expecting a list of strings.
+    """
+
+    if isinstance(obj, str):
+        return [obj]
+    return obj
+
+
 @attr.s
 class InstanceLocationConfig:
     """The host and port to talk to an instance via HTTP replication.
@@ -33,11 +45,13 @@ class WriterLocations:
     """Specifies the instances that write various streams.
 
     Attributes:
-        events: The instance that writes to the event and backfill streams.
-        events: The instance that writes to the typing stream.
+        events: The instances that write to the event and backfill streams.
+        typing: The instance that writes to the typing stream.
     """
 
-    events = attr.ib(default="master", type=str)
+    events = attr.ib(
+        default=["master"], type=List[str], converter=_instance_to_list_converter
+    )
     typing = attr.ib(default="master", type=str)
 
 
@@ -105,15 +119,18 @@ class WorkerConfig(Config):
         writers = config.get("stream_writers") or {}
         self.writers = WriterLocations(**writers)
 
-        # Check that the configured writer for events and typing also appears in
+        # Check that the configured writers for events and typing also appears in
         # `instance_map`.
         for stream in ("events", "typing"):
-            instance = getattr(self.writers, stream)
-            if instance != "master" and instance not in self.instance_map:
-                raise ConfigError(
-                    "Instance %r is configured to write %s but does not appear in `instance_map` config."
-                    % (instance, stream)
-                )
+            instances = _instance_to_list_converter(getattr(self.writers, stream))
+            for instance in instances:
+                if instance != "master" and instance not in self.instance_map:
+                    raise ConfigError(
+                        "Instance %r is configured to write %s but does not appear in `instance_map` config."
+                        % (instance, stream)
+                    )
+
+        self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
 
     def generate_config_section(self, config_dir_path, server_name, **kwargs):
         return """\
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 2b03f5ac76..79668a402e 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -45,7 +45,11 @@ _TLS_VERSION_MAP = {
 
 class ServerContextFactory(ContextFactory):
     """Factory for PyOpenSSL SSL contexts that are used to handle incoming
-    connections."""
+    connections.
+
+    TODO: replace this with an implementation of IOpenSSLServerConnectionCreator,
+    per https://github.com/matrix-org/synapse/issues/1691
+    """
 
     def __init__(self, config):
         # TODO: once pyOpenSSL exposes TLS_METHOD and SSL_CTX_set_min_proto_version,
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 32c31b1cd1..42e4087a92 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -558,7 +558,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
     """KeyFetcher impl which fetches keys from the "perspectives" servers"""
 
     def __init__(self, hs):
-        super(PerspectivesKeyFetcher, self).__init__(hs)
+        super().__init__(hs)
         self.clock = hs.get_clock()
         self.client = hs.get_http_client()
         self.key_servers = self.config.key_servers
@@ -728,7 +728,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
     """KeyFetcher impl which fetches keys from the origin servers"""
 
     def __init__(self, hs):
-        super(ServerKeyFetcher, self).__init__(hs)
+        super().__init__(hs)
         self.clock = hs.get_clock()
         self.client = hs.get_http_client()
 
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index d42930d1b9..688d43fffb 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -79,7 +79,7 @@ class InvalidResponseError(RuntimeError):
 
 class FederationClient(FederationBase):
     def __init__(self, hs):
-        super(FederationClient, self).__init__(hs)
+        super().__init__(hs)
 
         self.pdu_destination_tried = {}
         self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index ff00f0b302..2dcd081cbc 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -90,7 +90,7 @@ pdu_process_time = Histogram(
 
 class FederationServer(FederationBase):
     def __init__(self, hs):
-        super(FederationServer, self).__init__(hs)
+        super().__init__(hs)
 
         self.auth = hs.get_auth()
         self.handler = hs.get_handlers().federation_handler
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 552519e82c..8bb17b3a05 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -55,6 +55,15 @@ sent_pdus_destination_dist_total = Counter(
     "Total number of PDUs queued for sending across all destinations",
 )
 
+# Time (in s) after Synapse's startup that we will begin to wake up destinations
+# that have catch-up outstanding.
+CATCH_UP_STARTUP_DELAY_SEC = 15
+
+# Time (in s) to wait in between waking up each destination, i.e. one destination
+# will be woken up every <x> seconds after Synapse's startup until we have woken
+# every destination has outstanding catch-up.
+CATCH_UP_STARTUP_INTERVAL_SEC = 5
+
 
 class FederationSender:
     def __init__(self, hs: "synapse.server.HomeServer"):
@@ -125,6 +134,14 @@ class FederationSender:
             1000.0 / hs.config.federation_rr_transactions_per_room_per_second
         )
 
+        # wake up destinations that have outstanding PDUs to be caught up
+        self._catchup_after_startup_timer = self.clock.call_later(
+            CATCH_UP_STARTUP_DELAY_SEC,
+            run_as_background_process,
+            "wake_destinations_needing_catchup",
+            self._wake_destinations_needing_catchup,
+        )
+
     def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
         """Get or create a PerDestinationQueue for the given destination
 
@@ -209,7 +226,7 @@ class FederationSender:
                     logger.debug("Sending %s to %r", event, destinations)
 
                     if destinations:
-                        self._send_pdu(event, destinations)
+                        await self._send_pdu(event, destinations)
 
                         now = self.clock.time_msec()
                         ts = await self.store.get_received_ts(event.event_id)
@@ -265,7 +282,7 @@ class FederationSender:
         finally:
             self._is_processing = False
 
-    def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None:
+    async def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None:
         # We loop through all destinations to see whether we already have
         # a transaction in progress. If we do, stick it in the pending_pdus
         # table and we'll get back to it later.
@@ -280,6 +297,13 @@ class FederationSender:
         sent_pdus_destination_dist_total.inc(len(destinations))
         sent_pdus_destination_dist_count.inc()
 
+        # track the fact that we have a PDU for these destinations,
+        # to allow us to perform catch-up later on if the remote is unreachable
+        # for a while.
+        await self.store.store_destination_rooms_entries(
+            destinations, pdu.room_id, pdu.internal_metadata.stream_ordering,
+        )
+
         for destination in destinations:
             self._get_per_destination_queue(destination).send_pdu(pdu)
 
@@ -553,3 +577,37 @@ class FederationSender:
         # Dummy implementation for case where federation sender isn't offloaded
         # to a worker.
         return [], 0, False
+
+    async def _wake_destinations_needing_catchup(self):
+        """
+        Wakes up destinations that need catch-up and are not currently being
+        backed off from.
+
+        In order to reduce load spikes, adds a delay between each destination.
+        """
+
+        last_processed = None  # type: Optional[str]
+
+        while True:
+            destinations_to_wake = await self.store.get_catch_up_outstanding_destinations(
+                last_processed
+            )
+
+            if not destinations_to_wake:
+                # finished waking all destinations!
+                self._catchup_after_startup_timer = None
+                break
+
+            destinations_to_wake = [
+                d
+                for d in destinations_to_wake
+                if self._federation_shard_config.should_handle(self._instance_name, d)
+            ]
+
+            for last_processed in destinations_to_wake:
+                logger.info(
+                    "Destination %s has outstanding catch-up, waking up.",
+                    last_processed,
+                )
+                self.wake_destination(last_processed)
+                await self.clock.sleep(CATCH_UP_STARTUP_INTERVAL_SEC)
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index defc228c23..2657767fd1 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 import datetime
 import logging
-from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast
 
 from prometheus_client import Counter
 
@@ -92,6 +92,21 @@ class PerDestinationQueue:
         self._destination = destination
         self.transmission_loop_running = False
 
+        # True whilst we are sending events that the remote homeserver missed
+        # because it was unreachable. We start in this state so we can perform
+        # catch-up at startup.
+        # New events will only be sent once this is finished, at which point
+        # _catching_up is flipped to False.
+        self._catching_up = True  # type: bool
+
+        # The stream_ordering of the most recent PDU that was discarded due to
+        # being in catch-up mode.
+        self._catchup_last_skipped = 0  # type: int
+
+        # Cache of the last successfully-transmitted stream ordering for this
+        # destination (we are the only updater so this is safe)
+        self._last_successful_stream_ordering = None  # type: Optional[int]
+
         # a list of pending PDUs
         self._pending_pdus = []  # type: List[EventBase]
 
@@ -138,7 +153,13 @@ class PerDestinationQueue:
         Args:
             pdu: pdu to send
         """
-        self._pending_pdus.append(pdu)
+        if not self._catching_up or self._last_successful_stream_ordering is None:
+            # only enqueue the PDU if we are not catching up (False) or do not
+            # yet know if we have anything to catch up (None)
+            self._pending_pdus.append(pdu)
+        else:
+            self._catchup_last_skipped = pdu.internal_metadata.stream_ordering
+
         self.attempt_new_transaction()
 
     def send_presence(self, states: Iterable[UserPresenceState]) -> None:
@@ -218,6 +239,13 @@ class PerDestinationQueue:
             # hence why we throw the result away.
             await get_retry_limiter(self._destination, self._clock, self._store)
 
+            if self._catching_up:
+                # we potentially need to catch-up first
+                await self._catch_up_transmission_loop()
+                if self._catching_up:
+                    # not caught up yet
+                    return
+
             pending_pdus = []
             while True:
                 # We have to keep 2 free slots for presence and rr_edus
@@ -325,6 +353,17 @@ class PerDestinationQueue:
 
                     self._last_device_stream_id = device_stream_id
                     self._last_device_list_stream_id = dev_list_id
+
+                    if pending_pdus:
+                        # we sent some PDUs and it was successful, so update our
+                        # last_successful_stream_ordering in the destinations table.
+                        final_pdu = pending_pdus[-1]
+                        last_successful_stream_ordering = (
+                            final_pdu.internal_metadata.stream_ordering
+                        )
+                        await self._store.set_destination_last_successful_stream_ordering(
+                            self._destination, last_successful_stream_ordering
+                        )
                 else:
                     break
         except NotRetryingDestination as e:
@@ -340,8 +379,9 @@ class PerDestinationQueue:
             if e.retry_interval > 60 * 60 * 1000:
                 # we won't retry for another hour!
                 # (this suggests a significant outage)
-                # We drop pending PDUs and EDUs because otherwise they will
+                # We drop pending EDUs because otherwise they will
                 # rack up indefinitely.
+                # (Dropping PDUs is already performed by `_start_catching_up`.)
                 # Note that:
                 # - the EDUs that are being dropped here are those that we can
                 #   afford to drop (specifically, only typing notifications,
@@ -353,11 +393,12 @@ class PerDestinationQueue:
 
                 # dropping read receipts is a bit sad but should be solved
                 # through another mechanism, because this is all volatile!
-                self._pending_pdus = []
                 self._pending_edus = []
                 self._pending_edus_keyed = {}
                 self._pending_presence = {}
                 self._pending_rrs = {}
+
+            self._start_catching_up()
         except FederationDeniedError as e:
             logger.info(e)
         except HttpResponseException as e:
@@ -367,6 +408,8 @@ class PerDestinationQueue:
                 e.code,
                 e,
             )
+
+            self._start_catching_up()
         except RequestSendFailed as e:
             logger.warning(
                 "TX [%s] Failed to send transaction: %s", self._destination, e
@@ -376,16 +419,96 @@ class PerDestinationQueue:
                 logger.info(
                     "Failed to send event %s to %s", p.event_id, self._destination
                 )
+
+            self._start_catching_up()
         except Exception:
             logger.exception("TX [%s] Failed to send transaction", self._destination)
             for p in pending_pdus:
                 logger.info(
                     "Failed to send event %s to %s", p.event_id, self._destination
                 )
+
+            self._start_catching_up()
         finally:
             # We want to be *very* sure we clear this after we stop processing
             self.transmission_loop_running = False
 
+    async def _catch_up_transmission_loop(self) -> None:
+        first_catch_up_check = self._last_successful_stream_ordering is None
+
+        if first_catch_up_check:
+            # first catchup so get last_successful_stream_ordering from database
+            self._last_successful_stream_ordering = await self._store.get_destination_last_successful_stream_ordering(
+                self._destination
+            )
+
+        if self._last_successful_stream_ordering is None:
+            # if it's still None, then this means we don't have the information
+            # in our database ­ we haven't successfully sent a PDU to this server
+            # (at least since the introduction of the feature tracking
+            # last_successful_stream_ordering).
+            # Sadly, this means we can't do anything here as we don't know what
+            # needs catching up — so catching up is futile; let's stop.
+            self._catching_up = False
+            return
+
+        # get at most 50 catchup room/PDUs
+        while True:
+            event_ids = await self._store.get_catch_up_room_event_ids(
+                self._destination, self._last_successful_stream_ordering,
+            )
+
+            if not event_ids:
+                # No more events to catch up on, but we can't ignore the chance
+                # of a race condition, so we check that no new events have been
+                # skipped due to us being in catch-up mode
+
+                if self._catchup_last_skipped > self._last_successful_stream_ordering:
+                    # another event has been skipped because we were in catch-up mode
+                    continue
+
+                # we are done catching up!
+                self._catching_up = False
+                break
+
+            if first_catch_up_check:
+                # as this is our check for needing catch-up, we may have PDUs in
+                # the queue from before we *knew* we had to do catch-up, so
+                # clear those out now.
+                self._start_catching_up()
+
+            # fetch the relevant events from the event store
+            # - redacted behaviour of REDACT is fine, since we only send metadata
+            #   of redacted events to the destination.
+            # - don't need to worry about rejected events as we do not actively
+            #   forward received events over federation.
+            catchup_pdus = await self._store.get_events_as_list(event_ids)
+            if not catchup_pdus:
+                raise AssertionError(
+                    "No events retrieved when we asked for %r. "
+                    "This should not happen." % event_ids
+                )
+
+            if logger.isEnabledFor(logging.INFO):
+                rooms = (p.room_id for p in catchup_pdus)
+                logger.info("Catching up rooms to %s: %r", self._destination, rooms)
+
+            success = await self._transaction_manager.send_new_transaction(
+                self._destination, catchup_pdus, []
+            )
+
+            if not success:
+                return
+
+            sent_transactions_counter.inc()
+            final_pdu = catchup_pdus[-1]
+            self._last_successful_stream_ordering = cast(
+                int, final_pdu.internal_metadata.stream_ordering
+            )
+            await self._store.set_destination_last_successful_stream_ordering(
+                self._destination, self._last_successful_stream_ordering
+            )
+
     def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]:
         if not self._pending_rrs:
             return
@@ -446,3 +569,12 @@ class PerDestinationQueue:
         ]
 
         return (edus, stream_id)
+
+    def _start_catching_up(self) -> None:
+        """
+        Marks this destination as being in catch-up mode.
+
+        This throws away the PDU queue.
+        """
+        self._catching_up = True
+        self._pending_pdus = []
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index cc7e9a973b..3a6b95631e 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -68,7 +68,7 @@ class TransportLayerServer(JsonResource):
         self.clock = hs.get_clock()
         self.servlet_groups = servlet_groups
 
-        super(TransportLayerServer, self).__init__(hs, canonical_json=False)
+        super().__init__(hs, canonical_json=False)
 
         self.authenticator = Authenticator(hs)
         self.ratelimiter = hs.get_federation_ratelimiter()
@@ -376,9 +376,7 @@ class FederationSendServlet(BaseFederationServlet):
     RATELIMIT = False
 
     def __init__(self, handler, server_name, **kwargs):
-        super(FederationSendServlet, self).__init__(
-            handler, server_name=server_name, **kwargs
-        )
+        super().__init__(handler, server_name=server_name, **kwargs)
         self.server_name = server_name
 
     # This is when someone is trying to send us a bunch of data.
@@ -773,9 +771,7 @@ class PublicRoomList(BaseFederationServlet):
     PATH = "/publicRooms"
 
     def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access):
-        super(PublicRoomList, self).__init__(
-            handler, authenticator, ratelimiter, server_name
-        )
+        super().__init__(handler, authenticator, ratelimiter, server_name)
         self.allow_access = allow_access
 
     async def on_GET(self, origin, content, query):
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 1dd20ee4e1..e5f85b472d 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -336,7 +336,7 @@ class GroupsServerWorkerHandler:
 
 class GroupsServerHandler(GroupsServerWorkerHandler):
     def __init__(self, hs):
-        super(GroupsServerHandler, self).__init__(hs)
+        super().__init__(hs)
 
         # Ensure attestations get renewed
         hs.get_groups_attestation_renewer()
diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py
index 69650ff221..7294649d71 100644
--- a/synapse/handlers/acme_issuing_service.py
+++ b/synapse/handlers/acme_issuing_service.py
@@ -76,7 +76,7 @@ def create_issuing_service(reactor, acme_url, account_key_file, well_known_resou
     )
 
 
-@attr.s
+@attr.s(slots=True)
 @implementer(ICertificateStore)
 class ErsatzStore:
     """
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 918d0e037c..dd981c597e 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
 
 class AdminHandler(BaseHandler):
     def __init__(self, hs):
-        super(AdminHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
@@ -125,8 +125,8 @@ class AdminHandler(BaseHandler):
             else:
                 stream_ordering = room.stream_ordering
 
-            from_key = str(RoomStreamToken(0, 0))
-            to_key = str(RoomStreamToken(None, stream_ordering))
+            from_key = RoomStreamToken(0, 0)
+            to_key = RoomStreamToken(None, stream_ordering)
 
             written_events = set()  # Events that we've processed in this room
 
@@ -153,7 +153,7 @@ class AdminHandler(BaseHandler):
                 if not events:
                     break
 
-                from_key = events[-1].internal_metadata.after
+                from_key = RoomStreamToken.parse(events[-1].internal_metadata.after)
 
                 events = await filter_events_for_client(self.storage, user_id, events)
 
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 90189869cc..0322b60cfc 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -145,7 +145,7 @@ class AuthHandler(BaseHandler):
         Args:
             hs (synapse.server.HomeServer):
         """
-        super(AuthHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.checkers = {}  # type: Dict[str, UserInteractiveAuthChecker]
         for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
@@ -1235,7 +1235,7 @@ class AuthHandler(BaseHandler):
         return urllib.parse.urlunparse(url_parts)
 
 
-@attr.s
+@attr.s(slots=True)
 class MacaroonGenerator:
 
     hs = attr.ib()
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 25169157c1..0635ad5708 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -29,7 +29,7 @@ class DeactivateAccountHandler(BaseHandler):
     """Handler which deals with deactivating user accounts."""
 
     def __init__(self, hs):
-        super(DeactivateAccountHandler, self).__init__(hs)
+        super().__init__(hs)
         self.hs = hs
         self._auth_handler = hs.get_auth_handler()
         self._device_handler = hs.get_device_handler()
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 643d71a710..4149520d6c 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -20,6 +20,7 @@ from typing import Any, Dict, List, Optional
 from synapse.api import errors
 from synapse.api.constants import EventTypes
 from synapse.api.errors import (
+    Codes,
     FederationDeniedError,
     HttpResponseException,
     RequestSendFailed,
@@ -29,6 +30,7 @@ from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import (
     RoomStreamToken,
+    StreamToken,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
 )
@@ -47,7 +49,7 @@ MAX_DEVICE_DISPLAY_NAME_LEN = 100
 
 class DeviceWorkerHandler(BaseHandler):
     def __init__(self, hs):
-        super(DeviceWorkerHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.hs = hs
         self.state = hs.get_state_handler()
@@ -104,18 +106,15 @@ class DeviceWorkerHandler(BaseHandler):
 
     @trace
     @measure_func("device.get_user_ids_changed")
-    async def get_user_ids_changed(self, user_id, from_token):
+    async def get_user_ids_changed(self, user_id: str, from_token: StreamToken):
         """Get list of users that have had the devices updated, or have newly
         joined a room, that `user_id` may be interested in.
-
-        Args:
-            user_id (str)
-            from_token (StreamToken)
         """
 
         set_tag("user_id", user_id)
         set_tag("from_token", from_token)
-        now_room_key = await self.store.get_room_events_max_id()
+        now_room_id = self.store.get_room_max_stream_ordering()
+        now_room_key = RoomStreamToken(None, now_room_id)
 
         room_ids = await self.store.get_rooms_for_user(user_id)
 
@@ -142,7 +141,7 @@ class DeviceWorkerHandler(BaseHandler):
         )
         rooms_changed.update(event.room_id for event in member_events)
 
-        stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key).stream
+        stream_ordering = from_token.room_key.stream
 
         possibly_changed = set(changed)
         possibly_left = set()
@@ -253,7 +252,7 @@ class DeviceWorkerHandler(BaseHandler):
 
 class DeviceHandler(DeviceWorkerHandler):
     def __init__(self, hs):
-        super(DeviceHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.federation_sender = hs.get_federation_sender()
 
@@ -267,6 +266,24 @@ class DeviceHandler(DeviceWorkerHandler):
 
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
+    def _check_device_name_length(self, name: str):
+        """
+        Checks whether a device name is longer than the maximum allowed length.
+
+        Args:
+            name: The name of the device.
+
+        Raises:
+            SynapseError: if the device name is too long.
+        """
+        if name and len(name) > MAX_DEVICE_DISPLAY_NAME_LEN:
+            raise SynapseError(
+                400,
+                "Device display name is too long (max %i)"
+                % (MAX_DEVICE_DISPLAY_NAME_LEN,),
+                errcode=Codes.TOO_LARGE,
+            )
+
     async def check_device_registered(
         self, user_id, device_id, initial_device_display_name=None
     ):
@@ -284,6 +301,9 @@ class DeviceHandler(DeviceWorkerHandler):
         Returns:
             str: device id (generated if none was supplied)
         """
+
+        self._check_device_name_length(initial_device_display_name)
+
         if device_id is not None:
             new_device = await self.store.store_device(
                 user_id=user_id,
@@ -399,12 +419,8 @@ class DeviceHandler(DeviceWorkerHandler):
 
         # Reject a new displayname which is too long.
         new_display_name = content.get("display_name")
-        if new_display_name and len(new_display_name) > MAX_DEVICE_DISPLAY_NAME_LEN:
-            raise SynapseError(
-                400,
-                "Device display name is too long (max %i)"
-                % (MAX_DEVICE_DISPLAY_NAME_LEN,),
-            )
+
+        self._check_device_name_length(new_display_name)
 
         try:
             await self.store.update_device(
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 46826eb784..62aa9a2da8 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
 
 class DirectoryHandler(BaseHandler):
     def __init__(self, hs):
-        super(DirectoryHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.state = hs.get_state_handler()
         self.appservice_handler = hs.get_application_service_handler()
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d629c7c16c..dd40fd1299 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -1201,7 +1201,7 @@ def _one_time_keys_match(old_key_json, new_key):
     return old_key == new_key_copy
 
 
-@attr.s
+@attr.s(slots=True)
 class SignatureListItem:
     """An item in the signature list as used by upload_signatures_for_device_keys.
     """
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index b05e32f457..0875b74ea8 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -37,11 +37,7 @@ logger = logging.getLogger(__name__)
 
 class EventStreamHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
-        super(EventStreamHandler, self).__init__(hs)
-
-        self.distributor = hs.get_distributor()
-        self.distributor.declare("started_user_eventstream")
-        self.distributor.declare("stopped_user_eventstream")
+        super().__init__(hs)
 
         self.clock = hs.get_clock()
 
@@ -146,7 +142,7 @@ class EventStreamHandler(BaseHandler):
 
 class EventHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
-        super(EventHandler, self).__init__(hs)
+        super().__init__(hs)
         self.storage = hs.get_storage()
 
     async def get_event(
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 014dab2940..9f773aefa7 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -69,18 +69,18 @@ from synapse.replication.http.federation import (
     ReplicationFederationSendEventsRestServlet,
     ReplicationStoreRoomOnInviteRestServlet,
 )
-from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
 from synapse.state import StateResolutionStore, resolve_events_with_store
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.types import (
     JsonDict,
     MutableStateMap,
+    PersistedEventPosition,
+    RoomStreamToken,
     StateMap,
     UserID,
     get_domain_from_id,
 )
 from synapse.util.async_helpers import Linearizer, concurrently_execute
-from synapse.util.distributor import user_joined_room
 from synapse.util.retryutils import NotRetryingDestination
 from synapse.util.stringutils import shortstr
 from synapse.visibility import filter_events_for_server
@@ -88,7 +88,7 @@ from synapse.visibility import filter_events_for_server
 logger = logging.getLogger(__name__)
 
 
-@attr.s
+@attr.s(slots=True)
 class _NewEventInfo:
     """Holds information about a received event, ready for passing to _handle_new_events
 
@@ -117,7 +117,7 @@ class FederationHandler(BaseHandler):
     """
 
     def __init__(self, hs):
-        super(FederationHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.hs = hs
 
@@ -130,7 +130,6 @@ class FederationHandler(BaseHandler):
         self.keyring = hs.get_keyring()
         self.action_generator = hs.get_action_generator()
         self.is_mine_id = hs.is_mine_id
-        self.pusher_pool = hs.get_pusherpool()
         self.spam_checker = hs.get_spam_checker()
         self.event_creation_handler = hs.get_event_creation_handler()
         self._message_handler = hs.get_message_handler()
@@ -141,9 +140,6 @@ class FederationHandler(BaseHandler):
         self._replication = hs.get_replication_data_handler()
 
         self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
-        self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client(
-            hs
-        )
         self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
             hs
         )
@@ -704,31 +700,10 @@ class FederationHandler(BaseHandler):
         logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
 
         try:
-            context = await self._handle_new_event(origin, event, state=state)
+            await self._handle_new_event(origin, event, state=state)
         except AuthError as e:
             raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
 
-        if event.type == EventTypes.Member:
-            if event.membership == Membership.JOIN:
-                # Only fire user_joined_room if the user has acutally
-                # joined the room. Don't bother if the user is just
-                # changing their profile info.
-                newly_joined = True
-
-                prev_state_ids = await context.get_prev_state_ids()
-
-                prev_state_id = prev_state_ids.get((event.type, event.state_key))
-                if prev_state_id:
-                    prev_state = await self.store.get_event(
-                        prev_state_id, allow_none=True
-                    )
-                    if prev_state and prev_state.membership == Membership.JOIN:
-                        newly_joined = False
-
-                if newly_joined:
-                    user = UserID.from_string(event.state_key)
-                    await self.user_joined_room(user, room_id)
-
         # For encrypted messages we check that we know about the sending device,
         # if we don't then we mark the device cache for that user as stale.
         if event.type == EventTypes.Encrypted:
@@ -923,7 +898,8 @@ class FederationHandler(BaseHandler):
                 )
             )
 
-        await self._handle_new_events(dest, ev_infos, backfilled=True)
+        if ev_infos:
+            await self._handle_new_events(dest, room_id, ev_infos, backfilled=True)
 
         # Step 2: Persist the rest of the events in the chunk one by one
         events.sort(key=lambda e: e.depth)
@@ -1265,7 +1241,7 @@ class FederationHandler(BaseHandler):
             event_infos.append(_NewEventInfo(event, None, auth))
 
         await self._handle_new_events(
-            destination, event_infos,
+            destination, room_id, event_infos,
         )
 
     def _sanity_check_event(self, ev):
@@ -1412,15 +1388,15 @@ class FederationHandler(BaseHandler):
             )
 
             max_stream_id = await self._persist_auth_tree(
-                origin, auth_chain, state, event, room_version_obj
+                origin, room_id, auth_chain, state, event, room_version_obj
             )
 
             # We wait here until this instance has seen the events come down
             # replication (if we're using replication) as the below uses caches.
-            #
-            # TODO: Currently the events stream is written to from master
             await self._replication.wait_for_stream_position(
-                self.config.worker.writers.events, "events", max_stream_id
+                self.config.worker.events_shard_config.get_instance(room_id),
+                "events",
+                max_stream_id,
             )
 
             # Check whether this room is the result of an upgrade of a room we already know
@@ -1599,11 +1575,6 @@ class FederationHandler(BaseHandler):
             event.signatures,
         )
 
-        if event.type == EventTypes.Member:
-            if event.content["membership"] == Membership.JOIN:
-                user = UserID.from_string(event.state_key)
-                await self.user_joined_room(user, event.room_id)
-
         prev_state_ids = await context.get_prev_state_ids()
 
         state_ids = list(prev_state_ids.values())
@@ -1674,7 +1645,7 @@ class FederationHandler(BaseHandler):
         )
 
         context = await self.state_handler.compute_event_context(event)
-        await self.persist_events_and_notify([(event, context)])
+        await self.persist_events_and_notify(event.room_id, [(event, context)])
 
         return event
 
@@ -1701,7 +1672,9 @@ class FederationHandler(BaseHandler):
         await self.federation_client.send_leave(host_list, event)
 
         context = await self.state_handler.compute_event_context(event)
-        stream_id = await self.persist_events_and_notify([(event, context)])
+        stream_id = await self.persist_events_and_notify(
+            event.room_id, [(event, context)]
+        )
 
         return event, stream_id
 
@@ -1949,7 +1922,7 @@ class FederationHandler(BaseHandler):
                 )
 
             await self.persist_events_and_notify(
-                [(event, context)], backfilled=backfilled
+                event.room_id, [(event, context)], backfilled=backfilled
             )
         except Exception:
             run_in_background(
@@ -1962,6 +1935,7 @@ class FederationHandler(BaseHandler):
     async def _handle_new_events(
         self,
         origin: str,
+        room_id: str,
         event_infos: Iterable[_NewEventInfo],
         backfilled: bool = False,
     ) -> None:
@@ -1993,6 +1967,7 @@ class FederationHandler(BaseHandler):
         )
 
         await self.persist_events_and_notify(
+            room_id,
             [
                 (ev_info.event, context)
                 for ev_info, context in zip(event_infos, contexts)
@@ -2003,6 +1978,7 @@ class FederationHandler(BaseHandler):
     async def _persist_auth_tree(
         self,
         origin: str,
+        room_id: str,
         auth_events: List[EventBase],
         state: List[EventBase],
         event: EventBase,
@@ -2017,6 +1993,7 @@ class FederationHandler(BaseHandler):
 
         Args:
             origin: Where the events came from
+            room_id,
             auth_events
             state
             event
@@ -2091,17 +2068,20 @@ class FederationHandler(BaseHandler):
                 events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
 
         await self.persist_events_and_notify(
+            room_id,
             [
                 (e, events_to_context[e.event_id])
                 for e in itertools.chain(auth_events, state)
-            ]
+            ],
         )
 
         new_event_context = await self.state_handler.compute_event_context(
             event, old_state=state
         )
 
-        return await self.persist_events_and_notify([(event, new_event_context)])
+        return await self.persist_events_and_notify(
+            room_id, [(event, new_event_context)]
+        )
 
     async def _prep_event(
         self,
@@ -2952,6 +2932,7 @@ class FederationHandler(BaseHandler):
 
     async def persist_events_and_notify(
         self,
+        room_id: str,
         event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
         backfilled: bool = False,
     ) -> int:
@@ -2959,20 +2940,25 @@ class FederationHandler(BaseHandler):
         necessary.
 
         Args:
-            event_and_contexts:
+            room_id: The room ID of events being persisted.
+            event_and_contexts: Sequence of events with their associated
+                context that should be persisted. All events must belong to
+                the same room.
             backfilled: Whether these events are a result of
                 backfilling or not
         """
-        if self.config.worker.writers.events != self._instance_name:
+        instance = self.config.worker.events_shard_config.get_instance(room_id)
+        if instance != self._instance_name:
             result = await self._send_events(
-                instance_name=self.config.worker.writers.events,
+                instance_name=instance,
                 store=self.store,
+                room_id=room_id,
                 event_and_contexts=event_and_contexts,
                 backfilled=backfilled,
             )
             return result["max_stream_id"]
         else:
-            max_stream_id = await self.storage.persistence.persist_events(
+            max_stream_token = await self.storage.persistence.persist_events(
                 event_and_contexts, backfilled=backfilled
             )
 
@@ -2983,12 +2969,12 @@ class FederationHandler(BaseHandler):
 
             if not backfilled:  # Never notify for backfilled events
                 for event, _ in event_and_contexts:
-                    await self._notify_persisted_event(event, max_stream_id)
+                    await self._notify_persisted_event(event, max_stream_token)
 
-            return max_stream_id
+            return max_stream_token.stream
 
     async def _notify_persisted_event(
-        self, event: EventBase, max_stream_id: int
+        self, event: EventBase, max_stream_token: RoomStreamToken
     ) -> None:
         """Checks to see if notifier/pushers should be notified about the
         event or not.
@@ -3014,13 +3000,13 @@ class FederationHandler(BaseHandler):
         elif event.internal_metadata.is_outlier():
             return
 
-        event_stream_id = event.internal_metadata.stream_ordering
+        event_pos = PersistedEventPosition(
+            self._instance_name, event.internal_metadata.stream_ordering
+        )
         self.notifier.on_new_room_event(
-            event, event_stream_id, max_stream_id, extra_users=extra_users
+            event, event_pos, max_stream_token, extra_users=extra_users
         )
 
-        await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
-
     async def _clean_room_for_join(self, room_id: str) -> None:
         """Called to clean up any data in DB for a given room, ready for the
         server to join the room.
@@ -3033,16 +3019,6 @@ class FederationHandler(BaseHandler):
         else:
             await self.store.clean_room_for_join(room_id)
 
-    async def user_joined_room(self, user: UserID, room_id: str) -> None:
-        """Called when a new user has joined the room
-        """
-        if self.config.worker_app:
-            await self._notify_user_membership_change(
-                room_id=room_id, user_id=user.to_string(), change="joined"
-            )
-        else:
-            user_joined_room(self.distributor, user, room_id)
-
     async def get_room_complexity(
         self, remote_room_hosts: List[str], room_id: str
     ) -> Optional[dict]:
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 44df567983..9684e60fc8 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -240,7 +240,7 @@ class GroupsLocalWorkerHandler:
 
 class GroupsLocalHandler(GroupsLocalWorkerHandler):
     def __init__(self, hs):
-        super(GroupsLocalHandler, self).__init__(hs)
+        super().__init__(hs)
 
         # Ensure attestations get renewed
         hs.get_groups_attestation_renewer()
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 0ce6ddfbe4..ab15570f7a 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -45,7 +45,7 @@ id_server_scheme = "https://"
 
 class IdentityHandler(BaseHandler):
     def __init__(self, hs):
-        super(IdentityHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.http_client = SimpleHttpClient(hs)
         # We create a blacklisting instance of SimpleHttpClient for contacting identity
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index d5ddc583ad..8cd7eb22a3 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -25,7 +25,7 @@ from synapse.handlers.presence import format_user_presence_state
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage.roommember import RoomsForUser
 from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, Requester, StreamToken, UserID
+from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.response_cache import ResponseCache
@@ -42,7 +42,7 @@ logger = logging.getLogger(__name__)
 
 class InitialSyncHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
-        super(InitialSyncHandler, self).__init__(hs)
+        super().__init__(hs)
         self.hs = hs
         self.state = hs.get_state_handler()
         self.clock = hs.get_clock()
@@ -116,14 +116,13 @@ class InitialSyncHandler(BaseHandler):
         now_token = self.hs.get_event_sources().get_current_token()
 
         presence_stream = self.hs.get_event_sources().sources["presence"]
-        pagination_config = PaginationConfig(from_token=now_token)
-        presence, _ = await presence_stream.get_pagination_rows(
-            user, pagination_config.get_source_config("presence"), None
+        presence, _ = await presence_stream.get_new_events(
+            user, from_key=None, include_offline=False
         )
 
-        receipt_stream = self.hs.get_event_sources().sources["receipt"]
-        receipt, _ = await receipt_stream.get_pagination_rows(
-            user, pagination_config.get_source_config("receipt"), None
+        joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN]
+        receipt = await self.store.get_linearized_receipts_for_rooms(
+            joined_rooms, to_key=int(now_token.receipt_key),
         )
 
         tags_by_room = await self.store.get_tags_for_user(user_id)
@@ -168,7 +167,7 @@ class InitialSyncHandler(BaseHandler):
                         self.state_handler.get_current_state, event.room_id
                     )
                 elif event.membership == Membership.LEAVE:
-                    room_end_token = "s%d" % (event.stream_ordering,)
+                    room_end_token = RoomStreamToken(None, event.stream_ordering,)
                     deferred_room_state = run_in_background(
                         self.state_store.get_state_for_events, [event.event_id]
                     )
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 8a7b4916cd..ee271e85e5 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -376,9 +376,8 @@ class EventCreationHandler:
         self.notifier = hs.get_notifier()
         self.config = hs.config
         self.require_membership_for_aliases = hs.config.require_membership_for_aliases
-        self._is_event_writer = (
-            self.config.worker.writers.events == hs.get_instance_name()
-        )
+        self._events_shard_config = self.config.worker.events_shard_config
+        self._instance_name = hs.get_instance_name()
 
         self.room_invite_state_types = self.hs.config.room_invite_state_types
 
@@ -387,8 +386,6 @@ class EventCreationHandler:
         # This is only used to get at ratelimit function, and maybe_kick_guest_users
         self.base_handler = BaseHandler(hs)
 
-        self.pusher_pool = hs.get_pusherpool()
-
         # We arbitrarily limit concurrent event creation for a room to 5.
         # This is to stop us from diverging history *too* much.
         self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")
@@ -904,9 +901,10 @@ class EventCreationHandler:
 
         try:
             # If we're a worker we need to hit out to the master.
-            if not self._is_event_writer:
+            writer_instance = self._events_shard_config.get_instance(event.room_id)
+            if writer_instance != self._instance_name:
                 result = await self.send_event(
-                    instance_name=self.config.worker.writers.events,
+                    instance_name=writer_instance,
                     event_id=event.event_id,
                     store=self.store,
                     requester=requester,
@@ -974,7 +972,10 @@ class EventCreationHandler:
 
         This should only be run on the instance in charge of persisting events.
         """
-        assert self._is_event_writer
+        assert self.storage.persistence is not None
+        assert self._events_shard_config.should_handle(
+            self._instance_name, event.room_id
+        )
 
         if ratelimit:
             # We check if this is a room admin redacting an event so that we
@@ -1137,7 +1138,7 @@ class EventCreationHandler:
             if prev_state_ids:
                 raise AuthError(403, "Changing the room create event is forbidden")
 
-        event_stream_id, max_stream_id = await self.storage.persistence.persist_event(
+        event_pos, max_stream_token = await self.storage.persistence.persist_event(
             event, context=context
         )
 
@@ -1145,12 +1146,10 @@ class EventCreationHandler:
             # If there's an expiry timestamp on the event, schedule its expiry.
             self._message_handler.maybe_schedule_expiry(event)
 
-        await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
-
         def _notify():
             try:
                 self.notifier.on_new_room_event(
-                    event, event_stream_id, max_stream_id, extra_users=extra_users
+                    event, event_pos, max_stream_token, extra_users=extra_users
                 )
             except Exception:
                 logger.exception("Error notifying about new room event")
@@ -1162,7 +1161,7 @@ class EventCreationHandler:
             # matters as sometimes presence code can take a while.
             run_in_background(self._bump_active_time, requester.user)
 
-        return event_stream_id
+        return event_pos.stream
 
     async def _bump_active_time(self, user: UserID) -> None:
         try:
@@ -1183,54 +1182,7 @@ class EventCreationHandler:
         )
 
         for room_id in room_ids:
-            # For each room we need to find a joined member we can use to send
-            # the dummy event with.
-
-            latest_event_ids = await self.store.get_prev_events_for_room(room_id)
-
-            members = await self.state.get_current_users_in_room(
-                room_id, latest_event_ids=latest_event_ids
-            )
-            dummy_event_sent = False
-            for user_id in members:
-                if not self.hs.is_mine_id(user_id):
-                    continue
-                requester = create_requester(user_id)
-                try:
-                    event, context = await self.create_event(
-                        requester,
-                        {
-                            "type": "org.matrix.dummy_event",
-                            "content": {},
-                            "room_id": room_id,
-                            "sender": user_id,
-                        },
-                        prev_event_ids=latest_event_ids,
-                    )
-
-                    event.internal_metadata.proactively_send = False
-
-                    # Since this is a dummy-event it is OK if it is sent by a
-                    # shadow-banned user.
-                    await self.send_nonmember_event(
-                        requester,
-                        event,
-                        context,
-                        ratelimit=False,
-                        ignore_shadow_ban=True,
-                    )
-                    dummy_event_sent = True
-                    break
-                except ConsentNotGivenError:
-                    logger.info(
-                        "Failed to send dummy event into room %s for user %s due to "
-                        "lack of consent. Will try another user" % (room_id, user_id)
-                    )
-                except AuthError:
-                    logger.info(
-                        "Failed to send dummy event into room %s for user %s due to "
-                        "lack of power. Will try another user" % (room_id, user_id)
-                    )
+            dummy_event_sent = await self._send_dummy_event_for_room(room_id)
 
             if not dummy_event_sent:
                 # Did not find a valid user in the room, so remove from future attempts
@@ -1243,6 +1195,59 @@ class EventCreationHandler:
                 now = self.clock.time_msec()
                 self._rooms_to_exclude_from_dummy_event_insertion[room_id] = now
 
+    async def _send_dummy_event_for_room(self, room_id: str) -> bool:
+        """Attempt to send a dummy event for the given room.
+
+        Args:
+            room_id: room to try to send an event from
+
+        Returns:
+            True if a dummy event was successfully sent. False if no user was able
+            to send an event.
+        """
+
+        # For each room we need to find a joined member we can use to send
+        # the dummy event with.
+        latest_event_ids = await self.store.get_prev_events_for_room(room_id)
+        members = await self.state.get_current_users_in_room(
+            room_id, latest_event_ids=latest_event_ids
+        )
+        for user_id in members:
+            if not self.hs.is_mine_id(user_id):
+                continue
+            requester = create_requester(user_id)
+            try:
+                event, context = await self.create_event(
+                    requester,
+                    {
+                        "type": "org.matrix.dummy_event",
+                        "content": {},
+                        "room_id": room_id,
+                        "sender": user_id,
+                    },
+                    prev_event_ids=latest_event_ids,
+                )
+
+                event.internal_metadata.proactively_send = False
+
+                # Since this is a dummy-event it is OK if it is sent by a
+                # shadow-banned user.
+                await self.send_nonmember_event(
+                    requester, event, context, ratelimit=False, ignore_shadow_ban=True,
+                )
+                return True
+            except ConsentNotGivenError:
+                logger.info(
+                    "Failed to send dummy event into room %s for user %s due to "
+                    "lack of consent. Will try another user" % (room_id, user_id)
+                )
+            except AuthError:
+                logger.info(
+                    "Failed to send dummy event into room %s for user %s due to "
+                    "lack of power. Will try another user" % (room_id, user_id)
+                )
+        return False
+
     def _expire_rooms_to_exclude_from_dummy_event_insertion(self):
         expire_before = self.clock.time_msec() - _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY
         to_expire = set()
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 1b06f3173f..4230dbaf99 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -131,10 +131,10 @@ class OidcHandler:
     def _render_error(
         self, request, error: str, error_description: Optional[str] = None
     ) -> None:
-        """Renders the error template and respond with it.
+        """Render the error template and respond to the request with it.
 
         This is used to show errors to the user. The template of this page can
-        be found under ``synapse/res/templates/sso_error.html``.
+        be found under `synapse/res/templates/sso_error.html`.
 
         Args:
             request: The incoming request from the browser.
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 6067585f9b..a0b3bdb5e0 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -335,20 +335,16 @@ class PaginationHandler:
         user_id = requester.user.to_string()
 
         if pagin_config.from_token:
-            room_token = pagin_config.from_token.room_key
+            from_token = pagin_config.from_token
         else:
-            pagin_config.from_token = (
-                self.hs.get_event_sources().get_current_token_for_pagination()
-            )
-            room_token = pagin_config.from_token.room_key
-
-        room_token = RoomStreamToken.parse(room_token)
+            from_token = self.hs.get_event_sources().get_current_token_for_pagination()
 
-        pagin_config.from_token = pagin_config.from_token.copy_and_replace(
-            "room_key", str(room_token)
-        )
+        if pagin_config.limit is None:
+            # This shouldn't happen as we've set a default limit before this
+            # gets called.
+            raise Exception("limit not set")
 
-        source_config = pagin_config.get_source_config("room")
+        room_token = from_token.room_key
 
         with await self.pagination_lock.read(room_id):
             (
@@ -358,7 +354,7 @@ class PaginationHandler:
                 room_id, user_id, allow_departed_users=True
             )
 
-            if source_config.direction == "b":
+            if pagin_config.direction == "b":
                 # if we're going backwards, we might need to backfill. This
                 # requires that we have a topo token.
                 if room_token.topological:
@@ -377,26 +373,35 @@ class PaginationHandler:
                     # case "JOIN" would have been returned.
                     assert member_event_id
 
-                    leave_token = await self.store.get_topological_token_for_event(
+                    leave_token_str = await self.store.get_topological_token_for_event(
                         member_event_id
                     )
-                    if RoomStreamToken.parse(leave_token).topological < curr_topo:
-                        source_config.from_key = str(leave_token)
+                    leave_token = RoomStreamToken.parse(leave_token_str)
+                    assert leave_token.topological is not None
+
+                    if leave_token.topological < curr_topo:
+                        from_token = from_token.copy_and_replace(
+                            "room_key", leave_token
+                        )
 
                 await self.hs.get_handlers().federation_handler.maybe_backfill(
-                    room_id, curr_topo, limit=source_config.limit,
+                    room_id, curr_topo, limit=pagin_config.limit,
                 )
 
+            to_room_key = None
+            if pagin_config.to_token:
+                to_room_key = pagin_config.to_token.room_key
+
             events, next_key = await self.store.paginate_room_events(
                 room_id=room_id,
-                from_key=source_config.from_key,
-                to_key=source_config.to_key,
-                direction=source_config.direction,
-                limit=source_config.limit,
+                from_key=from_token.room_key,
+                to_key=to_room_key,
+                direction=pagin_config.direction,
+                limit=pagin_config.limit,
                 event_filter=event_filter,
             )
 
-            next_token = pagin_config.from_token.copy_and_replace("room_key", next_key)
+            next_token = from_token.copy_and_replace("room_key", next_key)
 
         if events:
             if event_filter:
@@ -409,7 +414,7 @@ class PaginationHandler:
         if not events:
             return {
                 "chunk": [],
-                "start": pagin_config.from_token.to_string(),
+                "start": from_token.to_string(),
                 "end": next_token.to_string(),
             }
 
@@ -438,7 +443,7 @@ class PaginationHandler:
                     events, time_now, as_client_event=as_client_event
                 )
             ),
-            "start": pagin_config.from_token.to_string(),
+            "start": from_token.to_string(),
             "end": next_token.to_string(),
         }
 
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 91a3aec1cc..1000ac95ff 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -1108,9 +1108,6 @@ class PresenceEventSource:
     def get_current_key(self):
         return self.store.get_current_presence_token()
 
-    async def get_pagination_rows(self, user, pagination_config, key):
-        return await self.get_new_events(user, from_key=None, include_offline=False)
-
     @cached(num_args=2, cache_context=True)
     async def _get_interested_in(self, user, explicit_room_id, cache_context):
         """Returns the set of users that the given user should see presence
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 0cb8fad89a..5453e6dfc8 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -44,7 +44,7 @@ class BaseProfileHandler(BaseHandler):
     """
 
     def __init__(self, hs):
-        super(BaseProfileHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.federation = hs.get_federation_client()
         hs.get_federation_registry().register_query_handler(
@@ -369,7 +369,7 @@ class MasterProfileHandler(BaseProfileHandler):
     PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
 
     def __init__(self, hs):
-        super(MasterProfileHandler, self).__init__(hs)
+        super().__init__(hs)
 
         assert hs.config.worker_app is None
 
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index e3b528d271..c32f314a1c 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -24,7 +24,7 @@ logger = logging.getLogger(__name__)
 
 class ReadMarkerHandler(BaseHandler):
     def __init__(self, hs):
-        super(ReadMarkerHandler, self).__init__(hs)
+        super().__init__(hs)
         self.server_name = hs.config.server_name
         self.store = hs.get_datastore()
         self.read_marker_linearizer = Linearizer(name="read_marker")
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 2cc6c2eb68..7225923757 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
 
 class ReceiptsHandler(BaseHandler):
     def __init__(self, hs):
-        super(ReceiptsHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.server_name = hs.config.server_name
         self.store = hs.get_datastore()
@@ -142,18 +142,3 @@ class ReceiptEventSource:
 
     def get_current_key(self, direction="f"):
         return self.store.get_max_receipt_stream_id()
-
-    async def get_pagination_rows(self, user, config, key):
-        to_key = int(config.from_key)
-
-        if config.to_key:
-            from_key = int(config.to_key)
-        else:
-            from_key = None
-
-        room_ids = await self.store.get_rooms_for_user(user.to_string())
-        events = await self.store.get_linearized_receipts_for_rooms(
-            room_ids, from_key=from_key, to_key=to_key
-        )
-
-        return (events, to_key)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index cde2dbca92..538f4b2a61 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -42,7 +42,7 @@ class RegistrationHandler(BaseHandler):
         Args:
             hs (synapse.server.HomeServer):
         """
-        super(RegistrationHandler, self).__init__(hs)
+        super().__init__(hs)
         self.hs = hs
         self.auth = hs.get_auth()
         self._auth_handler = hs.get_auth_handler()
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a29305f655..11bf146bed 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -70,7 +70,7 @@ FIVE_MINUTES_IN_MS = 5 * 60 * 1000
 
 class RoomCreationHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
-        super(RoomCreationHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.spam_checker = hs.get_spam_checker()
         self.event_creation_handler = hs.get_event_creation_handler()
@@ -804,7 +804,9 @@ class RoomCreationHandler(BaseHandler):
 
         # Always wait for room creation to progate before returning
         await self._replication.wait_for_stream_position(
-            self.hs.config.worker.writers.events, "events", last_stream_id
+            self.hs.config.worker.events_shard_config.get_instance(room_id),
+            "events",
+            last_stream_id,
         )
 
         return result, last_stream_id
@@ -1091,20 +1093,19 @@ class RoomEventSource:
     async def get_new_events(
         self,
         user: UserID,
-        from_key: str,
+        from_key: RoomStreamToken,
         limit: int,
         room_ids: List[str],
         is_guest: bool,
         explicit_room_id: Optional[str] = None,
-    ) -> Tuple[List[EventBase], str]:
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         # We just ignore the key for now.
 
         to_key = self.get_current_key()
 
-        from_token = RoomStreamToken.parse(from_key)
-        if from_token.topological:
+        if from_key.topological:
             logger.warning("Stream has topological part!!!! %r", from_key)
-            from_key = "s%s" % (from_token.stream,)
+            from_key = RoomStreamToken(None, from_key.stream)
 
         app_service = self.store.get_app_service_by_user_id(user.to_string())
         if app_service:
@@ -1133,14 +1134,14 @@ class RoomEventSource:
                 events[:] = events[:limit]
 
             if events:
-                end_key = events[-1].internal_metadata.after
+                end_key = RoomStreamToken.parse(events[-1].internal_metadata.after)
             else:
                 end_key = to_key
 
         return (events, end_key)
 
-    def get_current_key(self) -> str:
-        return "s%d" % (self.store.get_room_max_stream_ordering(),)
+    def get_current_key(self) -> RoomStreamToken:
+        return RoomStreamToken(None, self.store.get_room_max_stream_ordering())
 
     def get_current_key_for_room(self, room_id: str) -> Awaitable[str]:
         return self.store.get_room_events_max_id(room_id)
@@ -1260,10 +1261,10 @@ class RoomShutdownHandler:
             # We now wait for the create room to come back in via replication so
             # that we can assume that all the joins/invites have propogated before
             # we try and auto join below.
-            #
-            # TODO: Currently the events stream is written to from master
             await self._replication.wait_for_stream_position(
-                self.hs.config.worker.writers.events, "events", stream_id
+                self.hs.config.worker.events_shard_config.get_instance(new_room_id),
+                "events",
+                stream_id,
             )
         else:
             new_room_id = None
@@ -1293,7 +1294,9 @@ class RoomShutdownHandler:
 
                 # Wait for leave to come in over replication before trying to forget.
                 await self._replication.wait_for_stream_position(
-                    self.hs.config.worker.writers.events, "events", stream_id
+                    self.hs.config.worker.events_shard_config.get_instance(room_id),
+                    "events",
+                    stream_id,
                 )
 
                 await self.room_member_handler.forget(target_requester.user, room_id)
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 5dd7b28391..4a13c8e912 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -38,7 +38,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
 
 class RoomListHandler(BaseHandler):
     def __init__(self, hs):
-        super(RoomListHandler, self).__init__(hs)
+        super().__init__(hs)
         self.enable_room_list_search = hs.config.enable_room_list_search
         self.response_cache = ResponseCache(hs, "room_list")
         self.remote_response_cache = ResponseCache(
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 32b7e323fa..8feba8c90a 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -40,7 +40,7 @@ from synapse.events.validator import EventValidator
 from synapse.storage.roommember import RoomsForUser
 from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
 from synapse.util.async_helpers import Linearizer
-from synapse.util.distributor import user_joined_room, user_left_room
+from synapse.util.distributor import user_left_room
 
 from ._base import BaseHandler
 
@@ -51,14 +51,12 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class RoomMemberHandler:
+class RoomMemberHandler(metaclass=abc.ABCMeta):
     # TODO(paul): This handler currently contains a messy conflation of
     #   low-level API that works on UserID objects and so on, and REST-level
     #   API that takes ID strings and returns pagination chunks. These concerns
     #   ought to be separated out a lot better.
 
-    __metaclass__ = abc.ABCMeta
-
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
@@ -82,13 +80,6 @@ class RoomMemberHandler:
         self._enable_lookup = hs.config.enable_3pid_lookup
         self.allow_per_room_profiles = self.config.allow_per_room_profiles
 
-        self._event_stream_writer_instance = hs.config.worker.writers.events
-        self._is_on_event_persistence_instance = (
-            self._event_stream_writer_instance == hs.get_instance_name()
-        )
-        if self._is_on_event_persistence_instance:
-            self.persist_event_storage = hs.get_storage().persistence
-
         self._join_rate_limiter_local = Ratelimiter(
             clock=self.clock,
             rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
@@ -149,17 +140,6 @@ class RoomMemberHandler:
         raise NotImplementedError()
 
     @abc.abstractmethod
-    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
-        """Notifies distributor on master process that the user has joined the
-        room.
-
-        Args:
-            target
-            room_id
-        """
-        raise NotImplementedError()
-
-    @abc.abstractmethod
     async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Notifies distributor on master process that the user has left the
         room.
@@ -221,7 +201,6 @@ class RoomMemberHandler:
 
         prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
 
-        newly_joined = False
         if event.membership == Membership.JOIN:
             newly_joined = True
             if prev_member_event_id:
@@ -246,12 +225,7 @@ class RoomMemberHandler:
             requester, event, context, extra_users=[target], ratelimit=ratelimit,
         )
 
-        if event.membership == Membership.JOIN and newly_joined:
-            # Only fire user_joined_room if the user has actually joined the
-            # room. Don't bother if the user is just changing their profile
-            # info.
-            await self._user_joined_room(target, room_id)
-        elif event.membership == Membership.LEAVE:
+        if event.membership == Membership.LEAVE:
             if prev_member_event_id:
                 prev_member_event = await self.store.get_event(prev_member_event_id)
                 if prev_member_event.membership == Membership.JOIN:
@@ -726,17 +700,7 @@ class RoomMemberHandler:
             (EventTypes.Member, event.state_key), None
         )
 
-        if event.membership == Membership.JOIN:
-            # Only fire user_joined_room if the user has actually joined the
-            # room. Don't bother if the user is just changing their profile
-            # info.
-            newly_joined = True
-            if prev_member_event_id:
-                prev_member_event = await self.store.get_event(prev_member_event_id)
-                newly_joined = prev_member_event.membership != Membership.JOIN
-            if newly_joined:
-                await self._user_joined_room(target_user, room_id)
-        elif event.membership == Membership.LEAVE:
+        if event.membership == Membership.LEAVE:
             if prev_member_event_id:
                 prev_member_event = await self.store.get_event(prev_member_event_id)
                 if prev_member_event.membership == Membership.JOIN:
@@ -1002,10 +966,9 @@ class RoomMemberHandler:
 
 class RoomMemberMasterHandler(RoomMemberHandler):
     def __init__(self, hs):
-        super(RoomMemberMasterHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.distributor = hs.get_distributor()
-        self.distributor.declare("user_joined_room")
         self.distributor.declare("user_left_room")
 
     async def _is_remote_room_too_complex(
@@ -1085,7 +1048,6 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         event_id, stream_id = await self.federation_handler.do_invite_join(
             remote_room_hosts, room_id, user.to_string(), content
         )
-        await self._user_joined_room(user, room_id)
 
         # Check the room we just joined wasn't too large, if we didn't fetch the
         # complexity of it before.
@@ -1228,11 +1190,6 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         )
         return event.event_id, stream_id
 
-    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
-        """Implements RoomMemberHandler._user_joined_room
-        """
-        user_joined_room(self.distributor, target, room_id)
-
     async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_left_room
         """
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 897338fd54..f2e88f6a5b 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
 
 class RoomMemberWorkerHandler(RoomMemberHandler):
     def __init__(self, hs):
-        super(RoomMemberWorkerHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self._remote_join_client = ReplRemoteJoin.make_client(hs)
         self._remote_reject_client = ReplRejectInvite.make_client(hs)
@@ -57,8 +57,6 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
             content=content,
         )
 
-        await self._user_joined_room(user, room_id)
-
         return ret["event_id"], ret["stream_id"]
 
     async def remote_reject_invite(
@@ -81,13 +79,6 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
         )
         return ret["event_id"], ret["stream_id"]
 
-    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
-        """Implements RoomMemberHandler._user_joined_room
-        """
-        await self._notify_change_client(
-            user_id=target.to_string(), room_id=room_id, change="joined"
-        )
-
     async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_left_room
         """
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 66b063f991..285c481a96 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -21,9 +21,10 @@ import saml2
 import saml2.response
 from saml2.client import Saml2Client
 
-from synapse.api.errors import AuthError, SynapseError
+from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
 from synapse.config.saml2_config import SamlAttributeRequirement
+from synapse.http.server import respond_with_html
 from synapse.http.servlet import parse_string
 from synapse.http.site import SynapseRequest
 from synapse.module_api import ModuleApi
@@ -41,7 +42,11 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-@attr.s
+class MappingException(Exception):
+    """Used to catch errors when mapping the SAML2 response to a user."""
+
+
+@attr.s(slots=True)
 class Saml2SessionData:
     """Data we track about SAML2 sessions"""
 
@@ -68,6 +73,7 @@ class SamlHandler:
             hs.config.saml2_grandfathered_mxid_source_attribute
         )
         self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements
+        self._error_template = hs.config.sso_error_template
 
         # plugin to do custom mapping from saml response to mxid
         self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
@@ -84,6 +90,25 @@ class SamlHandler:
         # a lock on the mappings
         self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
 
+    def _render_error(
+        self, request, error: str, error_description: Optional[str] = None
+    ) -> None:
+        """Render the error template and respond to the request with it.
+
+        This is used to show errors to the user. The template of this page can
+        be found under `synapse/res/templates/sso_error.html`.
+
+        Args:
+            request: The incoming request from the browser.
+                We'll respond with an HTML page describing the error.
+            error: A technical identifier for this error.
+            error_description: A human-readable description of the error.
+        """
+        html = self._error_template.render(
+            error=error, error_description=error_description
+        )
+        respond_with_html(request, 400, html)
+
     def handle_redirect_request(
         self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
     ) -> bytes:
@@ -134,49 +159,6 @@ class SamlHandler:
         # the dict.
         self.expire_sessions()
 
-        # Pull out the user-agent and IP from the request.
-        user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
-            0
-        ].decode("ascii", "surrogateescape")
-        ip_address = self.hs.get_ip_from_request(request)
-
-        user_id, current_session = await self._map_saml_response_to_user(
-            resp_bytes, relay_state, user_agent, ip_address
-        )
-
-        # Complete the interactive auth session or the login.
-        if current_session and current_session.ui_auth_session_id:
-            await self._auth_handler.complete_sso_ui_auth(
-                user_id, current_session.ui_auth_session_id, request
-            )
-
-        else:
-            await self._auth_handler.complete_sso_login(user_id, request, relay_state)
-
-    async def _map_saml_response_to_user(
-        self,
-        resp_bytes: str,
-        client_redirect_url: str,
-        user_agent: str,
-        ip_address: str,
-    ) -> Tuple[str, Optional[Saml2SessionData]]:
-        """
-        Given a sample response, retrieve the cached session and user for it.
-
-        Args:
-            resp_bytes: The SAML response.
-            client_redirect_url: The redirect URL passed in by the client.
-            user_agent: The user agent of the client making the request.
-            ip_address: The IP address of the client making the request.
-
-        Returns:
-             Tuple of the user ID and SAML session associated with this response.
-
-        Raises:
-            SynapseError if there was a problem with the response.
-            RedirectException: some mapping providers may raise this if they need
-                to redirect to an interstitial page.
-        """
         try:
             saml2_auth = self._saml_client.parse_authn_request_response(
                 resp_bytes,
@@ -189,12 +171,23 @@ class SamlHandler:
             # in the (user-visible) exception message, so let's log the exception here
             # so we can track down the session IDs later.
             logger.warning(str(e))
-            raise SynapseError(400, "Unexpected SAML2 login.")
+            self._render_error(
+                request, "unsolicited_response", "Unexpected SAML2 login."
+            )
+            return
         except Exception as e:
-            raise SynapseError(400, "Unable to parse SAML2 response: %s." % (e,))
+            self._render_error(
+                request,
+                "invalid_response",
+                "Unable to parse SAML2 response: %s." % (e,),
+            )
+            return
 
         if saml2_auth.not_signed:
-            raise SynapseError(400, "SAML2 response was not signed.")
+            self._render_error(
+                request, "unsigned_respond", "SAML2 response was not signed."
+            )
+            return
 
         logger.debug("SAML2 response: %s", saml2_auth.origxml)
         for assertion in saml2_auth.assertions:
@@ -213,15 +206,73 @@ class SamlHandler:
             saml2_auth.in_response_to, None
         )
 
+        # Ensure that the attributes of the logged in user meet the required
+        # attributes.
         for requirement in self._saml2_attribute_requirements:
-            _check_attribute_requirement(saml2_auth.ava, requirement)
+            if not _check_attribute_requirement(saml2_auth.ava, requirement):
+                self._render_error(
+                    request, "unauthorised", "You are not authorised to log in here."
+                )
+                return
+
+        # Pull out the user-agent and IP from the request.
+        user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
+            0
+        ].decode("ascii", "surrogateescape")
+        ip_address = self.hs.get_ip_from_request(request)
+
+        # Call the mapper to register/login the user
+        try:
+            user_id = await self._map_saml_response_to_user(
+                saml2_auth, relay_state, user_agent, ip_address
+            )
+        except MappingException as e:
+            logger.exception("Could not map user")
+            self._render_error(request, "mapping_error", str(e))
+            return
+
+        # Complete the interactive auth session or the login.
+        if current_session and current_session.ui_auth_session_id:
+            await self._auth_handler.complete_sso_ui_auth(
+                user_id, current_session.ui_auth_session_id, request
+            )
+
+        else:
+            await self._auth_handler.complete_sso_login(user_id, request, relay_state)
+
+    async def _map_saml_response_to_user(
+        self,
+        saml2_auth: saml2.response.AuthnResponse,
+        client_redirect_url: str,
+        user_agent: str,
+        ip_address: str,
+    ) -> str:
+        """
+        Given a SAML response, retrieve the user ID for it and possibly register the user.
+
+        Args:
+            saml2_auth: The parsed SAML2 response.
+            client_redirect_url: The redirect URL passed in by the client.
+            user_agent: The user agent of the client making the request.
+            ip_address: The IP address of the client making the request.
+
+        Returns:
+             The user ID associated with this response.
+
+        Raises:
+            MappingException if there was a problem mapping the response to a user.
+            RedirectException: some mapping providers may raise this if they need
+                to redirect to an interstitial page.
+        """
 
         remote_user_id = self._user_mapping_provider.get_remote_user_id(
             saml2_auth, client_redirect_url
         )
 
         if not remote_user_id:
-            raise Exception("Failed to extract remote user id from SAML response")
+            raise MappingException(
+                "Failed to extract remote user id from SAML response"
+            )
 
         with (await self._mapping_lock.queue(self._auth_provider_id)):
             # first of all, check if we already have a mapping for this user
@@ -235,7 +286,7 @@ class SamlHandler:
             )
             if registered_user_id is not None:
                 logger.info("Found existing mapping %s", registered_user_id)
-                return registered_user_id, current_session
+                return registered_user_id
 
             # backwards-compatibility hack: see if there is an existing user with a
             # suitable mapping from the uid
@@ -260,7 +311,7 @@ class SamlHandler:
                     await self._datastore.record_user_external_id(
                         self._auth_provider_id, remote_user_id, registered_user_id
                     )
-                    return registered_user_id, current_session
+                    return registered_user_id
 
             # Map saml response to user attributes using the configured mapping provider
             for i in range(1000):
@@ -277,7 +328,7 @@ class SamlHandler:
 
                 localpart = attribute_dict.get("mxid_localpart")
                 if not localpart:
-                    raise Exception(
+                    raise MappingException(
                         "Error parsing SAML2 response: SAML mapping provider plugin "
                         "did not return a mxid_localpart value"
                     )
@@ -294,8 +345,8 @@ class SamlHandler:
             else:
                 # Unable to generate a username in 1000 iterations
                 # Break and return error to the user
-                raise SynapseError(
-                    500, "Unable to generate a Matrix ID from the SAML response"
+                raise MappingException(
+                    "Unable to generate a Matrix ID from the SAML response"
                 )
 
             logger.info("Mapped SAML user to local part %s", localpart)
@@ -310,7 +361,7 @@ class SamlHandler:
             await self._datastore.record_user_external_id(
                 self._auth_provider_id, remote_user_id, registered_user_id
             )
-            return registered_user_id, current_session
+            return registered_user_id
 
     def expire_sessions(self):
         expire_before = self._clock.time_msec() - self._saml2_session_lifetime
@@ -323,11 +374,11 @@ class SamlHandler:
             del self._outstanding_requests_dict[reqid]
 
 
-def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement):
+def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool:
     values = ava.get(req.attribute, [])
     for v in values:
         if v == req.value:
-            return
+            return True
 
     logger.info(
         "SAML2 attribute %s did not match required value '%s' (was '%s')",
@@ -335,7 +386,7 @@ def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement):
         req.value,
         values,
     )
-    raise AuthError(403, "You are not authorized to log in here.")
+    return False
 
 
 DOT_REPLACE_PATTERN = re.compile(
@@ -390,7 +441,7 @@ class DefaultSamlMappingProvider:
             return saml_response.ava["uid"][0]
         except KeyError:
             logger.warning("SAML2 response lacks a 'uid' attestation")
-            raise SynapseError(400, "'uid' not in SAML2 response")
+            raise MappingException("'uid' not in SAML2 response")
 
     def saml_response_to_user_attributes(
         self,
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index d58f9788c5..6a76c20d79 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
 
 class SearchHandler(BaseHandler):
     def __init__(self, hs):
-        super(SearchHandler, self).__init__(hs)
+        super().__init__(hs)
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 4d245b618b..a5d67f828f 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -27,7 +27,7 @@ class SetPasswordHandler(BaseHandler):
     """Handler which deals with changing user account passwords"""
 
     def __init__(self, hs):
-        super(SetPasswordHandler, self).__init__(hs)
+        super().__init__(hs)
         self._auth_handler = hs.get_auth_handler()
         self._device_handler = hs.get_device_handler()
         self._password_policy_handler = hs.get_password_policy_handler()
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index e2ddb628ff..e948efef2e 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -89,14 +89,12 @@ class TimelineBatch:
     events = attr.ib(type=List[EventBase])
     limited = attr.ib(bool)
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
         to tell if room needs to be part of the sync result.
         """
         return bool(self.events)
 
-    __bool__ = __nonzero__  # python3
-
 
 # We can't freeze this class, because we need to update it after it's instantiated to
 # update its unread count. This is because we calculate the unread count for a room only
@@ -114,7 +112,7 @@ class JoinedSyncResult:
     summary = attr.ib(type=Optional[JsonDict])
     unread_count = attr.ib(type=int)
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
         to tell if room needs to be part of the sync result.
         """
@@ -127,8 +125,6 @@ class JoinedSyncResult:
             # else in the result, we don't need to send it.
         )
 
-    __bool__ = __nonzero__  # python3
-
 
 @attr.s(slots=True, frozen=True)
 class ArchivedSyncResult:
@@ -137,26 +133,22 @@ class ArchivedSyncResult:
     state = attr.ib(type=StateMap[EventBase])
     account_data = attr.ib(type=List[JsonDict])
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
         to tell if room needs to be part of the sync result.
         """
         return bool(self.timeline or self.state or self.account_data)
 
-    __bool__ = __nonzero__  # python3
-
 
 @attr.s(slots=True, frozen=True)
 class InvitedSyncResult:
     room_id = attr.ib(type=str)
     invite = attr.ib(type=EventBase)
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         """Invited rooms should always be reported to the client"""
         return True
 
-    __bool__ = __nonzero__  # python3
-
 
 @attr.s(slots=True, frozen=True)
 class GroupsSyncResult:
@@ -164,11 +156,9 @@ class GroupsSyncResult:
     invite = attr.ib(type=JsonDict)
     leave = attr.ib(type=JsonDict)
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         return bool(self.join or self.invite or self.leave)
 
-    __bool__ = __nonzero__  # python3
-
 
 @attr.s(slots=True, frozen=True)
 class DeviceLists:
@@ -181,13 +171,11 @@ class DeviceLists:
     changed = attr.ib(type=Collection[str])
     left = attr.ib(type=Collection[str])
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         return bool(self.changed or self.left)
 
-    __bool__ = __nonzero__  # python3
-
 
-@attr.s
+@attr.s(slots=True)
 class _RoomChanges:
     """The set of room entries to include in the sync, plus the set of joined
     and left room IDs since last sync.
@@ -227,7 +215,7 @@ class SyncResult:
     device_one_time_keys_count = attr.ib(type=JsonDict)
     groups = attr.ib(type=Optional[GroupsSyncResult])
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
         to tell if the notifier needs to wait for more events when polling for
         events.
@@ -243,8 +231,6 @@ class SyncResult:
             or self.groups
         )
 
-    __bool__ = __nonzero__  # python3
-
 
 class SyncHandler:
     def __init__(self, hs: "HomeServer"):
@@ -378,7 +364,7 @@ class SyncHandler:
         sync_config = sync_result_builder.sync_config
 
         with Measure(self.clock, "ephemeral_by_room"):
-            typing_key = since_token.typing_key if since_token else "0"
+            typing_key = since_token.typing_key if since_token else 0
 
             room_ids = sync_result_builder.joined_room_ids
 
@@ -402,7 +388,7 @@ class SyncHandler:
                 event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
                 ephemeral_by_room.setdefault(room_id, []).append(event_copy)
 
-            receipt_key = since_token.receipt_key if since_token else "0"
+            receipt_key = since_token.receipt_key if since_token else 0
 
             receipt_source = self.event_sources.sources["receipt"]
             receipts, receipt_key = await receipt_source.get_new_events(
@@ -533,7 +519,7 @@ class SyncHandler:
             if len(recents) > timeline_limit:
                 limited = True
                 recents = recents[-timeline_limit:]
-                room_key = recents[0].internal_metadata.before
+                room_key = RoomStreamToken.parse(recents[0].internal_metadata.before)
 
             prev_batch_token = now_token.copy_and_replace("room_key", room_key)
 
@@ -981,7 +967,7 @@ class SyncHandler:
             raise NotImplementedError()
         else:
             joined_room_ids = await self.get_rooms_for_user_at(
-                user_id, now_token.room_stream_id
+                user_id, now_token.room_key
             )
         sync_result_builder = SyncResultBuilder(
             sync_config,
@@ -1310,12 +1296,11 @@ class SyncHandler:
         presence_source = self.event_sources.sources["presence"]
 
         since_token = sync_result_builder.since_token
+        presence_key = None
+        include_offline = False
         if since_token and not sync_result_builder.full_state:
             presence_key = since_token.presence_key
             include_offline = True
-        else:
-            presence_key = None
-            include_offline = False
 
         presence, presence_key = await presence_source.get_new_events(
             user=user,
@@ -1323,6 +1308,7 @@ class SyncHandler:
             is_guest=sync_config.is_guest,
             include_offline=include_offline,
         )
+        assert presence_key
         sync_result_builder.now_token = now_token.copy_and_replace(
             "presence_key", presence_key
         )
@@ -1485,7 +1471,7 @@ class SyncHandler:
         if rooms_changed:
             return True
 
-        stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
+        stream_id = since_token.room_key.stream
         for room_id in sync_result_builder.joined_room_ids:
             if self.store.has_room_changed_since(room_id, stream_id):
                 return True
@@ -1751,7 +1737,7 @@ class SyncHandler:
                             continue
 
                 leave_token = now_token.copy_and_replace(
-                    "room_key", "s%d" % (event.stream_ordering,)
+                    "room_key", RoomStreamToken(None, event.stream_ordering)
                 )
                 room_entries.append(
                     RoomSyncResultBuilder(
@@ -1930,7 +1916,7 @@ class SyncHandler:
             raise Exception("Unrecognized rtype: %r", room_builder.rtype)
 
     async def get_rooms_for_user_at(
-        self, user_id: str, stream_ordering: int
+        self, user_id: str, room_key: RoomStreamToken
     ) -> FrozenSet[str]:
         """Get set of joined rooms for a user at the given stream ordering.
 
@@ -1956,15 +1942,15 @@ class SyncHandler:
         # If the membership's stream ordering is after the given stream
         # ordering, we need to go and work out if the user was in the room
         # before.
-        for room_id, membership_stream_ordering in joined_rooms:
-            if membership_stream_ordering <= stream_ordering:
+        for room_id, event_pos in joined_rooms:
+            if not event_pos.persisted_after(room_key):
                 joined_room_ids.add(room_id)
                 continue
 
             logger.info("User joined room after current token: %s", room_id)
 
             extrems = await self.store.get_forward_extremeties_for_room(
-                room_id, stream_ordering
+                room_id, event_pos.stream
             )
             users_in_room = await self.state.get_current_users_in_room(room_id, extrems)
             if user_id in users_in_room:
@@ -2038,7 +2024,7 @@ def _calculate_state(
     return {event_id_to_key[e]: e for e in state_ids}
 
 
-@attr.s
+@attr.s(slots=True)
 class SyncResultBuilder:
     """Used to help build up a new SyncResult for a user
 
@@ -2074,7 +2060,7 @@ class SyncResultBuilder:
     to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list))
 
 
-@attr.s
+@attr.s(slots=True)
 class RoomSyncResultBuilder:
     """Stores information needed to create either a `JoinedSyncResult` or
     `ArchivedSyncResult`.
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index e21f8dbc58..79393c8829 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -37,7 +37,7 @@ class UserDirectoryHandler(StateDeltasHandler):
     """
 
     def __init__(self, hs):
-        super(UserDirectoryHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index 3880ce0d94..8eb3638591 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -27,7 +27,7 @@ class RequestTimedOutError(SynapseError):
     """Exception representing timeout of an outbound request"""
 
     def __init__(self):
-        super(RequestTimedOutError, self).__init__(504, "Timed out")
+        super().__init__(504, "Timed out")
 
 
 def cancelled_to_request_timed_out_error(value, timeout):
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 13fcab3378..4694adc400 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -17,6 +17,18 @@
 import logging
 import urllib
 from io import BytesIO
+from typing import (
+    Any,
+    BinaryIO,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 
 import treq
 from canonicaljson import encode_canonical_json
@@ -37,6 +49,7 @@ from twisted.web._newclient import ResponseDone
 from twisted.web.client import Agent, HTTPConnectionPool, readBody
 from twisted.web.http import PotentialDataLoss
 from twisted.web.http_headers import Headers
+from twisted.web.iweb import IResponse
 
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
 from synapse.http import (
@@ -57,6 +70,19 @@ incoming_responses_counter = Counter(
     "synapse_http_client_responses", "", ["method", "code"]
 )
 
+# the type of the headers list, to be passed to the t.w.h.Headers.
+# Actually we can mix str and bytes keys, but Mapping treats 'key' as invariant so
+# we simplify.
+RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValue"]]
+
+# the value actually has to be a List, but List is invariant so we can't specify that
+# the entries can either be Lists or bytes.
+RawHeaderValue = Sequence[Union[str, bytes]]
+
+# the type of the query params, to be passed into `urlencode`
+QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
+QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
+
 
 def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
     """
@@ -285,13 +311,26 @@ class SimpleHttpClient:
                 ip_blacklist=self._ip_blacklist,
             )
 
-    async def request(self, method, uri, data=None, headers=None):
+    async def request(
+        self,
+        method: str,
+        uri: str,
+        data: Optional[bytes] = None,
+        headers: Optional[Headers] = None,
+    ) -> IResponse:
         """
         Args:
-            method (str): HTTP method to use.
-            uri (str): URI to query.
-            data (bytes): Data to send in the request body, if applicable.
-            headers (t.w.http_headers.Headers): Request headers.
+            method: HTTP method to use.
+            uri: URI to query.
+            data: Data to send in the request body, if applicable.
+            headers: Request headers.
+
+        Returns:
+            Response object, once the headers have been read.
+
+        Raises:
+            RequestTimedOutError if the request times out before the headers are read
+
         """
         # A small wrapper around self.agent.request() so we can easily attach
         # counters to it
@@ -324,6 +363,8 @@ class SimpleHttpClient:
                     headers=headers,
                     **self._extra_treq_args
                 )
+                # we use our own timeout mechanism rather than treq's as a workaround
+                # for https://twistedmatrix.com/trac/ticket/9534.
                 request_deferred = timeout_deferred(
                     request_deferred,
                     60,
@@ -353,18 +394,26 @@ class SimpleHttpClient:
                 set_tag("error_reason", e.args[0])
                 raise
 
-    async def post_urlencoded_get_json(self, uri, args={}, headers=None):
+    async def post_urlencoded_get_json(
+        self,
+        uri: str,
+        args: Mapping[str, Union[str, List[str]]] = {},
+        headers: Optional[RawHeaders] = None,
+    ) -> Any:
         """
         Args:
-            uri (str):
-            args (dict[str, str|List[str]]): query params
-            headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
-               header name to a list of values for that header
+            uri: uri to query
+            args: parameters to be url-encoded in the body
+            headers: a map from header name to a list of values for that header
 
         Returns:
-            object: parsed json
+            parsed json
 
         Raises:
+            RequestTimedOutException: if there is a timeout before the response headers
+               are received. Note there is currently no timeout on reading the response
+               body.
+
             HttpResponseException: On a non-2xx HTTP response.
 
             ValueError: if the response was not JSON
@@ -398,19 +447,24 @@ class SimpleHttpClient:
                 response.code, response.phrase.decode("ascii", errors="replace"), body
             )
 
-    async def post_json_get_json(self, uri, post_json, headers=None):
+    async def post_json_get_json(
+        self, uri: str, post_json: Any, headers: Optional[RawHeaders] = None
+    ) -> Any:
         """
 
         Args:
-            uri (str):
-            post_json (object):
-            headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
-               header name to a list of values for that header
+            uri: URI to query.
+            post_json: request body, to be encoded as json
+            headers: a map from header name to a list of values for that header
 
         Returns:
-            object: parsed json
+            parsed json
 
         Raises:
+            RequestTimedOutException: if there is a timeout before the response headers
+               are received. Note there is currently no timeout on reading the response
+               body.
+
             HttpResponseException: On a non-2xx HTTP response.
 
             ValueError: if the response was not JSON
@@ -440,21 +494,22 @@ class SimpleHttpClient:
                 response.code, response.phrase.decode("ascii", errors="replace"), body
             )
 
-    async def get_json(self, uri, args={}, headers=None):
-        """ Gets some json from the given URI.
+    async def get_json(
+        self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None,
+    ) -> Any:
+        """Gets some json from the given URI.
 
         Args:
-            uri (str): The URI to request, not including query parameters
-            args (dict): A dictionary used to create query strings, defaults to
-                None.
-                **Note**: The value of each key is assumed to be an iterable
-                and *not* a string.
-            headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
-               header name to a list of values for that header
+            uri: The URI to request, not including query parameters
+            args: A dictionary used to create query string
+            headers: a map from header name to a list of values for that header
         Returns:
-            Succeeds when we get *any* 2xx HTTP response, with the
-            HTTP body as JSON.
+            Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.
         Raises:
+            RequestTimedOutException: if there is a timeout before the response headers
+               are received. Note there is currently no timeout on reading the response
+               body.
+
             HttpResponseException On a non-2xx HTTP response.
 
             ValueError: if the response was not JSON
@@ -466,22 +521,27 @@ class SimpleHttpClient:
         body = await self.get_raw(uri, args, headers=headers)
         return json_decoder.decode(body.decode("utf-8"))
 
-    async def put_json(self, uri, json_body, args={}, headers=None):
-        """ Puts some json to the given URI.
+    async def put_json(
+        self,
+        uri: str,
+        json_body: Any,
+        args: QueryParams = {},
+        headers: RawHeaders = None,
+    ) -> Any:
+        """Puts some json to the given URI.
 
         Args:
-            uri (str): The URI to request, not including query parameters
-            json_body (dict): The JSON to put in the HTTP body,
-            args (dict): A dictionary used to create query strings, defaults to
-                None.
-                **Note**: The value of each key is assumed to be an iterable
-                and *not* a string.
-            headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
-               header name to a list of values for that header
+            uri: The URI to request, not including query parameters
+            json_body: The JSON to put in the HTTP body,
+            args: A dictionary used to create query strings
+            headers: a map from header name to a list of values for that header
         Returns:
-            Succeeds when we get *any* 2xx HTTP response, with the
-            HTTP body as JSON.
+            Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.
         Raises:
+             RequestTimedOutException: if there is a timeout before the response headers
+               are received. Note there is currently no timeout on reading the response
+               body.
+
             HttpResponseException On a non-2xx HTTP response.
 
             ValueError: if the response was not JSON
@@ -513,21 +573,23 @@ class SimpleHttpClient:
                 response.code, response.phrase.decode("ascii", errors="replace"), body
             )
 
-    async def get_raw(self, uri, args={}, headers=None):
-        """ Gets raw text from the given URI.
+    async def get_raw(
+        self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None
+    ) -> bytes:
+        """Gets raw text from the given URI.
 
         Args:
-            uri (str): The URI to request, not including query parameters
-            args (dict): A dictionary used to create query strings, defaults to
-                None.
-                **Note**: The value of each key is assumed to be an iterable
-                and *not* a string.
-            headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
-               header name to a list of values for that header
+            uri: The URI to request, not including query parameters
+            args: A dictionary used to create query strings
+            headers: a map from header name to a list of values for that header
         Returns:
-            Succeeds when we get *any* 2xx HTTP response, with the
+            Succeeds when we get a 2xx HTTP response, with the
             HTTP body as bytes.
         Raises:
+            RequestTimedOutException: if there is a timeout before the response headers
+               are received. Note there is currently no timeout on reading the response
+               body.
+
             HttpResponseException on a non-2xx HTTP response.
         """
         if len(args):
@@ -552,16 +614,29 @@ class SimpleHttpClient:
     # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
     # The two should be factored out.
 
-    async def get_file(self, url, output_stream, max_size=None, headers=None):
+    async def get_file(
+        self,
+        url: str,
+        output_stream: BinaryIO,
+        max_size: Optional[int] = None,
+        headers: Optional[RawHeaders] = None,
+    ) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
         """GETs a file from a given URL
         Args:
-            url (str): The URL to GET
-            output_stream (file): File to write the response body to.
-            headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
-               header name to a list of values for that header
+            url: The URL to GET
+            output_stream: File to write the response body to.
+            headers: A map from header name to a list of values for that header
         Returns:
-            A (int,dict,string,int) tuple of the file length, dict of the response
+            A tuple of the file length, dict of the response
             headers, absolute URI of the response and HTTP response code.
+
+        Raises:
+            RequestTimedOutException: if there is a timeout before the response headers
+               are received. Note there is currently no timeout on reading the response
+               body.
+
+            SynapseError: if the response is not a 2xx, the remote file is too large, or
+               another exception happens during the download.
         """
 
         actual_headers = {b"User-Agent": [self.user_agent]}
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index e6f067ca29..a306faa267 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -311,7 +311,7 @@ def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
     return cache_controls
 
 
-@attr.s()
+@attr.s(slots=True)
 class _FetchWellKnownFailure(Exception):
     # True if we didn't get a non-5xx HTTP response, i.e. this may or may not be
     # a temporary failure.
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 5eaf3151ce..3c86cbc546 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -76,7 +76,7 @@ MAXINT = sys.maxsize
 _next_id = 1
 
 
-@attr.s(frozen=True)
+@attr.s(slots=True, frozen=True)
 class MatrixFederationRequest:
     method = attr.ib()
     """HTTP method
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 22598e02d2..2e282d9d67 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -217,11 +217,9 @@ class _Sentinel:
     def record_event_fetch(self, event_count):
         pass
 
-    def __nonzero__(self):
+    def __bool__(self):
         return False
 
-    __bool__ = __nonzero__  # python3
-
 
 SENTINEL_CONTEXT = _Sentinel()
 
diff --git a/synapse/logging/formatter.py b/synapse/logging/formatter.py
index d736ad5b9b..11f60a77f7 100644
--- a/synapse/logging/formatter.py
+++ b/synapse/logging/formatter.py
@@ -30,7 +30,7 @@ class LogFormatter(logging.Formatter):
     """
 
     def __init__(self, *args, **kwargs):
-        super(LogFormatter, self).__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
 
     def formatException(self, ei):
         sio = StringIO()
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 7df0aa197d..e58850faff 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -509,7 +509,7 @@ def start_active_span_from_edu(
     ]
 
     # For some reason jaeger decided not to support the visualization of multiple parent
-    # spans or explicitely show references. I include the span context as a tag here as
+    # spans or explicitly show references. I include the span context as a tag here as
     # an aid to people debugging but it's really not an ideal solution.
 
     references += _references
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
index 026854b4c7..7b9c657456 100644
--- a/synapse/logging/scopecontextmanager.py
+++ b/synapse/logging/scopecontextmanager.py
@@ -107,7 +107,7 @@ class _LogContextScope(Scope):
             finish_on_close (Boolean):
                 if True finish the span when the scope is closed
         """
-        super(_LogContextScope, self).__init__(manager, span)
+        super().__init__(manager, span)
         self.logcontext = logcontext
         self._finish_on_close = finish_on_close
         self._enter_logcontext = enter_logcontext
@@ -120,9 +120,9 @@ class _LogContextScope(Scope):
 
     def __exit__(self, type, value, traceback):
         if type == twisted.internet.defer._DefGen_Return:
-            super(_LogContextScope, self).__exit__(None, None, None)
+            super().__exit__(None, None, None)
         else:
-            super(_LogContextScope, self).__exit__(type, value, traceback)
+            super().__exit__(type, value, traceback)
         if self._enter_logcontext:
             self.logcontext.__exit__(type, value, traceback)
         else:  # the logcontext existed before the creation of the scope
diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py
index fea774e2e5..becf66dd86 100644
--- a/synapse/logging/utils.py
+++ b/synapse/logging/utils.py
@@ -29,11 +29,11 @@ def _log_debug_as_f(f, msg, msg_args):
         lineno = f.__code__.co_firstlineno
         pathname = f.__code__.co_filename
 
-        record = logging.LogRecord(
+        record = logger.makeRecord(
             name=name,
             level=logging.DEBUG,
-            pathname=pathname,
-            lineno=lineno,
+            fn=pathname,
+            lno=lineno,
             msg=msg,
             args=msg_args,
             exc_info=None,
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 2643380d9e..a1f7ca3449 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -59,7 +59,7 @@ class RegistryProxy:
                 yield metric
 
 
-@attr.s(hash=True)
+@attr.s(slots=True, hash=True)
 class LaterGauge:
 
     name = attr.ib(type=str)
@@ -205,7 +205,7 @@ class InFlightGauge:
         all_gauges[self.name] = self
 
 
-@attr.s(hash=True)
+@attr.s(slots=True, hash=True)
 class BucketCollector:
     """
     Like a Histogram, but allows buckets to be point-in-time instead of
diff --git a/synapse/notifier.py b/synapse/notifier.py
index b7f4041306..441b3d15e2 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -42,7 +42,13 @@ from synapse.logging.utils import log_function
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.streams.config import PaginationConfig
-from synapse.types import Collection, StreamToken, UserID
+from synapse.types import (
+    Collection,
+    PersistedEventPosition,
+    RoomStreamToken,
+    StreamToken,
+    UserID,
+)
 from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
 from synapse.util.metrics import Measure
 from synapse.visibility import filter_events_for_client
@@ -112,7 +118,9 @@ class _NotifierUserStream:
         with PreserveLoggingContext():
             self.notify_deferred = ObservableDeferred(defer.Deferred())
 
-    def notify(self, stream_key: str, stream_id: int, time_now_ms: int):
+    def notify(
+        self, stream_key: str, stream_id: Union[int, RoomStreamToken], time_now_ms: int,
+    ):
         """Notify any listeners for this user of a new event from an
         event source.
         Args:
@@ -162,11 +170,9 @@ class _NotifierUserStream:
 
 
 class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
-    def __nonzero__(self):
+    def __bool__(self):
         return bool(self.events)
 
-    __bool__ = __nonzero__  # python3
-
 
 class Notifier:
     """ This class is responsible for notifying any listeners when there are
@@ -187,7 +193,7 @@ class Notifier:
         self.store = hs.get_datastore()
         self.pending_new_room_events = (
             []
-        )  # type: List[Tuple[int, EventBase, Collection[Union[str, UserID]]]]
+        )  # type: List[Tuple[PersistedEventPosition, EventBase, Collection[UserID]]]
 
         # Called when there are new things to stream over replication
         self.replication_callbacks = []  # type: List[Callable[[], None]]
@@ -198,6 +204,7 @@ class Notifier:
 
         self.clock = hs.get_clock()
         self.appservice_handler = hs.get_application_service_handler()
+        self._pusher_pool = hs.get_pusherpool()
 
         self.federation_sender = None
         if hs.should_send_federation():
@@ -245,9 +252,9 @@ class Notifier:
     def on_new_room_event(
         self,
         event: EventBase,
-        room_stream_id: int,
-        max_room_stream_id: int,
-        extra_users: Collection[Union[str, UserID]] = [],
+        event_pos: PersistedEventPosition,
+        max_room_stream_token: RoomStreamToken,
+        extra_users: Collection[UserID] = [],
     ):
         """ Used by handlers to inform the notifier something has happened
         in the room, room event wise.
@@ -260,61 +267,79 @@ class Notifier:
         until all previous events have been persisted before notifying
         the client streams.
         """
-        self.pending_new_room_events.append((room_stream_id, event, extra_users))
-        self._notify_pending_new_room_events(max_room_stream_id)
+        self.pending_new_room_events.append((event_pos, event, extra_users))
+        self._notify_pending_new_room_events(max_room_stream_token)
 
         self.notify_replication()
 
-    def _notify_pending_new_room_events(self, max_room_stream_id: int):
+    def _notify_pending_new_room_events(self, max_room_stream_token: RoomStreamToken):
         """Notify for the room events that were queued waiting for a previous
         event to be persisted.
         Args:
-            max_room_stream_id: The highest stream_id below which all
+            max_room_stream_token: The highest stream_id below which all
                 events have been persisted.
         """
         pending = self.pending_new_room_events
         self.pending_new_room_events = []
-        for room_stream_id, event, extra_users in pending:
-            if room_stream_id > max_room_stream_id:
-                self.pending_new_room_events.append(
-                    (room_stream_id, event, extra_users)
-                )
+
+        users = set()  # type: Set[UserID]
+        rooms = set()  # type: Set[str]
+
+        for event_pos, event, extra_users in pending:
+            if event_pos.persisted_after(max_room_stream_token):
+                self.pending_new_room_events.append((event_pos, event, extra_users))
             else:
-                self._on_new_room_event(event, room_stream_id, extra_users)
+                if (
+                    event.type == EventTypes.Member
+                    and event.membership == Membership.JOIN
+                ):
+                    self._user_joined_room(event.state_key, event.room_id)
+
+                users.update(extra_users)
+                rooms.add(event.room_id)
+
+        if users or rooms:
+            self.on_new_event(
+                "room_key", max_room_stream_token, users=users, rooms=rooms,
+            )
+            self._on_updated_room_token(max_room_stream_token)
+
+    def _on_updated_room_token(self, max_room_stream_token: RoomStreamToken):
+        """Poke services that might care that the room position has been
+        updated.
+        """
 
-    def _on_new_room_event(
-        self,
-        event: EventBase,
-        room_stream_id: int,
-        extra_users: Collection[Union[str, UserID]] = [],
-    ):
-        """Notify any user streams that are interested in this room event"""
         # poke any interested application service.
         run_as_background_process(
-            "notify_app_services", self._notify_app_services, room_stream_id
+            "_notify_app_services", self._notify_app_services, max_room_stream_token
         )
 
-        if self.federation_sender:
-            self.federation_sender.notify_new_events(room_stream_id)
-
-        if event.type == EventTypes.Member and event.membership == Membership.JOIN:
-            self._user_joined_room(event.state_key, event.room_id)
-
-        self.on_new_event(
-            "room_key", room_stream_id, users=extra_users, rooms=[event.room_id]
+        run_as_background_process(
+            "_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_token
         )
 
-    async def _notify_app_services(self, room_stream_id: int):
+        if self.federation_sender:
+            self.federation_sender.notify_new_events(max_room_stream_token.stream)
+
+    async def _notify_app_services(self, max_room_stream_token: RoomStreamToken):
         try:
-            await self.appservice_handler.notify_interested_services(room_stream_id)
+            await self.appservice_handler.notify_interested_services(
+                max_room_stream_token.stream
+            )
         except Exception:
             logger.exception("Error notifying application services of event")
 
+    async def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken):
+        try:
+            await self._pusher_pool.on_new_notifications(max_room_stream_token.stream)
+        except Exception:
+            logger.exception("Error pusher pool of event")
+
     def on_new_event(
         self,
         stream_key: str,
-        new_token: int,
-        users: Collection[Union[str, UserID]] = [],
+        new_token: Union[int, RoomStreamToken],
+        users: Collection[UserID] = [],
         rooms: Collection[str] = [],
     ):
         """ Used to inform listeners that something has happened event wise.
@@ -432,8 +457,9 @@ class Notifier:
         If explicit_room_id is set, that room will be polled for events only if
         it is world readable or the user has joined the room.
         """
-        from_token = pagination_config.from_token
-        if not from_token:
+        if pagination_config.from_token:
+            from_token = pagination_config.from_token
+        else:
             from_token = self.event_sources.get_current_token()
 
         limit = pagination_config.limit
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index edf45dc599..5a437f9810 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -16,4 +16,4 @@
 
 class PusherConfigException(Exception):
     def __init__(self, msg):
-        super(PusherConfigException, self).__init__(msg)
+        super().__init__(msg)
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index b7ea4438e0..28bd8ab748 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -91,7 +91,7 @@ class EmailPusher:
                 pass
             self.timed_call = None
 
-    def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
+    def on_new_notifications(self, max_stream_ordering):
         if self.max_stream_ordering:
             self.max_stream_ordering = max(
                 max_stream_ordering, self.max_stream_ordering
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index f21fa9b659..26706bf3e1 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -114,7 +114,7 @@ class HttpPusher:
         if should_check_for_notifs:
             self._start_processing()
 
-    def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
+    def on_new_notifications(self, max_stream_ordering):
         self.max_stream_ordering = max(
             max_stream_ordering, self.max_stream_ordering or 0
         )
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 6c57854018..455a1acb46 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -123,7 +123,7 @@ class Mailer:
         params = {"token": token, "client_secret": client_secret, "sid": sid}
         link = (
             self.hs.config.public_baseurl
-            + "_matrix/client/unstable/password_reset/email/submit_token?%s"
+            + "_synapse/client/password_reset/email/submit_token?%s"
             % urllib.parse.urlencode(params)
         )
 
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 3c3262a88c..76150e117b 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -60,10 +60,18 @@ class PusherPool:
         self.store = self.hs.get_datastore()
         self.clock = self.hs.get_clock()
 
+        self._account_validity = hs.config.account_validity
+
         # We shard the handling of push notifications by user ID.
         self._pusher_shard_config = hs.config.push.pusher_shard_config
         self._instance_name = hs.get_instance_name()
 
+        # Record the last stream ID that we were poked about so we can get
+        # changes since then. We set this to the current max stream ID on
+        # startup as every individual pusher will have checked for changes on
+        # startup.
+        self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering()
+
         # map from user id to app_id:pushkey to pusher
         self.pushers = {}  # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
 
@@ -178,20 +186,35 @@ class PusherPool:
                 )
                 await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
 
-    async def on_new_notifications(self, min_stream_id, max_stream_id):
+    async def on_new_notifications(self, max_stream_id: int):
         if not self.pushers:
             # nothing to do here.
             return
 
+        if max_stream_id < self._last_room_stream_id_seen:
+            # Nothing to do
+            return
+
+        prev_stream_id = self._last_room_stream_id_seen
+        self._last_room_stream_id_seen = max_stream_id
+
         try:
             users_affected = await self.store.get_push_action_users_in_range(
-                min_stream_id, max_stream_id
+                prev_stream_id, max_stream_id
             )
 
             for u in users_affected:
+                # Don't push if the user account has expired
+                if self._account_validity.enabled:
+                    expired = await self.store.is_account_expired(
+                        u, self.clock.time_msec()
+                    )
+                    if expired:
+                        continue
+
                 if u in self.pushers:
                     for p in self.pushers[u].values():
-                        p.on_new_notifications(min_stream_id, max_stream_id)
+                        p.on_new_notifications(max_stream_id)
 
         except Exception:
             logger.exception("Exception in pusher on_new_notifications")
@@ -209,6 +232,14 @@ class PusherPool:
             )
 
             for u in users_affected:
+                # Don't push if the user account has expired
+                if self._account_validity.enabled:
+                    expired = await self.store.is_account_expired(
+                        u, self.clock.time_msec()
+                    )
+                    if expired:
+                        continue
+
                 if u in self.pushers:
                     for p in self.pushers[u].values():
                         p.on_new_receipts(min_stream_id, max_stream_id)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 2d995ec456..288631477e 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -37,13 +37,16 @@ logger = logging.getLogger(__name__)
 # installed when that optional dependency requirement is specified. It is passed
 # to setup() as extras_require in setup.py
 #
+# Note that these both represent runtime dependencies (and the versions
+# installed are checked at runtime).
+#
 # [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers.
 
 REQUIREMENTS = [
     "jsonschema>=2.5.1",
     "frozendict>=1",
     "unpaddedbase64>=1.1.0",
-    "canonicaljson>=1.3.0",
+    "canonicaljson>=1.4.0",
     # we use the type definitions added in signedjson 1.1.
     "signedjson>=1.1.0",
     "pynacl>=1.2.1",
@@ -92,12 +95,6 @@ CONDITIONAL_REQUIREMENTS = {
     "oidc": ["authlib>=0.14.0"],
     "systemd": ["systemd-python>=231"],
     "url_preview": ["lxml>=3.5.0"],
-    # Dependencies which are exclusively required by unit test code. This is
-    # NOT a list of all modules that are necessary to run the unit tests.
-    # Tests assume that all optional dependencies are installed.
-    #
-    # parameterized_class decorator was introduced in parameterized 0.7.0
-    "test": ["mock>=2.0", "parameterized>=0.7.0"],
     "sentry": ["sentry-sdk>=0.7.2"],
     "opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"],
     "jwt": ["pyjwt>=1.6.4"],
@@ -110,6 +107,7 @@ ALL_OPTIONAL_REQUIREMENTS = set()  # type: Set[str]
 
 for name, optional_deps in CONDITIONAL_REQUIREMENTS.items():
     # Exclude systemd as it's a system-based requirement.
+    # Exclude lint as it's a dev-based requirement.
     if name not in ["systemd"]:
         ALL_OPTIONAL_REQUIREMENTS = set(optional_deps) | ALL_OPTIONAL_REQUIREMENTS
 
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index ba16f22c91..b448da6710 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -33,7 +33,7 @@ from synapse.util.stringutils import random_string
 logger = logging.getLogger(__name__)
 
 
-class ReplicationEndpoint:
+class ReplicationEndpoint(metaclass=abc.ABCMeta):
     """Helper base class for defining new replication HTTP endpoints.
 
     This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..`
@@ -72,8 +72,6 @@ class ReplicationEndpoint:
             is received.
     """
 
-    __metaclass__ = abc.ABCMeta
-
     NAME = abc.abstractproperty()  # type: str  # type: ignore
     PATH_ARGS = abc.abstractproperty()  # type: Tuple[str, ...]  # type: ignore
     METHOD = "POST"
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index 20f3ba76c0..807b85d2e1 100644
--- a/synapse/replication/http/devices.py
+++ b/synapse/replication/http/devices.py
@@ -53,7 +53,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
     CACHE = False
 
     def __init__(self, hs):
-        super(ReplicationUserDevicesResyncRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.device_list_updater = hs.get_device_handler().device_list_updater
         self.store = hs.get_datastore()
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 6b56315148..5393b9a9e7 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -57,7 +57,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
     PATH_ARGS = ()
 
     def __init__(self, hs):
-        super(ReplicationFederationSendEventsRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.store = hs.get_datastore()
         self.storage = hs.get_storage()
@@ -65,10 +65,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         self.federation_handler = hs.get_handlers().federation_handler
 
     @staticmethod
-    async def _serialize_payload(store, event_and_contexts, backfilled):
+    async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
         """
         Args:
             store
+            room_id (str)
             event_and_contexts (list[tuple[FrozenEvent, EventContext]])
             backfilled (bool): Whether or not the events are the result of
                 backfilling
@@ -88,7 +89,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
                 }
             )
 
-        payload = {"events": event_payloads, "backfilled": backfilled}
+        payload = {
+            "events": event_payloads,
+            "backfilled": backfilled,
+            "room_id": room_id,
+        }
 
         return payload
 
@@ -96,6 +101,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         with Measure(self.clock, "repl_fed_send_events_parse"):
             content = parse_json_object_from_request(request)
 
+            room_id = content["room_id"]
             backfilled = content["backfilled"]
 
             event_payloads = content["events"]
@@ -120,7 +126,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         logger.info("Got %d events from federation", len(event_and_contexts))
 
         max_stream_id = await self.federation_handler.persist_events_and_notify(
-            event_and_contexts, backfilled
+            room_id, event_and_contexts, backfilled
         )
 
         return 200, {"max_stream_id": max_stream_id}
@@ -144,7 +150,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
     PATH_ARGS = ("edu_type",)
 
     def __init__(self, hs):
-        super(ReplicationFederationSendEduRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
@@ -187,7 +193,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
     CACHE = False
 
     def __init__(self, hs):
-        super(ReplicationGetQueryRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
@@ -230,7 +236,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
     PATH_ARGS = ("room_id",)
 
     def __init__(self, hs):
-        super(ReplicationCleanRoomRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.store = hs.get_datastore()
 
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index fb326bb869..4c81e2d784 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -32,7 +32,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
     PATH_ARGS = ("user_id",)
 
     def __init__(self, hs):
-        super(RegisterDeviceReplicationServlet, self).__init__(hs)
+        super().__init__(hs)
         self.registration_handler = hs.get_registration_handler()
 
     @staticmethod
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 741329ab5f..30680baee8 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Optional
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 from synapse.types import JsonDict, Requester, UserID
-from synapse.util.distributor import user_joined_room, user_left_room
+from synapse.util.distributor import user_left_room
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -45,7 +45,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
     PATH_ARGS = ("room_id", "user_id")
 
     def __init__(self, hs):
-        super(ReplicationRemoteJoinRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.federation_handler = hs.get_handlers().federation_handler
         self.store = hs.get_datastore()
@@ -107,7 +107,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
     PATH_ARGS = ("invite_event_id",)
 
     def __init__(self, hs: "HomeServer"):
-        super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
@@ -168,7 +168,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
     CACHE = False  # No point caching as should return instantly.
 
     def __init__(self, hs):
-        super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.registeration_handler = hs.get_registration_handler()
         self.store = hs.get_datastore()
@@ -181,9 +181,9 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
         Args:
             room_id (str)
             user_id (str)
-            change (str): Either "joined" or "left"
+            change (str): "left"
         """
-        assert change in ("joined", "left")
+        assert change == "left"
 
         return {}
 
@@ -192,9 +192,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
 
         user = UserID.from_string(user_id)
 
-        if change == "joined":
-            user_joined_room(self.distributor, user, room_id)
-        elif change == "left":
+        if change == "left":
             user_left_room(self.distributor, user, room_id)
         else:
             raise Exception("Unrecognized change: %r", change)
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index a02b27474d..7b12ec9060 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -29,7 +29,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
     PATH_ARGS = ("user_id",)
 
     def __init__(self, hs):
-        super(ReplicationRegisterServlet, self).__init__(hs)
+        super().__init__(hs)
         self.store = hs.get_datastore()
         self.registration_handler = hs.get_registration_handler()
 
@@ -104,7 +104,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
     PATH_ARGS = ("user_id",)
 
     def __init__(self, hs):
-        super(ReplicationPostRegisterActionsServlet, self).__init__(hs)
+        super().__init__(hs)
         self.store = hs.get_datastore()
         self.registration_handler = hs.get_registration_handler()
 
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index f13d452426..9a3a694d5d 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -52,7 +52,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
     PATH_ARGS = ("event_id",)
 
     def __init__(self, hs):
-        super(ReplicationSendEventRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.event_creation_handler = hs.get_event_creation_handler()
         self.store = hs.get_datastore()
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 60f2e1245f..d0089fe06c 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -26,16 +26,18 @@ logger = logging.getLogger(__name__)
 
 class BaseSlavedStore(CacheInvalidationWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(BaseSlavedStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
             self._cache_id_gen = MultiWriterIdGenerator(
                 db_conn,
                 database,
+                stream_name="caches",
                 instance_name=hs.get_instance_name(),
                 table="cache_invalidation_stream_by_instance",
                 instance_column="instance_name",
                 id_column="stream_id",
                 sequence_name="cache_invalidation_stream_seq",
+                writers=[],
             )  # type: Optional[MultiWriterIdGenerator]
         else:
             self._cache_id_gen = None
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index bb66ba9b80..4268565fc8 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -34,7 +34,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
             ],
         )
 
-        super(SlavedAccountDataStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     def get_max_account_data_stream_id(self):
         return self._account_data_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index a6fdedde63..1f8dafe7ea 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -22,7 +22,7 @@ from ._base import BaseSlavedStore
 
 class SlavedClientIpStore(BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedClientIpStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen", keylen=4, max_entries=50000
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 533d927701..5b045bed02 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -24,7 +24,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
         self._device_inbox_id_gen = SlavedIdTracker(
             db_conn, "device_inbox", "stream_id"
         )
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 3b788c9625..e0d86240dd 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -24,7 +24,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedDeviceStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.hs = hs
 
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index da1cc836cf..fbffe6d85c 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -56,7 +56,7 @@ class SlavedEventStore(
     BaseSlavedStore,
 ):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedEventStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         events_max = self._stream_id_gen.get_current_token()
         curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index 2562b6fc38..6a23252861 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -21,7 +21,7 @@ from ._base import BaseSlavedStore
 
 class SlavedFilteringStore(BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedFilteringStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     # Filters are immutable so this cache doesn't need to be expired
     get_user_filter = FilteringStore.__dict__["get_user_filter"]
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 567b4a5cc1..30955bcbfe 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -23,7 +23,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedGroupServerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.hs = hs
 
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 025f6f6be8..55620c03d8 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -25,7 +25,7 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 class SlavedPresenceStore(BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
         self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
 
         self._presence_on_startup = self._get_active_presence(db_conn)  # type: ignore
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 9da218bfe8..c418730ba8 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -24,7 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedPusherStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
         self._pushers_id_gen = SlavedIdTracker(
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
         )
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 5c2986e050..6195917376 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -30,7 +30,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
             db_conn, "receipts_linearized", "stream_id"
         )
 
-        super(SlavedReceiptsStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 80ae803ad9..109ac6bea1 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -23,7 +23,7 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 class RoomStore(RoomWorkerStore, BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RoomStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
         self._public_room_id_gen = SlavedIdTracker(
             db_conn, "public_room_list_stream", "stream_id"
         )
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index d6ecf5b327..55af3d41ea 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -29,6 +29,7 @@ from synapse.replication.tcp.streams.events import (
     EventsStreamEventRow,
     EventsStreamRow,
 )
+from synapse.types import PersistedEventPosition, RoomStreamToken, UserID
 from synapse.util.async_helpers import timeout_deferred
 from synapse.util.metrics import Measure
 
@@ -98,7 +99,6 @@ class ReplicationDataHandler:
 
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
-        self.pusher_pool = hs.get_pusherpool()
         self.notifier = hs.get_notifier()
         self._reactor = hs.get_reactor()
         self._clock = hs.get_clock()
@@ -148,13 +148,17 @@ class ReplicationDataHandler:
                 if event.rejected_reason:
                     continue
 
-                extra_users = ()  # type: Tuple[str, ...]
+                extra_users = ()  # type: Tuple[UserID, ...]
                 if event.type == EventTypes.Member:
-                    extra_users = (event.state_key,)
-                max_token = self.store.get_room_max_stream_ordering()
-                self.notifier.on_new_room_event(event, token, max_token, extra_users)
+                    extra_users = (UserID.from_string(event.state_key),)
 
-            await self.pusher_pool.on_new_notifications(token, token)
+                max_token = RoomStreamToken(
+                    None, self.store.get_room_max_stream_ordering()
+                )
+                event_pos = PersistedEventPosition(instance_name, token)
+                self.notifier.on_new_room_event(
+                    event, event_pos, max_token, extra_users
+                )
 
         # Notify any waiting deferreds. The list is ordered by position so we
         # just iterate through the list until we reach a position that is
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 1c303f3a46..b323841f73 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -109,7 +109,7 @@ class ReplicationCommandHandler:
             if isinstance(stream, (EventsStream, BackfillStream)):
                 # Only add EventStream and BackfillStream as a source on the
                 # instance in charge of event persistence.
-                if hs.config.worker.writers.events == hs.get_instance_name():
+                if hs.get_instance_name() in hs.config.worker.writers.events:
                     self._streams_to_replicate.append(stream)
 
                 continue
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 04d894fb3d..687984e7a8 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -93,7 +93,7 @@ class ReplicationStreamer:
         """
         if not self.command_handler.connected():
             # Don't bother if nothing is listening. We still need to advance
-            # the stream tokens otherwise they'll fall beihind forever
+            # the stream tokens otherwise they'll fall behind forever
             for stream in self.streams:
                 stream.discard_updates_and_advance()
             return
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 682d47f402..54dccd15a6 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -345,7 +345,7 @@ class PushRulesStream(Stream):
     def __init__(self, hs):
         self.store = hs.get_datastore()
 
-        super(PushRulesStream, self).__init__(
+        super().__init__(
             hs.get_instance_name(),
             self._current_token,
             self.store.get_all_push_rule_updates,
@@ -383,7 +383,7 @@ class CachesStream(Stream):
     the cache on the workers
     """
 
-    @attr.s
+    @attr.s(slots=True)
     class CachesStreamRow:
         """Stream to inform workers they should invalidate their cache.
 
@@ -441,7 +441,7 @@ class DeviceListsStream(Stream):
     told about a device update.
     """
 
-    @attr.s
+    @attr.s(slots=True)
     class DeviceListsStreamRow:
         entity = attr.ib(type=str)
 
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index f929fc3954..ccc7ca30d8 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -19,7 +19,7 @@ from typing import List, Tuple, Type
 
 import attr
 
-from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
+from ._base import Stream, StreamUpdateResult, Token
 
 """Handling of the 'events' replication stream
 
@@ -117,7 +117,7 @@ class EventsStream(Stream):
         self._store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            current_token_without_instance(self._store.get_current_events_token),
+            self._store._stream_id_gen.get_current_token_for_writer,
             self._update_function,
         )
 
diff --git a/synapse/res/templates/password_reset_confirmation.html b/synapse/res/templates/password_reset_confirmation.html
new file mode 100644
index 0000000000..def4b5162b
--- /dev/null
+++ b/synapse/res/templates/password_reset_confirmation.html
@@ -0,0 +1,16 @@
+<html>
+<head></head>
+<body>
+<!--Use a hidden form to resubmit the information necessary to reset the password-->
+<form method="post">
+    <input type="hidden" name="sid" value="{{ sid }}">
+    <input type="hidden" name="token" value="{{ token }}">
+    <input type="hidden" name="client_secret" value="{{ client_secret }}">
+
+    <p>You have requested to <strong>reset your Matrix account password</strong>. Click the link below to confirm this action. <br /><br />
+        If you did not mean to do this, please close this page and your password will not be changed.</p>
+    <p><button type="submit">Confirm changing my password</button></p>
+</form>
+</body>
+</html>
+
diff --git a/synapse/res/templates/saml_error.html b/synapse/res/templates/saml_error.html
deleted file mode 100644
index 01cd9bdaf3..0000000000
--- a/synapse/res/templates/saml_error.html
+++ /dev/null
@@ -1,52 +0,0 @@
-<!DOCTYPE html>
-<html lang="en">
-<head>
-    <meta charset="UTF-8">
-    <title>SSO login error</title>
-</head>
-<body>
-{# a 403 means we have actively rejected their login #}
-{% if code == 403 %}
-    <p>You are not allowed to log in here.</p>
-{% else %}
-    <p>
-        There was an error during authentication:
-    </p>
-    <div id="errormsg" style="margin:20px 80px">{{ msg }}</div>
-    <p>
-        If you are seeing this page after clicking a link sent to you via email, make
-        sure you only click the confirmation link once, and that you open the
-        validation link in the same client you're logging in from.
-    </p>
-    <p>
-        Try logging in again from your Matrix client and if the problem persists
-        please contact the server's administrator.
-    </p>
-
-    <script type="text/javascript">
-        // Error handling to support Auth0 errors that we might get through a GET request
-        // to the validation endpoint. If an error is provided, it's either going to be
-        // located in the query string or in a query string-like URI fragment.
-        // We try to locate the error from any of these two locations, but if we can't
-        // we just don't print anything specific.
-        let searchStr = "";
-        if (window.location.search) {
-            // window.location.searchParams isn't always defined when
-            // window.location.search is, so it's more reliable to parse the latter.
-            searchStr = window.location.search;
-        } else if (window.location.hash) {
-            // Replace the # with a ? so that URLSearchParams does the right thing and
-            // doesn't parse the first parameter incorrectly.
-            searchStr = window.location.hash.replace("#", "?");
-        }
-
-        // We might end up with no error in the URL, so we need to check if we have one
-        // to print one.
-        let errorDesc = new URLSearchParams(searchStr).get("error_description")
-        if (errorDesc) {
-            document.getElementById("errormsg").innerText = errorDesc;
-        }
-    </script>
-{% endif %}
-</body>
-</html>
diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html
index 43a211386b..af8459719a 100644
--- a/synapse/res/templates/sso_error.html
+++ b/synapse/res/templates/sso_error.html
@@ -5,14 +5,49 @@
     <title>SSO error</title>
 </head>
 <body>
-    <p>Oops! Something went wrong during authentication.</p>
+{# If an error of unauthorised is returned it means we have actively rejected their login #}
+{% if error == "unauthorised" %}
+    <p>You are not allowed to log in here.</p>
+{% else %}
+    <p>
+        There was an error during authentication:
+    </p>
+    <div id="errormsg" style="margin:20px 80px">{{ error_description }}</div>
+    <p>
+        If you are seeing this page after clicking a link sent to you via email, make
+        sure you only click the confirmation link once, and that you open the
+        validation link in the same client you're logging in from.
+    </p>
     <p>
         Try logging in again from your Matrix client and if the problem persists
         please contact the server's administrator.
     </p>
     <p>Error: <code>{{ error }}</code></p>
-    {% if error_description %}
-    <pre><code>{{ error_description }}</code></pre>
-    {% endif %}
+
+    <script type="text/javascript">
+        // Error handling to support Auth0 errors that we might get through a GET request
+        // to the validation endpoint. If an error is provided, it's either going to be
+        // located in the query string or in a query string-like URI fragment.
+        // We try to locate the error from any of these two locations, but if we can't
+        // we just don't print anything specific.
+        let searchStr = "";
+        if (window.location.search) {
+            // window.location.searchParams isn't always defined when
+            // window.location.search is, so it's more reliable to parse the latter.
+            searchStr = window.location.search;
+        } else if (window.location.hash) {
+            // Replace the # with a ? so that URLSearchParams does the right thing and
+            // doesn't parse the first parameter incorrectly.
+            searchStr = window.location.hash.replace("#", "?");
+        }
+
+        // We might end up with no error in the URL, so we need to check if we have one
+        // to print one.
+        let errorDesc = new URLSearchParams(searchStr).get("error_description")
+        if (errorDesc) {
+            document.getElementById("errormsg").innerText = errorDesc;
+        }
+    </script>
+{% endif %}
 </body>
 </html>
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 87f927890c..40f5c32db2 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -13,8 +13,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import synapse.rest.admin
 from synapse.http.server import JsonResource
+from synapse.rest import admin
 from synapse.rest.client import versions
 from synapse.rest.client.v1 import (
     directory,
@@ -123,9 +123,7 @@ class ClientRestResource(JsonResource):
         password_policy.register_servlets(hs, client_resource)
 
         # moving to /_synapse/admin
-        synapse.rest.admin.register_servlets_for_client_rest_resource(
-            hs, client_resource
-        )
+        admin.register_servlets_for_client_rest_resource(hs, client_resource)
 
         # unstable
         shared_rooms.register_servlets(hs, client_resource)
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 1c88c93f38..5c5f00b213 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -16,13 +16,13 @@
 
 import logging
 import platform
-import re
 
 import synapse
 from synapse.api.errors import Codes, NotFoundError, SynapseError
 from synapse.http.server import JsonResource
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.rest.admin._base import (
+    admin_patterns,
     assert_requester_is_admin,
     historical_admin_path_patterns,
 )
@@ -31,6 +31,7 @@ from synapse.rest.admin.devices import (
     DeviceRestServlet,
     DevicesRestServlet,
 )
+from synapse.rest.admin.event_reports import EventReportsRestServlet
 from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
 from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
 from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
@@ -49,6 +50,7 @@ from synapse.rest.admin.users import (
     ResetPasswordRestServlet,
     SearchUsersRestServlet,
     UserAdminServlet,
+    UserMembershipRestServlet,
     UserRegisterServlet,
     UserRestServletV2,
     UsersRestServlet,
@@ -61,7 +63,7 @@ logger = logging.getLogger(__name__)
 
 
 class VersionServlet(RestServlet):
-    PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"),)
+    PATTERNS = admin_patterns("/server_version$")
 
     def __init__(self, hs):
         self.res = {
@@ -209,11 +211,13 @@ def register_servlets(hs, http_server):
     SendServerNoticeServlet(hs).register(http_server)
     VersionServlet(hs).register(http_server)
     UserAdminServlet(hs).register(http_server)
+    UserMembershipRestServlet(hs).register(http_server)
     UserRestServletV2(hs).register(http_server)
     UsersRestServletV2(hs).register(http_server)
     DeviceRestServlet(hs).register(http_server)
     DevicesRestServlet(hs).register(http_server)
     DeleteDevicesRestServlet(hs).register(http_server)
+    EventReportsRestServlet(hs).register(http_server)
 
 
 def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index d82eaf5e38..db9fea263a 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -44,7 +44,7 @@ def historical_admin_path_patterns(path_regex):
     ]
 
 
-def admin_patterns(path_regex: str):
+def admin_patterns(path_regex: str, version: str = "v1"):
     """Returns the list of patterns for an admin endpoint
 
     Args:
@@ -54,7 +54,7 @@ def admin_patterns(path_regex: str):
     Returns:
         A list of regex patterns.
     """
-    admin_prefix = "^/_synapse/admin/v1"
+    admin_prefix = "^/_synapse/admin/" + version
     patterns = [re.compile(admin_prefix + path_regex)]
     return patterns
 
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index 8d32677339..a163863322 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import re
 
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.http.servlet import (
@@ -21,7 +20,7 @@ from synapse.http.servlet import (
     assert_params_in_dict,
     parse_json_object_from_request,
 )
-from synapse.rest.admin._base import assert_requester_is_admin
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
 from synapse.types import UserID
 
 logger = logging.getLogger(__name__)
@@ -32,14 +31,12 @@ class DeviceRestServlet(RestServlet):
     Get, update or delete the given user's device
     """
 
-    PATTERNS = (
-        re.compile(
-            "^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$"
-        ),
+    PATTERNS = admin_patterns(
+        "/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$", "v2"
     )
 
     def __init__(self, hs):
-        super(DeviceRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
@@ -98,7 +95,7 @@ class DevicesRestServlet(RestServlet):
     Retrieve the given user's devices
     """
 
-    PATTERNS = (re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/devices$"),)
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")
 
     def __init__(self, hs):
         """
@@ -131,9 +128,7 @@ class DeleteDevicesRestServlet(RestServlet):
     key which lists the device_ids to delete.
     """
 
-    PATTERNS = (
-        re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/delete_devices$"),
-    )
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2")
 
     def __init__(self, hs):
         self.hs = hs
diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
new file mode 100644
index 0000000000..5b8d0594cd
--- /dev/null
+++ b/synapse/rest/admin/event_reports.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
+
+logger = logging.getLogger(__name__)
+
+
+class EventReportsRestServlet(RestServlet):
+    """
+    List all reported events that are known to the homeserver. Results are returned
+    in a dictionary containing report information. Supports pagination.
+    The requester must have administrator access in Synapse.
+
+    GET /_synapse/admin/v1/event_reports
+    returns:
+        200 OK with list of reports if success otherwise an error.
+
+    Args:
+        The parameters `from` and `limit` are required only for pagination.
+        By default, a `limit` of 100 is used.
+        The parameter `dir` can be used to define the order of results.
+        The parameter `user_id` can be used to filter by user id.
+        The parameter `room_id` can be used to filter by room id.
+    Returns:
+        A list of reported events and an integer representing the total number of
+        reported events that exist given this query
+    """
+
+    PATTERNS = admin_patterns("/event_reports$")
+
+    def __init__(self, hs):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+
+    async def on_GET(self, request):
+        await assert_requester_is_admin(self.auth, request)
+
+        start = parse_integer(request, "from", default=0)
+        limit = parse_integer(request, "limit", default=100)
+        direction = parse_string(request, "dir", default="b")
+        user_id = parse_string(request, "user_id")
+        room_id = parse_string(request, "room_id")
+
+        if start < 0:
+            raise SynapseError(
+                400,
+                "The start parameter must be a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        if limit < 0:
+            raise SynapseError(
+                400,
+                "The limit parameter must be a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        if direction not in ("f", "b"):
+            raise SynapseError(
+                400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
+            )
+
+        event_reports, total = await self.store.get_event_reports_paginate(
+            start, limit, direction, user_id, room_id
+        )
+        ret = {"event_reports": event_reports, "total": total}
+        if (start + limit) < total:
+            ret["next_token"] = start + len(event_reports)
+
+        return 200, ret
diff --git a/synapse/rest/admin/purge_room_servlet.py b/synapse/rest/admin/purge_room_servlet.py
index f474066542..8b7bb6d44e 100644
--- a/synapse/rest/admin/purge_room_servlet.py
+++ b/synapse/rest/admin/purge_room_servlet.py
@@ -12,14 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import re
-
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
     parse_json_object_from_request,
 )
 from synapse.rest.admin import assert_requester_is_admin
+from synapse.rest.admin._base import admin_patterns
 
 
 class PurgeRoomServlet(RestServlet):
@@ -35,7 +34,7 @@ class PurgeRoomServlet(RestServlet):
     {}
     """
 
-    PATTERNS = (re.compile("^/_synapse/admin/v1/purge_room$"),)
+    PATTERNS = admin_patterns("/purge_room$")
 
     def __init__(self, hs):
         """
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index 6e9a874121..375d055445 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -12,8 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import re
-
 from synapse.api.constants import EventTypes
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import (
@@ -22,6 +20,7 @@ from synapse.http.servlet import (
     parse_json_object_from_request,
 )
 from synapse.rest.admin import assert_requester_is_admin
+from synapse.rest.admin._base import admin_patterns
 from synapse.rest.client.transactions import HttpTransactionCache
 from synapse.types import UserID
 
@@ -56,13 +55,13 @@ class SendServerNoticeServlet(RestServlet):
         self.snm = hs.get_server_notices_manager()
 
     def register(self, json_resource):
-        PATTERN = "^/_synapse/admin/v1/send_server_notice"
+        PATTERN = "/send_server_notice"
         json_resource.register_paths(
-            "POST", (re.compile(PATTERN + "$"),), self.on_POST, self.__class__.__name__
+            "POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__
         )
         json_resource.register_paths(
             "PUT",
-            (re.compile(PATTERN + "/(?P<txn_id>[^/]*)$"),),
+            admin_patterns(PATTERN + "/(?P<txn_id>[^/]*)$"),
             self.on_PUT,
             self.__class__.__name__,
         )
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index f3e77da850..20dc1d0e05 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -15,7 +15,6 @@
 import hashlib
 import hmac
 import logging
-import re
 from http import HTTPStatus
 
 from synapse.api.constants import UserTypes
@@ -29,6 +28,7 @@ from synapse.http.servlet import (
     parse_string,
 )
 from synapse.rest.admin._base import (
+    admin_patterns,
     assert_requester_is_admin,
     assert_user_is_admin,
     historical_admin_path_patterns,
@@ -60,7 +60,7 @@ class UsersRestServlet(RestServlet):
 
 
 class UsersRestServletV2(RestServlet):
-    PATTERNS = (re.compile("^/_synapse/admin/v2/users$"),)
+    PATTERNS = admin_patterns("/users$", "v2")
 
     """Get request to list all local users.
     This needs user to have administrator access in Synapse.
@@ -105,7 +105,7 @@ class UsersRestServletV2(RestServlet):
 
 
 class UserRestServletV2(RestServlet):
-    PATTERNS = (re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]+)$"),)
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)$", "v2")
 
     """Get request to list user details.
     This needs user to have administrator access in Synapse.
@@ -642,7 +642,7 @@ class UserAdminServlet(RestServlet):
                 {}
     """
 
-    PATTERNS = (re.compile("^/_synapse/admin/v1/users/(?P<user_id>[^/]*)/admin$"),)
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
 
     def __init__(self, hs):
         self.hs = hs
@@ -683,3 +683,29 @@ class UserAdminServlet(RestServlet):
         await self.store.set_server_admin(target_user, set_admin_to)
 
         return 200, {}
+
+
+class UserMembershipRestServlet(RestServlet):
+    """
+    Get room list of an user.
+    """
+
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")
+
+    def __init__(self, hs):
+        self.is_mine = hs.is_mine
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+
+    async def on_GET(self, request, user_id):
+        await assert_requester_is_admin(self.auth, request)
+
+        if not self.is_mine(UserID.from_string(user_id)):
+            raise SynapseError(400, "Can only lookup local users")
+
+        room_ids = await self.store.get_rooms_for_user(user_id)
+        if not room_ids:
+            raise NotFoundError("User not found")
+
+        ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
+        return 200, ret
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index b210015173..faabeeb91c 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -40,7 +40,7 @@ class ClientDirectoryServer(RestServlet):
     PATTERNS = client_patterns("/directory/room/(?P<room_alias>[^/]*)$", v1=True)
 
     def __init__(self, hs):
-        super(ClientDirectoryServer, self).__init__()
+        super().__init__()
         self.store = hs.get_datastore()
         self.handlers = hs.get_handlers()
         self.auth = hs.get_auth()
@@ -120,7 +120,7 @@ class ClientDirectoryListServer(RestServlet):
     PATTERNS = client_patterns("/directory/list/room/(?P<room_id>[^/]*)$", v1=True)
 
     def __init__(self, hs):
-        super(ClientDirectoryListServer, self).__init__()
+        super().__init__()
         self.store = hs.get_datastore()
         self.handlers = hs.get_handlers()
         self.auth = hs.get_auth()
@@ -160,7 +160,7 @@ class ClientAppserviceDirectoryListServer(RestServlet):
     )
 
     def __init__(self, hs):
-        super(ClientAppserviceDirectoryListServer, self).__init__()
+        super().__init__()
         self.store = hs.get_datastore()
         self.handlers = hs.get_handlers()
         self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index 25effd0261..985d994f6b 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -30,7 +30,7 @@ class EventStreamRestServlet(RestServlet):
     DEFAULT_LONGPOLL_TIME_MS = 30000
 
     def __init__(self, hs):
-        super(EventStreamRestServlet, self).__init__()
+        super().__init__()
         self.event_stream_handler = hs.get_event_stream_handler()
         self.auth = hs.get_auth()
 
@@ -74,7 +74,7 @@ class EventRestServlet(RestServlet):
     PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)
 
     def __init__(self, hs):
-        super(EventRestServlet, self).__init__()
+        super().__init__()
         self.clock = hs.get_clock()
         self.event_handler = hs.get_event_handler()
         self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index 910b3b4eeb..d7042786ce 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -24,7 +24,7 @@ class InitialSyncRestServlet(RestServlet):
     PATTERNS = client_patterns("/initialSync$", v1=True)
 
     def __init__(self, hs):
-        super(InitialSyncRestServlet, self).__init__()
+        super().__init__()
         self.initial_sync_handler = hs.get_initial_sync_handler()
         self.auth = hs.get_auth()
 
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index a14618ac84..250b03a025 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -18,6 +18,7 @@ from typing import Awaitable, Callable, Dict, Optional
 
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
+from synapse.appservice import ApplicationService
 from synapse.handlers.auth import (
     convert_client_dict_legacy_fields_to_identifier,
     login_id_phone_to_thirdparty,
@@ -44,9 +45,10 @@ class LoginRestServlet(RestServlet):
     TOKEN_TYPE = "m.login.token"
     JWT_TYPE = "org.matrix.login.jwt"
     JWT_TYPE_DEPRECATED = "m.login.jwt"
+    APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
 
     def __init__(self, hs):
-        super(LoginRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
 
         # JWT configuration variables.
@@ -61,6 +63,8 @@ class LoginRestServlet(RestServlet):
         self.cas_enabled = hs.config.cas_enabled
         self.oidc_enabled = hs.config.oidc_enabled
 
+        self.auth = hs.get_auth()
+
         self.auth_handler = self.hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
         self.handlers = hs.get_handlers()
@@ -107,6 +111,8 @@ class LoginRestServlet(RestServlet):
             ({"type": t} for t in self.auth_handler.get_supported_login_types())
         )
 
+        flows.append({"type": LoginRestServlet.APPSERVICE_TYPE})
+
         return 200, {"flows": flows}
 
     def on_OPTIONS(self, request: SynapseRequest):
@@ -116,8 +122,12 @@ class LoginRestServlet(RestServlet):
         self._address_ratelimiter.ratelimit(request.getClientIP())
 
         login_submission = parse_json_object_from_request(request)
+
         try:
-            if self.jwt_enabled and (
+            if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
+                appservice = self.auth.get_appservice_by_req(request)
+                result = await self._do_appservice_login(login_submission, appservice)
+            elif self.jwt_enabled and (
                 login_submission["type"] == LoginRestServlet.JWT_TYPE
                 or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
             ):
@@ -134,6 +144,33 @@ class LoginRestServlet(RestServlet):
             result["well_known"] = well_known_data
         return 200, result
 
+    def _get_qualified_user_id(self, identifier):
+        if identifier["type"] != "m.id.user":
+            raise SynapseError(400, "Unknown login identifier type")
+        if "user" not in identifier:
+            raise SynapseError(400, "User identifier is missing 'user' key")
+
+        if identifier["user"].startswith("@"):
+            return identifier["user"]
+        else:
+            return UserID(identifier["user"], self.hs.hostname).to_string()
+
+    async def _do_appservice_login(
+        self, login_submission: JsonDict, appservice: ApplicationService
+    ):
+        logger.info(
+            "Got appservice login request with identifier: %r",
+            login_submission.get("identifier"),
+        )
+
+        identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
+        qualified_user_id = self._get_qualified_user_id(identifier)
+
+        if not appservice.is_interested_in_user(qualified_user_id):
+            raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
+
+        return await self._complete_login(qualified_user_id, login_submission)
+
     async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
         """Handle non-token/saml/jwt logins
 
@@ -219,15 +256,7 @@ class LoginRestServlet(RestServlet):
 
         # by this point, the identifier should be an m.id.user: if it's anything
         # else, we haven't understood it.
-        if identifier["type"] != "m.id.user":
-            raise SynapseError(400, "Unknown login identifier type")
-        if "user" not in identifier:
-            raise SynapseError(400, "User identifier is missing 'user' key")
-
-        if identifier["user"].startswith("@"):
-            qualified_user_id = identifier["user"]
-        else:
-            qualified_user_id = UserID(identifier["user"], self.hs.hostname).to_string()
+        qualified_user_id = self._get_qualified_user_id(identifier)
 
         # Check if we've hit the failed ratelimit (but don't update it)
         self._failed_attempts_ratelimiter.ratelimit(
@@ -400,7 +429,7 @@ class CasTicketServlet(RestServlet):
     PATTERNS = client_patterns("/login/cas/ticket", v1=True)
 
     def __init__(self, hs):
-        super(CasTicketServlet, self).__init__()
+        super().__init__()
         self._cas_handler = hs.get_cas_handler()
 
     async def on_GET(self, request: SynapseRequest) -> None:
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index b0c30b65be..f792b50cdc 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -25,7 +25,7 @@ class LogoutRestServlet(RestServlet):
     PATTERNS = client_patterns("/logout$", v1=True)
 
     def __init__(self, hs):
-        super(LogoutRestServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self._auth_handler = hs.get_auth_handler()
         self._device_handler = hs.get_device_handler()
@@ -53,7 +53,7 @@ class LogoutAllRestServlet(RestServlet):
     PATTERNS = client_patterns("/logout/all$", v1=True)
 
     def __init__(self, hs):
-        super(LogoutAllRestServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self._auth_handler = hs.get_auth_handler()
         self._device_handler = hs.get_device_handler()
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 970fdd5834..79d8e3057f 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -30,7 +30,7 @@ class PresenceStatusRestServlet(RestServlet):
     PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True)
 
     def __init__(self, hs):
-        super(PresenceStatusRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.presence_handler = hs.get_presence_handler()
         self.clock = hs.get_clock()
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index e7fe50ed72..b686cd671f 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -25,7 +25,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
     PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True)
 
     def __init__(self, hs):
-        super(ProfileDisplaynameRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.profile_handler = hs.get_profile_handler()
         self.auth = hs.get_auth()
@@ -73,7 +73,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
     PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
 
     def __init__(self, hs):
-        super(ProfileAvatarURLRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.profile_handler = hs.get_profile_handler()
         self.auth = hs.get_auth()
@@ -124,7 +124,7 @@ class ProfileRestServlet(RestServlet):
     PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
 
     def __init__(self, hs):
-        super(ProfileRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.profile_handler = hs.get_profile_handler()
         self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index e781a3bcf4..f9eecb7cf5 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -38,7 +38,7 @@ class PushRuleRestServlet(RestServlet):
     )
 
     def __init__(self, hs):
-        super(PushRuleRestServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.notifier = hs.get_notifier()
@@ -163,6 +163,18 @@ class PushRuleRestServlet(RestServlet):
         self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
 
     async def set_rule_attr(self, user_id, spec, val):
+        if spec["attr"] not in ("enabled", "actions"):
+            # for the sake of potential future expansion, shouldn't report
+            # 404 in the case of an unknown request so check it corresponds to
+            # a known attribute first.
+            raise UnrecognizedRequestError()
+
+        namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
+        rule_id = spec["rule_id"]
+        is_default_rule = rule_id.startswith(".")
+        if is_default_rule:
+            if namespaced_rule_id not in BASE_RULE_IDS:
+                raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,))
         if spec["attr"] == "enabled":
             if isinstance(val, dict) and "enabled" in val:
                 val = val["enabled"]
@@ -171,9 +183,8 @@ class PushRuleRestServlet(RestServlet):
                 # This should *actually* take a dict, but many clients pass
                 # bools directly, so let's not break them.
                 raise SynapseError(400, "Value for 'enabled' must be boolean")
-            namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
             return await self.store.set_push_rule_enabled(
-                user_id, namespaced_rule_id, val
+                user_id, namespaced_rule_id, val, is_default_rule
             )
         elif spec["attr"] == "actions":
             actions = val.get("actions")
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 5f65cb7d83..28dabf1c7a 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -44,7 +44,7 @@ class PushersRestServlet(RestServlet):
     PATTERNS = client_patterns("/pushers$", v1=True)
 
     def __init__(self, hs):
-        super(PushersRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
 
@@ -68,7 +68,7 @@ class PushersSetRestServlet(RestServlet):
     PATTERNS = client_patterns("/pushers/set$", v1=True)
 
     def __init__(self, hs):
-        super(PushersSetRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.notifier = hs.get_notifier()
@@ -153,7 +153,7 @@ class PushersRemoveRestServlet(RestServlet):
     SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"
 
     def __init__(self, hs):
-        super(PushersRemoveRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.notifier = hs.get_notifier()
         self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 84baf3d59b..7e64a2e0fe 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -57,7 +57,7 @@ logger = logging.getLogger(__name__)
 
 class TransactionRestServlet(RestServlet):
     def __init__(self, hs):
-        super(TransactionRestServlet, self).__init__()
+        super().__init__()
         self.txns = HttpTransactionCache(hs)
 
 
@@ -65,7 +65,7 @@ class RoomCreateRestServlet(TransactionRestServlet):
     # No PATTERN; we have custom dispatch rules here
 
     def __init__(self, hs):
-        super(RoomCreateRestServlet, self).__init__(hs)
+        super().__init__(hs)
         self._room_creation_handler = hs.get_room_creation_handler()
         self.auth = hs.get_auth()
 
@@ -111,7 +111,7 @@ class RoomCreateRestServlet(TransactionRestServlet):
 # TODO: Needs unit testing for generic events
 class RoomStateEventRestServlet(TransactionRestServlet):
     def __init__(self, hs):
-        super(RoomStateEventRestServlet, self).__init__(hs)
+        super().__init__(hs)
         self.handlers = hs.get_handlers()
         self.event_creation_handler = hs.get_event_creation_handler()
         self.room_member_handler = hs.get_room_member_handler()
@@ -229,7 +229,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
 # TODO: Needs unit testing for generic events + feedback
 class RoomSendEventRestServlet(TransactionRestServlet):
     def __init__(self, hs):
-        super(RoomSendEventRestServlet, self).__init__(hs)
+        super().__init__(hs)
         self.event_creation_handler = hs.get_event_creation_handler()
         self.auth = hs.get_auth()
 
@@ -280,7 +280,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
 # TODO: Needs unit testing for room ID + alias joins
 class JoinRoomAliasServlet(TransactionRestServlet):
     def __init__(self, hs):
-        super(JoinRoomAliasServlet, self).__init__(hs)
+        super().__init__(hs)
         self.room_member_handler = hs.get_room_member_handler()
         self.auth = hs.get_auth()
 
@@ -343,7 +343,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
     PATTERNS = client_patterns("/publicRooms$", v1=True)
 
     def __init__(self, hs):
-        super(PublicRoomListRestServlet, self).__init__(hs)
+        super().__init__(hs)
         self.hs = hs
         self.auth = hs.get_auth()
 
@@ -448,7 +448,7 @@ class RoomMemberListRestServlet(RestServlet):
     PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True)
 
     def __init__(self, hs):
-        super(RoomMemberListRestServlet, self).__init__()
+        super().__init__()
         self.message_handler = hs.get_message_handler()
         self.auth = hs.get_auth()
 
@@ -499,7 +499,7 @@ class JoinedRoomMemberListRestServlet(RestServlet):
     PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True)
 
     def __init__(self, hs):
-        super(JoinedRoomMemberListRestServlet, self).__init__()
+        super().__init__()
         self.message_handler = hs.get_message_handler()
         self.auth = hs.get_auth()
 
@@ -518,7 +518,7 @@ class RoomMessageListRestServlet(RestServlet):
     PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True)
 
     def __init__(self, hs):
-        super(RoomMessageListRestServlet, self).__init__()
+        super().__init__()
         self.pagination_handler = hs.get_pagination_handler()
         self.auth = hs.get_auth()
 
@@ -557,7 +557,7 @@ class RoomStateRestServlet(RestServlet):
     PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True)
 
     def __init__(self, hs):
-        super(RoomStateRestServlet, self).__init__()
+        super().__init__()
         self.message_handler = hs.get_message_handler()
         self.auth = hs.get_auth()
 
@@ -577,7 +577,7 @@ class RoomInitialSyncRestServlet(RestServlet):
     PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True)
 
     def __init__(self, hs):
-        super(RoomInitialSyncRestServlet, self).__init__()
+        super().__init__()
         self.initial_sync_handler = hs.get_initial_sync_handler()
         self.auth = hs.get_auth()
 
@@ -596,7 +596,7 @@ class RoomEventServlet(RestServlet):
     )
 
     def __init__(self, hs):
-        super(RoomEventServlet, self).__init__()
+        super().__init__()
         self.clock = hs.get_clock()
         self.event_handler = hs.get_event_handler()
         self._event_serializer = hs.get_event_client_serializer()
@@ -628,7 +628,7 @@ class RoomEventContextServlet(RestServlet):
     )
 
     def __init__(self, hs):
-        super(RoomEventContextServlet, self).__init__()
+        super().__init__()
         self.clock = hs.get_clock()
         self.room_context_handler = hs.get_room_context_handler()
         self._event_serializer = hs.get_event_client_serializer()
@@ -675,7 +675,7 @@ class RoomEventContextServlet(RestServlet):
 
 class RoomForgetRestServlet(TransactionRestServlet):
     def __init__(self, hs):
-        super(RoomForgetRestServlet, self).__init__(hs)
+        super().__init__(hs)
         self.room_member_handler = hs.get_room_member_handler()
         self.auth = hs.get_auth()
 
@@ -701,7 +701,7 @@ class RoomForgetRestServlet(TransactionRestServlet):
 # TODO: Needs unit testing
 class RoomMembershipRestServlet(TransactionRestServlet):
     def __init__(self, hs):
-        super(RoomMembershipRestServlet, self).__init__(hs)
+        super().__init__(hs)
         self.room_member_handler = hs.get_room_member_handler()
         self.auth = hs.get_auth()
 
@@ -792,7 +792,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
 
 class RoomRedactEventRestServlet(TransactionRestServlet):
     def __init__(self, hs):
-        super(RoomRedactEventRestServlet, self).__init__(hs)
+        super().__init__(hs)
         self.handlers = hs.get_handlers()
         self.event_creation_handler = hs.get_event_creation_handler()
         self.auth = hs.get_auth()
@@ -841,7 +841,7 @@ class RoomTypingRestServlet(RestServlet):
     )
 
     def __init__(self, hs):
-        super(RoomTypingRestServlet, self).__init__()
+        super().__init__()
         self.presence_handler = hs.get_presence_handler()
         self.typing_handler = hs.get_typing_handler()
         self.auth = hs.get_auth()
@@ -914,7 +914,7 @@ class SearchRestServlet(RestServlet):
     PATTERNS = client_patterns("/search$", v1=True)
 
     def __init__(self, hs):
-        super(SearchRestServlet, self).__init__()
+        super().__init__()
         self.handlers = hs.get_handlers()
         self.auth = hs.get_auth()
 
@@ -935,7 +935,7 @@ class JoinedRoomsRestServlet(RestServlet):
     PATTERNS = client_patterns("/joined_rooms$", v1=True)
 
     def __init__(self, hs):
-        super(JoinedRoomsRestServlet, self).__init__()
+        super().__init__()
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 50277c6cf6..b8d491ca5c 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -25,7 +25,7 @@ class VoipRestServlet(RestServlet):
     PATTERNS = client_patterns("/voip/turnServer$", v1=True)
 
     def __init__(self, hs):
-        super(VoipRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
 
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 3481477731..c3ce0f6259 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -17,6 +17,11 @@
 import logging
 import random
 from http import HTTPStatus
+from typing import TYPE_CHECKING
+from urllib.parse import urlparse
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import (
@@ -47,7 +52,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/password/email/requestToken$")
 
     def __init__(self, hs):
-        super(EmailPasswordRequestTokenRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.datastore = hs.get_datastore()
         self.config = hs.config
@@ -98,6 +103,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        # Raise if the provided next_link value isn't valid
+        assert_valid_next_link(self.hs, next_link)
+
         # The email will be sent to the stored address.
         # This avoids a potential account hijack by requesting a password reset to
         # an email address which is controlled by the attacker but which, after
@@ -144,86 +152,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
         return 200, ret
 
 
-class PasswordResetSubmitTokenServlet(RestServlet):
-    """Handles 3PID validation token submission"""
-
-    PATTERNS = client_patterns(
-        "/password_reset/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True
-    )
-
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
-        super(PasswordResetSubmitTokenServlet, self).__init__()
-        self.hs = hs
-        self.auth = hs.get_auth()
-        self.config = hs.config
-        self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
-        if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
-            self._failure_email_template = (
-                self.config.email_password_reset_template_failure_html
-            )
-
-    async def on_GET(self, request, medium):
-        # We currently only handle threepid token submissions for email
-        if medium != "email":
-            raise SynapseError(
-                400, "This medium is currently not supported for password resets"
-            )
-        if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
-            if self.config.local_threepid_handling_disabled_due_to_email_config:
-                logger.warning(
-                    "Password reset emails have been disabled due to lack of an email config"
-                )
-            raise SynapseError(
-                400, "Email-based password resets are disabled on this server"
-            )
-
-        sid = parse_string(request, "sid", required=True)
-        token = parse_string(request, "token", required=True)
-        client_secret = parse_string(request, "client_secret", required=True)
-        assert_valid_client_secret(client_secret)
-
-        # Attempt to validate a 3PID session
-        try:
-            # Mark the session as valid
-            next_link = await self.store.validate_threepid_session(
-                sid, client_secret, token, self.clock.time_msec()
-            )
-
-            # Perform a 302 redirect if next_link is set
-            if next_link:
-                if next_link.startswith("file:///"):
-                    logger.warning(
-                        "Not redirecting to next_link as it is a local file: address"
-                    )
-                else:
-                    request.setResponseCode(302)
-                    request.setHeader("Location", next_link)
-                    finish_request(request)
-                    return None
-
-            # Otherwise show the success template
-            html = self.config.email_password_reset_template_success_html_content
-            status_code = 200
-        except ThreepidValidationError as e:
-            status_code = e.code
-
-            # Show a failure page with a reason
-            template_vars = {"failure_reason": e.msg}
-            html = self._failure_email_template.render(**template_vars)
-
-        respond_with_html(request, status_code, html)
-
-
 class PasswordRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/password$")
 
     def __init__(self, hs):
-        super(PasswordRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
@@ -349,7 +282,7 @@ class DeactivateAccountRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/deactivate$")
 
     def __init__(self, hs):
-        super(DeactivateAccountRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
@@ -368,7 +301,7 @@ class DeactivateAccountRestServlet(RestServlet):
 
         requester = await self.auth.get_user_by_req(request)
 
-        # allow ASes to dectivate their own users
+        # allow ASes to deactivate their own users
         if requester.app_service:
             await self._deactivate_account_handler.deactivate_account(
                 requester.user.to_string(), erase
@@ -397,7 +330,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid/email/requestToken$")
 
     def __init__(self, hs):
-        super(EmailThreepidRequestTokenRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.config = hs.config
         self.identity_handler = hs.get_handlers().identity_handler
@@ -446,6 +379,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        # Raise if the provided next_link value isn't valid
+        assert_valid_next_link(self.hs, next_link)
+
         existing_user_id = await self.store.get_user_id_by_threepid("email", email)
 
         if existing_user_id is not None:
@@ -491,7 +427,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
 
     def __init__(self, hs):
         self.hs = hs
-        super(MsisdnThreepidRequestTokenRestServlet, self).__init__()
+        super().__init__()
         self.store = self.hs.get_datastore()
         self.identity_handler = hs.get_handlers().identity_handler
 
@@ -517,6 +453,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        # Raise if the provided next_link value isn't valid
+        assert_valid_next_link(self.hs, next_link)
+
         existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
 
         if existing_user_id is not None:
@@ -603,15 +542,10 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
 
             # Perform a 302 redirect if next_link is set
             if next_link:
-                if next_link.startswith("file:///"):
-                    logger.warning(
-                        "Not redirecting to next_link as it is a local file: address"
-                    )
-                else:
-                    request.setResponseCode(302)
-                    request.setHeader("Location", next_link)
-                    finish_request(request)
-                    return None
+                request.setResponseCode(302)
+                request.setHeader("Location", next_link)
+                finish_request(request)
+                return None
 
             # Otherwise show the success template
             html = self.config.email_add_threepid_template_success_html_content
@@ -672,7 +606,7 @@ class ThreepidRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid$")
 
     def __init__(self, hs):
-        super(ThreepidRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.identity_handler = hs.get_handlers().identity_handler
         self.auth = hs.get_auth()
@@ -728,7 +662,7 @@ class ThreepidAddRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid/add$")
 
     def __init__(self, hs):
-        super(ThreepidAddRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.identity_handler = hs.get_handlers().identity_handler
         self.auth = hs.get_auth()
@@ -779,7 +713,7 @@ class ThreepidBindRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid/bind$")
 
     def __init__(self, hs):
-        super(ThreepidBindRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.identity_handler = hs.get_handlers().identity_handler
         self.auth = hs.get_auth()
@@ -808,7 +742,7 @@ class ThreepidUnbindRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid/unbind$")
 
     def __init__(self, hs):
-        super(ThreepidUnbindRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.identity_handler = hs.get_handlers().identity_handler
         self.auth = hs.get_auth()
@@ -839,7 +773,7 @@ class ThreepidDeleteRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid/delete$")
 
     def __init__(self, hs):
-        super(ThreepidDeleteRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
@@ -875,11 +809,50 @@ class ThreepidDeleteRestServlet(RestServlet):
         return 200, {"id_server_unbind_result": id_server_unbind_result}
 
 
+def assert_valid_next_link(hs: "HomeServer", next_link: str):
+    """
+    Raises a SynapseError if a given next_link value is invalid
+
+    next_link is valid if the scheme is http(s) and the next_link.domain_whitelist config
+    option is either empty or contains a domain that matches the one in the given next_link
+
+    Args:
+        hs: The homeserver object
+        next_link: The next_link value given by the client
+
+    Raises:
+        SynapseError: If the next_link is invalid
+    """
+    valid = True
+
+    # Parse the contents of the URL
+    next_link_parsed = urlparse(next_link)
+
+    # Scheme must not point to the local drive
+    if next_link_parsed.scheme == "file":
+        valid = False
+
+    # If the domain whitelist is set, the domain must be in it
+    if (
+        valid
+        and hs.config.next_link_domain_whitelist is not None
+        and next_link_parsed.hostname not in hs.config.next_link_domain_whitelist
+    ):
+        valid = False
+
+    if not valid:
+        raise SynapseError(
+            400,
+            "'next_link' domain not included in whitelist, or not http(s)",
+            errcode=Codes.INVALID_PARAM,
+        )
+
+
 class WhoamiRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/whoami$")
 
     def __init__(self, hs):
-        super(WhoamiRestServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
 
     async def on_GET(self, request):
@@ -890,7 +863,6 @@ class WhoamiRestServlet(RestServlet):
 
 def register_servlets(hs, http_server):
     EmailPasswordRequestTokenRestServlet(hs).register(http_server)
-    PasswordResetSubmitTokenServlet(hs).register(http_server)
     PasswordRestServlet(hs).register(http_server)
     DeactivateAccountRestServlet(hs).register(http_server)
     EmailThreepidRequestTokenRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index c1d4cd0caf..87a5b1b86b 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -34,7 +34,7 @@ class AccountDataServlet(RestServlet):
     )
 
     def __init__(self, hs):
-        super(AccountDataServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.notifier = hs.get_notifier()
@@ -86,7 +86,7 @@ class RoomAccountDataServlet(RestServlet):
     )
 
     def __init__(self, hs):
-        super(RoomAccountDataServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.notifier = hs.get_notifier()
diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
index d06336ceea..bd7f9ae203 100644
--- a/synapse/rest/client/v2_alpha/account_validity.py
+++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -32,7 +32,7 @@ class AccountValidityRenewServlet(RestServlet):
         Args:
             hs (synapse.server.HomeServer): server
         """
-        super(AccountValidityRenewServlet, self).__init__()
+        super().__init__()
 
         self.hs = hs
         self.account_activity_handler = hs.get_account_validity_handler()
@@ -67,7 +67,7 @@ class AccountValiditySendMailServlet(RestServlet):
         Args:
             hs (synapse.server.HomeServer): server
         """
-        super(AccountValiditySendMailServlet, self).__init__()
+        super().__init__()
 
         self.hs = hs
         self.account_activity_handler = hs.get_account_validity_handler()
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 8e585e9153..097538f968 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -124,7 +124,7 @@ class AuthRestServlet(RestServlet):
     PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
 
     def __init__(self, hs):
-        super(AuthRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py
index fe9d019c44..76879ac559 100644
--- a/synapse/rest/client/v2_alpha/capabilities.py
+++ b/synapse/rest/client/v2_alpha/capabilities.py
@@ -32,7 +32,7 @@ class CapabilitiesRestServlet(RestServlet):
         Args:
             hs (synapse.server.HomeServer): server
         """
-        super(CapabilitiesRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.config = hs.config
         self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index c0714fcfb1..7e174de692 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -35,7 +35,7 @@ class DevicesRestServlet(RestServlet):
         Args:
             hs (synapse.server.HomeServer): server
         """
-        super(DevicesRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
@@ -57,7 +57,7 @@ class DeleteDevicesRestServlet(RestServlet):
     PATTERNS = client_patterns("/delete_devices")
 
     def __init__(self, hs):
-        super(DeleteDevicesRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
@@ -102,7 +102,7 @@ class DeviceRestServlet(RestServlet):
         Args:
             hs (synapse.server.HomeServer): server
         """
-        super(DeviceRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index b28da017cd..7cc692643b 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -28,7 +28,7 @@ class GetFilterRestServlet(RestServlet):
     PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
 
     def __init__(self, hs):
-        super(GetFilterRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.filtering = hs.get_filtering()
@@ -64,7 +64,7 @@ class CreateFilterRestServlet(RestServlet):
     PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter")
 
     def __init__(self, hs):
-        super(CreateFilterRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.filtering = hs.get_filtering()
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 13ecf7005d..a3bb095c2d 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -32,7 +32,7 @@ class GroupServlet(RestServlet):
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
 
     def __init__(self, hs):
-        super(GroupServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
@@ -66,7 +66,7 @@ class GroupSummaryServlet(RestServlet):
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
 
     def __init__(self, hs):
-        super(GroupSummaryServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
@@ -97,7 +97,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
     )
 
     def __init__(self, hs):
-        super(GroupSummaryRoomsCatServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
@@ -137,7 +137,7 @@ class GroupCategoryServlet(RestServlet):
     )
 
     def __init__(self, hs):
-        super(GroupCategoryServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
@@ -181,7 +181,7 @@ class GroupCategoriesServlet(RestServlet):
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$")
 
     def __init__(self, hs):
-        super(GroupCategoriesServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
@@ -204,7 +204,7 @@ class GroupRoleServlet(RestServlet):
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$")
 
     def __init__(self, hs):
-        super(GroupRoleServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
@@ -248,7 +248,7 @@ class GroupRolesServlet(RestServlet):
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$")
 
     def __init__(self, hs):
-        super(GroupRolesServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
@@ -279,7 +279,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
     )
 
     def __init__(self, hs):
-        super(GroupSummaryUsersRoleServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
@@ -317,7 +317,7 @@ class GroupRoomServlet(RestServlet):
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
 
     def __init__(self, hs):
-        super(GroupRoomServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
@@ -343,7 +343,7 @@ class GroupUsersServlet(RestServlet):
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
 
     def __init__(self, hs):
-        super(GroupUsersServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
@@ -366,7 +366,7 @@ class GroupInvitedUsersServlet(RestServlet):
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
 
     def __init__(self, hs):
-        super(GroupInvitedUsersServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
@@ -389,7 +389,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
 
     def __init__(self, hs):
-        super(GroupSettingJoinPolicyServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.groups_handler = hs.get_groups_local_handler()
 
@@ -413,7 +413,7 @@ class GroupCreateServlet(RestServlet):
     PATTERNS = client_patterns("/create_group$")
 
     def __init__(self, hs):
-        super(GroupCreateServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
@@ -444,7 +444,7 @@ class GroupAdminRoomsServlet(RestServlet):
     )
 
     def __init__(self, hs):
-        super(GroupAdminRoomsServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
@@ -481,7 +481,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
     )
 
     def __init__(self, hs):
-        super(GroupAdminRoomsConfigServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
@@ -507,7 +507,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
     )
 
     def __init__(self, hs):
-        super(GroupAdminUsersInviteServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
@@ -536,7 +536,7 @@ class GroupAdminUsersKickServlet(RestServlet):
     )
 
     def __init__(self, hs):
-        super(GroupAdminUsersKickServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
@@ -560,7 +560,7 @@ class GroupSelfLeaveServlet(RestServlet):
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$")
 
     def __init__(self, hs):
-        super(GroupSelfLeaveServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
@@ -584,7 +584,7 @@ class GroupSelfJoinServlet(RestServlet):
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$")
 
     def __init__(self, hs):
-        super(GroupSelfJoinServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
@@ -608,7 +608,7 @@ class GroupSelfAcceptInviteServlet(RestServlet):
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$")
 
     def __init__(self, hs):
-        super(GroupSelfAcceptInviteServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
@@ -632,7 +632,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$")
 
     def __init__(self, hs):
-        super(GroupSelfUpdatePublicityServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
@@ -655,7 +655,7 @@ class PublicisedGroupsForUserServlet(RestServlet):
     PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$")
 
     def __init__(self, hs):
-        super(PublicisedGroupsForUserServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
@@ -676,7 +676,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
     PATTERNS = client_patterns("/publicised_groups$")
 
     def __init__(self, hs):
-        super(PublicisedGroupsForUsersServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
@@ -700,7 +700,7 @@ class GroupsForUserServlet(RestServlet):
     PATTERNS = client_patterns("/joined_groups$")
 
     def __init__(self, hs):
-        super(GroupsForUserServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 24bb090822..7abd6ff333 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -64,7 +64,7 @@ class KeyUploadServlet(RestServlet):
         Args:
             hs (synapse.server.HomeServer): server
         """
-        super(KeyUploadServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.e2e_keys_handler = hs.get_e2e_keys_handler()
 
@@ -147,7 +147,7 @@ class KeyQueryServlet(RestServlet):
         Args:
             hs (synapse.server.HomeServer):
         """
-        super(KeyQueryServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.e2e_keys_handler = hs.get_e2e_keys_handler()
 
@@ -177,7 +177,7 @@ class KeyChangesServlet(RestServlet):
         Args:
             hs (synapse.server.HomeServer):
         """
-        super(KeyChangesServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
 
@@ -222,7 +222,7 @@ class OneTimeKeyServlet(RestServlet):
     PATTERNS = client_patterns("/keys/claim$")
 
     def __init__(self, hs):
-        super(OneTimeKeyServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.e2e_keys_handler = hs.get_e2e_keys_handler()
 
@@ -250,7 +250,7 @@ class SigningKeyUploadServlet(RestServlet):
         Args:
             hs (synapse.server.HomeServer): server
         """
-        super(SigningKeyUploadServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.e2e_keys_handler = hs.get_e2e_keys_handler()
@@ -308,7 +308,7 @@ class SignaturesUploadServlet(RestServlet):
         Args:
             hs (synapse.server.HomeServer): server
         """
-        super(SignaturesUploadServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.e2e_keys_handler = hs.get_e2e_keys_handler()
 
diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
index aa911d75ee..87063ec8b1 100644
--- a/synapse/rest/client/v2_alpha/notifications.py
+++ b/synapse/rest/client/v2_alpha/notifications.py
@@ -27,7 +27,7 @@ class NotificationsServlet(RestServlet):
     PATTERNS = client_patterns("/notifications$")
 
     def __init__(self, hs):
-        super(NotificationsServlet, self).__init__()
+        super().__init__()
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py
index 6ae9a5a8e9..5b996e2d63 100644
--- a/synapse/rest/client/v2_alpha/openid.py
+++ b/synapse/rest/client/v2_alpha/openid.py
@@ -60,7 +60,7 @@ class IdTokenServlet(RestServlet):
     EXPIRES_MS = 3600 * 1000
 
     def __init__(self, hs):
-        super(IdTokenServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
diff --git a/synapse/rest/client/v2_alpha/password_policy.py b/synapse/rest/client/v2_alpha/password_policy.py
index 968403cca4..68b27ff23a 100644
--- a/synapse/rest/client/v2_alpha/password_policy.py
+++ b/synapse/rest/client/v2_alpha/password_policy.py
@@ -30,7 +30,7 @@ class PasswordPolicyServlet(RestServlet):
         Args:
             hs (synapse.server.HomeServer): server
         """
-        super(PasswordPolicyServlet, self).__init__()
+        super().__init__()
 
         self.policy = hs.config.password_policy
         self.enabled = hs.config.password_policy_enabled
diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py
index 67cbc37312..55c6688f52 100644
--- a/synapse/rest/client/v2_alpha/read_marker.py
+++ b/synapse/rest/client/v2_alpha/read_marker.py
@@ -26,7 +26,7 @@ class ReadMarkerRestServlet(RestServlet):
     PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
 
     def __init__(self, hs):
-        super(ReadMarkerRestServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.receipts_handler = hs.get_receipts_handler()
         self.read_marker_handler = hs.get_read_marker_handler()
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index 92555bd4a9..6f7246a394 100644
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -31,7 +31,7 @@ class ReceiptRestServlet(RestServlet):
     )
 
     def __init__(self, hs):
-        super(ReceiptRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.receipts_handler = hs.get_receipts_handler()
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index b6b90a8b30..ffa2dfce42 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -76,7 +76,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
         Args:
             hs (synapse.server.HomeServer): server
         """
-        super(EmailRegisterRequestTokenRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.identity_handler = hs.get_handlers().identity_handler
         self.config = hs.config
@@ -174,7 +174,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
         Args:
             hs (synapse.server.HomeServer): server
         """
-        super(MsisdnRegisterRequestTokenRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.identity_handler = hs.get_handlers().identity_handler
 
@@ -249,7 +249,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
         Args:
             hs (synapse.server.HomeServer): server
         """
-        super(RegistrationSubmitTokenServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.config = hs.config
@@ -319,7 +319,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
         Args:
             hs (synapse.server.HomeServer): server
         """
-        super(UsernameAvailabilityRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.registration_handler = hs.get_registration_handler()
         self.ratelimiter = FederationRateLimiter(
@@ -363,7 +363,7 @@ class RegisterRestServlet(RestServlet):
         Args:
             hs (synapse.server.HomeServer): server
         """
-        super(RegisterRestServlet, self).__init__()
+        super().__init__()
 
         self.hs = hs
         self.auth = hs.get_auth()
@@ -431,11 +431,14 @@ class RegisterRestServlet(RestServlet):
 
             access_token = self.auth.get_access_token_from_request(request)
 
-            if isinstance(desired_username, str):
-                result = await self._do_appservice_registration(
-                    desired_username, access_token, body
-                )
-            return 200, result  # we throw for non 200 responses
+            if not isinstance(desired_username, str):
+                raise SynapseError(400, "Desired Username is missing or not a string")
+
+            result = await self._do_appservice_registration(
+                desired_username, access_token, body
+            )
+
+            return 200, result
 
         # == Normal User Registration == (everyone else)
         if not self._registration_enabled:
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index e29f49f7f5..18c75738f8 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -61,7 +61,7 @@ class RelationSendServlet(RestServlet):
     )
 
     def __init__(self, hs):
-        super(RelationSendServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.event_creation_handler = hs.get_event_creation_handler()
         self.txns = HttpTransactionCache(hs)
@@ -138,7 +138,7 @@ class RelationPaginationServlet(RestServlet):
     )
 
     def __init__(self, hs):
-        super(RelationPaginationServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
@@ -233,7 +233,7 @@ class RelationAggregationPaginationServlet(RestServlet):
     )
 
     def __init__(self, hs):
-        super(RelationAggregationPaginationServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.event_handler = hs.get_event_handler()
@@ -311,7 +311,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
     )
 
     def __init__(self, hs):
-        super(RelationAggregationGroupPaginationServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
index e15927c4ea..215d619ca1 100644
--- a/synapse/rest/client/v2_alpha/report_event.py
+++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -32,7 +32,7 @@ class ReportEventRestServlet(RestServlet):
     PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$")
 
     def __init__(self, hs):
-        super(ReportEventRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py
index 59529707df..53de97923f 100644
--- a/synapse/rest/client/v2_alpha/room_keys.py
+++ b/synapse/rest/client/v2_alpha/room_keys.py
@@ -37,7 +37,7 @@ class RoomKeysServlet(RestServlet):
         Args:
             hs (synapse.server.HomeServer): server
         """
-        super(RoomKeysServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
 
@@ -248,7 +248,7 @@ class RoomKeysNewVersionServlet(RestServlet):
         Args:
             hs (synapse.server.HomeServer): server
         """
-        super(RoomKeysNewVersionServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
 
@@ -301,7 +301,7 @@ class RoomKeysVersionServlet(RestServlet):
         Args:
             hs (synapse.server.HomeServer): server
         """
-        super(RoomKeysVersionServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
 
diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
index 39a5518614..bf030e0ff4 100644
--- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -53,7 +53,7 @@ class RoomUpgradeRestServlet(RestServlet):
     )
 
     def __init__(self, hs):
-        super(RoomUpgradeRestServlet, self).__init__()
+        super().__init__()
         self._hs = hs
         self._room_creation_handler = hs.get_room_creation_handler()
         self._auth = hs.get_auth()
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index db829f3098..bc4f43639a 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -36,7 +36,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
         Args:
             hs (synapse.server.HomeServer): server
         """
-        super(SendToDeviceRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.txns = HttpTransactionCache(hs)
diff --git a/synapse/rest/client/v2_alpha/shared_rooms.py b/synapse/rest/client/v2_alpha/shared_rooms.py
index 2492634dac..c866d5151c 100644
--- a/synapse/rest/client/v2_alpha/shared_rooms.py
+++ b/synapse/rest/client/v2_alpha/shared_rooms.py
@@ -34,7 +34,7 @@ class UserSharedRoomsServlet(RestServlet):
     )
 
     def __init__(self, hs):
-        super(UserSharedRoomsServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.user_directory_active = hs.config.update_user_directory
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index a0b00135e1..51e395cc64 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -74,7 +74,7 @@ class SyncRestServlet(RestServlet):
     ALLOWED_PRESENCE = {"online", "offline", "unavailable"}
 
     def __init__(self, hs):
-        super(SyncRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.sync_handler = hs.get_sync_handler()
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index a3f12e8a77..bf3a79db44 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -31,7 +31,7 @@ class TagListServlet(RestServlet):
     PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags")
 
     def __init__(self, hs):
-        super(TagListServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
@@ -56,7 +56,7 @@ class TagServlet(RestServlet):
     )
 
     def __init__(self, hs):
-        super(TagServlet, self).__init__()
+        super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.notifier = hs.get_notifier()
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index 23709960ad..0c127a1b5f 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -28,7 +28,7 @@ class ThirdPartyProtocolsServlet(RestServlet):
     PATTERNS = client_patterns("/thirdparty/protocols")
 
     def __init__(self, hs):
-        super(ThirdPartyProtocolsServlet, self).__init__()
+        super().__init__()
 
         self.auth = hs.get_auth()
         self.appservice_handler = hs.get_application_service_handler()
@@ -44,7 +44,7 @@ class ThirdPartyProtocolServlet(RestServlet):
     PATTERNS = client_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
 
     def __init__(self, hs):
-        super(ThirdPartyProtocolServlet, self).__init__()
+        super().__init__()
 
         self.auth = hs.get_auth()
         self.appservice_handler = hs.get_application_service_handler()
@@ -65,7 +65,7 @@ class ThirdPartyUserServlet(RestServlet):
     PATTERNS = client_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
 
     def __init__(self, hs):
-        super(ThirdPartyUserServlet, self).__init__()
+        super().__init__()
 
         self.auth = hs.get_auth()
         self.appservice_handler = hs.get_application_service_handler()
@@ -87,7 +87,7 @@ class ThirdPartyLocationServlet(RestServlet):
     PATTERNS = client_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
 
     def __init__(self, hs):
-        super(ThirdPartyLocationServlet, self).__init__()
+        super().__init__()
 
         self.auth = hs.get_auth()
         self.appservice_handler = hs.get_application_service_handler()
diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py
index 83f3b6b70a..79317c74ba 100644
--- a/synapse/rest/client/v2_alpha/tokenrefresh.py
+++ b/synapse/rest/client/v2_alpha/tokenrefresh.py
@@ -28,7 +28,7 @@ class TokenRefreshRestServlet(RestServlet):
     PATTERNS = client_patterns("/tokenrefresh")
 
     def __init__(self, hs):
-        super(TokenRefreshRestServlet, self).__init__()
+        super().__init__()
 
     async def on_POST(self, request):
         raise AuthError(403, "tokenrefresh is no longer supported.")
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index bef91a2d3e..ad598cefe0 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -31,7 +31,7 @@ class UserDirectorySearchRestServlet(RestServlet):
         Args:
             hs (synapse.server.HomeServer): server
         """
-        super(UserDirectorySearchRestServlet, self).__init__()
+        super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.user_directory_handler = hs.get_user_directory_handler()
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index c560edbc59..d24a199318 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -29,7 +29,7 @@ class VersionsRestServlet(RestServlet):
     PATTERNS = [re.compile("^/_matrix/client/versions$")]
 
     def __init__(self, hs):
-        super(VersionsRestServlet, self).__init__()
+        super().__init__()
         self.config = hs.config
 
         # Calculate these once since they shouldn't change after start-up.
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 5db7f81c2d..f843f02454 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -35,7 +35,7 @@ class RemoteKey(DirectServeJsonResource):
 
     Supports individual GET APIs and a bulk query POST API.
 
-    Requsts:
+    Requests:
 
     GET /_matrix/key/v2/query/remote.server.example.com HTTP/1.1
 
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index d2826374a7..7447eeaebe 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -80,7 +80,7 @@ class MediaFilePaths:
         self, server_name, file_id, width, height, content_type, method
     ):
         top_level_type, sub_type = content_type.split("/")
-        file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
+        file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
             "remote_thumbnail",
             server_name,
@@ -92,6 +92,23 @@ class MediaFilePaths:
 
     remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
 
+    # Legacy path that was used to store thumbnails previously.
+    # Should be removed after some time, when most of the thumbnails are stored
+    # using the new path.
+    def remote_media_thumbnail_rel_legacy(
+        self, server_name, file_id, width, height, content_type
+    ):
+        top_level_type, sub_type = content_type.split("/")
+        file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
+        return os.path.join(
+            "remote_thumbnail",
+            server_name,
+            file_id[0:2],
+            file_id[2:4],
+            file_id[4:],
+            file_name,
+        )
+
     def remote_media_thumbnail_dir(self, server_name, file_id):
         return os.path.join(
             self.base_path,
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 9a1b7779f7..69f353d46f 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -53,7 +53,7 @@ from .media_storage import MediaStorage
 from .preview_url_resource import PreviewUrlResource
 from .storage_provider import StorageProviderWrapper
 from .thumbnail_resource import ThumbnailResource
-from .thumbnailer import Thumbnailer
+from .thumbnailer import Thumbnailer, ThumbnailError
 from .upload_resource import UploadResource
 
 logger = logging.getLogger(__name__)
@@ -460,13 +460,30 @@ class MediaRepository:
         return t_byte_source
 
     async def generate_local_exact_thumbnail(
-        self, media_id, t_width, t_height, t_method, t_type, url_cache
-    ):
+        self,
+        media_id: str,
+        t_width: int,
+        t_height: int,
+        t_method: str,
+        t_type: str,
+        url_cache: str,
+    ) -> Optional[str]:
         input_path = await self.media_storage.ensure_media_is_in_local_cache(
             FileInfo(None, media_id, url_cache=url_cache)
         )
 
-        thumbnailer = Thumbnailer(input_path)
+        try:
+            thumbnailer = Thumbnailer(input_path)
+        except ThumbnailError as e:
+            logger.warning(
+                "Unable to generate a thumbnail for local media %s using a method of %s and type of %s: %s",
+                media_id,
+                t_method,
+                t_type,
+                e,
+            )
+            return None
+
         t_byte_source = await defer_to_thread(
             self.hs.get_reactor(),
             self._generate_thumbnail,
@@ -506,14 +523,36 @@ class MediaRepository:
 
             return output_path
 
+        # Could not generate thumbnail.
+        return None
+
     async def generate_remote_exact_thumbnail(
-        self, server_name, file_id, media_id, t_width, t_height, t_method, t_type
-    ):
+        self,
+        server_name: str,
+        file_id: str,
+        media_id: str,
+        t_width: int,
+        t_height: int,
+        t_method: str,
+        t_type: str,
+    ) -> Optional[str]:
         input_path = await self.media_storage.ensure_media_is_in_local_cache(
             FileInfo(server_name, file_id, url_cache=False)
         )
 
-        thumbnailer = Thumbnailer(input_path)
+        try:
+            thumbnailer = Thumbnailer(input_path)
+        except ThumbnailError as e:
+            logger.warning(
+                "Unable to generate a thumbnail for remote media %s from %s using a method of %s and type of %s: %s",
+                media_id,
+                server_name,
+                t_method,
+                t_type,
+                e,
+            )
+            return None
+
         t_byte_source = await defer_to_thread(
             self.hs.get_reactor(),
             self._generate_thumbnail,
@@ -559,6 +598,9 @@ class MediaRepository:
 
             return output_path
 
+        # Could not generate thumbnail.
+        return None
+
     async def _generate_thumbnails(
         self,
         server_name: Optional[str],
@@ -590,7 +632,18 @@ class MediaRepository:
             FileInfo(server_name, file_id, url_cache=url_cache)
         )
 
-        thumbnailer = Thumbnailer(input_path)
+        try:
+            thumbnailer = Thumbnailer(input_path)
+        except ThumbnailError as e:
+            logger.warning(
+                "Unable to generate thumbnails for remote media %s from %s using a method of %s and type of %s: %s",
+                media_id,
+                server_name,
+                media_type,
+                e,
+            )
+            return None
+
         m_width = thumbnailer.width
         m_height = thumbnailer.height
 
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 3a352b5631..5681677fc9 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -147,6 +147,20 @@ class MediaStorage:
         if os.path.exists(local_path):
             return FileResponder(open(local_path, "rb"))
 
+        # Fallback for paths without method names
+        # Should be removed in the future
+        if file_info.thumbnail and file_info.server_name:
+            legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
+                server_name=file_info.server_name,
+                file_id=file_info.file_id,
+                width=file_info.thumbnail_width,
+                height=file_info.thumbnail_height,
+                content_type=file_info.thumbnail_type,
+            )
+            legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
+            if os.path.exists(legacy_local_path):
+                return FileResponder(open(legacy_local_path, "rb"))
+
         for provider in self.storage_providers:
             res = await provider.fetch(path, file_info)  # type: Any
             if res:
@@ -170,6 +184,20 @@ class MediaStorage:
         if os.path.exists(local_path):
             return local_path
 
+        # Fallback for paths without method names
+        # Should be removed in the future
+        if file_info.thumbnail and file_info.server_name:
+            legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
+                server_name=file_info.server_name,
+                file_id=file_info.file_id,
+                width=file_info.thumbnail_width,
+                height=file_info.thumbnail_height,
+                content_type=file_info.thumbnail_type,
+            )
+            legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
+            if os.path.exists(legacy_local_path):
+                return legacy_local_path
+
         dirname = os.path.dirname(local_path)
         if not os.path.exists(dirname):
             os.makedirs(dirname)
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index cd8c246594..dce6c4d168 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -102,7 +102,7 @@ for endpoint, globs in _oembed_globs.items():
         _oembed_patterns[re.compile(pattern)] = endpoint
 
 
-@attr.s
+@attr.s(slots=True)
 class OEmbedResult:
     # Either HTML content or URL must be provided.
     html = attr.ib(type=Optional[str])
@@ -450,7 +450,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
             raise OEmbedError() from e
 
-    async def _download_url(self, url, user):
+    async def _download_url(self, url: str, user):
         # TODO: we should probably honour robots.txt... except in practice
         # we're most likely being explicitly triggered by a human rather than a
         # bot, so are we really a robot?
@@ -460,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
 
         # If this URL can be accessed via oEmbed, use that instead.
-        url_to_download = url
+        url_to_download = url  # type: Optional[str]
         oembed_url = self._get_oembed_url(url)
         if oembed_url:
             # The result might be a new URL to download, or it might be HTML content.
@@ -520,9 +520,15 @@ class PreviewUrlResource(DirectServeJsonResource):
                 # FIXME: we should calculate a proper expiration based on the
                 # Cache-Control and Expire headers.  But for now, assume 1 hour.
                 expires = ONE_HOUR
-                etag = headers["ETag"][0] if "ETag" in headers else None
+                etag = (
+                    headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None
+                )
         else:
-            html_bytes = oembed_result.html.encode("utf-8")  # type: ignore
+            # we can only get here if we did an oembed request and have an oembed_result.html
+            assert oembed_result.html is not None
+            assert oembed_url is not None
+
+            html_bytes = oembed_result.html.encode("utf-8")
             with self.media_storage.store_into_file(file_info) as (f, fname, finish):
                 f.write(html_bytes)
                 await finish()
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index a83535b97b..30421b663a 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -16,6 +16,7 @@
 
 import logging
 
+from synapse.api.errors import SynapseError
 from synapse.http.server import DirectServeJsonResource, set_cors_headers
 from synapse.http.servlet import parse_integer, parse_string
 
@@ -173,7 +174,7 @@ class ThumbnailResource(DirectServeJsonResource):
             await respond_with_file(request, desired_type, file_path)
         else:
             logger.warning("Failed to generate thumbnail")
-            respond_404(request)
+            raise SynapseError(400, "Failed to generate thumbnail.")
 
     async def _select_or_generate_remote_thumbnail(
         self,
@@ -235,7 +236,7 @@ class ThumbnailResource(DirectServeJsonResource):
             await respond_with_file(request, desired_type, file_path)
         else:
             logger.warning("Failed to generate thumbnail")
-            respond_404(request)
+            raise SynapseError(400, "Failed to generate thumbnail.")
 
     async def _respond_remote_thumbnail(
         self, request, server_name, media_id, width, height, method, m_type
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index d681bf7bf0..32a8e4f960 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -15,7 +15,7 @@
 import logging
 from io import BytesIO
 
-from PIL import Image as Image
+from PIL import Image
 
 logger = logging.getLogger(__name__)
 
@@ -31,12 +31,22 @@ EXIF_TRANSPOSE_MAPPINGS = {
 }
 
 
+class ThumbnailError(Exception):
+    """An error occurred generating a thumbnail."""
+
+
 class Thumbnailer:
 
     FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
 
     def __init__(self, input_path):
-        self.image = Image.open(input_path)
+        try:
+            self.image = Image.open(input_path)
+        except OSError as e:
+            # If an error occurs opening the image, a thumbnail won't be able to
+            # be generated.
+            raise ThumbnailError from e
+
         self.width, self.height = self.image.size
         self.transpose_method = None
         try:
@@ -73,7 +83,7 @@ class Thumbnailer:
 
         Args:
             max_width: The largest possible width.
-            max_height: The larget possible height.
+            max_height: The largest possible height.
         """
 
         if max_width * self.height < max_height * self.width:
@@ -107,7 +117,7 @@ class Thumbnailer:
 
         Args:
             max_width: The largest possible width.
-            max_height: The larget possible height.
+            max_height: The largest possible height.
 
         Returns:
             BytesIO: the bytes of the encoded image ready to be written to disk
diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py
index c10188a5d7..f6668fb5e3 100644
--- a/synapse/rest/saml2/response_resource.py
+++ b/synapse/rest/saml2/response_resource.py
@@ -13,10 +13,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from twisted.python import failure
 
-from synapse.api.errors import SynapseError
-from synapse.http.server import DirectServeHtmlResource, return_html_error
+from synapse.http.server import DirectServeHtmlResource
 
 
 class SAML2ResponseResource(DirectServeHtmlResource):
@@ -27,21 +25,15 @@ class SAML2ResponseResource(DirectServeHtmlResource):
     def __init__(self, hs):
         super().__init__()
         self._saml_handler = hs.get_saml_handler()
-        self._error_html_template = hs.config.saml2.saml2_error_html_template
 
     async def _async_render_GET(self, request):
         # We're not expecting any GET request on that resource if everything goes right,
         # but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
         # In this case, just tell the user that something went wrong and they should
         # try to authenticate again.
-        f = failure.Failure(
-            SynapseError(400, "Unexpected GET request on /saml2/authn_response")
+        self._saml_handler._render_error(
+            request, "unexpected_get", "Unexpected GET request on /saml2/authn_response"
         )
-        return_html_error(f, request, self._error_html_template)
 
     async def _async_render_POST(self, request):
-        try:
-            await self._saml_handler.handle_saml_response(request)
-        except Exception:
-            f = failure.Failure()
-            return_html_error(f, request, self._error_html_template)
+        await self._saml_handler.handle_saml_response(request)
diff --git a/synapse/rest/synapse/__init__.py b/synapse/rest/synapse/__init__.py
new file mode 100644
index 0000000000..c0b733488b
--- /dev/null
+++ b/synapse/rest/synapse/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py
new file mode 100644
index 0000000000..c0b733488b
--- /dev/null
+++ b/synapse/rest/synapse/client/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py
new file mode 100644
index 0000000000..9e4fbc0cbd
--- /dev/null
+++ b/synapse/rest/synapse/client/password_reset.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.http import Request
+
+from synapse.api.errors import ThreepidValidationError
+from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.http.server import DirectServeHtmlResource
+from synapse.http.servlet import parse_string
+from synapse.util.stringutils import assert_valid_client_secret
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class PasswordResetSubmitTokenResource(DirectServeHtmlResource):
+    """Handles 3PID validation token submission
+
+    This resource gets mounted under /_synapse/client/password_reset/email/submit_token
+    """
+
+    isLeaf = 1
+
+    def __init__(self, hs: "HomeServer"):
+        """
+        Args:
+            hs: server
+        """
+        super().__init__()
+
+        self.clock = hs.get_clock()
+        self.store = hs.get_datastore()
+
+        self._local_threepid_handling_disabled_due_to_email_config = (
+            hs.config.local_threepid_handling_disabled_due_to_email_config
+        )
+        self._confirmation_email_template = (
+            hs.config.email_password_reset_template_confirmation_html
+        )
+        self._email_password_reset_template_success_html = (
+            hs.config.email_password_reset_template_success_html_content
+        )
+        self._failure_email_template = (
+            hs.config.email_password_reset_template_failure_html
+        )
+
+        # This resource should not be mounted if threepid behaviour is not LOCAL
+        assert hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL
+
+    async def _async_render_GET(self, request: Request) -> Tuple[int, bytes]:
+        sid = parse_string(request, "sid", required=True)
+        token = parse_string(request, "token", required=True)
+        client_secret = parse_string(request, "client_secret", required=True)
+        assert_valid_client_secret(client_secret)
+
+        # Show a confirmation page, just in case someone accidentally clicked this link when
+        # they didn't mean to
+        template_vars = {
+            "sid": sid,
+            "token": token,
+            "client_secret": client_secret,
+        }
+        return (
+            200,
+            self._confirmation_email_template.render(**template_vars).encode("utf-8"),
+        )
+
+    async def _async_render_POST(self, request: Request) -> Tuple[int, bytes]:
+        sid = parse_string(request, "sid", required=True)
+        token = parse_string(request, "token", required=True)
+        client_secret = parse_string(request, "client_secret", required=True)
+
+        # Attempt to validate a 3PID session
+        try:
+            # Mark the session as valid
+            next_link = await self.store.validate_threepid_session(
+                sid, client_secret, token, self.clock.time_msec()
+            )
+
+            # Perform a 302 redirect if next_link is set
+            if next_link:
+                if next_link.startswith("file:///"):
+                    logger.warning(
+                        "Not redirecting to next_link as it is a local file: address"
+                    )
+                else:
+                    next_link_bytes = next_link.encode("utf-8")
+                    request.setHeader("Location", next_link_bytes)
+                    return (
+                        302,
+                        (
+                            b'You are being redirected to <a src="%s">%s</a>.'
+                            % (next_link_bytes, next_link_bytes)
+                        ),
+                    )
+
+            # Otherwise show the success template
+            html_bytes = self._email_password_reset_template_success_html.encode(
+                "utf-8"
+            )
+            status_code = 200
+        except ThreepidValidationError as e:
+            status_code = e.code
+
+            # Show a failure page with a reason
+            template_vars = {"failure_reason": e.msg}
+            html_bytes = self._failure_email_template.render(**template_vars).encode(
+                "utf-8"
+            )
+
+        return status_code, html_bytes
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index c7e3015b5d..5a5ea39e01 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -25,7 +25,6 @@ from typing import (
     Sequence,
     Set,
     Union,
-    cast,
     overload,
 )
 
@@ -42,7 +41,7 @@ from synapse.logging.utils import log_function
 from synapse.state import v1, v2
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.roommember import ProfileInfo
-from synapse.types import Collection, MutableStateMap, StateMap
+from synapse.types import Collection, StateMap
 from synapse.util import Clock
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -472,10 +471,9 @@ class StateResolutionHandler:
     def __init__(self, hs):
         self.clock = hs.get_clock()
 
-        # dict of set of event_ids -> _StateCacheEntry.
-        self._state_cache = None
         self.resolve_linearizer = Linearizer(name="state_resolve_lock")
 
+        # dict of set of event_ids -> _StateCacheEntry.
         self._state_cache = ExpiringCache(
             cache_name="state_cache",
             clock=self.clock,
@@ -519,57 +517,28 @@ class StateResolutionHandler:
         Returns:
             The resolved state
         """
-        logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
-
         group_names = frozenset(state_groups_ids.keys())
 
         with (await self.resolve_linearizer.queue(group_names)):
-            if self._state_cache is not None:
-                cache = self._state_cache.get(group_names, None)
-                if cache:
-                    return cache
+            cache = self._state_cache.get(group_names, None)
+            if cache:
+                return cache
 
             logger.info(
-                "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
+                "Resolving state for %s with groups %s", room_id, list(group_names),
             )
 
             state_groups_histogram.observe(len(state_groups_ids))
 
-            # start by assuming we won't have any conflicted state, and build up the new
-            # state map by iterating through the state groups. If we discover a conflict,
-            # we give up and instead use `resolve_events_with_store`.
-            #
-            # XXX: is this actually worthwhile, or should we just let
-            # resolve_events_with_store do it?
-            new_state = {}  # type: MutableStateMap[str]
-            conflicted_state = False
-            for st in state_groups_ids.values():
-                for key, e_id in st.items():
-                    if key in new_state:
-                        conflicted_state = True
-                        break
-                    new_state[key] = e_id
-                if conflicted_state:
-                    break
-
-            if conflicted_state:
-                logger.info("Resolving conflicted state for %r", room_id)
-                with Measure(self.clock, "state._resolve_events"):
-                    # resolve_events_with_store returns a StateMap, but we can
-                    # treat it as a MutableStateMap as it is above. It isn't
-                    # actually mutated anymore (and is frozen in
-                    # _make_state_cache_entry below).
-                    new_state = cast(
-                        MutableStateMap,
-                        await resolve_events_with_store(
-                            self.clock,
-                            room_id,
-                            room_version,
-                            list(state_groups_ids.values()),
-                            event_map=event_map,
-                            state_res_store=state_res_store,
-                        ),
-                    )
+            with Measure(self.clock, "state._resolve_events"):
+                new_state = await resolve_events_with_store(
+                    self.clock,
+                    room_id,
+                    room_version,
+                    list(state_groups_ids.values()),
+                    event_map=event_map,
+                    state_res_store=state_res_store,
+                )
 
             # if the new state matches any of the input state groups, we can
             # use that state group again. Otherwise we will generate a state_id
@@ -579,8 +548,7 @@ class StateResolutionHandler:
             with Measure(self.clock, "state.create_group_ids"):
                 cache = _make_state_cache_entry(new_state, state_groups_ids)
 
-            if self._state_cache is not None:
-                self._state_cache[group_names] = cache
+            self._state_cache[group_names] = cache
 
             return cache
 
@@ -678,7 +646,7 @@ def resolve_events_with_store(
         )
 
 
-@attr.s
+@attr.s(slots=True)
 class StateResolutionStore:
     """Interface that allows state resolution algorithms to access the database
     in well defined way.
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 8e5d78f6f7..bbff3c8d5b 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -47,6 +47,9 @@ class Storage:
         # interfaces.
         self.main = stores.main
 
-        self.persistence = EventsPersistenceStorage(hs, stores)
         self.purge_events = PurgeEventsStorage(hs, stores)
         self.state = StateGroupStorage(hs, stores)
+
+        self.persistence = None
+        if stores.persist_events:
+            self.persistence = EventsPersistenceStorage(hs, stores)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index ed8a9bffb1..79ec8f119d 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -952,7 +952,7 @@ class DatabasePool:
         key_names: Collection[str],
         key_values: Collection[Iterable[Any]],
         value_names: Collection[str],
-        value_values: Iterable[Iterable[str]],
+        value_values: Iterable[Iterable[Any]],
     ) -> None:
         """
         Upsert, many times.
@@ -981,7 +981,7 @@ class DatabasePool:
         key_names: Iterable[str],
         key_values: Collection[Iterable[Any]],
         value_names: Collection[str],
-        value_values: Iterable[Iterable[str]],
+        value_values: Iterable[Iterable[Any]],
     ) -> None:
         """
         Upsert, many times, but without native UPSERT support or batching.
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 985b12df91..aa5d490624 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -75,7 +75,7 @@ class Databases:
 
                     # If we're on a process that can persist events also
                     # instantiate a `PersistEventsStore`
-                    if hs.config.worker.writers.events == hs.get_instance_name():
+                    if hs.get_instance_name() in hs.config.worker.writers.events:
                         persist_events = PersistEventsStore(hs, database, main)
 
                 if "state" in database_config.databases:
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 2ae2fbd5d7..0cb12f4c61 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -160,19 +160,25 @@ class DataStore(
         )
 
         if isinstance(self.database_engine, PostgresEngine):
+            # We set the `writers` to an empty list here as we don't care about
+            # missing updates over restarts, as we'll not have anything in our
+            # caches to invalidate. (This reduces the amount of writes to the DB
+            # that happen).
             self._cache_id_gen = MultiWriterIdGenerator(
                 db_conn,
                 database,
-                instance_name="master",
+                stream_name="caches",
+                instance_name=hs.get_instance_name(),
                 table="cache_invalidation_stream_by_instance",
                 instance_column="instance_name",
                 id_column="stream_id",
                 sequence_name="cache_invalidation_stream_seq",
+                writers=[],
             )
         else:
             self._cache_id_gen = None
 
-        super(DataStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self._presence_on_startup = self._get_active_presence(db_conn)
 
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 4436b1a83d..ef81d73573 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -29,22 +29,20 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 logger = logging.getLogger(__name__)
 
 
-class AccountDataWorkerStore(SQLBaseStore):
+# The ABCMeta metaclass ensures that it cannot be instantiated without
+# the abstract methods being implemented.
+class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
     """This is an abstract base class where subclasses must implement
     `get_max_account_data_stream_id` which can be called in the initializer.
     """
 
-    # This ABCMeta metaclass ensures that we cannot be instantiated without
-    # the abstract methods being implemented.
-    __metaclass__ = abc.ABCMeta
-
     def __init__(self, database: DatabasePool, db_conn, hs):
         account_max = self.get_max_account_data_stream_id()
         self._account_data_stream_cache = StreamChangeCache(
             "AccountDataAndTagsChangeCache", account_max
         )
 
-        super(AccountDataWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     @abc.abstractmethod
     def get_max_account_data_stream_id(self):
@@ -315,7 +313,7 @@ class AccountDataStore(AccountDataWorkerStore):
             ],
         )
 
-        super(AccountDataStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     def get_max_account_data_stream_id(self) -> int:
         """Get the current max stream id for the private user data stream
@@ -341,7 +339,7 @@ class AccountDataStore(AccountDataWorkerStore):
         """
         content_json = json_encoder.encode(content)
 
-        with await self._account_data_id_gen.get_next() as next_id:
+        async with self._account_data_id_gen.get_next() as next_id:
             # no need to lock here as room_account_data has a unique constraint
             # on (user_id, room_id, account_data_type) so simple_upsert will
             # retry if there is a conflict.
@@ -389,7 +387,7 @@ class AccountDataStore(AccountDataWorkerStore):
         """
         content_json = json_encoder.encode(content)
 
-        with await self._account_data_id_gen.get_next() as next_id:
+        async with self._account_data_id_gen.get_next() as next_id:
             # no need to lock here as account_data has a unique constraint on
             # (user_id, account_data_type) so simple_upsert will retry if
             # there is a conflict.
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 454c0bc50c..85f6b1e3fd 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -52,7 +52,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
         )
         self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
 
-        super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     def get_app_services(self):
         return self.services_cache
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index c2fc847fbc..239c7a949c 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -31,7 +31,7 @@ LAST_SEEN_GRANULARITY = 120 * 1000
 
 class ClientIpBackgroundUpdateStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
             "user_ips_device_index",
@@ -358,7 +358,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
             name="client_ip_last_seen", keylen=4, max_entries=50000
         )
 
-        super(ClientIpStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.user_ips_max_age = hs.config.user_ips_max_age
 
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 0044433110..d42faa3f1f 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -283,7 +283,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
             "device_inbox_stream_index",
@@ -313,7 +313,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(DeviceInboxStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         # Map of (user_id, device_id) to the last stream_id that has been
         # deleted up to. This is so that we can no op deletions.
@@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                 rows.append((destination, stream_id, now_ms, edu_json))
             txn.executemany(sql, rows)
 
-        with await self._device_inbox_id_gen.get_next() as stream_id:
+        async with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                 txn, stream_id, local_messages_by_user_then_device
             )
 
-        with await self._device_inbox_id_gen.get_next() as stream_id:
+        async with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_from_remote_to_device_inbox",
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index add4e3ea0e..fdf394c612 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -377,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore):
             THe new stream ID.
         """
 
-        with await self._device_list_id_gen.get_next() as stream_id:
+        async with self._device_list_id_gen.get_next() as stream_id:
             await self.db_pool.runInteraction(
                 "add_user_sig_change_to_streams",
                 self._add_user_signature_change_txn,
@@ -481,7 +481,7 @@ class DeviceWorkerStore(SQLBaseStore):
         }
 
     async def get_users_whose_devices_changed(
-        self, from_key: str, user_ids: Iterable[str]
+        self, from_key: int, user_ids: Iterable[str]
     ) -> Set[str]:
         """Get set of users whose devices have changed since `from_key` that
         are in the given list of user_ids.
@@ -493,7 +493,6 @@ class DeviceWorkerStore(SQLBaseStore):
         Returns:
             The set of user_ids whose devices have changed since `from_key`
         """
-        from_key = int(from_key)
 
         # Get set of users who *may* have changed. Users not in the returned
         # list have definitely not changed.
@@ -527,7 +526,7 @@ class DeviceWorkerStore(SQLBaseStore):
         )
 
     async def get_users_whose_signatures_changed(
-        self, user_id: str, from_key: str
+        self, user_id: str, from_key: int
     ) -> Set[str]:
         """Get the users who have new cross-signing signatures made by `user_id` since
         `from_key`.
@@ -539,7 +538,7 @@ class DeviceWorkerStore(SQLBaseStore):
         Returns:
             A set of user IDs with updated signatures.
         """
-        from_key = int(from_key)
+
         if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
             sql = """
                 SELECT DISTINCT user_ids FROM user_signature_stream
@@ -702,7 +701,7 @@ class DeviceWorkerStore(SQLBaseStore):
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
             "device_lists_stream_idx",
@@ -827,7 +826,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
 
 class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(DeviceStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         # Map of (user_id, device_id) -> bool. If there is an entry that implies
         # the device exists.
@@ -1094,7 +1093,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         if not device_ids:
             return
 
-        with await self._device_list_id_gen.get_next_mult(
+        async with self._device_list_id_gen.get_next_mult(
             len(device_ids)
         ) as stream_ids:
             await self.db_pool.runInteraction(
@@ -1109,7 +1108,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             return stream_ids[-1]
 
         context = get_active_span_text_map()
-        with await self._device_list_id_gen.get_next_mult(
+        async with self._device_list_id_gen.get_next_mult(
             len(hosts) * len(device_ids)
         ) as stream_ids:
             await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index fba3098ea2..22e1ed15d0 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
     from synapse.handlers.e2e_keys import SignatureListItem
 
 
-@attr.s
+@attr.s(slots=True)
 class DeviceKeyLookupResult:
     """The type returned by get_e2e_device_keys_and_signatures"""
 
@@ -831,7 +831,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             key (dict): the key data
         """
 
-        with await self._cross_signing_id_gen.get_next() as stream_id:
+        async with self._cross_signing_id_gen.get_next() as stream_id:
             return await self.db_pool.runInteraction(
                 "add_e2e_cross_signing_key",
                 self._set_e2e_cross_signing_key_txn,
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 0b69aa6a94..6d3689c09e 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         """
 
         if stream_ordering <= self.stream_ordering_month_ago:
-            raise StoreError(400, "stream_ordering too old")
+            raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
 
         sql = """
                 SELECT event_id FROM stream_ordering_to_exterm
@@ -600,7 +600,7 @@ class EventFederationStore(EventFederationWorkerStore):
     EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(EventFederationStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_update_handler(
             self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 5233ed83e2..62f1738732 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -68,7 +68,7 @@ def _deserialize_action(actions, is_highlight):
 
 class EventPushActionsWorkerStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         # These get correctly set by _find_stream_orderings_for_times_txn
         self.stream_ordering_month_ago = None
@@ -661,7 +661,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
     EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(EventPushActionsStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
             self.EPA_HIGHLIGHT_INDEX,
@@ -969,7 +969,7 @@ def _action_has_highlight(actions):
     return False
 
 
-@attr.s
+@attr.s(slots=True)
 class _EventPushSummary:
     """Summary of pending event push actions for a given user in a given room.
     Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index b3d27a2ee7..18def01f50 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,7 +17,7 @@
 import itertools
 import logging
 from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
 
 import attr
 from prometheus_client import Counter
@@ -32,7 +32,7 @@ from synapse.logging.utils import log_function
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.search import SearchEntry
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import StateMap, get_domain_from_id
 from synapse.util.frozenutils import frozendict_json_encoder
 from synapse.util.iterutils import batch_iter
@@ -97,18 +97,21 @@ class PersistEventsStore:
         self.store = main_data_store
         self.database_engine = db.engine
         self._clock = hs.get_clock()
+        self._instance_name = hs.get_instance_name()
 
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
         self.is_mine_id = hs.is_mine_id
 
         # Ideally we'd move these ID gens here, unfortunately some other ID
         # generators are chained off them so doing so is a bit of a PITA.
-        self._backfill_id_gen = self.store._backfill_id_gen  # type: StreamIdGenerator
-        self._stream_id_gen = self.store._stream_id_gen  # type: StreamIdGenerator
+        self._backfill_id_gen = (
+            self.store._backfill_id_gen
+        )  # type: MultiWriterIdGenerator
+        self._stream_id_gen = self.store._stream_id_gen  # type: MultiWriterIdGenerator
 
         # This should only exist on instances that are configured to write
         assert (
-            hs.config.worker.writers.events == hs.get_instance_name()
+            hs.get_instance_name() in hs.config.worker.writers.events
         ), "Can only instantiate EventsStore on master"
 
     async def _persist_events_and_state_updates(
@@ -153,15 +156,15 @@ class PersistEventsStore:
         # Note: Multiple instances of this function cannot be in flight at
         # the same time for the same room.
         if backfilled:
-            stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
+            stream_ordering_manager = self._backfill_id_gen.get_next_mult(
                 len(events_and_contexts)
             )
         else:
-            stream_ordering_manager = await self._stream_id_gen.get_next_mult(
+            stream_ordering_manager = self._stream_id_gen.get_next_mult(
                 len(events_and_contexts)
             )
 
-        with stream_ordering_manager as stream_orderings:
+        async with stream_ordering_manager as stream_orderings:
             for (event, context), stream in zip(events_and_contexts, stream_orderings):
                 event.internal_metadata.stream_ordering = stream
 
@@ -213,7 +216,7 @@ class PersistEventsStore:
         Returns:
             Filtered event ids
         """
-        results = []
+        results = []  # type: List[str]
 
         def _get_events_which_are_prevs_txn(txn, batch):
             sql = """
@@ -631,7 +634,9 @@ class PersistEventsStore:
         )
 
     @classmethod
-    def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
+    def _filter_events_and_contexts_for_duplicates(
+        cls, events_and_contexts: List[Tuple[EventBase, EventContext]]
+    ) -> List[Tuple[EventBase, EventContext]]:
         """Ensure that we don't have the same event twice.
 
         Pick the earliest non-outlier if there is one, else the earliest one.
@@ -641,7 +646,9 @@ class PersistEventsStore:
         Returns:
             list[(EventBase, EventContext)]: filtered list
         """
-        new_events_and_contexts = OrderedDict()
+        new_events_and_contexts = (
+            OrderedDict()
+        )  # type: OrderedDict[str, Tuple[EventBase, EventContext]]
         for event, context in events_and_contexts:
             prev_event_context = new_events_and_contexts.get(event.event_id)
             if prev_event_context:
@@ -655,7 +662,12 @@ class PersistEventsStore:
                 new_events_and_contexts[event.event_id] = (event, context)
         return list(new_events_and_contexts.values())
 
-    def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
+    def _update_room_depths_txn(
+        self,
+        txn,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        backfilled: bool,
+    ):
         """Update min_depth for each room
 
         Args:
@@ -664,7 +676,7 @@ class PersistEventsStore:
                 we are persisting
             backfilled (bool): True if the events were backfilled
         """
-        depth_updates = {}
+        depth_updates = {}  # type: Dict[str, int]
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
             txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
@@ -800,6 +812,7 @@ class PersistEventsStore:
             table="events",
             values=[
                 {
+                    "instance_name": self._instance_name,
                     "stream_ordering": event.internal_metadata.stream_ordering,
                     "topological_ordering": event.depth,
                     "depth": event.depth,
@@ -1095,6 +1108,10 @@ class PersistEventsStore:
     def _store_room_members_txn(self, txn, events, backfilled):
         """Store a room member in the database.
         """
+
+        def str_or_none(val: Any) -> Optional[str]:
+            return val if isinstance(val, str) else None
+
         self.db_pool.simple_insert_many_txn(
             txn,
             table="room_memberships",
@@ -1105,8 +1122,8 @@ class PersistEventsStore:
                     "sender": event.user_id,
                     "room_id": event.room_id,
                     "membership": event.membership,
-                    "display_name": event.content.get("displayname", None),
-                    "avatar_url": event.content.get("avatar_url", None),
+                    "display_name": str_or_none(event.content.get("displayname")),
+                    "avatar_url": str_or_none(event.content.get("avatar_url")),
                 }
                 for event in events
             ],
@@ -1436,7 +1453,7 @@ class PersistEventsStore:
 
         Forward extremities are handled when we first start persisting the events.
         """
-        events_by_room = {}
+        events_by_room = {}  # type: Dict[str, List[EventBase]]
         for ev in events:
             events_by_room.setdefault(ev.room_id, []).append(ev)
 
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index e53c6373a8..5e4af2eb51 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -29,7 +29,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
     DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_update_handler(
             self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index a7a73cc3d8..f95679ebc4 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -13,8 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from __future__ import division
-
 import itertools
 import logging
 import threading
@@ -42,7 +40,8 @@ from synapse.replication.tcp.streams import BackfillStream
 from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.types import Collection, get_domain_from_id
 from synapse.util.caches.descriptors import Cache, cached
 from synapse.util.iterutils import batch_iter
@@ -76,29 +75,60 @@ class EventRedactBehaviour(Names):
 
 class EventsWorkerStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(EventsWorkerStore, self).__init__(database, db_conn, hs)
-
-        if hs.config.worker.writers.events == hs.get_instance_name():
-            # We are the process in charge of generating stream ids for events,
-            # so instantiate ID generators based on the database
-            self._stream_id_gen = StreamIdGenerator(
-                db_conn, "events", "stream_ordering",
+        super().__init__(database, db_conn, hs)
+
+        if isinstance(database.engine, PostgresEngine):
+            # If we're using Postgres than we can use `MultiWriterIdGenerator`
+            # regardless of whether this process writes to the streams or not.
+            self._stream_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                stream_name="events",
+                instance_name=hs.get_instance_name(),
+                table="events",
+                instance_column="instance_name",
+                id_column="stream_ordering",
+                sequence_name="events_stream_seq",
+                writers=hs.config.worker.writers.events,
             )
-            self._backfill_id_gen = StreamIdGenerator(
-                db_conn,
-                "events",
-                "stream_ordering",
-                step=-1,
-                extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+            self._backfill_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                stream_name="backfill",
+                instance_name=hs.get_instance_name(),
+                table="events",
+                instance_column="instance_name",
+                id_column="stream_ordering",
+                sequence_name="events_backfill_stream_seq",
+                positive=False,
+                writers=hs.config.worker.writers.events,
             )
         else:
-            # Another process is in charge of persisting events and generating
-            # stream IDs: rely on the replication streams to let us know which
-            # IDs we can process.
-            self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
-            self._backfill_id_gen = SlavedIdTracker(
-                db_conn, "events", "stream_ordering", step=-1
-            )
+            # We shouldn't be running in worker mode with SQLite, but its useful
+            # to support it for unit tests.
+            #
+            # If this process is the writer than we need to use
+            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+            # updated over replication. (Multiple writers are not supported for
+            # SQLite).
+            if hs.get_instance_name() in hs.config.worker.writers.events:
+                self._stream_id_gen = StreamIdGenerator(
+                    db_conn, "events", "stream_ordering",
+                )
+                self._backfill_id_gen = StreamIdGenerator(
+                    db_conn,
+                    "events",
+                    "stream_ordering",
+                    step=-1,
+                    extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+                )
+            else:
+                self._stream_id_gen = SlavedIdTracker(
+                    db_conn, "events", "stream_ordering"
+                )
+                self._backfill_id_gen = SlavedIdTracker(
+                    db_conn, "events", "stream_ordering", step=-1
+                )
 
         self._get_event_cache = Cache(
             "*getEvent*",
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index ccfbb2135e..7218191965 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -1265,7 +1265,7 @@ class GroupServerStore(GroupServerWorkerStore):
 
             return next_id
 
-        with await self._group_updates_id_gen.get_next() as next_id:
+        async with self._group_updates_id_gen.get_next() as next_id:
             res = await self.db_pool.runInteraction(
                 "register_user_group_membership",
                 _register_user_group_membership_txn,
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 86557d5512..cc538c5c10 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -17,12 +17,14 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 
+BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
+    "media_repository_drop_index_wo_method"
+)
+
 
 class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(MediaRepositoryBackgroundUpdateStore, self).__init__(
-            database, db_conn, hs
-        )
+        super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
             update_name="local_media_repository_url_idx",
@@ -32,12 +34,65 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
             where_clause="url_cache IS NOT NULL",
         )
 
+        # The following the updates add the method to the unique constraint of
+        # the thumbnail databases. That fixes an issue, where thumbnails of the
+        # same resolution, but different methods could overwrite one another.
+        # This can happen with custom thumbnail configs or with dynamic thumbnailing.
+        self.db_pool.updates.register_background_index_update(
+            update_name="local_media_repository_thumbnails_method_idx",
+            index_name="local_media_repository_thumbn_media_id_width_height_method_key",
+            table="local_media_repository_thumbnails",
+            columns=[
+                "media_id",
+                "thumbnail_width",
+                "thumbnail_height",
+                "thumbnail_type",
+                "thumbnail_method",
+            ],
+            unique=True,
+        )
+
+        self.db_pool.updates.register_background_index_update(
+            update_name="remote_media_repository_thumbnails_method_idx",
+            index_name="remote_media_repository_thumbn_media_origin_id_width_height_method_key",
+            table="remote_media_cache_thumbnails",
+            columns=[
+                "media_origin",
+                "media_id",
+                "thumbnail_width",
+                "thumbnail_height",
+                "thumbnail_type",
+                "thumbnail_method",
+            ],
+            unique=True,
+        )
+
+        self.db_pool.updates.register_background_update_handler(
+            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD,
+            self._drop_media_index_without_method,
+        )
+
+    async def _drop_media_index_without_method(self, progress, batch_size):
+        def f(txn):
+            txn.execute(
+                "ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key"
+            )
+            txn.execute(
+                "ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_repository_thumbn_media_id_thumbnail_width_thum_key"
+            )
+
+        await self.db_pool.runInteraction("drop_media_indices_without_method", f)
+        await self.db_pool.updates._end_background_update(
+            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD
+        )
+        return 1
+
 
 class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
     """Persistence for attachments and avatars"""
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
         """Get the metadata for a local piece of media
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 1d793d3deb..e0cedd1aac 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -28,7 +28,7 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
 
 class MonthlyActiveUsersWorkerStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
         self._clock = hs.get_clock()
         self.hs = hs
 
@@ -120,7 +120,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
 
 class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self._limit_usage_by_mau = hs.config.limit_usage_by_mau
         self._mau_stats_only = hs.config.mau_stats_only
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index c9f655dfb7..dbbb99cb95 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -23,11 +23,11 @@ from synapse.util.iterutils import batch_iter
 
 class PresenceStore(SQLBaseStore):
     async def update_presence(self, presence_states):
-        stream_ordering_manager = await self._presence_id_gen.get_next_mult(
+        stream_ordering_manager = self._presence_id_gen.get_next_mult(
             len(presence_states)
         )
 
-        with stream_ordering_manager as stream_orderings:
+        async with stream_ordering_manager as stream_orderings:
             await self.db_pool.runInteraction(
                 "update_presence",
                 self._update_presence_txn,
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index ea833829ae..d7a03cbf7d 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -69,6 +69,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
         #     room_depth
         #     state_groups
         #     state_groups_state
+        #     destination_rooms
 
         # we will build a temporary table listing the events so that we don't
         # have to keep shovelling the list back and forth across the
@@ -336,6 +337,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
         # and finally, the tables with an index on room_id (or no useful index)
         for table in (
             "current_state_events",
+            "destination_rooms",
             "event_backward_extremities",
             "event_forward_extremities",
             "event_json",
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 0de802a86b..711d5aa23d 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -13,11 +13,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import abc
 import logging
 from typing import List, Tuple, Union
 
+from synapse.api.errors import NotFoundError, StoreError
 from synapse.push.baserules import list_with_base_rules
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -27,6 +27,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.pusher import PusherWorkerStore
 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.util import json_encoder
@@ -60,6 +61,8 @@ def _load_rules(rawrules, enabled_map, use_new_defaults=False):
     return rules
 
 
+# The ABCMeta metaclass ensures that it cannot be instantiated without
+# the abstract methods being implemented.
 class PushRulesWorkerStore(
     ApplicationServiceWorkerStore,
     ReceiptsWorkerStore,
@@ -67,17 +70,14 @@ class PushRulesWorkerStore(
     RoomMemberWorkerStore,
     EventsWorkerStore,
     SQLBaseStore,
+    metaclass=abc.ABCMeta,
 ):
     """This is an abstract base class where subclasses must implement
     `get_max_push_rules_stream_id` which can be called in the initializer.
     """
 
-    # This ABCMeta metaclass ensures that we cannot be instantiated without
-    # the abstract methods being implemented.
-    __metaclass__ = abc.ABCMeta
-
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
             self._push_rules_stream_id_gen = StreamIdGenerator(
@@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
     ) -> None:
         conditions_json = json_encoder.encode(conditions)
         actions_json = json_encoder.encode(actions)
-        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+        async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
 
             if before or after:
@@ -540,6 +540,25 @@ class PushRuleStore(PushRulesWorkerStore):
                 },
             )
 
+        # ensure we have a push_rules_enable row
+        # enabledness defaults to true
+        if isinstance(self.database_engine, PostgresEngine):
+            sql = """
+                INSERT INTO push_rules_enable (id, user_name, rule_id, enabled)
+                VALUES (?, ?, ?, ?)
+                ON CONFLICT DO NOTHING
+            """
+        elif isinstance(self.database_engine, Sqlite3Engine):
+            sql = """
+                INSERT OR IGNORE INTO push_rules_enable (id, user_name, rule_id, enabled)
+                VALUES (?, ?, ?, ?)
+            """
+        else:
+            raise RuntimeError("Unknown database engine")
+
+        new_enable_id = self._push_rules_enable_id_gen.get_next()
+        txn.execute(sql, (new_enable_id, user_id, rule_id, 1))
+
     async def delete_push_rule(self, user_id: str, rule_id: str) -> None:
         """
         Delete a push rule. Args specify the row to be deleted and can be
@@ -552,6 +571,12 @@ class PushRuleStore(PushRulesWorkerStore):
         """
 
         def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
+            # we don't use simple_delete_one_txn because that would fail if the
+            # user did not have a push_rule_enable row.
+            self.db_pool.simple_delete_txn(
+                txn, "push_rules_enable", {"user_name": user_id, "rule_id": rule_id}
+            )
+
             self.db_pool.simple_delete_one_txn(
                 txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}
             )
@@ -560,7 +585,7 @@ class PushRuleStore(PushRulesWorkerStore):
                 txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
             )
 
-        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+        async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
 
             await self.db_pool.runInteraction(
@@ -570,10 +595,29 @@ class PushRuleStore(PushRulesWorkerStore):
                 event_stream_ordering,
             )
 
-    async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
-        with await self._push_rules_stream_id_gen.get_next() as stream_id:
-            event_stream_ordering = self._stream_id_gen.get_current_token()
+    async def set_push_rule_enabled(
+        self, user_id: str, rule_id: str, enabled: bool, is_default_rule: bool
+    ) -> None:
+        """
+        Sets the `enabled` state of a push rule.
 
+        Args:
+            user_id: the user ID of the user who wishes to enable/disable the rule
+                e.g. '@tina:example.org'
+            rule_id: the full rule ID of the rule to be enabled/disabled
+                e.g. 'global/override/.m.rule.roomnotif'
+                  or 'global/override/myCustomRule'
+            enabled: True if the rule is to be enabled, False if it is to be
+                disabled
+            is_default_rule: True if and only if this is a server-default rule.
+                This skips the check for existence (as only user-created rules
+                are always stored in the database `push_rules` table).
+
+        Raises:
+            NotFoundError if the rule does not exist.
+        """
+        async with self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
             await self.db_pool.runInteraction(
                 "_set_push_rule_enabled_txn",
                 self._set_push_rule_enabled_txn,
@@ -582,12 +626,47 @@ class PushRuleStore(PushRulesWorkerStore):
                 user_id,
                 rule_id,
                 enabled,
+                is_default_rule,
             )
 
     def _set_push_rule_enabled_txn(
-        self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
+        self,
+        txn,
+        stream_id,
+        event_stream_ordering,
+        user_id,
+        rule_id,
+        enabled,
+        is_default_rule,
     ):
         new_id = self._push_rules_enable_id_gen.get_next()
+
+        if not is_default_rule:
+            # first check it exists; we need to lock for key share so that a
+            # transaction that deletes the push rule will conflict with this one.
+            # We also need a push_rule_enable row to exist for every push_rules
+            # row, otherwise it is possible to simultaneously delete a push rule
+            # (that has no _enable row) and enable it, resulting in a dangling
+            # _enable row. To solve this: we either need to use SERIALISABLE or
+            # ensure we always have a push_rule_enable row for every push_rule
+            # row. We chose the latter.
+            for_key_share = "FOR KEY SHARE"
+            if not isinstance(self.database_engine, PostgresEngine):
+                # For key share is not applicable/available on SQLite
+                for_key_share = ""
+            sql = (
+                """
+                SELECT 1 FROM push_rules
+                WHERE user_name = ? AND rule_id = ?
+                %s
+            """
+                % for_key_share
+            )
+            txn.execute(sql, (user_id, rule_id))
+            if txn.fetchone() is None:
+                # needed to set NOT_FOUND code.
+                raise NotFoundError("Push rule does not exist.")
+
         self.db_pool.simple_upsert_txn(
             txn,
             "push_rules_enable",
@@ -606,8 +685,30 @@ class PushRuleStore(PushRulesWorkerStore):
         )
 
     async def set_push_rule_actions(
-        self, user_id, rule_id, actions, is_default_rule
+        self,
+        user_id: str,
+        rule_id: str,
+        actions: List[Union[dict, str]],
+        is_default_rule: bool,
     ) -> None:
+        """
+        Sets the `actions` state of a push rule.
+
+        Will throw NotFoundError if the rule does not exist; the Code for this
+        is NOT_FOUND.
+
+        Args:
+            user_id: the user ID of the user who wishes to enable/disable the rule
+                e.g. '@tina:example.org'
+            rule_id: the full rule ID of the rule to be enabled/disabled
+                e.g. 'global/override/.m.rule.roomnotif'
+                  or 'global/override/myCustomRule'
+            actions: A list of actions (each action being a dict or string),
+                e.g. ["notify", {"set_tweak": "highlight", "value": false}]
+            is_default_rule: True if and only if this is a server-default rule.
+                This skips the check for existence (as only user-created rules
+                are always stored in the database `push_rules` table).
+        """
         actions_json = json_encoder.encode(actions)
 
         def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
@@ -629,12 +730,19 @@ class PushRuleStore(PushRulesWorkerStore):
                     update_stream=False,
                 )
             else:
-                self.db_pool.simple_update_one_txn(
-                    txn,
-                    "push_rules",
-                    {"user_name": user_id, "rule_id": rule_id},
-                    {"actions": actions_json},
-                )
+                try:
+                    self.db_pool.simple_update_one_txn(
+                        txn,
+                        "push_rules",
+                        {"user_name": user_id, "rule_id": rule_id},
+                        {"actions": actions_json},
+                    )
+                except StoreError as serr:
+                    if serr.code == 404:
+                        # this sets the NOT_FOUND error Code
+                        raise NotFoundError("Push rule does not exist")
+                    else:
+                        raise
 
             self._insert_push_rules_update_txn(
                 txn,
@@ -646,7 +754,7 @@ class PushRuleStore(PushRulesWorkerStore):
                 data={"actions": actions_json},
             )
 
-        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+        async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
 
             await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index c388468273..df8609b97b 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
         last_stream_ordering,
         profile_tag="",
     ) -> None:
-        with await self._pushers_id_gen.get_next() as stream_id:
+        async with self._pushers_id_gen.get_next() as stream_id:
             # no need to lock because `pushers` has a unique key on
             # (app_id, pushkey, user_name) so simple_upsert will retry
             await self.db_pool.simple_upsert(
@@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
                 },
             )
 
-        with await self._pushers_id_gen.get_next() as stream_id:
+        async with self._pushers_id_gen.get_next() as stream_id:
             await self.db_pool.runInteraction(
                 "delete_pusher", delete_pusher_txn, stream_id
             )
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 4a0d5a320e..c79ddff680 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -31,17 +31,15 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 logger = logging.getLogger(__name__)
 
 
-class ReceiptsWorkerStore(SQLBaseStore):
+# The ABCMeta metaclass ensures that it cannot be instantiated without
+# the abstract methods being implemented.
+class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
     """This is an abstract base class where subclasses must implement
     `get_max_receipt_stream_id` which can be called in the initializer.
     """
 
-    # This ABCMeta metaclass ensures that we cannot be instantiated without
-    # the abstract methods being implemented.
-    __metaclass__ = abc.ABCMeta
-
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self._receipts_stream_cache = StreamChangeCache(
             "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
@@ -388,7 +386,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
             db_conn, "receipts_linearized", "stream_id"
         )
 
-        super(ReceiptsStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_current_token()
@@ -526,7 +524,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 "insert_receipt_conv", graph_to_linear
             )
 
-        with await self._receipts_id_gen.get_next() as stream_id:
+        async with self._receipts_id_gen.get_next() as stream_id:
             event_ts = await self.db_pool.runInteraction(
                 "insert_linearized_receipt",
                 self.insert_linearized_receipt_txn,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 01f20c03c2..33825e8949 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)
 
 class RegistrationWorkerStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RegistrationWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.config = hs.config
         self.clock = hs.get_clock()
@@ -116,6 +116,20 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="get_expiration_ts_for_user",
         )
 
+    async def is_account_expired(self, user_id: str, current_ts: int) -> bool:
+        """
+        Returns whether an user account is expired.
+
+        Args:
+            user_id: The user's ID
+            current_ts: The current timestamp
+
+        Returns:
+            Whether the user account has expired
+        """
+        expiration_ts = await self.get_expiration_ts_for_user(user_id)
+        return expiration_ts is not None and current_ts >= expiration_ts
+
     async def set_account_validity_for_user(
         self,
         user_id: str,
@@ -764,7 +778,7 @@ class RegistrationWorkerStore(SQLBaseStore):
 
 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.clock = hs.get_clock()
         self.config = hs.config
@@ -892,7 +906,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
 
 class RegistrationStore(RegistrationBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RegistrationStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self._account_validity = hs.config.account_validity
         self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 717df97301..3c7630857f 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -69,7 +69,7 @@ class RoomSortOrder(Enum):
 
 class RoomWorkerStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RoomWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.config = hs.config
 
@@ -104,7 +104,8 @@ class RoomWorkerStore(SQLBaseStore):
                   curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
                   rooms.creator, state.encryption, state.is_federatable AS federatable,
                   rooms.is_public AS public, state.join_rules, state.guest_access,
-                  state.history_visibility, curr.current_state_events AS state_events
+                  state.history_visibility, curr.current_state_events AS state_events,
+                  state.avatar, state.topic
                 FROM rooms
                 LEFT JOIN room_stats_state state USING (room_id)
                 LEFT JOIN room_stats_current curr USING (room_id)
@@ -862,7 +863,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
     ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column"
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.config = hs.config
 
@@ -1073,7 +1074,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
 
 class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RoomStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.config = hs.config
 
@@ -1136,7 +1137,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                         },
                     )
 
-            with await self._public_room_id_gen.get_next() as next_id:
+            async with self._public_room_id_gen.get_next() as next_id:
                 await self.db_pool.runInteraction(
                     "store_room_txn", store_room_txn, next_id
                 )
@@ -1203,7 +1204,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                     },
                 )
 
-        with await self._public_room_id_gen.get_next() as next_id:
+        async with self._public_room_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
                 "set_room_is_public", set_room_is_public_txn, next_id
             )
@@ -1283,7 +1284,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                     },
                 )
 
-        with await self._public_room_id_gen.get_next() as next_id:
+        async with self._public_room_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
                 "set_room_is_public_appservice",
                 set_room_is_public_appservice_txn,
@@ -1327,6 +1328,101 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             desc="add_event_report",
         )
 
+    async def get_event_reports_paginate(
+        self,
+        start: int,
+        limit: int,
+        direction: str = "b",
+        user_id: Optional[str] = None,
+        room_id: Optional[str] = None,
+    ) -> Tuple[List[Dict[str, Any]], int]:
+        """Retrieve a paginated list of event reports
+
+        Args:
+            start: event offset to begin the query from
+            limit: number of rows to retrieve
+            direction: Whether to fetch the most recent first (`"b"`) or the
+                oldest first (`"f"`)
+            user_id: search for user_id. Ignored if user_id is None
+            room_id: search for room_id. Ignored if room_id is None
+        Returns:
+            event_reports: json list of event reports
+            count: total number of event reports matching the filter criteria
+        """
+
+        def _get_event_reports_paginate_txn(txn):
+            filters = []
+            args = []
+
+            if user_id:
+                filters.append("er.user_id LIKE ?")
+                args.extend(["%" + user_id + "%"])
+            if room_id:
+                filters.append("er.room_id LIKE ?")
+                args.extend(["%" + room_id + "%"])
+
+            if direction == "b":
+                order = "DESC"
+            else:
+                order = "ASC"
+
+            where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+
+            sql = """
+                SELECT COUNT(*) as total_event_reports
+                FROM event_reports AS er
+                {}
+                """.format(
+                where_clause
+            )
+            txn.execute(sql, args)
+            count = txn.fetchone()[0]
+
+            sql = """
+                SELECT
+                    er.id,
+                    er.received_ts,
+                    er.room_id,
+                    er.event_id,
+                    er.user_id,
+                    er.reason,
+                    er.content,
+                    events.sender,
+                    room_aliases.room_alias,
+                    event_json.json AS event_json
+                FROM event_reports AS er
+                LEFT JOIN room_aliases
+                    ON room_aliases.room_id = er.room_id
+                JOIN events
+                    ON events.event_id = er.event_id
+                JOIN event_json
+                    ON event_json.event_id = er.event_id
+                {where_clause}
+                ORDER BY er.received_ts {order}
+                LIMIT ?
+                OFFSET ?
+            """.format(
+                where_clause=where_clause, order=order,
+            )
+
+            args += [limit, start]
+            txn.execute(sql, args)
+            event_reports = self.db_pool.cursor_to_dict(txn)
+
+            if count > 0:
+                for row in event_reports:
+                    try:
+                        row["content"] = db_to_json(row["content"])
+                        row["event_json"] = db_to_json(row["event_json"])
+                    except Exception:
+                        continue
+
+            return event_reports, count
+
+        return await self.db_pool.runInteraction(
+            "get_event_reports_paginate", _get_event_reports_paginate_txn
+        )
+
     def get_current_public_room_stream_id(self):
         return self._public_room_id_gen.get_current_token()
 
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 91a8b43da3..86ffe2479e 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
 
@@ -37,7 +36,7 @@ from synapse.storage.roommember import (
     ProfileInfo,
     RoomsForUser,
 )
-from synapse.types import Collection, get_domain_from_id
+from synapse.types import Collection, PersistedEventPosition, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -55,7 +54,7 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
 
 class RoomMemberWorkerStore(EventsWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         # Is the current_state_events.membership up to date? Or is the
         # background update still running?
@@ -387,7 +386,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # for rooms the server is participating in.
         if self._current_state_events_membership_up_to_date:
             sql = """
-                SELECT room_id, e.stream_ordering
+                SELECT room_id, e.instance_name, e.stream_ordering
                 FROM current_state_events AS c
                 INNER JOIN events AS e USING (room_id, event_id)
                 WHERE
@@ -397,7 +396,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             """
         else:
             sql = """
-                SELECT room_id, e.stream_ordering
+                SELECT room_id, e.instance_name, e.stream_ordering
                 FROM current_state_events AS c
                 INNER JOIN room_memberships AS m USING (room_id, event_id)
                 INNER JOIN events AS e USING (room_id, event_id)
@@ -408,7 +407,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             """
 
         txn.execute(sql, (user_id, Membership.JOIN))
-        return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
+        return frozenset(
+            GetRoomsForUserWithStreamOrdering(
+                room_id, PersistedEventPosition(instance, stream_id)
+            )
+            for room_id, instance, stream_id in txn
+        )
 
     async def get_users_server_still_shares_room_with(
         self, user_ids: Collection[str]
@@ -819,7 +823,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
 class RoomMemberBackgroundUpdateStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
         self.db_pool.updates.register_background_update_handler(
             _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
         )
@@ -973,7 +977,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
 
 class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RoomMemberStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     async def forget(self, user_id: str, room_id: str) -> None:
         """Indicate that user_id wishes to discard history for room_id."""
diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
new file mode 100644
index 0000000000..b64926e9c9
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
@@ -0,0 +1,33 @@
+/* Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * This adds the method to the unique key constraint of the thumbnail databases.
+ * Otherwise you can't have a scaled and a cropped thumbnail with the same
+ * resolution, which happens quite often with dynamic thumbnailing.
+ * This is the postgres specific migration modifying the table with a background
+ * migration.
+ */
+
+-- add new index that includes method to local media
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('local_media_repository_thumbnails_method_idx', '{}');
+
+-- add new index that includes method to remote media
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+  ('remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
+
+-- drop old index
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+  ('media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');
+
diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite
new file mode 100644
index 0000000000..1d0c04b53a
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite
@@ -0,0 +1,44 @@
+/* Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * This adds the method to the unique key constraint of the thumbnail databases.
+ * Otherwise you can't have a scaled and a cropped thumbnail with the same
+ * resolution, which happens quite often with dynamic thumbnailing.
+ * This is a sqlite specific migration, since sqlite can't modify the unique
+ * constraint of a table without recreating it.
+ */
+
+CREATE TABLE local_media_repository_thumbnails_new ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method ) );
+
+INSERT INTO local_media_repository_thumbnails_new
+    SELECT media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method, thumbnail_length
+    FROM local_media_repository_thumbnails;
+
+DROP TABLE local_media_repository_thumbnails;
+
+ALTER TABLE local_media_repository_thumbnails_new RENAME TO local_media_repository_thumbnails;
+
+CREATE INDEX local_media_repository_thumbnails_media_id ON local_media_repository_thumbnails (media_id);
+
+
+
+CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails_new ( media_origin TEXT, media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_method TEXT, thumbnail_type TEXT, thumbnail_length INTEGER, filesystem_id TEXT, UNIQUE ( media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method ) );
+
+INSERT INTO remote_media_cache_thumbnails_new
+    SELECT media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_method, thumbnail_type, thumbnail_length, filesystem_id
+    FROM remote_media_cache_thumbnails;
+
+DROP TABLE remote_media_cache_thumbnails;
+
+ALTER TABLE remote_media_cache_thumbnails_new RENAME TO remote_media_cache_thumbnails;
diff --git a/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql b/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql
new file mode 100644
index 0000000000..847aebd85e
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql
@@ -0,0 +1,28 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+  Delete stuck 'enabled' bits that correspond to deleted or non-existent push rules.
+  We ignore rules that are server-default rules because they are not defined
+  in the `push_rules` table.
+**/
+
+DELETE FROM push_rules_enable WHERE
+  rule_id NOT LIKE 'global/%/.m.rule.%'
+  AND NOT EXISTS (
+    SELECT 1 FROM push_rules
+    WHERE push_rules.user_name = push_rules_enable.user_name
+      AND push_rules.rule_id = push_rules_enable.rule_id
+  );
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql
new file mode 100644
index 0000000000..98ff76d709
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql
@@ -0,0 +1,16 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE events ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
new file mode 100644
index 0000000000..97c1e6a0c5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
@@ -0,0 +1,26 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS events_stream_seq;
+
+SELECT setval('events_stream_seq', (
+    SELECT COALESCE(MAX(stream_ordering), 1) FROM events
+));
+
+CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
+
+SELECT setval('events_backfill_stream_seq', (
+    SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
+));
diff --git a/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql b/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql
new file mode 100644
index 0000000000..ebfbed7925
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql
@@ -0,0 +1,42 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+-- This schema delta alters the schema to enable 'catching up' remote homeservers
+-- after there has been a connectivity problem for any reason.
+
+-- This stores, for each (destination, room) pair, the stream_ordering of the
+-- latest event for that destination.
+CREATE TABLE IF NOT EXISTS destination_rooms (
+  -- the destination in question.
+  destination TEXT NOT NULL REFERENCES destinations (destination),
+  -- the ID of the room in question
+  room_id TEXT NOT NULL REFERENCES rooms (room_id),
+  -- the stream_ordering of the event
+  stream_ordering BIGINT NOT NULL,
+  PRIMARY KEY (destination, room_id)
+  -- We don't declare a foreign key on stream_ordering here because that'd mean
+  -- we'd need to either maintain an index (expensive) or do a table scan of
+  -- destination_rooms whenever we delete an event (also potentially expensive).
+  -- In addition to that, a foreign key on stream_ordering would be redundant
+  -- as this row doesn't need to refer to a specific event; if the event gets
+  -- deleted then it doesn't affect the validity of the stream_ordering here.
+);
+
+-- This index is needed to make it so that a deletion of a room (in the rooms
+-- table) can be efficient, as otherwise a table scan would need to be performed
+-- to check that no destination_rooms rows point to the room to be deleted.
+-- Also: it makes it efficient to delete all the entries for a given room ID,
+-- such as when purging a room.
+CREATE INDEX IF NOT EXISTS destination_rooms_room_id
+    ON destination_rooms (room_id);
diff --git a/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql b/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql
new file mode 100644
index 0000000000..55f5d0f732
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+-- This delta file fixes a regression introduced by 58/12room_stats.sql, removing the hacky
+-- populate_stats_process_rooms_2 background job and restores the functionality under the
+-- original name.
+-- See https://github.com/matrix-org/synapse/issues/8238 for details
+
+DELETE FROM background_updates WHERE update_name = 'populate_stats_process_rooms';
+UPDATE background_updates SET update_name = 'populate_stats_process_rooms'
+    WHERE update_name = 'populate_stats_process_rooms_2';
diff --git a/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql b/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql
new file mode 100644
index 0000000000..a67aa5e500
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql
@@ -0,0 +1,21 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This column tracks the stream_ordering of the event that was most recently
+-- successfully transmitted to the destination.
+-- A value of NULL means that we have not sent an event successfully yet
+-- (at least, not since the introduction of this column).
+ALTER TABLE destinations
+    ADD COLUMN last_successful_stream_ordering BIGINT;
diff --git a/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql
new file mode 100644
index 0000000000..985fd949a2
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE stream_positions (
+    stream_name TEXT NOT NULL,
+    instance_name TEXT NOT NULL,
+    stream_id BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX stream_positions_idx ON stream_positions(stream_name, instance_name);
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index f01cf2fd02..e34fce6281 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -89,7 +89,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
     EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         if not hs.config.enable_search:
             return
@@ -342,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
 class SearchStore(SearchBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SearchStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     async def search_msgs(self, room_ids, search_term, keys):
         """Performs a full text search over events with given keys.
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 5c6168e301..3c1e33819b 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -56,7 +56,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     """
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(StateGroupWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     async def get_room_version(self, room_id: str) -> RoomVersion:
         """Get the room_version of a given room
@@ -320,7 +320,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
     DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
 
@@ -506,4 +506,4 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
     """
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(StateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 55a250ef06..5beb302be3 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -61,7 +61,7 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")}
 
 class StatsStore(StateDeltasStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(StatsStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
         self.clock = self.hs.get_clock()
@@ -74,9 +74,6 @@ class StatsStore(StateDeltasStore):
             "populate_stats_process_rooms", self._populate_stats_process_rooms
         )
         self.db_pool.updates.register_background_update_handler(
-            "populate_stats_process_rooms_2", self._populate_stats_process_rooms_2
-        )
-        self.db_pool.updates.register_background_update_handler(
             "populate_stats_process_users", self._populate_stats_process_users
         )
         # we no longer need to perform clean-up, but we will give ourselves
@@ -148,31 +145,10 @@ class StatsStore(StateDeltasStore):
         return len(users_to_work_on)
 
     async def _populate_stats_process_rooms(self, progress, batch_size):
-        """
-        This was a background update which regenerated statistics for rooms.
-
-        It has been replaced by StatsStore._populate_stats_process_rooms_2. This background
-        job has been scheduled to run as part of Synapse v1.0.0, and again now. To ensure
-        someone upgrading from <v1.0.0, this background task has been turned into a no-op
-        so that the potentially expensive task is not run twice.
-
-        Further context: https://github.com/matrix-org/synapse/pull/7977
-        """
-        await self.db_pool.updates._end_background_update(
-            "populate_stats_process_rooms"
-        )
-        return 1
-
-    async def _populate_stats_process_rooms_2(self, progress, batch_size):
-        """
-        This is a background update which regenerates statistics for rooms.
-
-        It replaces StatsStore._populate_stats_process_rooms. See its docstring for the
-        reasoning.
-        """
+        """This is a background update which regenerates statistics for rooms."""
         if not self.stats_enabled:
             await self.db_pool.updates._end_background_update(
-                "populate_stats_process_rooms_2"
+                "populate_stats_process_rooms"
             )
             return 1
 
@@ -189,13 +165,13 @@ class StatsStore(StateDeltasStore):
             return [r for r, in txn]
 
         rooms_to_work_on = await self.db_pool.runInteraction(
-            "populate_stats_rooms_2_get_batch", _get_next_batch
+            "populate_stats_rooms_get_batch", _get_next_batch
         )
 
         # No more rooms -- complete the transaction.
         if not rooms_to_work_on:
             await self.db_pool.updates._end_background_update(
-                "populate_stats_process_rooms_2"
+                "populate_stats_process_rooms"
             )
             return 1
 
@@ -204,9 +180,9 @@ class StatsStore(StateDeltasStore):
             progress["last_room_id"] = room_id
 
         await self.db_pool.runInteraction(
-            "_populate_stats_process_rooms_2",
+            "_populate_stats_process_rooms",
             self.db_pool.updates._background_update_progress_txn,
-            "populate_stats_process_rooms_2",
+            "populate_stats_process_rooms",
             progress,
         )
 
@@ -234,6 +210,7 @@ class StatsStore(StateDeltasStore):
         * topic
         * avatar
         * canonical_alias
+        * guest_access
 
         A is_federatable key can also be included with a boolean value.
 
@@ -258,6 +235,7 @@ class StatsStore(StateDeltasStore):
             "topic",
             "avatar",
             "canonical_alias",
+            "guest_access",
         ):
             field = fields.get(col, sentinel)
             if field is not sentinel and (not isinstance(field, str) or "\0" in field):
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index db20a3db30..92e96468b4 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -79,8 +79,8 @@ _EventDictReturn = namedtuple(
 def generate_pagination_where_clause(
     direction: str,
     column_names: Tuple[str, str],
-    from_token: Optional[Tuple[int, int]],
-    to_token: Optional[Tuple[int, int]],
+    from_token: Optional[Tuple[Optional[int], int]],
+    to_token: Optional[Tuple[Optional[int], int]],
     engine: BaseDatabaseEngine,
 ) -> str:
     """Creates an SQL expression to bound the columns by the pagination
@@ -259,16 +259,14 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
     return " AND ".join(clauses), args
 
 
-class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
+class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
     """This is an abstract base class where subclasses must implement
     `get_room_max_stream_ordering` and `get_room_min_stream_ordering`
     which can be called in the initializer.
     """
 
-    __metaclass__ = abc.ABCMeta
-
     def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
-        super(StreamWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
         self._send_federation = hs.should_send_federation()
@@ -310,11 +308,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
     async def get_room_events_stream_for_rooms(
         self,
         room_ids: Collection[str],
-        from_key: str,
-        to_key: str,
+        from_key: RoomStreamToken,
+        to_key: RoomStreamToken,
         limit: int = 0,
         order: str = "DESC",
-    ) -> Dict[str, Tuple[List[EventBase], str]]:
+    ) -> Dict[str, Tuple[List[EventBase], RoomStreamToken]]:
         """Get new room events in stream ordering since `from_key`.
 
         Args:
@@ -333,9 +331,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 - list of recent events in the room
                 - stream ordering key for the start of the chunk of events returned.
         """
-        from_id = RoomStreamToken.parse_stream_token(from_key).stream
-
-        room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)
+        room_ids = self._events_stream_cache.get_entities_changed(
+            room_ids, from_key.stream
+        )
 
         if not room_ids:
             return {}
@@ -364,16 +362,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return results
 
     def get_rooms_that_changed(
-        self, room_ids: Collection[str], from_key: str
+        self, room_ids: Collection[str], from_key: RoomStreamToken
     ) -> Set[str]:
         """Given a list of rooms and a token, return rooms where there may have
         been changes.
-
-        Args:
-            room_ids
-            from_key: The room_key portion of a StreamToken
         """
-        from_id = RoomStreamToken.parse_stream_token(from_key).stream
+        from_id = from_key.stream
         return {
             room_id
             for room_id in room_ids
@@ -383,11 +377,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
     async def get_room_events_stream_for_room(
         self,
         room_id: str,
-        from_key: str,
-        to_key: str,
+        from_key: RoomStreamToken,
+        to_key: RoomStreamToken,
         limit: int = 0,
         order: str = "DESC",
-    ) -> Tuple[List[EventBase], str]:
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         """Get new room events in stream ordering since `from_key`.
 
         Args:
@@ -408,8 +402,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         if from_key == to_key:
             return [], from_key
 
-        from_id = RoomStreamToken.parse_stream_token(from_key).stream
-        to_id = RoomStreamToken.parse_stream_token(to_key).stream
+        from_id = from_key.stream
+        to_id = to_key.stream
 
         has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
 
@@ -441,7 +435,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             ret.reverse()
 
         if rows:
-            key = "s%d" % min(r.stream_ordering for r in rows)
+            key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
         else:
             # Assume we didn't get anything because there was nothing to
             # get.
@@ -450,10 +444,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return ret, key
 
     async def get_membership_changes_for_user(
-        self, user_id: str, from_key: str, to_key: str
+        self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
     ) -> List[EventBase]:
-        from_id = RoomStreamToken.parse_stream_token(from_key).stream
-        to_id = RoomStreamToken.parse_stream_token(to_key).stream
+        from_id = from_key.stream
+        to_id = to_key.stream
 
         if from_key == to_key:
             return []
@@ -491,8 +485,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return ret
 
     async def get_recent_events_for_room(
-        self, room_id: str, limit: int, end_token: str
-    ) -> Tuple[List[EventBase], str]:
+        self, room_id: str, limit: int, end_token: RoomStreamToken
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         """Get the most recent events in the room in topological ordering.
 
         Args:
@@ -518,8 +512,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return (events, token)
 
     async def get_recent_event_ids_for_room(
-        self, room_id: str, limit: int, end_token: str
-    ) -> Tuple[List[_EventDictReturn], str]:
+        self, room_id: str, limit: int, end_token: RoomStreamToken
+    ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
         """Get the most recent events in the room in topological ordering.
 
         Args:
@@ -535,8 +529,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         if limit == 0:
             return [], end_token
 
-        end_token = RoomStreamToken.parse(end_token)
-
         rows, token = await self.db_pool.runInteraction(
             "get_recent_event_ids_for_room",
             self._paginate_room_events_txn,
@@ -619,17 +611,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             allow_none=allow_none,
         )
 
-    async def get_stream_token_for_event(self, event_id: str) -> str:
+    async def get_stream_token_for_event(self, event_id: str) -> RoomStreamToken:
         """The stream token for an event
         Args:
             event_id: The id of the event to look up a stream token for.
         Raises:
             StoreError if the event wasn't in the database.
         Returns:
-            A "s%d" stream token.
+            A stream token.
         """
         stream_id = await self.get_stream_id_for_event(event_id)
-        return "s%d" % (stream_id,)
+        return RoomStreamToken(None, stream_id)
 
     async def get_topological_token_for_event(self, event_id: str) -> str:
         """The stream token for an event
@@ -951,7 +943,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         direction: str = "b",
         limit: int = -1,
         event_filter: Optional[Filter] = None,
-    ) -> Tuple[List[_EventDictReturn], str]:
+    ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
         """Returns list of events before or after a given token.
 
         Args:
@@ -986,8 +978,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         bounds = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
-            from_token=from_token,
-            to_token=to_token,
+            from_token=from_token.as_tuple(),
+            to_token=to_token.as_tuple() if to_token else None,
             engine=self.database_engine,
         )
 
@@ -1051,17 +1043,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             # TODO (erikj): We should work out what to do here instead.
             next_token = to_token if to_token else from_token
 
-        return rows, str(next_token)
+        return rows, next_token
 
     async def paginate_room_events(
         self,
         room_id: str,
-        from_key: str,
-        to_key: Optional[str] = None,
+        from_key: RoomStreamToken,
+        to_key: Optional[RoomStreamToken] = None,
         direction: str = "b",
         limit: int = -1,
         event_filter: Optional[Filter] = None,
-    ) -> Tuple[List[EventBase], str]:
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         """Returns list of events before or after a given token.
 
         Args:
@@ -1080,10 +1072,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             and `to_key`).
         """
 
-        from_key = RoomStreamToken.parse(from_key)
-        if to_key:
-            to_key = RoomStreamToken.parse(to_key)
-
         rows, token = await self.db_pool.runInteraction(
             "paginate_room_events",
             self._paginate_room_events_txn,
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 96ffe26cc9..9f120d3cb6 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore):
             )
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
-        with await self._account_data_id_gen.get_next() as next_id:
+        async with self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
@@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore):
             txn.execute(sql, (user_id, room_id, tag))
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
-        with await self._account_data_id_gen.get_next() as next_id:
+        async with self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 5b31aab700..97aed1500e 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -15,13 +15,14 @@
 
 import logging
 from collections import namedtuple
-from typing import Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
 
 from canonicaljson import encode_canonical_json
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.types import JsonDict
 from synapse.util.caches.expiringcache import ExpiringCache
 
@@ -47,7 +48,7 @@ class TransactionStore(SQLBaseStore):
     """
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(TransactionStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
 
@@ -164,7 +165,9 @@ class TransactionStore(SQLBaseStore):
             allow_none=True,
         )
 
-        if result and result["retry_last_ts"] > 0:
+        # check we have a row and retry_last_ts is not null or zero
+        # (retry_last_ts can't be negative)
+        if result and result["retry_last_ts"]:
             return result
         else:
             return None
@@ -215,6 +218,7 @@ class TransactionStore(SQLBaseStore):
                         retry_interval = EXCLUDED.retry_interval
                     WHERE
                         EXCLUDED.retry_interval = 0
+                        OR destinations.retry_interval IS NULL
                         OR destinations.retry_interval < EXCLUDED.retry_interval
             """
 
@@ -246,7 +250,11 @@ class TransactionStore(SQLBaseStore):
                     "retry_interval": retry_interval,
                 },
             )
-        elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
+        elif (
+            retry_interval == 0
+            or prev_row["retry_interval"] is None
+            or prev_row["retry_interval"] < retry_interval
+        ):
             self.db_pool.simple_update_one_txn(
                 txn,
                 "destinations",
@@ -273,3 +281,196 @@ class TransactionStore(SQLBaseStore):
         await self.db_pool.runInteraction(
             "_cleanup_transactions", _cleanup_transactions_txn
         )
+
+    async def store_destination_rooms_entries(
+        self, destinations: Iterable[str], room_id: str, stream_ordering: int,
+    ) -> None:
+        """
+        Updates or creates `destination_rooms` entries in batch for a single event.
+
+        Args:
+            destinations: list of destinations
+            room_id: the room_id of the event
+            stream_ordering: the stream_ordering of the event
+        """
+
+        return await self.db_pool.runInteraction(
+            "store_destination_rooms_entries",
+            self._store_destination_rooms_entries_txn,
+            destinations,
+            room_id,
+            stream_ordering,
+        )
+
+    def _store_destination_rooms_entries_txn(
+        self,
+        txn: LoggingTransaction,
+        destinations: Iterable[str],
+        room_id: str,
+        stream_ordering: int,
+    ) -> None:
+
+        # ensure we have a `destinations` row for this destination, as there is
+        # a foreign key constraint.
+        if isinstance(self.database_engine, PostgresEngine):
+            q = """
+                INSERT INTO destinations (destination)
+                    VALUES (?)
+                    ON CONFLICT DO NOTHING;
+            """
+        elif isinstance(self.database_engine, Sqlite3Engine):
+            q = """
+                INSERT OR IGNORE INTO destinations (destination)
+                    VALUES (?);
+            """
+        else:
+            raise RuntimeError("Unknown database engine")
+
+        txn.execute_batch(q, ((destination,) for destination in destinations))
+
+        rows = [(destination, room_id) for destination in destinations]
+
+        self.db_pool.simple_upsert_many_txn(
+            txn,
+            "destination_rooms",
+            ["destination", "room_id"],
+            rows,
+            ["stream_ordering"],
+            [(stream_ordering,)] * len(rows),
+        )
+
+    async def get_destination_last_successful_stream_ordering(
+        self, destination: str
+    ) -> Optional[int]:
+        """
+        Gets the stream ordering of the PDU most-recently successfully sent
+        to the specified destination, or None if this information has not been
+        tracked yet.
+
+        Args:
+            destination: the destination to query
+        """
+        return await self.db_pool.simple_select_one_onecol(
+            "destinations",
+            {"destination": destination},
+            "last_successful_stream_ordering",
+            allow_none=True,
+            desc="get_last_successful_stream_ordering",
+        )
+
+    async def set_destination_last_successful_stream_ordering(
+        self, destination: str, last_successful_stream_ordering: int
+    ) -> None:
+        """
+        Marks that we have successfully sent the PDUs up to and including the
+        one specified.
+
+        Args:
+            destination: the destination we have successfully sent to
+            last_successful_stream_ordering: the stream_ordering of the most
+                recent successfully-sent PDU
+        """
+        return await self.db_pool.simple_upsert(
+            "destinations",
+            keyvalues={"destination": destination},
+            values={"last_successful_stream_ordering": last_successful_stream_ordering},
+            desc="set_last_successful_stream_ordering",
+        )
+
+    async def get_catch_up_room_event_ids(
+        self, destination: str, last_successful_stream_ordering: int,
+    ) -> List[str]:
+        """
+        Returns at most 50 event IDs and their corresponding stream_orderings
+        that correspond to the oldest events that have not yet been sent to
+        the destination.
+
+        Args:
+            destination: the destination in question
+            last_successful_stream_ordering: the stream_ordering of the
+                most-recently successfully-transmitted event to the destination
+
+        Returns:
+            list of event_ids
+        """
+        return await self.db_pool.runInteraction(
+            "get_catch_up_room_event_ids",
+            self._get_catch_up_room_event_ids_txn,
+            destination,
+            last_successful_stream_ordering,
+        )
+
+    @staticmethod
+    def _get_catch_up_room_event_ids_txn(
+        txn: LoggingTransaction, destination: str, last_successful_stream_ordering: int,
+    ) -> List[str]:
+        q = """
+                SELECT event_id FROM destination_rooms
+                 JOIN events USING (stream_ordering)
+                WHERE destination = ?
+                  AND stream_ordering > ?
+                ORDER BY stream_ordering
+                LIMIT 50
+            """
+        txn.execute(
+            q, (destination, last_successful_stream_ordering),
+        )
+        event_ids = [row[0] for row in txn]
+        return event_ids
+
+    async def get_catch_up_outstanding_destinations(
+        self, after_destination: Optional[str]
+    ) -> List[str]:
+        """
+        Gets at most 25 destinations which have outstanding PDUs to be caught up,
+        and are not being backed off from
+        Args:
+            after_destination:
+                If provided, all destinations must be lexicographically greater
+                than this one.
+
+        Returns:
+            list of up to 25 destinations with outstanding catch-up.
+                These are the lexicographically first destinations which are
+                lexicographically greater than after_destination (if provided).
+        """
+        time = self.hs.get_clock().time_msec()
+
+        return await self.db_pool.runInteraction(
+            "get_catch_up_outstanding_destinations",
+            self._get_catch_up_outstanding_destinations_txn,
+            time,
+            after_destination,
+        )
+
+    @staticmethod
+    def _get_catch_up_outstanding_destinations_txn(
+        txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str]
+    ) -> List[str]:
+        q = """
+            SELECT destination FROM destinations
+                WHERE destination IN (
+                    SELECT destination FROM destination_rooms
+                        WHERE destination_rooms.stream_ordering >
+                            destinations.last_successful_stream_ordering
+                )
+                AND destination > ?
+                AND (
+                    retry_last_ts IS NULL OR
+                    retry_last_ts + retry_interval < ?
+                )
+                ORDER BY destination
+                LIMIT 25
+        """
+        txn.execute(
+            q,
+            (
+                # everything is lexicographically greater than "" so this gives
+                # us the first batch of up to 25.
+                after_destination or "",
+                now_time_ms,
+            ),
+        )
+
+        destinations = [row[0] for row in txn]
+        return destinations
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index b89668d561..3b9211a6d2 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -23,7 +23,7 @@ from synapse.types import JsonDict
 from synapse.util import json_encoder, stringutils
 
 
-@attr.s
+@attr.s(slots=True)
 class UIAuthSessionData:
     session_id = attr.ib(type=str)
     # The dictionary from the client root level, not the 'auth' key.
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index f2f9a5799a..5a390ff2f6 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -38,7 +38,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
     SHARE_PRIVATE_WORKING_SET = 500
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
 
@@ -564,7 +564,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
     SHARE_PRIVATE_WORKING_SET = 500
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(UserDirectoryStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     async def remove_from_user_dir(self, user_id: str) -> None:
         def _remove_from_user_dir_txn(txn):
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index 2f7c95fc74..f9575b1f1f 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -100,7 +100,7 @@ class UserErasureStore(UserErasureWorkerStore):
                 return
 
             # They are there, delete them.
-            self.simple_delete_one_txn(
+            self.db_pool.simple_delete_one_txn(
                 txn, "erased_users", keyvalues={"user_id": user_id}
             )
 
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index 139085b672..acb24e33af 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -181,7 +181,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
     STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
         self.db_pool.updates.register_background_update_handler(
             self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
             self._background_deduplicate_state,
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index e924f1ca3b..bec3780a32 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -52,7 +52,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
     """
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(StateGroupDataStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         # Originally the state store used a single DictionaryCache to cache the
         # event IDs for the state types in a given state group to avoid hammering
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index dbaeef91dd..603cd7d825 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -18,7 +18,7 @@
 import itertools
 import logging
 from collections import deque, namedtuple
-from typing import Iterable, List, Optional, Set, Tuple
+from typing import Dict, Iterable, List, Optional, Set, Tuple
 
 from prometheus_client import Counter, Histogram
 
@@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.databases import Databases
 from synapse.storage.databases.main.events import DeltaState
-from synapse.types import StateMap
+from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.metrics import Measure
 
@@ -185,9 +185,12 @@ class EventsPersistenceStorage:
         # store for now.
         self.main_store = stores.main
         self.state_store = stores.state
+
+        assert stores.persist_events
         self.persist_events_store = stores.persist_events
 
         self._clock = hs.get_clock()
+        self._instance_name = hs.get_instance_name()
         self.is_mine_id = hs.is_mine_id
         self._event_persist_queue = _EventPeristenceQueue()
         self._state_resolution_handler = hs.get_state_resolution_handler()
@@ -196,7 +199,7 @@ class EventsPersistenceStorage:
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         backfilled: bool = False,
-    ) -> int:
+    ) -> RoomStreamToken:
         """
         Write events to the database
         Args:
@@ -208,7 +211,7 @@ class EventsPersistenceStorage:
         Returns:
             the stream ordering of the latest persisted event
         """
-        partitioned = {}
+        partitioned = {}  # type: Dict[str, List[Tuple[EventBase, EventContext]]]
         for event, ctx in events_and_contexts:
             partitioned.setdefault(event.room_id, []).append((event, ctx))
 
@@ -226,11 +229,11 @@ class EventsPersistenceStorage:
             defer.gatherResults(deferreds, consumeErrors=True)
         )
 
-        return self.main_store.get_current_events_token()
+        return RoomStreamToken(None, self.main_store.get_current_events_token())
 
     async def persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
-    ) -> Tuple[int, int]:
+    ) -> Tuple[PersistedEventPosition, RoomStreamToken]:
         """
         Returns:
             The stream ordering of `event`, and the stream ordering of the
@@ -245,7 +248,10 @@ class EventsPersistenceStorage:
         await make_deferred_yieldable(deferred)
 
         max_persisted_id = self.main_store.get_current_events_token()
-        return (event.internal_metadata.stream_ordering, max_persisted_id)
+        event_stream_id = event.internal_metadata.stream_ordering
+
+        pos = PersistedEventPosition(self._instance_name, event_stream_id)
+        return pos, RoomStreamToken(None, max_persisted_id)
 
     def _maybe_start_persisting(self, room_id: str):
         async def persisting_queue(item):
@@ -305,7 +311,9 @@ class EventsPersistenceStorage:
                     # Work out the new "current state" for each room.
                     # We do this by working out what the new extremities are and then
                     # calculating the state from that.
-                    events_by_room = {}
+                    events_by_room = (
+                        {}
+                    )  # type: Dict[str, List[Tuple[EventBase, EventContext]]]
                     for event, context in chunk:
                         events_by_room.setdefault(event.room_id, []).append(
                             (event, context)
@@ -436,7 +444,7 @@ class EventsPersistenceStorage:
         self,
         room_id: str,
         event_contexts: List[Tuple[EventBase, EventContext]],
-        latest_event_ids: List[str],
+        latest_event_ids: Collection[str],
     ):
         """Calculates the new forward extremities for a room given events to
         persist.
@@ -470,7 +478,7 @@ class EventsPersistenceStorage:
         # Remove any events which are prev_events of any existing events.
         existing_prevs = await self.persist_events_store._get_events_which_are_prevs(
             result
-        )
+        )  # type: Collection[str]
         result.difference_update(existing_prevs)
 
         # Finally handle the case where the new events have soft-failed prev
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index ee60e2a718..4957e77f4c 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -19,12 +19,15 @@ import logging
 import os
 import re
 from collections import Counter
-from typing import TextIO
+from typing import Optional, TextIO
 
 import attr
 
+from synapse.config.homeserver import HomeServerConfig
+from synapse.storage.engines import BaseDatabaseEngine
 from synapse.storage.engines.postgres import PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.storage.types import Connection, Cursor
+from synapse.types import Collection
 
 logger = logging.getLogger(__name__)
 
@@ -63,7 +66,12 @@ UNAPPLIED_DELTA_ON_WORKER_ERROR = (
 )
 
 
-def prepare_database(db_conn, database_engine, config, databases=["main", "state"]):
+def prepare_database(
+    db_conn: Connection,
+    database_engine: BaseDatabaseEngine,
+    config: Optional[HomeServerConfig],
+    databases: Collection[str] = ["main", "state"],
+):
     """Prepares a physical database for usage. Will either create all necessary tables
     or upgrade from an older schema version.
 
@@ -73,16 +81,24 @@ def prepare_database(db_conn, database_engine, config, databases=["main", "state
     Args:
         db_conn:
         database_engine:
-        config (synapse.config.homeserver.HomeServerConfig|None):
+        config :
             application config, or None if we are connecting to an existing
             database which we expect to be configured already
-        databases (list[str]): The name of the databases that will be used
+        databases: The name of the databases that will be used
             with this physical database. Defaults to all databases.
     """
 
     try:
         cur = db_conn.cursor()
 
+        # sqlite does not automatically start transactions for DDL / SELECT statements,
+        # so we start one before running anything. This ensures that any upgrades
+        # are either applied completely, or not at all.
+        #
+        # (psycopg2 automatically starts a transaction as soon as we run any statements
+        # at all, so this is redundant but harmless there.)
+        cur.execute("BEGIN TRANSACTION")
+
         logger.info("%r: Checking existing schema version", databases)
         version_info = _get_or_create_schema_state(cur, database_engine)
 
@@ -622,7 +638,7 @@ def _get_or_create_schema_state(txn, database_engine):
     return None
 
 
-@attr.s()
+@attr.s(slots=True)
 class _DirectoryListing:
     """Helper class to store schema file name and the
     absolute path to it.
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index d30e3f11e7..cec96ad6a7 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -22,7 +22,7 @@ from synapse.api.errors import SynapseError
 logger = logging.getLogger(__name__)
 
 
-@attr.s
+@attr.s(slots=True)
 class PaginationChunk:
     """Returned by relation pagination APIs.
 
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 8c4a83a840..f152f63321 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -25,7 +25,7 @@ RoomsForUser = namedtuple(
 )
 
 GetRoomsForUserWithStreamOrdering = namedtuple(
-    "_GetRoomsForUserWithStreamOrdering", ("room_id", "stream_ordering")
+    "_GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
 )
 
 
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index b7eb4f8ac9..727fcc521c 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -12,16 +12,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import contextlib
 import heapq
 import logging
 import threading
 from collections import deque
-from typing import Dict, List, Set
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Set, Union
 
+import attr
 from typing_extensions import Deque
 
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.util.sequence import PostgresSequenceGenerator
 
@@ -86,7 +87,7 @@ class StreamIdGenerator:
             upwards, -1 to grow downwards.
 
     Usage:
-        with await stream_id_gen.get_next() as stream_id:
+        async with stream_id_gen.get_next() as stream_id:
             # ... persist event ...
     """
 
@@ -101,10 +102,10 @@ class StreamIdGenerator:
             )
         self._unfinished_ids = deque()  # type: Deque[int]
 
-    async def get_next(self):
+    def get_next(self):
         """
         Usage:
-            with await stream_id_gen.get_next() as stream_id:
+            async with stream_id_gen.get_next() as stream_id:
                 # ... persist event ...
         """
         with self._lock:
@@ -113,7 +114,7 @@ class StreamIdGenerator:
 
             self._unfinished_ids.append(next_id)
 
-        @contextlib.contextmanager
+        @contextmanager
         def manager():
             try:
                 yield next_id
@@ -121,12 +122,12 @@ class StreamIdGenerator:
                 with self._lock:
                     self._unfinished_ids.remove(next_id)
 
-        return manager()
+        return _AsyncCtxManagerWrapper(manager())
 
-    async def get_next_mult(self, n):
+    def get_next_mult(self, n):
         """
         Usage:
-            with await stream_id_gen.get_next(n) as stream_ids:
+            async with stream_id_gen.get_next(n) as stream_ids:
                 # ... persist events ...
         """
         with self._lock:
@@ -140,7 +141,7 @@ class StreamIdGenerator:
             for next_id in next_ids:
                 self._unfinished_ids.append(next_id)
 
-        @contextlib.contextmanager
+        @contextmanager
         def manager():
             try:
                 yield next_ids
@@ -149,7 +150,7 @@ class StreamIdGenerator:
                     for next_id in next_ids:
                         self._unfinished_ids.remove(next_id)
 
-        return manager()
+        return _AsyncCtxManagerWrapper(manager())
 
     def get_current_token(self):
         """Returns the maximum stream id such that all stream ids less than or
@@ -184,12 +185,16 @@ class MultiWriterIdGenerator:
     Args:
         db_conn
         db
+        stream_name: A name for the stream.
         instance_name: The name of this instance.
         table: Database table associated with stream.
         instance_column: Column that stores the row's writer's instance name
         id_column: Column that stores the stream ID.
         sequence_name: The name of the postgres sequence used to generate new
             IDs.
+        writers: A list of known writers to use to populate current positions
+            on startup. Can be empty if nothing uses `get_current_token` or
+            `get_positions` (e.g. caches stream).
         positive: Whether the IDs are positive (true) or negative (false).
             When using negative IDs we go backwards from -1 to -2, -3, etc.
     """
@@ -198,16 +203,20 @@ class MultiWriterIdGenerator:
         self,
         db_conn,
         db: DatabasePool,
+        stream_name: str,
         instance_name: str,
         table: str,
         instance_column: str,
         id_column: str,
         sequence_name: str,
+        writers: List[str],
         positive: bool = True,
     ):
         self._db = db
+        self._stream_name = stream_name
         self._instance_name = instance_name
         self._positive = positive
+        self._writers = writers
         self._return_factor = 1 if positive else -1
 
         # We lock as some functions may be called from DB threads.
@@ -216,14 +225,16 @@ class MultiWriterIdGenerator:
         # Note: If we are a negative stream then we still store all the IDs as
         # positive to make life easier for us, and simply negate the IDs when we
         # return them.
-        self._current_positions = self._load_current_ids(
-            db_conn, table, instance_column, id_column
-        )
+        self._current_positions = {}  # type: Dict[str, int]
 
         # Set of local IDs that we're still processing. The current position
         # should be less than the minimum of this set (if not empty).
         self._unfinished_ids = set()  # type: Set[int]
 
+        # Set of local IDs that we've processed that are larger than the current
+        # position, due to there being smaller unpersisted IDs.
+        self._finished_ids = set()  # type: Set[int]
+
         # We track the max position where we know everything before has been
         # persisted. This is done by a) looking at the min across all instances
         # and b) noting that if we have seen a run of persisted positions
@@ -236,37 +247,91 @@ class MultiWriterIdGenerator:
         # gaps should be relatively rare it's still worth doing the book keeping
         # that allows us to skip forwards when there are gapless runs of
         # positions.
+        #
+        # We start at 1 here as a) the first generated stream ID will be 2, and
+        # b) other parts of the code assume that stream IDs are strictly greater
+        # than 0.
         self._persisted_upto_position = (
-            min(self._current_positions.values()) if self._current_positions else 0
+            min(self._current_positions.values()) if self._current_positions else 1
         )
         self._known_persisted_positions = []  # type: List[int]
 
         self._sequence_gen = PostgresSequenceGenerator(sequence_name)
 
+        # This goes and fills out the above state from the database.
+        self._load_current_ids(db_conn, table, instance_column, id_column)
+
     def _load_current_ids(
         self, db_conn, table: str, instance_column: str, id_column: str
-    ) -> Dict[str, int]:
-        # If positive stream aggregate via MAX. For negative stream use MIN
-        # *and* negate the result to get a positive number.
-        sql = """
-            SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
-            GROUP BY %(instance)s
-        """ % {
-            "instance": instance_column,
-            "id": id_column,
-            "table": table,
-            "agg": "MAX" if self._positive else "-MIN",
-        }
-
+    ):
         cur = db_conn.cursor()
-        cur.execute(sql)
 
-        # `cur` is an iterable over returned rows, which are 2-tuples.
-        current_positions = dict(cur)
+        # Load the current positions of all writers for the stream.
+        if self._writers:
+            sql = """
+                SELECT instance_name, stream_id FROM stream_positions
+                WHERE stream_name = ?
+            """
+            sql = self._db.engine.convert_param_style(sql)
 
-        cur.close()
+            cur.execute(sql, (self._stream_name,))
+
+            self._current_positions = {
+                instance: stream_id * self._return_factor
+                for instance, stream_id in cur
+                if instance in self._writers
+            }
+
+        # We set the `_persisted_upto_position` to be the minimum of all current
+        # positions. If empty we use the max stream ID from the DB table.
+        min_stream_id = min(self._current_positions.values(), default=None)
+
+        if min_stream_id is None:
+            sql = """
+                SELECT COALESCE(%(agg)s(%(id)s), 1) FROM %(table)s
+            """ % {
+                "id": id_column,
+                "table": table,
+                "agg": "MAX" if self._positive else "-MIN",
+            }
+            cur.execute(sql)
+            (stream_id,) = cur.fetchone()
+            self._persisted_upto_position = stream_id
+        else:
+            # If we have a min_stream_id then we pull out everything greater
+            # than it from the DB so that we can prefill
+            # `_known_persisted_positions` and get a more accurate
+            # `_persisted_upto_position`.
+            #
+            # We also check if any of the later rows are from this instance, in
+            # which case we use that for this instance's current position. This
+            # is to handle the case where we didn't finish persisting to the
+            # stream positions table before restart (or the stream position
+            # table otherwise got out of date).
+
+            sql = """
+                SELECT %(instance)s, %(id)s FROM %(table)s
+                WHERE ? %(cmp)s %(id)s
+            """ % {
+                "id": id_column,
+                "table": table,
+                "instance": instance_column,
+                "cmp": "<=" if self._positive else ">=",
+            }
+            sql = self._db.engine.convert_param_style(sql)
+            cur.execute(sql, (min_stream_id,))
+
+            self._persisted_upto_position = min_stream_id
+
+            with self._lock:
+                for (instance, stream_id,) in cur:
+                    stream_id = self._return_factor * stream_id
+                    self._add_persisted_position(stream_id)
 
-        return current_positions
+                    if instance == self._instance_name:
+                        self._current_positions[instance] = stream_id
+
+        cur.close()
 
     def _load_next_id_txn(self, txn) -> int:
         return self._sequence_gen.get_next_id_txn(txn)
@@ -274,59 +339,23 @@ class MultiWriterIdGenerator:
     def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
         return self._sequence_gen.get_next_mult_txn(txn, n)
 
-    async def get_next(self):
+    def get_next(self):
         """
         Usage:
-            with await stream_id_gen.get_next() as stream_id:
+            async with stream_id_gen.get_next() as stream_id:
                 # ... persist event ...
         """
-        next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
-
-        # Assert the fetched ID is actually greater than what we currently
-        # believe the ID to be. If not, then the sequence and table have got
-        # out of sync somehow.
-        with self._lock:
-            assert self._current_positions.get(self._instance_name, 0) < next_id
 
-            self._unfinished_ids.add(next_id)
+        return _MultiWriterCtxManager(self)
 
-        @contextlib.contextmanager
-        def manager():
-            try:
-                # Multiply by the return factor so that the ID has correct sign.
-                yield self._return_factor * next_id
-            finally:
-                self._mark_id_as_finished(next_id)
-
-        return manager()
-
-    async def get_next_mult(self, n: int):
+    def get_next_mult(self, n: int):
         """
         Usage:
-            with await stream_id_gen.get_next_mult(5) as stream_ids:
+            async with stream_id_gen.get_next_mult(5) as stream_ids:
                 # ... persist events ...
         """
-        next_ids = await self._db.runInteraction(
-            "_load_next_mult_id", self._load_next_mult_id_txn, n
-        )
 
-        # Assert the fetched ID is actually greater than any ID we've already
-        # seen. If not, then the sequence and table have got out of sync
-        # somehow.
-        with self._lock:
-            assert max(self._current_positions.values(), default=0) < min(next_ids)
-
-            self._unfinished_ids.update(next_ids)
-
-        @contextlib.contextmanager
-        def manager():
-            try:
-                yield [self._return_factor * i for i in next_ids]
-            finally:
-                for i in next_ids:
-                    self._mark_id_as_finished(i)
-
-        return manager()
+        return _MultiWriterCtxManager(self, n)
 
     def get_next_txn(self, txn: LoggingTransaction):
         """
@@ -344,21 +373,63 @@ class MultiWriterIdGenerator:
         txn.call_after(self._mark_id_as_finished, next_id)
         txn.call_on_exception(self._mark_id_as_finished, next_id)
 
+        # Update the `stream_positions` table with newly updated stream
+        # ID (unless self._writers is not set in which case we don't
+        # bother, as nothing will read it).
+        #
+        # We only do this on the success path so that the persisted current
+        # position points to a persited row with the correct instance name.
+        if self._writers:
+            txn.call_after(
+                run_as_background_process,
+                "MultiWriterIdGenerator._update_table",
+                self._db.runInteraction,
+                "MultiWriterIdGenerator._update_table",
+                self._update_stream_positions_table_txn,
+            )
+
         return self._return_factor * next_id
 
     def _mark_id_as_finished(self, next_id: int):
         """The ID has finished being processed so we should advance the
-        current poistion if possible.
+        current position if possible.
         """
 
         with self._lock:
             self._unfinished_ids.discard(next_id)
+            self._finished_ids.add(next_id)
+
+            new_cur = None
+
+            if self._unfinished_ids:
+                # If there are unfinished IDs then the new position will be the
+                # largest finished ID less than the minimum unfinished ID.
+
+                finished = set()
+
+                min_unfinshed = min(self._unfinished_ids)
+                for s in self._finished_ids:
+                    if s < min_unfinshed:
+                        if new_cur is None or new_cur < s:
+                            new_cur = s
+                    else:
+                        finished.add(s)
+
+                # We clear these out since they're now all less than the new
+                # position.
+                self._finished_ids = finished
+            else:
+                # There are no unfinished IDs so the new position is simply the
+                # largest finished one.
+                new_cur = max(self._finished_ids)
 
-            # Figure out if its safe to advance the position by checking there
-            # aren't any lower allocated IDs that are yet to finish.
-            if all(c > next_id for c in self._unfinished_ids):
+                # We clear these out since they're now all less than the new
+                # position.
+                self._finished_ids.clear()
+
+            if new_cur:
                 curr = self._current_positions.get(self._instance_name, 0)
-                self._current_positions[self._instance_name] = max(curr, next_id)
+                self._current_positions[self._instance_name] = max(curr, new_cur)
 
             self._add_persisted_position(next_id)
 
@@ -367,9 +438,7 @@ class MultiWriterIdGenerator:
         equal to it have been successfully persisted.
         """
 
-        # Currently we don't support this operation, as it's not obvious how to
-        # condense the stream positions of multiple writers into a single int.
-        raise NotImplementedError()
+        return self.get_persisted_upto_position()
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
         """Returns the position of the given writer.
@@ -428,7 +497,7 @@ class MultiWriterIdGenerator:
         # We move the current min position up if the minimum current positions
         # of all instances is higher (since by definition all positions less
         # that that have been persisted).
-        min_curr = min(self._current_positions.values())
+        min_curr = min(self._current_positions.values(), default=0)
         self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
 
         # We now iterate through the seen positions, discarding those that are
@@ -449,3 +518,95 @@ class MultiWriterIdGenerator:
                 # There was a gap in seen positions, so there is nothing more to
                 # do.
                 break
+
+    def _update_stream_positions_table_txn(self, txn):
+        """Update the `stream_positions` table with newly persisted position.
+        """
+
+        if not self._writers:
+            return
+
+        # We upsert the value, ensuring on conflict that we always increase the
+        # value (or decrease if stream goes backwards).
+        sql = """
+            INSERT INTO stream_positions (stream_name, instance_name, stream_id)
+            VALUES (?, ?, ?)
+            ON CONFLICT (stream_name, instance_name)
+            DO UPDATE SET
+                stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
+        """ % {
+            "agg": "GREATEST" if self._positive else "LEAST",
+        }
+
+        pos = (self.get_current_token_for_writer(self._instance_name),)
+        txn.execute(sql, (self._stream_name, self._instance_name, pos))
+
+
+@attr.s(slots=True)
+class _AsyncCtxManagerWrapper:
+    """Helper class to convert a plain context manager to an async one.
+
+    This is mainly useful if you have a plain context manager but the interface
+    requires an async one.
+    """
+
+    inner = attr.ib()
+
+    async def __aenter__(self):
+        return self.inner.__enter__()
+
+    async def __aexit__(self, exc_type, exc, tb):
+        return self.inner.__exit__(exc_type, exc, tb)
+
+
+@attr.s(slots=True)
+class _MultiWriterCtxManager:
+    """Async context manager returned by MultiWriterIdGenerator
+    """
+
+    id_gen = attr.ib(type=MultiWriterIdGenerator)
+    multiple_ids = attr.ib(type=Optional[int], default=None)
+    stream_ids = attr.ib(type=List[int], factory=list)
+
+    async def __aenter__(self) -> Union[int, List[int]]:
+        self.stream_ids = await self.id_gen._db.runInteraction(
+            "_load_next_mult_id",
+            self.id_gen._load_next_mult_id_txn,
+            self.multiple_ids or 1,
+        )
+
+        # Assert the fetched ID is actually greater than any ID we've already
+        # seen. If not, then the sequence and table have got out of sync
+        # somehow.
+        with self.id_gen._lock:
+            assert max(self.id_gen._current_positions.values(), default=0) < min(
+                self.stream_ids
+            )
+
+            self.id_gen._unfinished_ids.update(self.stream_ids)
+
+        if self.multiple_ids is None:
+            return self.stream_ids[0] * self.id_gen._return_factor
+        else:
+            return [i * self.id_gen._return_factor for i in self.stream_ids]
+
+    async def __aexit__(self, exc_type, exc, tb):
+        for i in self.stream_ids:
+            self.id_gen._mark_id_as_finished(i)
+
+        if exc_type is not None:
+            return False
+
+        # Update the `stream_positions` table with newly updated stream
+        # ID (unless self._writers is not set in which case we don't
+        # bother, as nothing will read it).
+        #
+        # We only do this on the success path so that the persisted current
+        # position points to a persisted row with the correct instance name.
+        if self.id_gen._writers:
+            await self.id_gen._db.runInteraction(
+                "MultiWriterIdGenerator._update_table",
+                self.id_gen._update_stream_positions_table_txn,
+            )
+
+        return False
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index d97dc4d101..0bdf846edf 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -14,9 +14,13 @@
 # limitations under the License.
 
 import logging
+from typing import Optional
+
+import attr
 
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import parse_integer, parse_string
+from synapse.http.site import SynapseRequest
 from synapse.types import StreamToken
 
 logger = logging.getLogger(__name__)
@@ -25,38 +29,22 @@ logger = logging.getLogger(__name__)
 MAX_LIMIT = 1000
 
 
-class SourcePaginationConfig:
-
-    """A configuration object which stores pagination parameters for a
-    specific event source."""
-
-    def __init__(self, from_key=None, to_key=None, direction="f", limit=None):
-        self.from_key = from_key
-        self.to_key = to_key
-        self.direction = "f" if direction == "f" else "b"
-        self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None
-
-    def __repr__(self):
-        return "StreamConfig(from_key=%r, to_key=%r, direction=%r, limit=%r)" % (
-            self.from_key,
-            self.to_key,
-            self.direction,
-            self.limit,
-        )
-
-
+@attr.s(slots=True)
 class PaginationConfig:
-
     """A configuration object which stores pagination parameters."""
 
-    def __init__(self, from_token=None, to_token=None, direction="f", limit=None):
-        self.from_token = from_token
-        self.to_token = to_token
-        self.direction = "f" if direction == "f" else "b"
-        self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None
+    from_token = attr.ib(type=Optional[StreamToken])
+    to_token = attr.ib(type=Optional[StreamToken])
+    direction = attr.ib(type=str)
+    limit = attr.ib(type=Optional[int])
 
     @classmethod
-    def from_request(cls, request, raise_invalid_params=True, default_limit=None):
+    def from_request(
+        cls,
+        request: SynapseRequest,
+        raise_invalid_params: bool = True,
+        default_limit: Optional[int] = None,
+    ) -> "PaginationConfig":
         direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
 
         from_tok = parse_string(request, "from")
@@ -78,8 +66,11 @@ class PaginationConfig:
 
         limit = parse_integer(request, "limit", default=default_limit)
 
-        if limit and limit < 0:
-            raise SynapseError(400, "Limit must be 0 or above")
+        if limit:
+            if limit < 0:
+                raise SynapseError(400, "Limit must be 0 or above")
+
+            limit = min(int(limit), MAX_LIMIT)
 
         try:
             return PaginationConfig(from_tok, to_tok, direction, limit)
@@ -87,20 +78,10 @@ class PaginationConfig:
             logger.exception("Failed to create pagination config")
             raise SynapseError(400, "Invalid request.")
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return ("PaginationConfig(from_tok=%r, to_tok=%r, direction=%r, limit=%r)") % (
             self.from_token,
             self.to_token,
             self.direction,
             self.limit,
         )
-
-    def get_source_config(self, source_name):
-        keyname = "%s_key" % source_name
-
-        return SourcePaginationConfig(
-            from_key=getattr(self.from_token, keyname),
-            to_key=getattr(self.to_token, keyname) if self.to_token else None,
-            direction=self.direction,
-            limit=self.limit,
-        )
diff --git a/synapse/types.py b/synapse/types.py
index f7de48f148..ec39f9e1e8 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -18,7 +18,7 @@ import re
 import string
 import sys
 from collections import namedtuple
-from typing import Any, Dict, Mapping, MutableMapping, Tuple, Type, TypeVar
+from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar
 
 import attr
 from signedjson.key import decode_verify_key_bytes
@@ -165,7 +165,9 @@ def get_localpart_from_id(string):
 DS = TypeVar("DS", bound="DomainSpecificString")
 
 
-class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "domain"))):
+class DomainSpecificString(
+    namedtuple("DomainSpecificString", ("localpart", "domain")), metaclass=abc.ABCMeta
+):
     """Common base class among ID/name strings that have a local part and a
     domain name, prefixed with a sigil.
 
@@ -175,8 +177,6 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom
         'domain' : The domain part of the name
     """
 
-    __metaclass__ = abc.ABCMeta
-
     SIGIL = abc.abstractproperty()  # type: str  # type: ignore
 
     # Deny iteration because it will bite you if you try to create a singleton
@@ -362,22 +362,81 @@ def map_username_to_mxid_localpart(username, case_sensitive=False):
     return username.decode("ascii")
 
 
-class StreamToken(
-    namedtuple(
-        "Token",
-        (
-            "room_key",
-            "presence_key",
-            "typing_key",
-            "receipt_key",
-            "account_data_key",
-            "push_rules_key",
-            "to_device_key",
-            "device_list_key",
-            "groups_key",
-        ),
+@attr.s(frozen=True, slots=True)
+class RoomStreamToken:
+    """Tokens are positions between events. The token "s1" comes after event 1.
+
+            s0    s1
+            |     |
+        [0] V [1] V [2]
+
+    Tokens can either be a point in the live event stream or a cursor going
+    through historic events.
+
+    When traversing the live event stream events are ordered by when they
+    arrived at the homeserver.
+
+    When traversing historic events the events are ordered by their depth in
+    the event graph "topological_ordering" and then by when they arrived at the
+    homeserver "stream_ordering".
+
+    Live tokens start with an "s" followed by the "stream_ordering" id of the
+    event it comes after. Historic tokens start with a "t" followed by the
+    "topological_ordering" id of the event it comes after, followed by "-",
+    followed by the "stream_ordering" id of the event it comes after.
+    """
+
+    topological = attr.ib(
+        type=Optional[int],
+        validator=attr.validators.optional(attr.validators.instance_of(int)),
     )
-):
+    stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
+
+    @classmethod
+    def parse(cls, string: str) -> "RoomStreamToken":
+        try:
+            if string[0] == "s":
+                return cls(topological=None, stream=int(string[1:]))
+            if string[0] == "t":
+                parts = string[1:].split("-", 1)
+                return cls(topological=int(parts[0]), stream=int(parts[1]))
+        except Exception:
+            pass
+        raise SynapseError(400, "Invalid token %r" % (string,))
+
+    @classmethod
+    def parse_stream_token(cls, string: str) -> "RoomStreamToken":
+        try:
+            if string[0] == "s":
+                return cls(topological=None, stream=int(string[1:]))
+        except Exception:
+            pass
+        raise SynapseError(400, "Invalid token %r" % (string,))
+
+    def as_tuple(self) -> Tuple[Optional[int], int]:
+        return (self.topological, self.stream)
+
+    def __str__(self) -> str:
+        if self.topological is not None:
+            return "t%d-%d" % (self.topological, self.stream)
+        else:
+            return "s%d" % (self.stream,)
+
+
+@attr.s(slots=True, frozen=True)
+class StreamToken:
+    room_key = attr.ib(
+        type=RoomStreamToken, validator=attr.validators.instance_of(RoomStreamToken)
+    )
+    presence_key = attr.ib(type=int)
+    typing_key = attr.ib(type=int)
+    receipt_key = attr.ib(type=int)
+    account_data_key = attr.ib(type=int)
+    push_rules_key = attr.ib(type=int)
+    to_device_key = attr.ib(type=int)
+    device_list_key = attr.ib(type=int)
+    groups_key = attr.ib(type=int)
+
     _SEPARATOR = "_"
     START = None  # type: StreamToken
 
@@ -385,24 +444,19 @@ class StreamToken(
     def from_string(cls, string):
         try:
             keys = string.split(cls._SEPARATOR)
-            while len(keys) < len(cls._fields):
+            while len(keys) < len(attr.fields(cls)):
                 # i.e. old token from before receipt_key
                 keys.append("0")
-            return cls(*keys)
+            return cls(RoomStreamToken.parse(keys[0]), *(int(k) for k in keys[1:]))
         except Exception:
             raise SynapseError(400, "Invalid Token")
 
     def to_string(self):
-        return self._SEPARATOR.join([str(k) for k in self])
+        return self._SEPARATOR.join([str(k) for k in attr.astuple(self, recurse=False)])
 
     @property
     def room_stream_id(self):
-        # TODO(markjh): Awful hack to work around hacks in the presence tests
-        # which assume that the keys are integers.
-        if type(self.room_key) is int:
-            return self.room_key
-        else:
-            return int(self.room_key[1:].split("-")[-1])
+        return self.room_key.stream
 
     def is_after(self, other):
         """Does this token contain events that the other doesn't?"""
@@ -418,7 +472,7 @@ class StreamToken(
             or (int(other.groups_key) < int(self.groups_key))
         )
 
-    def copy_and_advance(self, key, new_value):
+    def copy_and_advance(self, key, new_value) -> "StreamToken":
         """Advance the given key in the token to a new value if and only if the
         new value is after the old value.
         """
@@ -434,64 +488,26 @@ class StreamToken(
         else:
             return self
 
-    def copy_and_replace(self, key, new_value):
-        return self._replace(**{key: new_value})
+    def copy_and_replace(self, key, new_value) -> "StreamToken":
+        return attr.evolve(self, **{key: new_value})
 
 
-StreamToken.START = StreamToken(*(["s0"] + ["0"] * (len(StreamToken._fields) - 1)))
+StreamToken.START = StreamToken.from_string("s0_0")
 
 
-class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
-    """Tokens are positions between events. The token "s1" comes after event 1.
-
-            s0    s1
-            |     |
-        [0] V [1] V [2]
+@attr.s(slots=True, frozen=True)
+class PersistedEventPosition:
+    """Position of a newly persisted event with instance that persisted it.
 
-    Tokens can either be a point in the live event stream or a cursor going
-    through historic events.
-
-    When traversing the live event stream events are ordered by when they
-    arrived at the homeserver.
-
-    When traversing historic events the events are ordered by their depth in
-    the event graph "topological_ordering" and then by when they arrived at the
-    homeserver "stream_ordering".
-
-    Live tokens start with an "s" followed by the "stream_ordering" id of the
-    event it comes after. Historic tokens start with a "t" followed by the
-    "topological_ordering" id of the event it comes after, followed by "-",
-    followed by the "stream_ordering" id of the event it comes after.
+    This can be used to test whether the event is persisted before or after a
+    RoomStreamToken.
     """
 
-    __slots__ = []  # type: list
-
-    @classmethod
-    def parse(cls, string):
-        try:
-            if string[0] == "s":
-                return cls(topological=None, stream=int(string[1:]))
-            if string[0] == "t":
-                parts = string[1:].split("-", 1)
-                return cls(topological=int(parts[0]), stream=int(parts[1]))
-        except Exception:
-            pass
-        raise SynapseError(400, "Invalid token %r" % (string,))
-
-    @classmethod
-    def parse_stream_token(cls, string):
-        try:
-            if string[0] == "s":
-                return cls(topological=None, stream=int(string[1:]))
-        except Exception:
-            pass
-        raise SynapseError(400, "Invalid token %r" % (string,))
+    instance_name = attr.ib(type=str)
+    stream = attr.ib(type=int)
 
-    def __str__(self):
-        if self.topological is not None:
-            return "t%d-%d" % (self.topological, self.stream)
-        else:
-            return "s%d" % (self.stream,)
+    def persisted_after(self, token: RoomStreamToken) -> bool:
+        return token.stream < self.stream
 
 
 class ThirdPartyInstanceID(
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index a13f11f8d8..d55b93d763 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,11 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import logging
 import re
 
 import attr
-from canonicaljson import json
 
 from twisted.internet import defer, task
 
@@ -45,7 +45,7 @@ def unwrapFirstError(failure):
     return failure.value.subFailure
 
 
-@attr.s
+@attr.s(slots=True)
 class Clock:
     """
     A Clock wraps a Twisted reactor and provides utilities on top of it.
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index bb57e27beb..67ce9a5f39 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -17,13 +17,25 @@
 import collections
 import logging
 from contextlib import contextmanager
-from typing import Dict, Sequence, Set, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Hashable,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    TypeVar,
+    Union,
+)
 
 import attr
 from typing_extensions import ContextManager
 
 from twisted.internet import defer
 from twisted.internet.defer import CancelledError
+from twisted.internet.interfaces import IReactorTime
 from twisted.python import failure
 
 from synapse.logging.context import (
@@ -54,7 +66,7 @@ class ObservableDeferred:
 
     __slots__ = ["_deferred", "_observers", "_result"]
 
-    def __init__(self, deferred, consumeErrors=False):
+    def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False):
         object.__setattr__(self, "_deferred", deferred)
         object.__setattr__(self, "_result", None)
         object.__setattr__(self, "_observers", set())
@@ -111,25 +123,25 @@ class ObservableDeferred:
             success, res = self._result
             return defer.succeed(res) if success else defer.fail(res)
 
-    def observers(self):
+    def observers(self) -> List[defer.Deferred]:
         return self._observers
 
-    def has_called(self):
+    def has_called(self) -> bool:
         return self._result is not None
 
-    def has_succeeded(self):
+    def has_succeeded(self) -> bool:
         return self._result is not None and self._result[0] is True
 
-    def get_result(self):
+    def get_result(self) -> Any:
         return self._result[1]
 
-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Any:
         return getattr(self._deferred, name)
 
-    def __setattr__(self, name, value):
+    def __setattr__(self, name: str, value: Any) -> None:
         setattr(self._deferred, name, value)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
             id(self),
             self._result,
@@ -137,18 +149,20 @@ class ObservableDeferred:
         )
 
 
-def concurrently_execute(func, args, limit):
-    """Executes the function with each argument conncurrently while limiting
+def concurrently_execute(
+    func: Callable, args: Iterable[Any], limit: int
+) -> defer.Deferred:
+    """Executes the function with each argument concurrently while limiting
     the number of concurrent executions.
 
     Args:
-        func (func): Function to execute, should return a deferred or coroutine.
-        args (Iterable): List of arguments to pass to func, each invocation of func
+        func: Function to execute, should return a deferred or coroutine.
+        args: List of arguments to pass to func, each invocation of func
             gets a single argument.
-        limit (int): Maximum number of conccurent executions.
+        limit: Maximum number of conccurent executions.
 
     Returns:
-        deferred: Resolved when all function invocations have finished.
+        Deferred[list]: Resolved when all function invocations have finished.
     """
     it = iter(args)
 
@@ -167,14 +181,17 @@ def concurrently_execute(func, args, limit):
     ).addErrback(unwrapFirstError)
 
 
-def yieldable_gather_results(func, iter, *args, **kwargs):
+def yieldable_gather_results(
+    func: Callable, iter: Iterable, *args: Any, **kwargs: Any
+) -> defer.Deferred:
     """Executes the function with each argument concurrently.
 
     Args:
-        func (func): Function to execute that returns a Deferred
-        iter (iter): An iterable that yields items that get passed as the first
+        func: Function to execute that returns a Deferred
+        iter: An iterable that yields items that get passed as the first
             argument to the function
         *args: Arguments to be passed to each call to func
+        **kwargs: Keyword arguments to be passed to each call to func
 
     Returns
         Deferred[list]: Resolved when all functions have been invoked, or errors if
@@ -188,24 +205,37 @@ def yieldable_gather_results(func, iter, *args, **kwargs):
     ).addErrback(unwrapFirstError)
 
 
+@attr.s(slots=True)
+class _LinearizerEntry:
+    # The number of things executing.
+    count = attr.ib(type=int)
+    # Deferreds for the things blocked from executing.
+    deferreds = attr.ib(type=collections.OrderedDict)
+
+
 class Linearizer:
     """Limits concurrent access to resources based on a key. Useful to ensure
     only a few things happen at a time on a given resource.
 
     Example:
 
-        with (yield limiter.queue("test_key")):
+        with await limiter.queue("test_key"):
             # do some work.
 
     """
 
-    def __init__(self, name=None, max_count=1, clock=None):
+    def __init__(
+        self,
+        name: Optional[str] = None,
+        max_count: int = 1,
+        clock: Optional[Clock] = None,
+    ):
         """
         Args:
-            max_count(int): The maximum number of concurrent accesses
+            max_count: The maximum number of concurrent accesses
         """
         if name is None:
-            self.name = id(self)
+            self.name = id(self)  # type: Union[str, int]
         else:
             self.name = name
 
@@ -216,15 +246,10 @@ class Linearizer:
         self._clock = clock
         self.max_count = max_count
 
-        # key_to_defer is a map from the key to a 2 element list where
-        # the first element is the number of things executing, and
-        # the second element is an OrderedDict, where the keys are deferreds for the
-        # things blocked from executing.
-        self.key_to_defer = (
-            {}
-        )  # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
+        # key_to_defer is a map from the key to a _LinearizerEntry.
+        self.key_to_defer = {}  # type: Dict[Hashable, _LinearizerEntry]
 
-    def is_queued(self, key) -> bool:
+    def is_queued(self, key: Hashable) -> bool:
         """Checks whether there is a process queued up waiting
         """
         entry = self.key_to_defer.get(key)
@@ -234,25 +259,27 @@ class Linearizer:
 
         # There are waiting deferreds only in the OrderedDict of deferreds is
         # non-empty.
-        return bool(entry[1])
+        return bool(entry.deferreds)
 
-    def queue(self, key):
+    def queue(self, key: Hashable) -> defer.Deferred:
         # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
         # (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not
         # propagated inside inlineCallbacks until Twisted 18.7)
-        entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()])
+        entry = self.key_to_defer.setdefault(
+            key, _LinearizerEntry(0, collections.OrderedDict())
+        )
 
         # If the number of things executing is greater than the maximum
         # then add a deferred to the list of blocked items
         # When one of the things currently executing finishes it will callback
         # this item so that it can continue executing.
-        if entry[0] >= self.max_count:
+        if entry.count >= self.max_count:
             res = self._await_lock(key)
         else:
             logger.debug(
                 "Acquired uncontended linearizer lock %r for key %r", self.name, key
             )
-            entry[0] += 1
+            entry.count += 1
             res = defer.succeed(None)
 
         # once we successfully get the lock, we need to return a context manager which
@@ -267,15 +294,15 @@ class Linearizer:
 
                 # We've finished executing so check if there are any things
                 # blocked waiting to execute and start one of them
-                entry[0] -= 1
+                entry.count -= 1
 
-                if entry[1]:
-                    (next_def, _) = entry[1].popitem(last=False)
+                if entry.deferreds:
+                    (next_def, _) = entry.deferreds.popitem(last=False)
 
                     # we need to run the next thing in the sentinel context.
                     with PreserveLoggingContext():
                         next_def.callback(None)
-                elif entry[0] == 0:
+                elif entry.count == 0:
                     # We were the last thing for this key: remove it from the
                     # map.
                     del self.key_to_defer[key]
@@ -283,7 +310,7 @@ class Linearizer:
         res.addCallback(_ctx_manager)
         return res
 
-    def _await_lock(self, key):
+    def _await_lock(self, key: Hashable) -> defer.Deferred:
         """Helper for queue: adds a deferred to the queue
 
         Assumes that we've already checked that we've reached the limit of the number
@@ -298,11 +325,11 @@ class Linearizer:
         logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
 
         new_defer = make_deferred_yieldable(defer.Deferred())
-        entry[1][new_defer] = 1
+        entry.deferreds[new_defer] = 1
 
         def cb(_r):
             logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
-            entry[0] += 1
+            entry.count += 1
 
             # if the code holding the lock completes synchronously, then it
             # will recursively run the next claimant on the list. That can
@@ -331,7 +358,7 @@ class Linearizer:
                 )
 
             # we just have to take ourselves back out of the queue.
-            del entry[1][new_defer]
+            del entry.deferreds[new_defer]
             return e
 
         new_defer.addCallbacks(cb, eb)
@@ -419,14 +446,22 @@ class ReadWriteLock:
         return _ctx_manager()
 
 
-def _cancelled_to_timed_out_error(value, timeout):
+R = TypeVar("R")
+
+
+def _cancelled_to_timed_out_error(value: R, timeout: float) -> R:
     if isinstance(value, failure.Failure):
         value.trap(CancelledError)
         raise defer.TimeoutError(timeout, "Deferred")
     return value
 
 
-def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
+def timeout_deferred(
+    deferred: defer.Deferred,
+    timeout: float,
+    reactor: IReactorTime,
+    on_timeout_cancel: Optional[Callable[[Any, float], Any]] = None,
+) -> defer.Deferred:
     """The in built twisted `Deferred.addTimeout` fails to time out deferreds
     that have a canceller that throws exceptions. This method creates a new
     deferred that wraps and times out the given deferred, correctly handling
@@ -437,10 +472,10 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
     NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred
 
     Args:
-        deferred (Deferred)
-        timeout (float): Timeout in seconds
-        reactor (twisted.interfaces.IReactorTime): The twisted reactor to use
-        on_timeout_cancel (callable): A callable which is called immediately
+        deferred: The Deferred to potentially timeout.
+        timeout: Timeout in seconds
+        reactor: The twisted reactor to use
+        on_timeout_cancel: A callable which is called immediately
             after the deferred times out, and not if this deferred is
             otherwise cancelled before the timeout.
 
@@ -452,7 +487,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
             CancelledError Failure into a defer.TimeoutError.
 
     Returns:
-        Deferred
+        A new Deferred.
     """
 
     new_d = defer.Deferred()
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 237f588658..8fc05be278 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -42,7 +42,7 @@ response_cache_evicted = Gauge(
 response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"])
 
 
-@attr.s
+@attr.s(slots=True)
 class CacheMetric:
 
     _cache = attr.ib()
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index a750261e77..f73e95393c 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -16,8 +16,6 @@ import inspect
 import logging
 
 from twisted.internet import defer
-from twisted.internet.defer import Deferred, fail, succeed
-from twisted.python import failure
 
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -29,11 +27,6 @@ def user_left_room(distributor, user, room_id):
     distributor.fire("user_left_room", user=user, room_id=room_id)
 
 
-# XXX: this is no longer used. We should probably kill it.
-def user_joined_room(distributor, user, room_id):
-    distributor.fire("user_joined_room", user=user, room_id=room_id)
-
-
 class Distributor:
     """A central dispatch point for loosely-connected pieces of code to
     register, observe, and fire signals.
@@ -81,28 +74,6 @@ class Distributor:
         run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
 
 
-def maybeAwaitableDeferred(f, *args, **kw):
-    """
-    Invoke a function that may or may not return a Deferred or an Awaitable.
-
-    This is a modified version of twisted.internet.defer.maybeDeferred.
-    """
-    try:
-        result = f(*args, **kw)
-    except Exception:
-        return fail(failure.Failure(captureVars=Deferred.debug))
-
-    if isinstance(result, Deferred):
-        return result
-    # Handle the additional case of an awaitable being returned.
-    elif inspect.isawaitable(result):
-        return defer.ensureDeferred(result)
-    elif isinstance(result, failure.Failure):
-        return fail(result)
-    else:
-        return succeed(result)
-
-
 class Signal:
     """A Signal is a dispatch point that stores a list of callables as
     observers of it.
@@ -132,22 +103,17 @@ class Signal:
         Returns a Deferred that will complete when all the observers have
         completed."""
 
-        def do(observer):
-            def eb(failure):
+        async def do(observer):
+            try:
+                result = observer(*args, **kwargs)
+                if inspect.isawaitable(result):
+                    result = await result
+                return result
+            except Exception as e:
                 logger.warning(
-                    "%s signal observer %s failed: %r",
-                    self.name,
-                    observer,
-                    failure,
-                    exc_info=(
-                        failure.type,
-                        failure.value,
-                        failure.getTracebackObject(),
-                    ),
+                    "%s signal observer %s failed: %r", self.name, observer, e,
                 )
 
-            return maybeAwaitableDeferred(observer, *args, **kwargs).addErrback(eb)
-
         deferreds = [run_in_background(do, o) for o in self.observers]
 
         return make_deferred_yieldable(
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 0e445e01d7..bf094c9386 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from canonicaljson import json
+import json
+
 from frozendict import frozendict
 
 
@@ -66,5 +67,5 @@ def _handle_frozendict(obj):
 # A JSONEncoder which is capable of encoding frozendicts without barfing.
 # Additionally reduce the whitespace produced by JSON encoding.
 frozendict_json_encoder = json.JSONEncoder(
-    default=_handle_frozendict, separators=(",", ":"),
+    allow_nan=False, separators=(",", ":"), default=_handle_frozendict,
 )
diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py
index 631654f297..da24ba0470 100644
--- a/synapse/util/manhole.py
+++ b/synapse/util/manhole.py
@@ -94,7 +94,7 @@ class SynapseManhole(ColoredManhole):
     """Overrides connectionMade to create our own ManholeInterpreter"""
 
     def connectionMade(self):
-        super(SynapseManhole, self).connectionMade()
+        super().connectionMade()
 
         # replace the manhole interpreter with our own impl
         self.interpreter = SynapseManholeInterpreter(self, self.namespace)
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index 54c046b6e1..72574d3af2 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -13,8 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from __future__ import print_function
-
 import functools
 import sys
 from typing import Any, Callable, List
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 79869aaa44..a5cc9d0551 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -45,7 +45,7 @@ class NotRetryingDestination(Exception):
         """
 
         msg = "Not retrying server %s." % (destination,)
-        super(NotRetryingDestination, self).__init__(msg)
+        super().__init__(msg)
 
         self.retry_last_ts = retry_last_ts
         self.retry_interval = retry_interval