summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
authorBen Banfield-Zanin <benbz@matrix.org>2020-10-15 14:48:13 +0100
committerBen Banfield-Zanin <benbz@matrix.org>2020-10-15 14:48:13 +0100
commit8d9ae573f33110e0420204bceb111fd8df649e7c (patch)
treec8113c67df9769a14e8bb0a03620026dbe9aa0ba /synapse/handlers
parentMerge remote-tracking branch 'origin/anoa/3pid_check_invite_exemption' into b... (diff)
parentRemove racey assertion in MultiWriterIDGenerator (#8530) (diff)
downloadsynapse-8d9ae573f33110e0420204bceb111fd8df649e7c.tar.xz
Merge remote-tracking branch 'origin/release-v1.21.2' into bbz/info-mainline-1.21.2 github/bbz/info-mainline-1.21.2 bbz/info-mainline-1.21.2
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/acme_issuing_service.py2
-rw-r--r--synapse/handlers/admin.py6
-rw-r--r--synapse/handlers/auth.py64
-rw-r--r--synapse/handlers/deactivate_account.py2
-rw-r--r--synapse/handlers/device.py46
-rw-r--r--synapse/handlers/directory.py2
-rw-r--r--synapse/handlers/e2e_keys.py2
-rw-r--r--synapse/handlers/events.py12
-rw-r--r--synapse/handlers/federation.py138
-rw-r--r--synapse/handlers/groups_local.py2
-rw-r--r--synapse/handlers/identity.py27
-rw-r--r--synapse/handlers/initial_sync.py34
-rw-r--r--synapse/handlers/message.py127
-rw-r--r--synapse/handlers/oidc_handler.py102
-rw-r--r--synapse/handlers/pagination.py54
-rw-r--r--synapse/handlers/presence.py3
-rw-r--r--synapse/handlers/profile.py4
-rw-r--r--synapse/handlers/read_marker.py2
-rw-r--r--synapse/handlers/receipts.py17
-rw-r--r--synapse/handlers/register.py2
-rw-r--r--synapse/handlers/room.py37
-rw-r--r--synapse/handlers/room_list.py2
-rw-r--r--synapse/handlers/room_member.py53
-rw-r--r--synapse/handlers/room_member_worker.py11
-rw-r--r--synapse/handlers/saml_handler.py171
-rw-r--r--synapse/handlers/search.py10
-rw-r--r--synapse/handlers/set_password.py2
-rw-r--r--synapse/handlers/sync.py76
-rw-r--r--synapse/handlers/user_directory.py2
29 files changed, 557 insertions, 455 deletions
diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py
index 69650ff221..7294649d71 100644
--- a/synapse/handlers/acme_issuing_service.py
+++ b/synapse/handlers/acme_issuing_service.py
@@ -76,7 +76,7 @@ def create_issuing_service(reactor, acme_url, account_key_file, well_known_resou
     )
 
 
-@attr.s
+@attr.s(slots=True)
 @implementer(ICertificateStore)
 class ErsatzStore:
     """
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 918d0e037c..1ce2091b46 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
 
 class AdminHandler(BaseHandler):
     def __init__(self, hs):
-        super(AdminHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
@@ -125,8 +125,8 @@ class AdminHandler(BaseHandler):
             else:
                 stream_ordering = room.stream_ordering
 
-            from_key = str(RoomStreamToken(0, 0))
-            to_key = str(RoomStreamToken(None, stream_ordering))
+            from_key = RoomStreamToken(0, 0)
+            to_key = RoomStreamToken(None, stream_ordering)
 
             written_events = set()  # Events that we've processed in this room
 
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 90189869cc..00eae92052 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -137,6 +137,15 @@ def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]:
     }
 
 
+@attr.s(slots=True)
+class SsoLoginExtraAttributes:
+    """Data we track about SAML2 sessions"""
+
+    # time the session was created, in milliseconds
+    creation_time = attr.ib(type=int)
+    extra_attributes = attr.ib(type=JsonDict)
+
+
 class AuthHandler(BaseHandler):
     SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
 
@@ -145,7 +154,7 @@ class AuthHandler(BaseHandler):
         Args:
             hs (synapse.server.HomeServer):
         """
-        super(AuthHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.checkers = {}  # type: Dict[str, UserInteractiveAuthChecker]
         for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
@@ -239,6 +248,10 @@ class AuthHandler(BaseHandler):
         # cast to tuple for use with str.startswith
         self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
 
+        # A mapping of user ID to extra attributes to include in the login
+        # response.
+        self._extra_attributes = {}  # type: Dict[str, SsoLoginExtraAttributes]
+
     async def validate_user_via_ui_auth(
         self,
         requester: Requester,
@@ -1165,6 +1178,7 @@ class AuthHandler(BaseHandler):
         registered_user_id: str,
         request: SynapseRequest,
         client_redirect_url: str,
+        extra_attributes: Optional[JsonDict] = None,
     ):
         """Having figured out a mxid for this user, complete the HTTP request
 
@@ -1173,6 +1187,8 @@ class AuthHandler(BaseHandler):
             request: The request to complete.
             client_redirect_url: The URL to which to redirect the user at the end of the
                 process.
+            extra_attributes: Extra attributes which will be passed to the client
+                during successful login. Must be JSON serializable.
         """
         # If the account has been deactivated, do not proceed with the login
         # flow.
@@ -1181,19 +1197,30 @@ class AuthHandler(BaseHandler):
             respond_with_html(request, 403, self._sso_account_deactivated_template)
             return
 
-        self._complete_sso_login(registered_user_id, request, client_redirect_url)
+        self._complete_sso_login(
+            registered_user_id, request, client_redirect_url, extra_attributes
+        )
 
     def _complete_sso_login(
         self,
         registered_user_id: str,
         request: SynapseRequest,
         client_redirect_url: str,
+        extra_attributes: Optional[JsonDict] = None,
     ):
         """
         The synchronous portion of complete_sso_login.
 
         This exists purely for backwards compatibility of synapse.module_api.ModuleApi.
         """
+        # Store any extra attributes which will be passed in the login response.
+        # Note that this is per-user so it may overwrite a previous value, this
+        # is considered OK since the newest SSO attributes should be most valid.
+        if extra_attributes:
+            self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes(
+                self._clock.time_msec(), extra_attributes,
+            )
+
         # Create a login token
         login_token = self.macaroon_gen.generate_short_term_login_token(
             registered_user_id
@@ -1226,6 +1253,37 @@ class AuthHandler(BaseHandler):
         )
         respond_with_html(request, 200, html)
 
+    async def _sso_login_callback(self, login_result: JsonDict) -> None:
+        """
+        A login callback which might add additional attributes to the login response.
+
+        Args:
+            login_result: The data to be sent to the client. Includes the user
+                ID and access token.
+        """
+        # Expire attributes before processing. Note that there shouldn't be any
+        # valid logins that still have extra attributes.
+        self._expire_sso_extra_attributes()
+
+        extra_attributes = self._extra_attributes.get(login_result["user_id"])
+        if extra_attributes:
+            login_result.update(extra_attributes.extra_attributes)
+
+    def _expire_sso_extra_attributes(self) -> None:
+        """
+        Iterate through the mapping of user IDs to extra attributes and remove any that are no longer valid.
+        """
+        # TODO This should match the amount of time the macaroon is valid for.
+        LOGIN_TOKEN_EXPIRATION_TIME = 2 * 60 * 1000
+        expire_before = self._clock.time_msec() - LOGIN_TOKEN_EXPIRATION_TIME
+        to_expire = set()
+        for user_id, data in self._extra_attributes.items():
+            if data.creation_time < expire_before:
+                to_expire.add(user_id)
+        for user_id in to_expire:
+            logger.debug("Expiring extra attributes for user %s", user_id)
+            del self._extra_attributes[user_id]
+
     @staticmethod
     def add_query_param_to_url(url: str, param_name: str, param: Any):
         url_parts = list(urllib.parse.urlparse(url))
@@ -1235,7 +1293,7 @@ class AuthHandler(BaseHandler):
         return urllib.parse.urlunparse(url_parts)
 
 
-@attr.s
+@attr.s(slots=True)
 class MacaroonGenerator:
 
     hs = attr.ib()
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 25169157c1..0635ad5708 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -29,7 +29,7 @@ class DeactivateAccountHandler(BaseHandler):
     """Handler which deals with deactivating user accounts."""
 
     def __init__(self, hs):
-        super(DeactivateAccountHandler, self).__init__(hs)
+        super().__init__(hs)
         self.hs = hs
         self._auth_handler = hs.get_auth_handler()
         self._device_handler = hs.get_device_handler()
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 643d71a710..b9d9098104 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -20,6 +20,7 @@ from typing import Any, Dict, List, Optional
 from synapse.api import errors
 from synapse.api.constants import EventTypes
 from synapse.api.errors import (
+    Codes,
     FederationDeniedError,
     HttpResponseException,
     RequestSendFailed,
@@ -28,7 +29,7 @@ from synapse.api.errors import (
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import (
-    RoomStreamToken,
+    StreamToken,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
 )
@@ -47,7 +48,7 @@ MAX_DEVICE_DISPLAY_NAME_LEN = 100
 
 class DeviceWorkerHandler(BaseHandler):
     def __init__(self, hs):
-        super(DeviceWorkerHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.hs = hs
         self.state = hs.get_state_handler()
@@ -104,18 +105,14 @@ class DeviceWorkerHandler(BaseHandler):
 
     @trace
     @measure_func("device.get_user_ids_changed")
-    async def get_user_ids_changed(self, user_id, from_token):
+    async def get_user_ids_changed(self, user_id: str, from_token: StreamToken):
         """Get list of users that have had the devices updated, or have newly
         joined a room, that `user_id` may be interested in.
-
-        Args:
-            user_id (str)
-            from_token (StreamToken)
         """
 
         set_tag("user_id", user_id)
         set_tag("from_token", from_token)
-        now_room_key = await self.store.get_room_events_max_id()
+        now_room_key = self.store.get_room_max_token()
 
         room_ids = await self.store.get_rooms_for_user(user_id)
 
@@ -142,7 +139,7 @@ class DeviceWorkerHandler(BaseHandler):
         )
         rooms_changed.update(event.room_id for event in member_events)
 
-        stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key).stream
+        stream_ordering = from_token.room_key.stream
 
         possibly_changed = set(changed)
         possibly_left = set()
@@ -253,7 +250,7 @@ class DeviceWorkerHandler(BaseHandler):
 
 class DeviceHandler(DeviceWorkerHandler):
     def __init__(self, hs):
-        super(DeviceHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.federation_sender = hs.get_federation_sender()
 
@@ -267,6 +264,24 @@ class DeviceHandler(DeviceWorkerHandler):
 
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
+    def _check_device_name_length(self, name: str):
+        """
+        Checks whether a device name is longer than the maximum allowed length.
+
+        Args:
+            name: The name of the device.
+
+        Raises:
+            SynapseError: if the device name is too long.
+        """
+        if name and len(name) > MAX_DEVICE_DISPLAY_NAME_LEN:
+            raise SynapseError(
+                400,
+                "Device display name is too long (max %i)"
+                % (MAX_DEVICE_DISPLAY_NAME_LEN,),
+                errcode=Codes.TOO_LARGE,
+            )
+
     async def check_device_registered(
         self, user_id, device_id, initial_device_display_name=None
     ):
@@ -284,6 +299,9 @@ class DeviceHandler(DeviceWorkerHandler):
         Returns:
             str: device id (generated if none was supplied)
         """
+
+        self._check_device_name_length(initial_device_display_name)
+
         if device_id is not None:
             new_device = await self.store.store_device(
                 user_id=user_id,
@@ -399,12 +417,8 @@ class DeviceHandler(DeviceWorkerHandler):
 
         # Reject a new displayname which is too long.
         new_display_name = content.get("display_name")
-        if new_display_name and len(new_display_name) > MAX_DEVICE_DISPLAY_NAME_LEN:
-            raise SynapseError(
-                400,
-                "Device display name is too long (max %i)"
-                % (MAX_DEVICE_DISPLAY_NAME_LEN,),
-            )
+
+        self._check_device_name_length(new_display_name)
 
         try:
             await self.store.update_device(
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 46826eb784..62aa9a2da8 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
 
 class DirectoryHandler(BaseHandler):
     def __init__(self, hs):
-        super(DirectoryHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.state = hs.get_state_handler()
         self.appservice_handler = hs.get_application_service_handler()
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d629c7c16c..dd40fd1299 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -1201,7 +1201,7 @@ def _one_time_keys_match(old_key_json, new_key):
     return old_key == new_key_copy
 
 
-@attr.s
+@attr.s(slots=True)
 class SignatureListItem:
     """An item in the signature list as used by upload_signatures_for_device_keys.
     """
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index b05e32f457..539b4fc32e 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -37,11 +37,7 @@ logger = logging.getLogger(__name__)
 
 class EventStreamHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
-        super(EventStreamHandler, self).__init__(hs)
-
-        self.distributor = hs.get_distributor()
-        self.distributor.declare("started_user_eventstream")
-        self.distributor.declare("stopped_user_eventstream")
+        super().__init__(hs)
 
         self.clock = hs.get_clock()
 
@@ -137,8 +133,8 @@ class EventStreamHandler(BaseHandler):
 
             chunk = {
                 "chunk": chunks,
-                "start": tokens[0].to_string(),
-                "end": tokens[1].to_string(),
+                "start": await tokens[0].to_string(self.store),
+                "end": await tokens[1].to_string(self.store),
             }
 
             return chunk
@@ -146,7 +142,7 @@ class EventStreamHandler(BaseHandler):
 
 class EventHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
-        super(EventHandler, self).__init__(hs)
+        super().__init__(hs)
         self.storage = hs.get_storage()
 
     async def get_event(
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 014dab2940..1a8144405a 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -21,7 +21,7 @@ import itertools
 import logging
 from collections.abc import Container
 from http import HTTPStatus
-from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 import attr
 from signedjson.key import decode_verify_key_bytes
@@ -69,26 +69,29 @@ from synapse.replication.http.federation import (
     ReplicationFederationSendEventsRestServlet,
     ReplicationStoreRoomOnInviteRestServlet,
 )
-from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
-from synapse.state import StateResolutionStore, resolve_events_with_store
+from synapse.state import StateResolutionStore
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.types import (
     JsonDict,
     MutableStateMap,
+    PersistedEventPosition,
+    RoomStreamToken,
     StateMap,
     UserID,
     get_domain_from_id,
 )
 from synapse.util.async_helpers import Linearizer, concurrently_execute
-from synapse.util.distributor import user_joined_room
 from synapse.util.retryutils import NotRetryingDestination
 from synapse.util.stringutils import shortstr
 from synapse.visibility import filter_events_for_server
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
-@attr.s
+@attr.s(slots=True)
 class _NewEventInfo:
     """Holds information about a received event, ready for passing to _handle_new_events
 
@@ -116,8 +119,8 @@ class FederationHandler(BaseHandler):
         rooms.
     """
 
-    def __init__(self, hs):
-        super(FederationHandler, self).__init__(hs)
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
 
         self.hs = hs
 
@@ -126,11 +129,11 @@ class FederationHandler(BaseHandler):
         self.state_store = self.storage.state
         self.federation_client = hs.get_federation_client()
         self.state_handler = hs.get_state_handler()
+        self._state_resolution_handler = hs.get_state_resolution_handler()
         self.server_name = hs.hostname
         self.keyring = hs.get_keyring()
         self.action_generator = hs.get_action_generator()
         self.is_mine_id = hs.is_mine_id
-        self.pusher_pool = hs.get_pusherpool()
         self.spam_checker = hs.get_spam_checker()
         self.event_creation_handler = hs.get_event_creation_handler()
         self._message_handler = hs.get_message_handler()
@@ -141,9 +144,6 @@ class FederationHandler(BaseHandler):
         self._replication = hs.get_replication_data_handler()
 
         self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
-        self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client(
-            hs
-        )
         self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
             hs
         )
@@ -159,8 +159,9 @@ class FederationHandler(BaseHandler):
             self._device_list_updater = hs.get_device_handler().device_list_updater
             self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite
 
-        # When joining a room we need to queue any events for that room up
-        self.room_queues = {}
+        # When joining a room we need to queue any events for that room up.
+        # For each room, a list of (pdu, origin) tuples.
+        self.room_queues = {}  # type: Dict[str, List[Tuple[EventBase, str]]]
         self._room_pdu_linearizer = Linearizer("fed_room_pdu")
 
         self.third_party_event_rules = hs.get_third_party_event_rules()
@@ -285,7 +286,7 @@ class FederationHandler(BaseHandler):
                             raise Exception(
                                 "Error fetching missing prev_events for %s: %s"
                                 % (event_id, e)
-                            )
+                            ) from e
 
                         # Update the set of things we've seen after trying to
                         # fetch the missing stuff
@@ -384,8 +385,7 @@ class FederationHandler(BaseHandler):
                                 event_map[x.event_id] = x
 
                     room_version = await self.store.get_room_version_id(room_id)
-                    state_map = await resolve_events_with_store(
-                        self.clock,
+                    state_map = await self._state_resolution_handler.resolve_events_with_store(
                         room_id,
                         room_version,
                         state_maps,
@@ -704,31 +704,10 @@ class FederationHandler(BaseHandler):
         logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
 
         try:
-            context = await self._handle_new_event(origin, event, state=state)
+            await self._handle_new_event(origin, event, state=state)
         except AuthError as e:
             raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
 
-        if event.type == EventTypes.Member:
-            if event.membership == Membership.JOIN:
-                # Only fire user_joined_room if the user has acutally
-                # joined the room. Don't bother if the user is just
-                # changing their profile info.
-                newly_joined = True
-
-                prev_state_ids = await context.get_prev_state_ids()
-
-                prev_state_id = prev_state_ids.get((event.type, event.state_key))
-                if prev_state_id:
-                    prev_state = await self.store.get_event(
-                        prev_state_id, allow_none=True
-                    )
-                    if prev_state and prev_state.membership == Membership.JOIN:
-                        newly_joined = False
-
-                if newly_joined:
-                    user = UserID.from_string(event.state_key)
-                    await self.user_joined_room(user, room_id)
-
         # For encrypted messages we check that we know about the sending device,
         # if we don't then we mark the device cache for that user as stale.
         if event.type == EventTypes.Encrypted:
@@ -839,6 +818,9 @@ class FederationHandler(BaseHandler):
             dest, room_id, limit=limit, extremities=extremities
         )
 
+        if not events:
+            return []
+
         # ideally we'd sanity check the events here for excess prev_events etc,
         # but it's hard to reject events at this point without completely
         # breaking backfill in the same way that it is currently broken by
@@ -923,7 +905,8 @@ class FederationHandler(BaseHandler):
                 )
             )
 
-        await self._handle_new_events(dest, ev_infos, backfilled=True)
+        if ev_infos:
+            await self._handle_new_events(dest, room_id, ev_infos, backfilled=True)
 
         # Step 2: Persist the rest of the events in the chunk one by one
         events.sort(key=lambda e: e.depth)
@@ -1265,7 +1248,7 @@ class FederationHandler(BaseHandler):
             event_infos.append(_NewEventInfo(event, None, auth))
 
         await self._handle_new_events(
-            destination, event_infos,
+            destination, room_id, event_infos,
         )
 
     def _sanity_check_event(self, ev):
@@ -1412,15 +1395,15 @@ class FederationHandler(BaseHandler):
             )
 
             max_stream_id = await self._persist_auth_tree(
-                origin, auth_chain, state, event, room_version_obj
+                origin, room_id, auth_chain, state, event, room_version_obj
             )
 
             # We wait here until this instance has seen the events come down
             # replication (if we're using replication) as the below uses caches.
-            #
-            # TODO: Currently the events stream is written to from master
             await self._replication.wait_for_stream_position(
-                self.config.worker.writers.events, "events", max_stream_id
+                self.config.worker.events_shard_config.get_instance(room_id),
+                "events",
+                max_stream_id,
             )
 
             # Check whether this room is the result of an upgrade of a room we already know
@@ -1599,11 +1582,6 @@ class FederationHandler(BaseHandler):
             event.signatures,
         )
 
-        if event.type == EventTypes.Member:
-            if event.content["membership"] == Membership.JOIN:
-                user = UserID.from_string(event.state_key)
-                await self.user_joined_room(user, event.room_id)
-
         prev_state_ids = await context.get_prev_state_ids()
 
         state_ids = list(prev_state_ids.values())
@@ -1674,7 +1652,7 @@ class FederationHandler(BaseHandler):
         )
 
         context = await self.state_handler.compute_event_context(event)
-        await self.persist_events_and_notify([(event, context)])
+        await self.persist_events_and_notify(event.room_id, [(event, context)])
 
         return event
 
@@ -1701,7 +1679,9 @@ class FederationHandler(BaseHandler):
         await self.federation_client.send_leave(host_list, event)
 
         context = await self.state_handler.compute_event_context(event)
-        stream_id = await self.persist_events_and_notify([(event, context)])
+        stream_id = await self.persist_events_and_notify(
+            event.room_id, [(event, context)]
+        )
 
         return event, stream_id
 
@@ -1949,7 +1929,7 @@ class FederationHandler(BaseHandler):
                 )
 
             await self.persist_events_and_notify(
-                [(event, context)], backfilled=backfilled
+                event.room_id, [(event, context)], backfilled=backfilled
             )
         except Exception:
             run_in_background(
@@ -1962,6 +1942,7 @@ class FederationHandler(BaseHandler):
     async def _handle_new_events(
         self,
         origin: str,
+        room_id: str,
         event_infos: Iterable[_NewEventInfo],
         backfilled: bool = False,
     ) -> None:
@@ -1993,6 +1974,7 @@ class FederationHandler(BaseHandler):
         )
 
         await self.persist_events_and_notify(
+            room_id,
             [
                 (ev_info.event, context)
                 for ev_info, context in zip(event_infos, contexts)
@@ -2003,6 +1985,7 @@ class FederationHandler(BaseHandler):
     async def _persist_auth_tree(
         self,
         origin: str,
+        room_id: str,
         auth_events: List[EventBase],
         state: List[EventBase],
         event: EventBase,
@@ -2017,6 +2000,7 @@ class FederationHandler(BaseHandler):
 
         Args:
             origin: Where the events came from
+            room_id,
             auth_events
             state
             event
@@ -2091,17 +2075,20 @@ class FederationHandler(BaseHandler):
                 events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
 
         await self.persist_events_and_notify(
+            room_id,
             [
                 (e, events_to_context[e.event_id])
                 for e in itertools.chain(auth_events, state)
-            ]
+            ],
         )
 
         new_event_context = await self.state_handler.compute_event_context(
             event, old_state=state
         )
 
-        return await self.persist_events_and_notify([(event, new_event_context)])
+        return await self.persist_events_and_notify(
+            room_id, [(event, new_event_context)]
+        )
 
     async def _prep_event(
         self,
@@ -2184,10 +2171,10 @@ class FederationHandler(BaseHandler):
             # given state at the event. This should correctly handle cases
             # like bans, especially with state res v2.
 
-            state_sets = await self.state_store.get_state_groups(
+            state_sets_d = await self.state_store.get_state_groups(
                 event.room_id, extrem_ids
             )
-            state_sets = list(state_sets.values())
+            state_sets = list(state_sets_d.values())  # type: List[Iterable[EventBase]]
             state_sets.append(state)
             current_states = await self.state_handler.resolve_events(
                 room_version, state_sets, event
@@ -2952,6 +2939,7 @@ class FederationHandler(BaseHandler):
 
     async def persist_events_and_notify(
         self,
+        room_id: str,
         event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
         backfilled: bool = False,
     ) -> int:
@@ -2959,20 +2947,26 @@ class FederationHandler(BaseHandler):
         necessary.
 
         Args:
-            event_and_contexts:
+            room_id: The room ID of events being persisted.
+            event_and_contexts: Sequence of events with their associated
+                context that should be persisted. All events must belong to
+                the same room.
             backfilled: Whether these events are a result of
                 backfilling or not
         """
-        if self.config.worker.writers.events != self._instance_name:
+        instance = self.config.worker.events_shard_config.get_instance(room_id)
+        if instance != self._instance_name:
             result = await self._send_events(
-                instance_name=self.config.worker.writers.events,
+                instance_name=instance,
                 store=self.store,
+                room_id=room_id,
                 event_and_contexts=event_and_contexts,
                 backfilled=backfilled,
             )
             return result["max_stream_id"]
         else:
-            max_stream_id = await self.storage.persistence.persist_events(
+            assert self.storage.persistence
+            max_stream_token = await self.storage.persistence.persist_events(
                 event_and_contexts, backfilled=backfilled
             )
 
@@ -2983,12 +2977,12 @@ class FederationHandler(BaseHandler):
 
             if not backfilled:  # Never notify for backfilled events
                 for event, _ in event_and_contexts:
-                    await self._notify_persisted_event(event, max_stream_id)
+                    await self._notify_persisted_event(event, max_stream_token)
 
-            return max_stream_id
+            return max_stream_token.stream
 
     async def _notify_persisted_event(
-        self, event: EventBase, max_stream_id: int
+        self, event: EventBase, max_stream_token: RoomStreamToken
     ) -> None:
         """Checks to see if notifier/pushers should be notified about the
         event or not.
@@ -3014,13 +3008,13 @@ class FederationHandler(BaseHandler):
         elif event.internal_metadata.is_outlier():
             return
 
-        event_stream_id = event.internal_metadata.stream_ordering
+        event_pos = PersistedEventPosition(
+            self._instance_name, event.internal_metadata.stream_ordering
+        )
         self.notifier.on_new_room_event(
-            event, event_stream_id, max_stream_id, extra_users=extra_users
+            event, event_pos, max_stream_token, extra_users=extra_users
         )
 
-        await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
-
     async def _clean_room_for_join(self, room_id: str) -> None:
         """Called to clean up any data in DB for a given room, ready for the
         server to join the room.
@@ -3033,16 +3027,6 @@ class FederationHandler(BaseHandler):
         else:
             await self.store.clean_room_for_join(room_id)
 
-    async def user_joined_room(self, user: UserID, room_id: str) -> None:
-        """Called when a new user has joined the room
-        """
-        if self.config.worker_app:
-            await self._notify_user_membership_change(
-                room_id=room_id, user_id=user.to_string(), change="joined"
-            )
-        else:
-            user_joined_room(self.distributor, user, room_id)
-
     async def get_room_complexity(
         self, remote_room_hosts: List[str], room_id: str
     ) -> Optional[dict]:
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index ba47e3a242..489a7b885d 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -240,7 +240,7 @@ class GroupsLocalWorkerHandler:
 
 class GroupsLocalHandler(GroupsLocalWorkerHandler):
     def __init__(self, hs):
-        super(GroupsLocalHandler, self).__init__(hs)
+        super().__init__(hs)
 
         # Ensure attestations get renewed
         hs.get_groups_attestation_renewer()
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 0ce6ddfbe4..bc3e9607ca 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -21,8 +21,6 @@ import logging
 import urllib.parse
 from typing import Awaitable, Callable, Dict, List, Optional, Tuple
 
-from twisted.internet.error import TimeoutError
-
 from synapse.api.errors import (
     CodeMessageException,
     Codes,
@@ -30,6 +28,7 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.http import RequestTimedOutError
 from synapse.http.client import SimpleHttpClient
 from synapse.types import JsonDict, Requester
 from synapse.util import json_decoder
@@ -45,7 +44,7 @@ id_server_scheme = "https://"
 
 class IdentityHandler(BaseHandler):
     def __init__(self, hs):
-        super(IdentityHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.http_client = SimpleHttpClient(hs)
         # We create a blacklisting instance of SimpleHttpClient for contacting identity
@@ -93,7 +92,7 @@ class IdentityHandler(BaseHandler):
 
         try:
             data = await self.http_client.get_json(url, query_params)
-        except TimeoutError:
+        except RequestTimedOutError:
             raise SynapseError(500, "Timed out contacting identity server")
         except HttpResponseException as e:
             logger.info(
@@ -173,7 +172,7 @@ class IdentityHandler(BaseHandler):
             if e.code != 404 or not use_v2:
                 logger.error("3PID bind failed with Matrix error: %r", e)
                 raise e.to_synapse_error()
-        except TimeoutError:
+        except RequestTimedOutError:
             raise SynapseError(500, "Timed out contacting identity server")
         except CodeMessageException as e:
             data = json_decoder.decode(e.msg)  # XXX WAT?
@@ -273,7 +272,7 @@ class IdentityHandler(BaseHandler):
             else:
                 logger.error("Failed to unbind threepid on identity server: %s", e)
                 raise SynapseError(500, "Failed to contact identity server")
-        except TimeoutError:
+        except RequestTimedOutError:
             raise SynapseError(500, "Timed out contacting identity server")
 
         await self.store.remove_user_bound_threepid(
@@ -419,7 +418,7 @@ class IdentityHandler(BaseHandler):
         except HttpResponseException as e:
             logger.info("Proxied requestToken failed: %r", e)
             raise e.to_synapse_error()
-        except TimeoutError:
+        except RequestTimedOutError:
             raise SynapseError(500, "Timed out contacting identity server")
 
     async def requestMsisdnToken(
@@ -471,7 +470,7 @@ class IdentityHandler(BaseHandler):
         except HttpResponseException as e:
             logger.info("Proxied requestToken failed: %r", e)
             raise e.to_synapse_error()
-        except TimeoutError:
+        except RequestTimedOutError:
             raise SynapseError(500, "Timed out contacting identity server")
 
         assert self.hs.config.public_baseurl
@@ -553,7 +552,7 @@ class IdentityHandler(BaseHandler):
                 id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken",
                 body,
             )
-        except TimeoutError:
+        except RequestTimedOutError:
             raise SynapseError(500, "Timed out contacting identity server")
         except HttpResponseException as e:
             logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
@@ -627,7 +626,7 @@ class IdentityHandler(BaseHandler):
                 # require or validate it. See the following for context:
                 # https://github.com/matrix-org/synapse/issues/5253#issuecomment-666246950
                 return data["mxid"]
-        except TimeoutError:
+        except RequestTimedOutError:
             raise SynapseError(500, "Timed out contacting identity server")
         except IOError as e:
             logger.warning("Error from v1 identity server lookup: %s" % (e,))
@@ -655,7 +654,7 @@ class IdentityHandler(BaseHandler):
                 "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
                 {"access_token": id_access_token},
             )
-        except TimeoutError:
+        except RequestTimedOutError:
             raise SynapseError(500, "Timed out contacting identity server")
 
         if not isinstance(hash_details, dict):
@@ -727,7 +726,7 @@ class IdentityHandler(BaseHandler):
                 },
                 headers=headers,
             )
-        except TimeoutError:
+        except RequestTimedOutError:
             raise SynapseError(500, "Timed out contacting identity server")
         except Exception as e:
             logger.warning("Error when performing a v2 3pid lookup: %s", e)
@@ -823,7 +822,7 @@ class IdentityHandler(BaseHandler):
                     invite_config,
                     {"Authorization": create_id_access_token_header(id_access_token)},
                 )
-            except TimeoutError:
+            except RequestTimedOutError:
                 raise SynapseError(500, "Timed out contacting identity server")
             except HttpResponseException as e:
                 if e.code != 404:
@@ -841,7 +840,7 @@ class IdentityHandler(BaseHandler):
                 data = await self.blacklisting_http_client.post_json_get_json(
                     url, invite_config
                 )
-            except TimeoutError:
+            except RequestTimedOutError:
                 raise SynapseError(500, "Timed out contacting identity server")
             except HttpResponseException as e:
                 logger.warning(
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index d5ddc583ad..39a85801c1 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -25,7 +25,7 @@ from synapse.handlers.presence import format_user_presence_state
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage.roommember import RoomsForUser
 from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, Requester, StreamToken, UserID
+from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.response_cache import ResponseCache
@@ -42,7 +42,7 @@ logger = logging.getLogger(__name__)
 
 class InitialSyncHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
-        super(InitialSyncHandler, self).__init__(hs)
+        super().__init__(hs)
         self.hs = hs
         self.state = hs.get_state_handler()
         self.clock = hs.get_clock()
@@ -116,14 +116,13 @@ class InitialSyncHandler(BaseHandler):
         now_token = self.hs.get_event_sources().get_current_token()
 
         presence_stream = self.hs.get_event_sources().sources["presence"]
-        pagination_config = PaginationConfig(from_token=now_token)
-        presence, _ = await presence_stream.get_pagination_rows(
-            user, pagination_config.get_source_config("presence"), None
+        presence, _ = await presence_stream.get_new_events(
+            user, from_key=None, include_offline=False
         )
 
-        receipt_stream = self.hs.get_event_sources().sources["receipt"]
-        receipt, _ = await receipt_stream.get_pagination_rows(
-            user, pagination_config.get_source_config("receipt"), None
+        joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN]
+        receipt = await self.store.get_linearized_receipts_for_rooms(
+            joined_rooms, to_key=int(now_token.receipt_key),
         )
 
         tags_by_room = await self.store.get_tags_for_user(user_id)
@@ -168,7 +167,7 @@ class InitialSyncHandler(BaseHandler):
                         self.state_handler.get_current_state, event.room_id
                     )
                 elif event.membership == Membership.LEAVE:
-                    room_end_token = "s%d" % (event.stream_ordering,)
+                    room_end_token = RoomStreamToken(None, event.stream_ordering,)
                     deferred_room_state = run_in_background(
                         self.state_store.get_state_for_events, [event.event_id]
                     )
@@ -204,8 +203,8 @@ class InitialSyncHandler(BaseHandler):
                             messages, time_now=time_now, as_client_event=as_client_event
                         )
                     ),
-                    "start": start_token.to_string(),
-                    "end": end_token.to_string(),
+                    "start": await start_token.to_string(self.store),
+                    "end": await end_token.to_string(self.store),
                 }
 
                 d["state"] = await self._event_serializer.serialize_events(
@@ -250,7 +249,7 @@ class InitialSyncHandler(BaseHandler):
             ],
             "account_data": account_data_events,
             "receipts": receipt,
-            "end": now_token.to_string(),
+            "end": await now_token.to_string(self.store),
         }
 
         return ret
@@ -326,7 +325,8 @@ class InitialSyncHandler(BaseHandler):
         if limit is None:
             limit = 10
 
-        stream_token = await self.store.get_stream_token_for_event(member_event_id)
+        leave_position = await self.store.get_position_for_event(member_event_id)
+        stream_token = leave_position.to_room_stream_token()
 
         messages, token = await self.store.get_recent_events_for_room(
             room_id, limit=limit, end_token=stream_token
@@ -348,8 +348,8 @@ class InitialSyncHandler(BaseHandler):
                 "chunk": (
                     await self._event_serializer.serialize_events(messages, time_now)
                 ),
-                "start": start_token.to_string(),
-                "end": end_token.to_string(),
+                "start": await start_token.to_string(self.store),
+                "end": await end_token.to_string(self.store),
             },
             "state": (
                 await self._event_serializer.serialize_events(
@@ -447,8 +447,8 @@ class InitialSyncHandler(BaseHandler):
                 "chunk": (
                     await self._event_serializer.serialize_events(messages, time_now)
                 ),
-                "start": start_token.to_string(),
-                "end": end_token.to_string(),
+                "start": await start_token.to_string(self.store),
+                "end": await end_token.to_string(self.store),
             },
             "state": state,
             "presence": presence,
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 8a7b4916cd..ee271e85e5 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -376,9 +376,8 @@ class EventCreationHandler:
         self.notifier = hs.get_notifier()
         self.config = hs.config
         self.require_membership_for_aliases = hs.config.require_membership_for_aliases
-        self._is_event_writer = (
-            self.config.worker.writers.events == hs.get_instance_name()
-        )
+        self._events_shard_config = self.config.worker.events_shard_config
+        self._instance_name = hs.get_instance_name()
 
         self.room_invite_state_types = self.hs.config.room_invite_state_types
 
@@ -387,8 +386,6 @@ class EventCreationHandler:
         # This is only used to get at ratelimit function, and maybe_kick_guest_users
         self.base_handler = BaseHandler(hs)
 
-        self.pusher_pool = hs.get_pusherpool()
-
         # We arbitrarily limit concurrent event creation for a room to 5.
         # This is to stop us from diverging history *too* much.
         self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")
@@ -904,9 +901,10 @@ class EventCreationHandler:
 
         try:
             # If we're a worker we need to hit out to the master.
-            if not self._is_event_writer:
+            writer_instance = self._events_shard_config.get_instance(event.room_id)
+            if writer_instance != self._instance_name:
                 result = await self.send_event(
-                    instance_name=self.config.worker.writers.events,
+                    instance_name=writer_instance,
                     event_id=event.event_id,
                     store=self.store,
                     requester=requester,
@@ -974,7 +972,10 @@ class EventCreationHandler:
 
         This should only be run on the instance in charge of persisting events.
         """
-        assert self._is_event_writer
+        assert self.storage.persistence is not None
+        assert self._events_shard_config.should_handle(
+            self._instance_name, event.room_id
+        )
 
         if ratelimit:
             # We check if this is a room admin redacting an event so that we
@@ -1137,7 +1138,7 @@ class EventCreationHandler:
             if prev_state_ids:
                 raise AuthError(403, "Changing the room create event is forbidden")
 
-        event_stream_id, max_stream_id = await self.storage.persistence.persist_event(
+        event_pos, max_stream_token = await self.storage.persistence.persist_event(
             event, context=context
         )
 
@@ -1145,12 +1146,10 @@ class EventCreationHandler:
             # If there's an expiry timestamp on the event, schedule its expiry.
             self._message_handler.maybe_schedule_expiry(event)
 
-        await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
-
         def _notify():
             try:
                 self.notifier.on_new_room_event(
-                    event, event_stream_id, max_stream_id, extra_users=extra_users
+                    event, event_pos, max_stream_token, extra_users=extra_users
                 )
             except Exception:
                 logger.exception("Error notifying about new room event")
@@ -1162,7 +1161,7 @@ class EventCreationHandler:
             # matters as sometimes presence code can take a while.
             run_in_background(self._bump_active_time, requester.user)
 
-        return event_stream_id
+        return event_pos.stream
 
     async def _bump_active_time(self, user: UserID) -> None:
         try:
@@ -1183,54 +1182,7 @@ class EventCreationHandler:
         )
 
         for room_id in room_ids:
-            # For each room we need to find a joined member we can use to send
-            # the dummy event with.
-
-            latest_event_ids = await self.store.get_prev_events_for_room(room_id)
-
-            members = await self.state.get_current_users_in_room(
-                room_id, latest_event_ids=latest_event_ids
-            )
-            dummy_event_sent = False
-            for user_id in members:
-                if not self.hs.is_mine_id(user_id):
-                    continue
-                requester = create_requester(user_id)
-                try:
-                    event, context = await self.create_event(
-                        requester,
-                        {
-                            "type": "org.matrix.dummy_event",
-                            "content": {},
-                            "room_id": room_id,
-                            "sender": user_id,
-                        },
-                        prev_event_ids=latest_event_ids,
-                    )
-
-                    event.internal_metadata.proactively_send = False
-
-                    # Since this is a dummy-event it is OK if it is sent by a
-                    # shadow-banned user.
-                    await self.send_nonmember_event(
-                        requester,
-                        event,
-                        context,
-                        ratelimit=False,
-                        ignore_shadow_ban=True,
-                    )
-                    dummy_event_sent = True
-                    break
-                except ConsentNotGivenError:
-                    logger.info(
-                        "Failed to send dummy event into room %s for user %s due to "
-                        "lack of consent. Will try another user" % (room_id, user_id)
-                    )
-                except AuthError:
-                    logger.info(
-                        "Failed to send dummy event into room %s for user %s due to "
-                        "lack of power. Will try another user" % (room_id, user_id)
-                    )
+            dummy_event_sent = await self._send_dummy_event_for_room(room_id)
 
             if not dummy_event_sent:
                 # Did not find a valid user in the room, so remove from future attempts
@@ -1243,6 +1195,59 @@ class EventCreationHandler:
                 now = self.clock.time_msec()
                 self._rooms_to_exclude_from_dummy_event_insertion[room_id] = now
 
+    async def _send_dummy_event_for_room(self, room_id: str) -> bool:
+        """Attempt to send a dummy event for the given room.
+
+        Args:
+            room_id: room to try to send an event from
+
+        Returns:
+            True if a dummy event was successfully sent. False if no user was able
+            to send an event.
+        """
+
+        # For each room we need to find a joined member we can use to send
+        # the dummy event with.
+        latest_event_ids = await self.store.get_prev_events_for_room(room_id)
+        members = await self.state.get_current_users_in_room(
+            room_id, latest_event_ids=latest_event_ids
+        )
+        for user_id in members:
+            if not self.hs.is_mine_id(user_id):
+                continue
+            requester = create_requester(user_id)
+            try:
+                event, context = await self.create_event(
+                    requester,
+                    {
+                        "type": "org.matrix.dummy_event",
+                        "content": {},
+                        "room_id": room_id,
+                        "sender": user_id,
+                    },
+                    prev_event_ids=latest_event_ids,
+                )
+
+                event.internal_metadata.proactively_send = False
+
+                # Since this is a dummy-event it is OK if it is sent by a
+                # shadow-banned user.
+                await self.send_nonmember_event(
+                    requester, event, context, ratelimit=False, ignore_shadow_ban=True,
+                )
+                return True
+            except ConsentNotGivenError:
+                logger.info(
+                    "Failed to send dummy event into room %s for user %s due to "
+                    "lack of consent. Will try another user" % (room_id, user_id)
+                )
+            except AuthError:
+                logger.info(
+                    "Failed to send dummy event into room %s for user %s due to "
+                    "lack of power. Will try another user" % (room_id, user_id)
+                )
+        return False
+
     def _expire_rooms_to_exclude_from_dummy_event_insertion(self):
         expire_before = self.clock.time_msec() - _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY
         to_expire = set()
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 1b06f3173f..19cd652675 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -37,7 +37,7 @@ from synapse.config import ConfigError
 from synapse.http.server import respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
 from synapse.util import json_decoder
 
 if TYPE_CHECKING:
@@ -114,6 +114,7 @@ class OidcHandler:
             hs.config.oidc_user_mapping_provider_config
         )  # type: OidcMappingProvider
         self._skip_verification = hs.config.oidc_skip_verification  # type: bool
+        self._allow_existing_users = hs.config.oidc_allow_existing_users  # type: bool
 
         self._http_client = hs.get_proxied_http_client()
         self._auth_handler = hs.get_auth_handler()
@@ -131,10 +132,10 @@ class OidcHandler:
     def _render_error(
         self, request, error: str, error_description: Optional[str] = None
     ) -> None:
-        """Renders the error template and respond with it.
+        """Render the error template and respond to the request with it.
 
         This is used to show errors to the user. The template of this page can
-        be found under ``synapse/res/templates/sso_error.html``.
+        be found under `synapse/res/templates/sso_error.html`.
 
         Args:
             request: The incoming request from the browser.
@@ -706,6 +707,15 @@ class OidcHandler:
             self._render_error(request, "mapping_error", str(e))
             return
 
+        # Mapping providers might not have get_extra_attributes: only call this
+        # method if it exists.
+        extra_attributes = None
+        get_extra_attributes = getattr(
+            self._user_mapping_provider, "get_extra_attributes", None
+        )
+        if get_extra_attributes:
+            extra_attributes = await get_extra_attributes(userinfo, token)
+
         # and finally complete the login
         if ui_auth_session_id:
             await self._auth_handler.complete_sso_ui_auth(
@@ -713,7 +723,7 @@ class OidcHandler:
             )
         else:
             await self._auth_handler.complete_sso_login(
-                user_id, request, client_redirect_url
+                user_id, request, client_redirect_url, extra_attributes
             )
 
     def _generate_oidc_session_token(
@@ -849,7 +859,8 @@ class OidcHandler:
         If we don't find the user that way, we should register the user,
         mapping the localpart and the display name from the UserInfo.
 
-        If a user already exists with the mxid we've mapped, raise an exception.
+        If a user already exists with the mxid we've mapped and allow_existing_users
+        is disabled, raise an exception.
 
         Args:
             userinfo: an object representing the user
@@ -905,21 +916,31 @@ class OidcHandler:
 
         localpart = map_username_to_mxid_localpart(attributes["localpart"])
 
-        user_id = UserID(localpart, self._hostname)
-        if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()):
-            # This mxid is taken
-            raise MappingException(
-                "mxid '{}' is already taken".format(user_id.to_string())
+        user_id = UserID(localpart, self._hostname).to_string()
+        users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+        if users:
+            if self._allow_existing_users:
+                if len(users) == 1:
+                    registered_user_id = next(iter(users))
+                elif user_id in users:
+                    registered_user_id = user_id
+                else:
+                    raise MappingException(
+                        "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
+                            user_id, list(users.keys())
+                        )
+                    )
+            else:
+                # This mxid is taken
+                raise MappingException("mxid '{}' is already taken".format(user_id))
+        else:
+            # It's the first time this user is logging in and the mapped mxid was
+            # not taken, register the user
+            registered_user_id = await self._registration_handler.register_user(
+                localpart=localpart,
+                default_display_name=attributes["display_name"],
+                user_agent_ips=(user_agent, ip_address),
             )
-
-        # It's the first time this user is logging in and the mapped mxid was
-        # not taken, register the user
-        registered_user_id = await self._registration_handler.register_user(
-            localpart=localpart,
-            default_display_name=attributes["display_name"],
-            user_agent_ips=(user_agent, ip_address),
-        )
-
         await self._datastore.record_user_external_id(
             self._auth_provider_id, remote_user_id, registered_user_id,
         )
@@ -972,7 +993,7 @@ class OidcMappingProvider(Generic[C]):
     async def map_user_attributes(
         self, userinfo: UserInfo, token: Token
     ) -> UserAttribute:
-        """Map a ``UserInfo`` objects into user attributes.
+        """Map a `UserInfo` object into user attributes.
 
         Args:
             userinfo: An object representing the user given by the OIDC provider
@@ -983,6 +1004,18 @@ class OidcMappingProvider(Generic[C]):
         """
         raise NotImplementedError()
 
+    async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
+        """Map a `UserInfo` object into additional attributes passed to the client during login.
+
+        Args:
+            userinfo: An object representing the user given by the OIDC provider
+            token: A dict with the tokens returned by the provider
+
+        Returns:
+            A dict containing additional attributes. Must be JSON serializable.
+        """
+        return {}
+
 
 # Used to clear out "None" values in templates
 def jinja_finalize(thing):
@@ -997,6 +1030,7 @@ class JinjaOidcMappingConfig:
     subject_claim = attr.ib()  # type: str
     localpart_template = attr.ib()  # type: Template
     display_name_template = attr.ib()  # type: Optional[Template]
+    extra_attributes = attr.ib()  # type: Dict[str, Template]
 
 
 class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@@ -1035,10 +1069,28 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
                     % (e,)
                 )
 
+        extra_attributes = {}  # type Dict[str, Template]
+        if "extra_attributes" in config:
+            extra_attributes_config = config.get("extra_attributes") or {}
+            if not isinstance(extra_attributes_config, dict):
+                raise ConfigError(
+                    "oidc_config.user_mapping_provider.config.extra_attributes must be a dict"
+                )
+
+            for key, value in extra_attributes_config.items():
+                try:
+                    extra_attributes[key] = env.from_string(value)
+                except Exception as e:
+                    raise ConfigError(
+                        "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r"
+                        % (key, e)
+                    )
+
         return JinjaOidcMappingConfig(
             subject_claim=subject_claim,
             localpart_template=localpart_template,
             display_name_template=display_name_template,
+            extra_attributes=extra_attributes,
         )
 
     def get_remote_user_id(self, userinfo: UserInfo) -> str:
@@ -1059,3 +1111,13 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
                 display_name = None
 
         return UserAttribute(localpart=localpart, display_name=display_name)
+
+    async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
+        extras = {}  # type: Dict[str, str]
+        for key, template in self._config.extra_attributes.items():
+            try:
+                extras[key] = template.render(user=userinfo).strip()
+            except Exception as e:
+                # Log an error and skip this value (don't break login for this).
+                logger.error("Failed to render OIDC extra attribute %s: %s" % (key, e))
+        return extras
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 6067585f9b..2c2a633938 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -25,7 +25,7 @@ from synapse.logging.context import run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.state import StateFilter
 from synapse.streams.config import PaginationConfig
-from synapse.types import Requester, RoomStreamToken
+from synapse.types import Requester
 from synapse.util.async_helpers import ReadWriteLock
 from synapse.util.stringutils import random_string
 from synapse.visibility import filter_events_for_client
@@ -335,20 +335,16 @@ class PaginationHandler:
         user_id = requester.user.to_string()
 
         if pagin_config.from_token:
-            room_token = pagin_config.from_token.room_key
+            from_token = pagin_config.from_token
         else:
-            pagin_config.from_token = (
-                self.hs.get_event_sources().get_current_token_for_pagination()
-            )
-            room_token = pagin_config.from_token.room_key
-
-        room_token = RoomStreamToken.parse(room_token)
+            from_token = self.hs.get_event_sources().get_current_token_for_pagination()
 
-        pagin_config.from_token = pagin_config.from_token.copy_and_replace(
-            "room_key", str(room_token)
-        )
+        if pagin_config.limit is None:
+            # This shouldn't happen as we've set a default limit before this
+            # gets called.
+            raise Exception("limit not set")
 
-        source_config = pagin_config.get_source_config("room")
+        room_token = from_token.room_key
 
         with await self.pagination_lock.read(room_id):
             (
@@ -358,7 +354,7 @@ class PaginationHandler:
                 room_id, user_id, allow_departed_users=True
             )
 
-            if source_config.direction == "b":
+            if pagin_config.direction == "b":
                 # if we're going backwards, we might need to backfill. This
                 # requires that we have a topo token.
                 if room_token.topological:
@@ -380,23 +376,31 @@ class PaginationHandler:
                     leave_token = await self.store.get_topological_token_for_event(
                         member_event_id
                     )
-                    if RoomStreamToken.parse(leave_token).topological < curr_topo:
-                        source_config.from_key = str(leave_token)
+                    assert leave_token.topological is not None
+
+                    if leave_token.topological < curr_topo:
+                        from_token = from_token.copy_and_replace(
+                            "room_key", leave_token
+                        )
 
                 await self.hs.get_handlers().federation_handler.maybe_backfill(
-                    room_id, curr_topo, limit=source_config.limit,
+                    room_id, curr_topo, limit=pagin_config.limit,
                 )
 
+            to_room_key = None
+            if pagin_config.to_token:
+                to_room_key = pagin_config.to_token.room_key
+
             events, next_key = await self.store.paginate_room_events(
                 room_id=room_id,
-                from_key=source_config.from_key,
-                to_key=source_config.to_key,
-                direction=source_config.direction,
-                limit=source_config.limit,
+                from_key=from_token.room_key,
+                to_key=to_room_key,
+                direction=pagin_config.direction,
+                limit=pagin_config.limit,
                 event_filter=event_filter,
             )
 
-            next_token = pagin_config.from_token.copy_and_replace("room_key", next_key)
+            next_token = from_token.copy_and_replace("room_key", next_key)
 
         if events:
             if event_filter:
@@ -409,8 +413,8 @@ class PaginationHandler:
         if not events:
             return {
                 "chunk": [],
-                "start": pagin_config.from_token.to_string(),
-                "end": next_token.to_string(),
+                "start": await from_token.to_string(self.store),
+                "end": await next_token.to_string(self.store),
             }
 
         state = None
@@ -438,8 +442,8 @@ class PaginationHandler:
                     events, time_now, as_client_event=as_client_event
                 )
             ),
-            "start": pagin_config.from_token.to_string(),
-            "end": next_token.to_string(),
+            "start": await from_token.to_string(self.store),
+            "end": await next_token.to_string(self.store),
         }
 
         if state:
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 91a3aec1cc..1000ac95ff 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -1108,9 +1108,6 @@ class PresenceEventSource:
     def get_current_key(self):
         return self.store.get_current_presence_token()
 
-    async def get_pagination_rows(self, user, pagination_config, key):
-        return await self.get_new_events(user, from_key=None, include_offline=False)
-
     @cached(num_args=2, cache_context=True)
     async def _get_interested_in(self, user, explicit_room_id, cache_context):
         """Returns the set of users that the given user should see presence
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 0cb8fad89a..5453e6dfc8 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -44,7 +44,7 @@ class BaseProfileHandler(BaseHandler):
     """
 
     def __init__(self, hs):
-        super(BaseProfileHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.federation = hs.get_federation_client()
         hs.get_federation_registry().register_query_handler(
@@ -369,7 +369,7 @@ class MasterProfileHandler(BaseProfileHandler):
     PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
 
     def __init__(self, hs):
-        super(MasterProfileHandler, self).__init__(hs)
+        super().__init__(hs)
 
         assert hs.config.worker_app is None
 
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index e3b528d271..c32f314a1c 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -24,7 +24,7 @@ logger = logging.getLogger(__name__)
 
 class ReadMarkerHandler(BaseHandler):
     def __init__(self, hs):
-        super(ReadMarkerHandler, self).__init__(hs)
+        super().__init__(hs)
         self.server_name = hs.config.server_name
         self.store = hs.get_datastore()
         self.read_marker_linearizer = Linearizer(name="read_marker")
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 2cc6c2eb68..7225923757 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
 
 class ReceiptsHandler(BaseHandler):
     def __init__(self, hs):
-        super(ReceiptsHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.server_name = hs.config.server_name
         self.store = hs.get_datastore()
@@ -142,18 +142,3 @@ class ReceiptEventSource:
 
     def get_current_key(self, direction="f"):
         return self.store.get_max_receipt_stream_id()
-
-    async def get_pagination_rows(self, user, config, key):
-        to_key = int(config.from_key)
-
-        if config.to_key:
-            from_key = int(config.to_key)
-        else:
-            from_key = None
-
-        room_ids = await self.store.get_rooms_for_user(user.to_string())
-        events = await self.store.get_linearized_receipts_for_rooms(
-            room_ids, from_key=from_key, to_key=to_key
-        )
-
-        return (events, to_key)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index cde2dbca92..538f4b2a61 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -42,7 +42,7 @@ class RegistrationHandler(BaseHandler):
         Args:
             hs (synapse.server.HomeServer):
         """
-        super(RegistrationHandler, self).__init__(hs)
+        super().__init__(hs)
         self.hs = hs
         self.auth = hs.get_auth()
         self._auth_handler = hs.get_auth_handler()
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a29305f655..d5f7c78edf 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -70,7 +70,7 @@ FIVE_MINUTES_IN_MS = 5 * 60 * 1000
 
 class RoomCreationHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
-        super(RoomCreationHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.spam_checker = hs.get_spam_checker()
         self.event_creation_handler = hs.get_event_creation_handler()
@@ -804,7 +804,9 @@ class RoomCreationHandler(BaseHandler):
 
         # Always wait for room creation to progate before returning
         await self._replication.wait_for_stream_position(
-            self.hs.config.worker.writers.events, "events", last_stream_id
+            self.hs.config.worker.events_shard_config.get_instance(room_id),
+            "events",
+            last_stream_id,
         )
 
         return result, last_stream_id
@@ -1075,11 +1077,13 @@ class RoomContextHandler:
         # the token, which we replace.
         token = StreamToken.START
 
-        results["start"] = token.copy_and_replace(
+        results["start"] = await token.copy_and_replace(
             "room_key", results["start"]
-        ).to_string()
+        ).to_string(self.store)
 
-        results["end"] = token.copy_and_replace("room_key", results["end"]).to_string()
+        results["end"] = await token.copy_and_replace(
+            "room_key", results["end"]
+        ).to_string(self.store)
 
         return results
 
@@ -1091,20 +1095,19 @@ class RoomEventSource:
     async def get_new_events(
         self,
         user: UserID,
-        from_key: str,
+        from_key: RoomStreamToken,
         limit: int,
         room_ids: List[str],
         is_guest: bool,
         explicit_room_id: Optional[str] = None,
-    ) -> Tuple[List[EventBase], str]:
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         # We just ignore the key for now.
 
         to_key = self.get_current_key()
 
-        from_token = RoomStreamToken.parse(from_key)
-        if from_token.topological:
+        if from_key.topological:
             logger.warning("Stream has topological part!!!! %r", from_key)
-            from_key = "s%s" % (from_token.stream,)
+            from_key = RoomStreamToken(None, from_key.stream)
 
         app_service = self.store.get_app_service_by_user_id(user.to_string())
         if app_service:
@@ -1139,8 +1142,8 @@ class RoomEventSource:
 
         return (events, end_key)
 
-    def get_current_key(self) -> str:
-        return "s%d" % (self.store.get_room_max_stream_ordering(),)
+    def get_current_key(self) -> RoomStreamToken:
+        return self.store.get_room_max_token()
 
     def get_current_key_for_room(self, room_id: str) -> Awaitable[str]:
         return self.store.get_room_events_max_id(room_id)
@@ -1260,10 +1263,10 @@ class RoomShutdownHandler:
             # We now wait for the create room to come back in via replication so
             # that we can assume that all the joins/invites have propogated before
             # we try and auto join below.
-            #
-            # TODO: Currently the events stream is written to from master
             await self._replication.wait_for_stream_position(
-                self.hs.config.worker.writers.events, "events", stream_id
+                self.hs.config.worker.events_shard_config.get_instance(new_room_id),
+                "events",
+                stream_id,
             )
         else:
             new_room_id = None
@@ -1293,7 +1296,9 @@ class RoomShutdownHandler:
 
                 # Wait for leave to come in over replication before trying to forget.
                 await self._replication.wait_for_stream_position(
-                    self.hs.config.worker.writers.events, "events", stream_id
+                    self.hs.config.worker.events_shard_config.get_instance(room_id),
+                    "events",
+                    stream_id,
                 )
 
                 await self.room_member_handler.forget(target_requester.user, room_id)
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 5dd7b28391..4a13c8e912 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -38,7 +38,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
 
 class RoomListHandler(BaseHandler):
     def __init__(self, hs):
-        super(RoomListHandler, self).__init__(hs)
+        super().__init__(hs)
         self.enable_room_list_search = hs.config.enable_room_list_search
         self.response_cache = ResponseCache(hs, "room_list")
         self.remote_response_cache = ResponseCache(
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 32b7e323fa..8feba8c90a 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -40,7 +40,7 @@ from synapse.events.validator import EventValidator
 from synapse.storage.roommember import RoomsForUser
 from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
 from synapse.util.async_helpers import Linearizer
-from synapse.util.distributor import user_joined_room, user_left_room
+from synapse.util.distributor import user_left_room
 
 from ._base import BaseHandler
 
@@ -51,14 +51,12 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class RoomMemberHandler:
+class RoomMemberHandler(metaclass=abc.ABCMeta):
     # TODO(paul): This handler currently contains a messy conflation of
     #   low-level API that works on UserID objects and so on, and REST-level
     #   API that takes ID strings and returns pagination chunks. These concerns
     #   ought to be separated out a lot better.
 
-    __metaclass__ = abc.ABCMeta
-
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
@@ -82,13 +80,6 @@ class RoomMemberHandler:
         self._enable_lookup = hs.config.enable_3pid_lookup
         self.allow_per_room_profiles = self.config.allow_per_room_profiles
 
-        self._event_stream_writer_instance = hs.config.worker.writers.events
-        self._is_on_event_persistence_instance = (
-            self._event_stream_writer_instance == hs.get_instance_name()
-        )
-        if self._is_on_event_persistence_instance:
-            self.persist_event_storage = hs.get_storage().persistence
-
         self._join_rate_limiter_local = Ratelimiter(
             clock=self.clock,
             rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
@@ -149,17 +140,6 @@ class RoomMemberHandler:
         raise NotImplementedError()
 
     @abc.abstractmethod
-    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
-        """Notifies distributor on master process that the user has joined the
-        room.
-
-        Args:
-            target
-            room_id
-        """
-        raise NotImplementedError()
-
-    @abc.abstractmethod
     async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Notifies distributor on master process that the user has left the
         room.
@@ -221,7 +201,6 @@ class RoomMemberHandler:
 
         prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
 
-        newly_joined = False
         if event.membership == Membership.JOIN:
             newly_joined = True
             if prev_member_event_id:
@@ -246,12 +225,7 @@ class RoomMemberHandler:
             requester, event, context, extra_users=[target], ratelimit=ratelimit,
         )
 
-        if event.membership == Membership.JOIN and newly_joined:
-            # Only fire user_joined_room if the user has actually joined the
-            # room. Don't bother if the user is just changing their profile
-            # info.
-            await self._user_joined_room(target, room_id)
-        elif event.membership == Membership.LEAVE:
+        if event.membership == Membership.LEAVE:
             if prev_member_event_id:
                 prev_member_event = await self.store.get_event(prev_member_event_id)
                 if prev_member_event.membership == Membership.JOIN:
@@ -726,17 +700,7 @@ class RoomMemberHandler:
             (EventTypes.Member, event.state_key), None
         )
 
-        if event.membership == Membership.JOIN:
-            # Only fire user_joined_room if the user has actually joined the
-            # room. Don't bother if the user is just changing their profile
-            # info.
-            newly_joined = True
-            if prev_member_event_id:
-                prev_member_event = await self.store.get_event(prev_member_event_id)
-                newly_joined = prev_member_event.membership != Membership.JOIN
-            if newly_joined:
-                await self._user_joined_room(target_user, room_id)
-        elif event.membership == Membership.LEAVE:
+        if event.membership == Membership.LEAVE:
             if prev_member_event_id:
                 prev_member_event = await self.store.get_event(prev_member_event_id)
                 if prev_member_event.membership == Membership.JOIN:
@@ -1002,10 +966,9 @@ class RoomMemberHandler:
 
 class RoomMemberMasterHandler(RoomMemberHandler):
     def __init__(self, hs):
-        super(RoomMemberMasterHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.distributor = hs.get_distributor()
-        self.distributor.declare("user_joined_room")
         self.distributor.declare("user_left_room")
 
     async def _is_remote_room_too_complex(
@@ -1085,7 +1048,6 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         event_id, stream_id = await self.federation_handler.do_invite_join(
             remote_room_hosts, room_id, user.to_string(), content
         )
-        await self._user_joined_room(user, room_id)
 
         # Check the room we just joined wasn't too large, if we didn't fetch the
         # complexity of it before.
@@ -1228,11 +1190,6 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         )
         return event.event_id, stream_id
 
-    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
-        """Implements RoomMemberHandler._user_joined_room
-        """
-        user_joined_room(self.distributor, target, room_id)
-
     async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_left_room
         """
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 897338fd54..f2e88f6a5b 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
 
 class RoomMemberWorkerHandler(RoomMemberHandler):
     def __init__(self, hs):
-        super(RoomMemberWorkerHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self._remote_join_client = ReplRemoteJoin.make_client(hs)
         self._remote_reject_client = ReplRejectInvite.make_client(hs)
@@ -57,8 +57,6 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
             content=content,
         )
 
-        await self._user_joined_room(user, room_id)
-
         return ret["event_id"], ret["stream_id"]
 
     async def remote_reject_invite(
@@ -81,13 +79,6 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
         )
         return ret["event_id"], ret["stream_id"]
 
-    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
-        """Implements RoomMemberHandler._user_joined_room
-        """
-        await self._notify_change_client(
-            user_id=target.to_string(), room_id=room_id, change="joined"
-        )
-
     async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_left_room
         """
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 66b063f991..285c481a96 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -21,9 +21,10 @@ import saml2
 import saml2.response
 from saml2.client import Saml2Client
 
-from synapse.api.errors import AuthError, SynapseError
+from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
 from synapse.config.saml2_config import SamlAttributeRequirement
+from synapse.http.server import respond_with_html
 from synapse.http.servlet import parse_string
 from synapse.http.site import SynapseRequest
 from synapse.module_api import ModuleApi
@@ -41,7 +42,11 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-@attr.s
+class MappingException(Exception):
+    """Used to catch errors when mapping the SAML2 response to a user."""
+
+
+@attr.s(slots=True)
 class Saml2SessionData:
     """Data we track about SAML2 sessions"""
 
@@ -68,6 +73,7 @@ class SamlHandler:
             hs.config.saml2_grandfathered_mxid_source_attribute
         )
         self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements
+        self._error_template = hs.config.sso_error_template
 
         # plugin to do custom mapping from saml response to mxid
         self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
@@ -84,6 +90,25 @@ class SamlHandler:
         # a lock on the mappings
         self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
 
+    def _render_error(
+        self, request, error: str, error_description: Optional[str] = None
+    ) -> None:
+        """Render the error template and respond to the request with it.
+
+        This is used to show errors to the user. The template of this page can
+        be found under `synapse/res/templates/sso_error.html`.
+
+        Args:
+            request: The incoming request from the browser.
+                We'll respond with an HTML page describing the error.
+            error: A technical identifier for this error.
+            error_description: A human-readable description of the error.
+        """
+        html = self._error_template.render(
+            error=error, error_description=error_description
+        )
+        respond_with_html(request, 400, html)
+
     def handle_redirect_request(
         self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
     ) -> bytes:
@@ -134,49 +159,6 @@ class SamlHandler:
         # the dict.
         self.expire_sessions()
 
-        # Pull out the user-agent and IP from the request.
-        user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
-            0
-        ].decode("ascii", "surrogateescape")
-        ip_address = self.hs.get_ip_from_request(request)
-
-        user_id, current_session = await self._map_saml_response_to_user(
-            resp_bytes, relay_state, user_agent, ip_address
-        )
-
-        # Complete the interactive auth session or the login.
-        if current_session and current_session.ui_auth_session_id:
-            await self._auth_handler.complete_sso_ui_auth(
-                user_id, current_session.ui_auth_session_id, request
-            )
-
-        else:
-            await self._auth_handler.complete_sso_login(user_id, request, relay_state)
-
-    async def _map_saml_response_to_user(
-        self,
-        resp_bytes: str,
-        client_redirect_url: str,
-        user_agent: str,
-        ip_address: str,
-    ) -> Tuple[str, Optional[Saml2SessionData]]:
-        """
-        Given a sample response, retrieve the cached session and user for it.
-
-        Args:
-            resp_bytes: The SAML response.
-            client_redirect_url: The redirect URL passed in by the client.
-            user_agent: The user agent of the client making the request.
-            ip_address: The IP address of the client making the request.
-
-        Returns:
-             Tuple of the user ID and SAML session associated with this response.
-
-        Raises:
-            SynapseError if there was a problem with the response.
-            RedirectException: some mapping providers may raise this if they need
-                to redirect to an interstitial page.
-        """
         try:
             saml2_auth = self._saml_client.parse_authn_request_response(
                 resp_bytes,
@@ -189,12 +171,23 @@ class SamlHandler:
             # in the (user-visible) exception message, so let's log the exception here
             # so we can track down the session IDs later.
             logger.warning(str(e))
-            raise SynapseError(400, "Unexpected SAML2 login.")
+            self._render_error(
+                request, "unsolicited_response", "Unexpected SAML2 login."
+            )
+            return
         except Exception as e:
-            raise SynapseError(400, "Unable to parse SAML2 response: %s." % (e,))
+            self._render_error(
+                request,
+                "invalid_response",
+                "Unable to parse SAML2 response: %s." % (e,),
+            )
+            return
 
         if saml2_auth.not_signed:
-            raise SynapseError(400, "SAML2 response was not signed.")
+            self._render_error(
+                request, "unsigned_respond", "SAML2 response was not signed."
+            )
+            return
 
         logger.debug("SAML2 response: %s", saml2_auth.origxml)
         for assertion in saml2_auth.assertions:
@@ -213,15 +206,73 @@ class SamlHandler:
             saml2_auth.in_response_to, None
         )
 
+        # Ensure that the attributes of the logged in user meet the required
+        # attributes.
         for requirement in self._saml2_attribute_requirements:
-            _check_attribute_requirement(saml2_auth.ava, requirement)
+            if not _check_attribute_requirement(saml2_auth.ava, requirement):
+                self._render_error(
+                    request, "unauthorised", "You are not authorised to log in here."
+                )
+                return
+
+        # Pull out the user-agent and IP from the request.
+        user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
+            0
+        ].decode("ascii", "surrogateescape")
+        ip_address = self.hs.get_ip_from_request(request)
+
+        # Call the mapper to register/login the user
+        try:
+            user_id = await self._map_saml_response_to_user(
+                saml2_auth, relay_state, user_agent, ip_address
+            )
+        except MappingException as e:
+            logger.exception("Could not map user")
+            self._render_error(request, "mapping_error", str(e))
+            return
+
+        # Complete the interactive auth session or the login.
+        if current_session and current_session.ui_auth_session_id:
+            await self._auth_handler.complete_sso_ui_auth(
+                user_id, current_session.ui_auth_session_id, request
+            )
+
+        else:
+            await self._auth_handler.complete_sso_login(user_id, request, relay_state)
+
+    async def _map_saml_response_to_user(
+        self,
+        saml2_auth: saml2.response.AuthnResponse,
+        client_redirect_url: str,
+        user_agent: str,
+        ip_address: str,
+    ) -> str:
+        """
+        Given a SAML response, retrieve the user ID for it and possibly register the user.
+
+        Args:
+            saml2_auth: The parsed SAML2 response.
+            client_redirect_url: The redirect URL passed in by the client.
+            user_agent: The user agent of the client making the request.
+            ip_address: The IP address of the client making the request.
+
+        Returns:
+             The user ID associated with this response.
+
+        Raises:
+            MappingException if there was a problem mapping the response to a user.
+            RedirectException: some mapping providers may raise this if they need
+                to redirect to an interstitial page.
+        """
 
         remote_user_id = self._user_mapping_provider.get_remote_user_id(
             saml2_auth, client_redirect_url
         )
 
         if not remote_user_id:
-            raise Exception("Failed to extract remote user id from SAML response")
+            raise MappingException(
+                "Failed to extract remote user id from SAML response"
+            )
 
         with (await self._mapping_lock.queue(self._auth_provider_id)):
             # first of all, check if we already have a mapping for this user
@@ -235,7 +286,7 @@ class SamlHandler:
             )
             if registered_user_id is not None:
                 logger.info("Found existing mapping %s", registered_user_id)
-                return registered_user_id, current_session
+                return registered_user_id
 
             # backwards-compatibility hack: see if there is an existing user with a
             # suitable mapping from the uid
@@ -260,7 +311,7 @@ class SamlHandler:
                     await self._datastore.record_user_external_id(
                         self._auth_provider_id, remote_user_id, registered_user_id
                     )
-                    return registered_user_id, current_session
+                    return registered_user_id
 
             # Map saml response to user attributes using the configured mapping provider
             for i in range(1000):
@@ -277,7 +328,7 @@ class SamlHandler:
 
                 localpart = attribute_dict.get("mxid_localpart")
                 if not localpart:
-                    raise Exception(
+                    raise MappingException(
                         "Error parsing SAML2 response: SAML mapping provider plugin "
                         "did not return a mxid_localpart value"
                     )
@@ -294,8 +345,8 @@ class SamlHandler:
             else:
                 # Unable to generate a username in 1000 iterations
                 # Break and return error to the user
-                raise SynapseError(
-                    500, "Unable to generate a Matrix ID from the SAML response"
+                raise MappingException(
+                    "Unable to generate a Matrix ID from the SAML response"
                 )
 
             logger.info("Mapped SAML user to local part %s", localpart)
@@ -310,7 +361,7 @@ class SamlHandler:
             await self._datastore.record_user_external_id(
                 self._auth_provider_id, remote_user_id, registered_user_id
             )
-            return registered_user_id, current_session
+            return registered_user_id
 
     def expire_sessions(self):
         expire_before = self._clock.time_msec() - self._saml2_session_lifetime
@@ -323,11 +374,11 @@ class SamlHandler:
             del self._outstanding_requests_dict[reqid]
 
 
-def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement):
+def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool:
     values = ava.get(req.attribute, [])
     for v in values:
         if v == req.value:
-            return
+            return True
 
     logger.info(
         "SAML2 attribute %s did not match required value '%s' (was '%s')",
@@ -335,7 +386,7 @@ def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement):
         req.value,
         values,
     )
-    raise AuthError(403, "You are not authorized to log in here.")
+    return False
 
 
 DOT_REPLACE_PATTERN = re.compile(
@@ -390,7 +441,7 @@ class DefaultSamlMappingProvider:
             return saml_response.ava["uid"][0]
         except KeyError:
             logger.warning("SAML2 response lacks a 'uid' attestation")
-            raise SynapseError(400, "'uid' not in SAML2 response")
+            raise MappingException("'uid' not in SAML2 response")
 
     def saml_response_to_user_attributes(
         self,
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index d58f9788c5..e9402e6e2e 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
 
 class SearchHandler(BaseHandler):
     def __init__(self, hs):
-        super(SearchHandler, self).__init__(hs)
+        super().__init__(hs)
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
@@ -362,13 +362,13 @@ class SearchHandler(BaseHandler):
                     self.storage, user.to_string(), res["events_after"]
                 )
 
-                res["start"] = now_token.copy_and_replace(
+                res["start"] = await now_token.copy_and_replace(
                     "room_key", res["start"]
-                ).to_string()
+                ).to_string(self.store)
 
-                res["end"] = now_token.copy_and_replace(
+                res["end"] = await now_token.copy_and_replace(
                     "room_key", res["end"]
-                ).to_string()
+                ).to_string(self.store)
 
                 if include_profile:
                     senders = {
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 4d245b618b..a5d67f828f 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -27,7 +27,7 @@ class SetPasswordHandler(BaseHandler):
     """Handler which deals with changing user account passwords"""
 
     def __init__(self, hs):
-        super(SetPasswordHandler, self).__init__(hs)
+        super().__init__(hs)
         self._auth_handler = hs.get_auth_handler()
         self._device_handler = hs.get_device_handler()
         self._password_policy_handler = hs.get_password_policy_handler()
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index e2ddb628ff..bfe2583002 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -89,14 +89,12 @@ class TimelineBatch:
     events = attr.ib(type=List[EventBase])
     limited = attr.ib(bool)
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
         to tell if room needs to be part of the sync result.
         """
         return bool(self.events)
 
-    __bool__ = __nonzero__  # python3
-
 
 # We can't freeze this class, because we need to update it after it's instantiated to
 # update its unread count. This is because we calculate the unread count for a room only
@@ -114,7 +112,7 @@ class JoinedSyncResult:
     summary = attr.ib(type=Optional[JsonDict])
     unread_count = attr.ib(type=int)
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
         to tell if room needs to be part of the sync result.
         """
@@ -127,8 +125,6 @@ class JoinedSyncResult:
             # else in the result, we don't need to send it.
         )
 
-    __bool__ = __nonzero__  # python3
-
 
 @attr.s(slots=True, frozen=True)
 class ArchivedSyncResult:
@@ -137,26 +133,22 @@ class ArchivedSyncResult:
     state = attr.ib(type=StateMap[EventBase])
     account_data = attr.ib(type=List[JsonDict])
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
         to tell if room needs to be part of the sync result.
         """
         return bool(self.timeline or self.state or self.account_data)
 
-    __bool__ = __nonzero__  # python3
-
 
 @attr.s(slots=True, frozen=True)
 class InvitedSyncResult:
     room_id = attr.ib(type=str)
     invite = attr.ib(type=EventBase)
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         """Invited rooms should always be reported to the client"""
         return True
 
-    __bool__ = __nonzero__  # python3
-
 
 @attr.s(slots=True, frozen=True)
 class GroupsSyncResult:
@@ -164,11 +156,9 @@ class GroupsSyncResult:
     invite = attr.ib(type=JsonDict)
     leave = attr.ib(type=JsonDict)
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         return bool(self.join or self.invite or self.leave)
 
-    __bool__ = __nonzero__  # python3
-
 
 @attr.s(slots=True, frozen=True)
 class DeviceLists:
@@ -181,13 +171,11 @@ class DeviceLists:
     changed = attr.ib(type=Collection[str])
     left = attr.ib(type=Collection[str])
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         return bool(self.changed or self.left)
 
-    __bool__ = __nonzero__  # python3
 
-
-@attr.s
+@attr.s(slots=True)
 class _RoomChanges:
     """The set of room entries to include in the sync, plus the set of joined
     and left room IDs since last sync.
@@ -227,7 +215,7 @@ class SyncResult:
     device_one_time_keys_count = attr.ib(type=JsonDict)
     groups = attr.ib(type=Optional[GroupsSyncResult])
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
         to tell if the notifier needs to wait for more events when polling for
         events.
@@ -243,8 +231,6 @@ class SyncResult:
             or self.groups
         )
 
-    __bool__ = __nonzero__  # python3
-
 
 class SyncHandler:
     def __init__(self, hs: "HomeServer"):
@@ -378,7 +364,7 @@ class SyncHandler:
         sync_config = sync_result_builder.sync_config
 
         with Measure(self.clock, "ephemeral_by_room"):
-            typing_key = since_token.typing_key if since_token else "0"
+            typing_key = since_token.typing_key if since_token else 0
 
             room_ids = sync_result_builder.joined_room_ids
 
@@ -402,7 +388,7 @@ class SyncHandler:
                 event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
                 ephemeral_by_room.setdefault(room_id, []).append(event_copy)
 
-            receipt_key = since_token.receipt_key if since_token else "0"
+            receipt_key = since_token.receipt_key if since_token else 0
 
             receipt_source = self.event_sources.sources["receipt"]
             receipts, receipt_key = await receipt_source.get_new_events(
@@ -981,7 +967,7 @@ class SyncHandler:
             raise NotImplementedError()
         else:
             joined_room_ids = await self.get_rooms_for_user_at(
-                user_id, now_token.room_stream_id
+                user_id, now_token.room_key
             )
         sync_result_builder = SyncResultBuilder(
             sync_config,
@@ -1310,12 +1296,11 @@ class SyncHandler:
         presence_source = self.event_sources.sources["presence"]
 
         since_token = sync_result_builder.since_token
+        presence_key = None
+        include_offline = False
         if since_token and not sync_result_builder.full_state:
             presence_key = since_token.presence_key
             include_offline = True
-        else:
-            presence_key = None
-            include_offline = False
 
         presence, presence_key = await presence_source.get_new_events(
             user=user,
@@ -1323,6 +1308,7 @@ class SyncHandler:
             is_guest=sync_config.is_guest,
             include_offline=include_offline,
         )
+        assert presence_key
         sync_result_builder.now_token = now_token.copy_and_replace(
             "presence_key", presence_key
         )
@@ -1485,7 +1471,7 @@ class SyncHandler:
         if rooms_changed:
             return True
 
-        stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
+        stream_id = since_token.room_key.stream
         for room_id in sync_result_builder.joined_room_ids:
             if self.store.has_room_changed_since(room_id, stream_id):
                 return True
@@ -1609,16 +1595,24 @@ class SyncHandler:
 
             if leave_events:
                 leave_event = leave_events[-1]
-                leave_stream_token = await self.store.get_stream_token_for_event(
+                leave_position = await self.store.get_position_for_event(
                     leave_event.event_id
                 )
-                leave_token = since_token.copy_and_replace(
-                    "room_key", leave_stream_token
-                )
 
-                if since_token and since_token.is_after(leave_token):
+                # If the leave event happened before the since token then we
+                # bail.
+                if since_token and not leave_position.persisted_after(
+                    since_token.room_key
+                ):
                     continue
 
+                # We can safely convert the position of the leave event into a
+                # stream token as it'll only be used in the context of this
+                # room. (c.f. the docstring of `to_room_stream_token`).
+                leave_token = since_token.copy_and_replace(
+                    "room_key", leave_position.to_room_stream_token()
+                )
+
                 # If this is an out of band message, like a remote invite
                 # rejection, we include it in the recents batch. Otherwise, we
                 # let _load_filtered_recents handle fetching the correct
@@ -1751,7 +1745,7 @@ class SyncHandler:
                             continue
 
                 leave_token = now_token.copy_and_replace(
-                    "room_key", "s%d" % (event.stream_ordering,)
+                    "room_key", RoomStreamToken(None, event.stream_ordering)
                 )
                 room_entries.append(
                     RoomSyncResultBuilder(
@@ -1930,7 +1924,7 @@ class SyncHandler:
             raise Exception("Unrecognized rtype: %r", room_builder.rtype)
 
     async def get_rooms_for_user_at(
-        self, user_id: str, stream_ordering: int
+        self, user_id: str, room_key: RoomStreamToken
     ) -> FrozenSet[str]:
         """Get set of joined rooms for a user at the given stream ordering.
 
@@ -1956,15 +1950,15 @@ class SyncHandler:
         # If the membership's stream ordering is after the given stream
         # ordering, we need to go and work out if the user was in the room
         # before.
-        for room_id, membership_stream_ordering in joined_rooms:
-            if membership_stream_ordering <= stream_ordering:
+        for room_id, event_pos in joined_rooms:
+            if not event_pos.persisted_after(room_key):
                 joined_room_ids.add(room_id)
                 continue
 
             logger.info("User joined room after current token: %s", room_id)
 
             extrems = await self.store.get_forward_extremeties_for_room(
-                room_id, stream_ordering
+                room_id, event_pos.stream
             )
             users_in_room = await self.state.get_current_users_in_room(room_id, extrems)
             if user_id in users_in_room:
@@ -2038,7 +2032,7 @@ def _calculate_state(
     return {event_id_to_key[e]: e for e in state_ids}
 
 
-@attr.s
+@attr.s(slots=True)
 class SyncResultBuilder:
     """Used to help build up a new SyncResult for a user
 
@@ -2074,7 +2068,7 @@ class SyncResultBuilder:
     to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list))
 
 
-@attr.s
+@attr.s(slots=True)
 class RoomSyncResultBuilder:
     """Stores information needed to create either a `JoinedSyncResult` or
     `ArchivedSyncResult`.
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index e21f8dbc58..79393c8829 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -37,7 +37,7 @@ class UserDirectoryHandler(StateDeltasHandler):
     """
 
     def __init__(self, hs):
-        super(UserDirectoryHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()