summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/acme_issuing_service.py2
-rw-r--r--synapse/handlers/admin.py8
-rw-r--r--synapse/handlers/auth.py4
-rw-r--r--synapse/handlers/deactivate_account.py2
-rw-r--r--synapse/handlers/device.py16
-rw-r--r--synapse/handlers/directory.py2
-rw-r--r--synapse/handlers/e2e_keys.py2
-rw-r--r--synapse/handlers/events.py8
-rw-r--r--synapse/handlers/federation.py94
-rw-r--r--synapse/handlers/groups_local.py2
-rw-r--r--synapse/handlers/identity.py2
-rw-r--r--synapse/handlers/initial_sync.py17
-rw-r--r--synapse/handlers/message.py19
-rw-r--r--synapse/handlers/oidc_handler.py4
-rw-r--r--synapse/handlers/pagination.py49
-rw-r--r--synapse/handlers/presence.py3
-rw-r--r--synapse/handlers/profile.py4
-rw-r--r--synapse/handlers/read_marker.py2
-rw-r--r--synapse/handlers/receipts.py17
-rw-r--r--synapse/handlers/register.py2
-rw-r--r--synapse/handlers/room.py31
-rw-r--r--synapse/handlers/room_list.py2
-rw-r--r--synapse/handlers/room_member.py53
-rw-r--r--synapse/handlers/room_member_worker.py11
-rw-r--r--synapse/handlers/saml_handler.py171
-rw-r--r--synapse/handlers/search.py2
-rw-r--r--synapse/handlers/set_password.py2
-rw-r--r--synapse/handlers/sync.py50
-rw-r--r--synapse/handlers/user_directory.py2
29 files changed, 261 insertions, 322 deletions
diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py
index 69650ff221..7294649d71 100644
--- a/synapse/handlers/acme_issuing_service.py
+++ b/synapse/handlers/acme_issuing_service.py
@@ -76,7 +76,7 @@ def create_issuing_service(reactor, acme_url, account_key_file, well_known_resou
     )
 
 
-@attr.s
+@attr.s(slots=True)
 @implementer(ICertificateStore)
 class ErsatzStore:
     """
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 918d0e037c..dd981c597e 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
 
 class AdminHandler(BaseHandler):
     def __init__(self, hs):
-        super(AdminHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
@@ -125,8 +125,8 @@ class AdminHandler(BaseHandler):
             else:
                 stream_ordering = room.stream_ordering
 
-            from_key = str(RoomStreamToken(0, 0))
-            to_key = str(RoomStreamToken(None, stream_ordering))
+            from_key = RoomStreamToken(0, 0)
+            to_key = RoomStreamToken(None, stream_ordering)
 
             written_events = set()  # Events that we've processed in this room
 
@@ -153,7 +153,7 @@ class AdminHandler(BaseHandler):
                 if not events:
                     break
 
-                from_key = events[-1].internal_metadata.after
+                from_key = RoomStreamToken.parse(events[-1].internal_metadata.after)
 
                 events = await filter_events_for_client(self.storage, user_id, events)
 
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 90189869cc..0322b60cfc 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -145,7 +145,7 @@ class AuthHandler(BaseHandler):
         Args:
             hs (synapse.server.HomeServer):
         """
-        super(AuthHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.checkers = {}  # type: Dict[str, UserInteractiveAuthChecker]
         for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
@@ -1235,7 +1235,7 @@ class AuthHandler(BaseHandler):
         return urllib.parse.urlunparse(url_parts)
 
 
-@attr.s
+@attr.s(slots=True)
 class MacaroonGenerator:
 
     hs = attr.ib()
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 25169157c1..0635ad5708 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -29,7 +29,7 @@ class DeactivateAccountHandler(BaseHandler):
     """Handler which deals with deactivating user accounts."""
 
     def __init__(self, hs):
-        super(DeactivateAccountHandler, self).__init__(hs)
+        super().__init__(hs)
         self.hs = hs
         self._auth_handler = hs.get_auth_handler()
         self._device_handler = hs.get_device_handler()
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 643d71a710..55a9787439 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -29,6 +29,7 @@ from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import (
     RoomStreamToken,
+    StreamToken,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
 )
@@ -47,7 +48,7 @@ MAX_DEVICE_DISPLAY_NAME_LEN = 100
 
 class DeviceWorkerHandler(BaseHandler):
     def __init__(self, hs):
-        super(DeviceWorkerHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.hs = hs
         self.state = hs.get_state_handler()
@@ -104,18 +105,15 @@ class DeviceWorkerHandler(BaseHandler):
 
     @trace
     @measure_func("device.get_user_ids_changed")
-    async def get_user_ids_changed(self, user_id, from_token):
+    async def get_user_ids_changed(self, user_id: str, from_token: StreamToken):
         """Get list of users that have had the devices updated, or have newly
         joined a room, that `user_id` may be interested in.
-
-        Args:
-            user_id (str)
-            from_token (StreamToken)
         """
 
         set_tag("user_id", user_id)
         set_tag("from_token", from_token)
-        now_room_key = await self.store.get_room_events_max_id()
+        now_room_id = self.store.get_room_max_stream_ordering()
+        now_room_key = RoomStreamToken(None, now_room_id)
 
         room_ids = await self.store.get_rooms_for_user(user_id)
 
@@ -142,7 +140,7 @@ class DeviceWorkerHandler(BaseHandler):
         )
         rooms_changed.update(event.room_id for event in member_events)
 
-        stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key).stream
+        stream_ordering = from_token.room_key.stream
 
         possibly_changed = set(changed)
         possibly_left = set()
@@ -253,7 +251,7 @@ class DeviceWorkerHandler(BaseHandler):
 
 class DeviceHandler(DeviceWorkerHandler):
     def __init__(self, hs):
-        super(DeviceHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.federation_sender = hs.get_federation_sender()
 
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 46826eb784..62aa9a2da8 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
 
 class DirectoryHandler(BaseHandler):
     def __init__(self, hs):
-        super(DirectoryHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.state = hs.get_state_handler()
         self.appservice_handler = hs.get_application_service_handler()
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d629c7c16c..dd40fd1299 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -1201,7 +1201,7 @@ def _one_time_keys_match(old_key_json, new_key):
     return old_key == new_key_copy
 
 
-@attr.s
+@attr.s(slots=True)
 class SignatureListItem:
     """An item in the signature list as used by upload_signatures_for_device_keys.
     """
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index b05e32f457..0875b74ea8 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -37,11 +37,7 @@ logger = logging.getLogger(__name__)
 
 class EventStreamHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
-        super(EventStreamHandler, self).__init__(hs)
-
-        self.distributor = hs.get_distributor()
-        self.distributor.declare("started_user_eventstream")
-        self.distributor.declare("stopped_user_eventstream")
+        super().__init__(hs)
 
         self.clock = hs.get_clock()
 
@@ -146,7 +142,7 @@ class EventStreamHandler(BaseHandler):
 
 class EventHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
-        super(EventHandler, self).__init__(hs)
+        super().__init__(hs)
         self.storage = hs.get_storage()
 
     async def get_event(
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 014dab2940..ea9264e751 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -69,7 +69,6 @@ from synapse.replication.http.federation import (
     ReplicationFederationSendEventsRestServlet,
     ReplicationStoreRoomOnInviteRestServlet,
 )
-from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
 from synapse.state import StateResolutionStore, resolve_events_with_store
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.types import (
@@ -80,7 +79,6 @@ from synapse.types import (
     get_domain_from_id,
 )
 from synapse.util.async_helpers import Linearizer, concurrently_execute
-from synapse.util.distributor import user_joined_room
 from synapse.util.retryutils import NotRetryingDestination
 from synapse.util.stringutils import shortstr
 from synapse.visibility import filter_events_for_server
@@ -88,7 +86,7 @@ from synapse.visibility import filter_events_for_server
 logger = logging.getLogger(__name__)
 
 
-@attr.s
+@attr.s(slots=True)
 class _NewEventInfo:
     """Holds information about a received event, ready for passing to _handle_new_events
 
@@ -117,7 +115,7 @@ class FederationHandler(BaseHandler):
     """
 
     def __init__(self, hs):
-        super(FederationHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.hs = hs
 
@@ -130,7 +128,6 @@ class FederationHandler(BaseHandler):
         self.keyring = hs.get_keyring()
         self.action_generator = hs.get_action_generator()
         self.is_mine_id = hs.is_mine_id
-        self.pusher_pool = hs.get_pusherpool()
         self.spam_checker = hs.get_spam_checker()
         self.event_creation_handler = hs.get_event_creation_handler()
         self._message_handler = hs.get_message_handler()
@@ -141,9 +138,6 @@ class FederationHandler(BaseHandler):
         self._replication = hs.get_replication_data_handler()
 
         self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
-        self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client(
-            hs
-        )
         self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
             hs
         )
@@ -704,31 +698,10 @@ class FederationHandler(BaseHandler):
         logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
 
         try:
-            context = await self._handle_new_event(origin, event, state=state)
+            await self._handle_new_event(origin, event, state=state)
         except AuthError as e:
             raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
 
-        if event.type == EventTypes.Member:
-            if event.membership == Membership.JOIN:
-                # Only fire user_joined_room if the user has acutally
-                # joined the room. Don't bother if the user is just
-                # changing their profile info.
-                newly_joined = True
-
-                prev_state_ids = await context.get_prev_state_ids()
-
-                prev_state_id = prev_state_ids.get((event.type, event.state_key))
-                if prev_state_id:
-                    prev_state = await self.store.get_event(
-                        prev_state_id, allow_none=True
-                    )
-                    if prev_state and prev_state.membership == Membership.JOIN:
-                        newly_joined = False
-
-                if newly_joined:
-                    user = UserID.from_string(event.state_key)
-                    await self.user_joined_room(user, room_id)
-
         # For encrypted messages we check that we know about the sending device,
         # if we don't then we mark the device cache for that user as stale.
         if event.type == EventTypes.Encrypted:
@@ -923,7 +896,8 @@ class FederationHandler(BaseHandler):
                 )
             )
 
-        await self._handle_new_events(dest, ev_infos, backfilled=True)
+        if ev_infos:
+            await self._handle_new_events(dest, room_id, ev_infos, backfilled=True)
 
         # Step 2: Persist the rest of the events in the chunk one by one
         events.sort(key=lambda e: e.depth)
@@ -1265,7 +1239,7 @@ class FederationHandler(BaseHandler):
             event_infos.append(_NewEventInfo(event, None, auth))
 
         await self._handle_new_events(
-            destination, event_infos,
+            destination, room_id, event_infos,
         )
 
     def _sanity_check_event(self, ev):
@@ -1412,15 +1386,15 @@ class FederationHandler(BaseHandler):
             )
 
             max_stream_id = await self._persist_auth_tree(
-                origin, auth_chain, state, event, room_version_obj
+                origin, room_id, auth_chain, state, event, room_version_obj
             )
 
             # We wait here until this instance has seen the events come down
             # replication (if we're using replication) as the below uses caches.
-            #
-            # TODO: Currently the events stream is written to from master
             await self._replication.wait_for_stream_position(
-                self.config.worker.writers.events, "events", max_stream_id
+                self.config.worker.events_shard_config.get_instance(room_id),
+                "events",
+                max_stream_id,
             )
 
             # Check whether this room is the result of an upgrade of a room we already know
@@ -1599,11 +1573,6 @@ class FederationHandler(BaseHandler):
             event.signatures,
         )
 
-        if event.type == EventTypes.Member:
-            if event.content["membership"] == Membership.JOIN:
-                user = UserID.from_string(event.state_key)
-                await self.user_joined_room(user, event.room_id)
-
         prev_state_ids = await context.get_prev_state_ids()
 
         state_ids = list(prev_state_ids.values())
@@ -1674,7 +1643,7 @@ class FederationHandler(BaseHandler):
         )
 
         context = await self.state_handler.compute_event_context(event)
-        await self.persist_events_and_notify([(event, context)])
+        await self.persist_events_and_notify(event.room_id, [(event, context)])
 
         return event
 
@@ -1701,7 +1670,9 @@ class FederationHandler(BaseHandler):
         await self.federation_client.send_leave(host_list, event)
 
         context = await self.state_handler.compute_event_context(event)
-        stream_id = await self.persist_events_and_notify([(event, context)])
+        stream_id = await self.persist_events_and_notify(
+            event.room_id, [(event, context)]
+        )
 
         return event, stream_id
 
@@ -1949,7 +1920,7 @@ class FederationHandler(BaseHandler):
                 )
 
             await self.persist_events_and_notify(
-                [(event, context)], backfilled=backfilled
+                event.room_id, [(event, context)], backfilled=backfilled
             )
         except Exception:
             run_in_background(
@@ -1962,6 +1933,7 @@ class FederationHandler(BaseHandler):
     async def _handle_new_events(
         self,
         origin: str,
+        room_id: str,
         event_infos: Iterable[_NewEventInfo],
         backfilled: bool = False,
     ) -> None:
@@ -1993,6 +1965,7 @@ class FederationHandler(BaseHandler):
         )
 
         await self.persist_events_and_notify(
+            room_id,
             [
                 (ev_info.event, context)
                 for ev_info, context in zip(event_infos, contexts)
@@ -2003,6 +1976,7 @@ class FederationHandler(BaseHandler):
     async def _persist_auth_tree(
         self,
         origin: str,
+        room_id: str,
         auth_events: List[EventBase],
         state: List[EventBase],
         event: EventBase,
@@ -2017,6 +1991,7 @@ class FederationHandler(BaseHandler):
 
         Args:
             origin: Where the events came from
+            room_id,
             auth_events
             state
             event
@@ -2091,17 +2066,20 @@ class FederationHandler(BaseHandler):
                 events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
 
         await self.persist_events_and_notify(
+            room_id,
             [
                 (e, events_to_context[e.event_id])
                 for e in itertools.chain(auth_events, state)
-            ]
+            ],
         )
 
         new_event_context = await self.state_handler.compute_event_context(
             event, old_state=state
         )
 
-        return await self.persist_events_and_notify([(event, new_event_context)])
+        return await self.persist_events_and_notify(
+            room_id, [(event, new_event_context)]
+        )
 
     async def _prep_event(
         self,
@@ -2952,6 +2930,7 @@ class FederationHandler(BaseHandler):
 
     async def persist_events_and_notify(
         self,
+        room_id: str,
         event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
         backfilled: bool = False,
     ) -> int:
@@ -2959,14 +2938,19 @@ class FederationHandler(BaseHandler):
         necessary.
 
         Args:
-            event_and_contexts:
+            room_id: The room ID of events being persisted.
+            event_and_contexts: Sequence of events with their associated
+                context that should be persisted. All events must belong to
+                the same room.
             backfilled: Whether these events are a result of
                 backfilling or not
         """
-        if self.config.worker.writers.events != self._instance_name:
+        instance = self.config.worker.events_shard_config.get_instance(room_id)
+        if instance != self._instance_name:
             result = await self._send_events(
-                instance_name=self.config.worker.writers.events,
+                instance_name=instance,
                 store=self.store,
+                room_id=room_id,
                 event_and_contexts=event_and_contexts,
                 backfilled=backfilled,
             )
@@ -3019,8 +3003,6 @@ class FederationHandler(BaseHandler):
             event, event_stream_id, max_stream_id, extra_users=extra_users
         )
 
-        await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
-
     async def _clean_room_for_join(self, room_id: str) -> None:
         """Called to clean up any data in DB for a given room, ready for the
         server to join the room.
@@ -3033,16 +3015,6 @@ class FederationHandler(BaseHandler):
         else:
             await self.store.clean_room_for_join(room_id)
 
-    async def user_joined_room(self, user: UserID, room_id: str) -> None:
-        """Called when a new user has joined the room
-        """
-        if self.config.worker_app:
-            await self._notify_user_membership_change(
-                room_id=room_id, user_id=user.to_string(), change="joined"
-            )
-        else:
-            user_joined_room(self.distributor, user, room_id)
-
     async def get_room_complexity(
         self, remote_room_hosts: List[str], room_id: str
     ) -> Optional[dict]:
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 44df567983..9684e60fc8 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -240,7 +240,7 @@ class GroupsLocalWorkerHandler:
 
 class GroupsLocalHandler(GroupsLocalWorkerHandler):
     def __init__(self, hs):
-        super(GroupsLocalHandler, self).__init__(hs)
+        super().__init__(hs)
 
         # Ensure attestations get renewed
         hs.get_groups_attestation_renewer()
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 0ce6ddfbe4..ab15570f7a 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -45,7 +45,7 @@ id_server_scheme = "https://"
 
 class IdentityHandler(BaseHandler):
     def __init__(self, hs):
-        super(IdentityHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.http_client = SimpleHttpClient(hs)
         # We create a blacklisting instance of SimpleHttpClient for contacting identity
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index d5ddc583ad..8cd7eb22a3 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -25,7 +25,7 @@ from synapse.handlers.presence import format_user_presence_state
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage.roommember import RoomsForUser
 from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, Requester, StreamToken, UserID
+from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.response_cache import ResponseCache
@@ -42,7 +42,7 @@ logger = logging.getLogger(__name__)
 
 class InitialSyncHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
-        super(InitialSyncHandler, self).__init__(hs)
+        super().__init__(hs)
         self.hs = hs
         self.state = hs.get_state_handler()
         self.clock = hs.get_clock()
@@ -116,14 +116,13 @@ class InitialSyncHandler(BaseHandler):
         now_token = self.hs.get_event_sources().get_current_token()
 
         presence_stream = self.hs.get_event_sources().sources["presence"]
-        pagination_config = PaginationConfig(from_token=now_token)
-        presence, _ = await presence_stream.get_pagination_rows(
-            user, pagination_config.get_source_config("presence"), None
+        presence, _ = await presence_stream.get_new_events(
+            user, from_key=None, include_offline=False
         )
 
-        receipt_stream = self.hs.get_event_sources().sources["receipt"]
-        receipt, _ = await receipt_stream.get_pagination_rows(
-            user, pagination_config.get_source_config("receipt"), None
+        joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN]
+        receipt = await self.store.get_linearized_receipts_for_rooms(
+            joined_rooms, to_key=int(now_token.receipt_key),
         )
 
         tags_by_room = await self.store.get_tags_for_user(user_id)
@@ -168,7 +167,7 @@ class InitialSyncHandler(BaseHandler):
                         self.state_handler.get_current_state, event.room_id
                     )
                 elif event.membership == Membership.LEAVE:
-                    room_end_token = "s%d" % (event.stream_ordering,)
+                    room_end_token = RoomStreamToken(None, event.stream_ordering,)
                     deferred_room_state = run_in_background(
                         self.state_store.get_state_for_events, [event.event_id]
                     )
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 8a7b4916cd..a8fe5cf4e2 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -376,9 +376,8 @@ class EventCreationHandler:
         self.notifier = hs.get_notifier()
         self.config = hs.config
         self.require_membership_for_aliases = hs.config.require_membership_for_aliases
-        self._is_event_writer = (
-            self.config.worker.writers.events == hs.get_instance_name()
-        )
+        self._events_shard_config = self.config.worker.events_shard_config
+        self._instance_name = hs.get_instance_name()
 
         self.room_invite_state_types = self.hs.config.room_invite_state_types
 
@@ -387,8 +386,6 @@ class EventCreationHandler:
         # This is only used to get at ratelimit function, and maybe_kick_guest_users
         self.base_handler = BaseHandler(hs)
 
-        self.pusher_pool = hs.get_pusherpool()
-
         # We arbitrarily limit concurrent event creation for a room to 5.
         # This is to stop us from diverging history *too* much.
         self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")
@@ -904,9 +901,10 @@ class EventCreationHandler:
 
         try:
             # If we're a worker we need to hit out to the master.
-            if not self._is_event_writer:
+            writer_instance = self._events_shard_config.get_instance(event.room_id)
+            if writer_instance != self._instance_name:
                 result = await self.send_event(
-                    instance_name=self.config.worker.writers.events,
+                    instance_name=writer_instance,
                     event_id=event.event_id,
                     store=self.store,
                     requester=requester,
@@ -974,7 +972,10 @@ class EventCreationHandler:
 
         This should only be run on the instance in charge of persisting events.
         """
-        assert self._is_event_writer
+        assert self.storage.persistence is not None
+        assert self._events_shard_config.should_handle(
+            self._instance_name, event.room_id
+        )
 
         if ratelimit:
             # We check if this is a room admin redacting an event so that we
@@ -1145,8 +1146,6 @@ class EventCreationHandler:
             # If there's an expiry timestamp on the event, schedule its expiry.
             self._message_handler.maybe_schedule_expiry(event)
 
-        await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
-
         def _notify():
             try:
                 self.notifier.on_new_room_event(
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 1b06f3173f..4230dbaf99 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -131,10 +131,10 @@ class OidcHandler:
     def _render_error(
         self, request, error: str, error_description: Optional[str] = None
     ) -> None:
-        """Renders the error template and respond with it.
+        """Render the error template and respond to the request with it.
 
         This is used to show errors to the user. The template of this page can
-        be found under ``synapse/res/templates/sso_error.html``.
+        be found under `synapse/res/templates/sso_error.html`.
 
         Args:
             request: The incoming request from the browser.
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 6067585f9b..f132ed3368 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -335,20 +335,16 @@ class PaginationHandler:
         user_id = requester.user.to_string()
 
         if pagin_config.from_token:
-            room_token = pagin_config.from_token.room_key
+            from_token = pagin_config.from_token
         else:
-            pagin_config.from_token = (
-                self.hs.get_event_sources().get_current_token_for_pagination()
-            )
-            room_token = pagin_config.from_token.room_key
-
-        room_token = RoomStreamToken.parse(room_token)
+            from_token = self.hs.get_event_sources().get_current_token_for_pagination()
 
-        pagin_config.from_token = pagin_config.from_token.copy_and_replace(
-            "room_key", str(room_token)
-        )
+        if pagin_config.limit is None:
+            # This shouldn't happen as we've set a default limit before this
+            # gets called.
+            raise Exception("limit not set")
 
-        source_config = pagin_config.get_source_config("room")
+        room_token = from_token.room_key
 
         with await self.pagination_lock.read(room_id):
             (
@@ -358,7 +354,7 @@ class PaginationHandler:
                 room_id, user_id, allow_departed_users=True
             )
 
-            if source_config.direction == "b":
+            if pagin_config.direction == "b":
                 # if we're going backwards, we might need to backfill. This
                 # requires that we have a topo token.
                 if room_token.topological:
@@ -377,26 +373,35 @@ class PaginationHandler:
                     # case "JOIN" would have been returned.
                     assert member_event_id
 
-                    leave_token = await self.store.get_topological_token_for_event(
+                    leave_token_str = await self.store.get_topological_token_for_event(
                         member_event_id
                     )
-                    if RoomStreamToken.parse(leave_token).topological < curr_topo:
-                        source_config.from_key = str(leave_token)
+                    leave_token = RoomStreamToken.parse(leave_token_str)
+                    assert leave_token.topological is not None
+
+                    if leave_token.topological < curr_topo:
+                        from_token = from_token.copy_and_replace(
+                            "room_key", leave_token
+                        )
 
                 await self.hs.get_handlers().federation_handler.maybe_backfill(
                     room_id, curr_topo, limit=source_config.limit,
                 )
 
+            to_room_key = None
+            if pagin_config.to_token:
+                to_room_key = pagin_config.to_token.room_key
+
             events, next_key = await self.store.paginate_room_events(
                 room_id=room_id,
-                from_key=source_config.from_key,
-                to_key=source_config.to_key,
-                direction=source_config.direction,
-                limit=source_config.limit,
+                from_key=from_token.room_key,
+                to_key=to_room_key,
+                direction=pagin_config.direction,
+                limit=pagin_config.limit,
                 event_filter=event_filter,
             )
 
-            next_token = pagin_config.from_token.copy_and_replace("room_key", next_key)
+            next_token = from_token.copy_and_replace("room_key", next_key)
 
         if events:
             if event_filter:
@@ -409,7 +414,7 @@ class PaginationHandler:
         if not events:
             return {
                 "chunk": [],
-                "start": pagin_config.from_token.to_string(),
+                "start": from_token.to_string(),
                 "end": next_token.to_string(),
             }
 
@@ -438,7 +443,7 @@ class PaginationHandler:
                     events, time_now, as_client_event=as_client_event
                 )
             ),
-            "start": pagin_config.from_token.to_string(),
+            "start": from_token.to_string(),
             "end": next_token.to_string(),
         }
 
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 91a3aec1cc..1000ac95ff 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -1108,9 +1108,6 @@ class PresenceEventSource:
     def get_current_key(self):
         return self.store.get_current_presence_token()
 
-    async def get_pagination_rows(self, user, pagination_config, key):
-        return await self.get_new_events(user, from_key=None, include_offline=False)
-
     @cached(num_args=2, cache_context=True)
     async def _get_interested_in(self, user, explicit_room_id, cache_context):
         """Returns the set of users that the given user should see presence
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 0cb8fad89a..5453e6dfc8 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -44,7 +44,7 @@ class BaseProfileHandler(BaseHandler):
     """
 
     def __init__(self, hs):
-        super(BaseProfileHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.federation = hs.get_federation_client()
         hs.get_federation_registry().register_query_handler(
@@ -369,7 +369,7 @@ class MasterProfileHandler(BaseProfileHandler):
     PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
 
     def __init__(self, hs):
-        super(MasterProfileHandler, self).__init__(hs)
+        super().__init__(hs)
 
         assert hs.config.worker_app is None
 
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index e3b528d271..c32f314a1c 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -24,7 +24,7 @@ logger = logging.getLogger(__name__)
 
 class ReadMarkerHandler(BaseHandler):
     def __init__(self, hs):
-        super(ReadMarkerHandler, self).__init__(hs)
+        super().__init__(hs)
         self.server_name = hs.config.server_name
         self.store = hs.get_datastore()
         self.read_marker_linearizer = Linearizer(name="read_marker")
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 2cc6c2eb68..7225923757 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
 
 class ReceiptsHandler(BaseHandler):
     def __init__(self, hs):
-        super(ReceiptsHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.server_name = hs.config.server_name
         self.store = hs.get_datastore()
@@ -142,18 +142,3 @@ class ReceiptEventSource:
 
     def get_current_key(self, direction="f"):
         return self.store.get_max_receipt_stream_id()
-
-    async def get_pagination_rows(self, user, config, key):
-        to_key = int(config.from_key)
-
-        if config.to_key:
-            from_key = int(config.to_key)
-        else:
-            from_key = None
-
-        room_ids = await self.store.get_rooms_for_user(user.to_string())
-        events = await self.store.get_linearized_receipts_for_rooms(
-            room_ids, from_key=from_key, to_key=to_key
-        )
-
-        return (events, to_key)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index cde2dbca92..538f4b2a61 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -42,7 +42,7 @@ class RegistrationHandler(BaseHandler):
         Args:
             hs (synapse.server.HomeServer):
         """
-        super(RegistrationHandler, self).__init__(hs)
+        super().__init__(hs)
         self.hs = hs
         self.auth = hs.get_auth()
         self._auth_handler = hs.get_auth_handler()
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a29305f655..11bf146bed 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -70,7 +70,7 @@ FIVE_MINUTES_IN_MS = 5 * 60 * 1000
 
 class RoomCreationHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
-        super(RoomCreationHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.spam_checker = hs.get_spam_checker()
         self.event_creation_handler = hs.get_event_creation_handler()
@@ -804,7 +804,9 @@ class RoomCreationHandler(BaseHandler):
 
         # Always wait for room creation to progate before returning
         await self._replication.wait_for_stream_position(
-            self.hs.config.worker.writers.events, "events", last_stream_id
+            self.hs.config.worker.events_shard_config.get_instance(room_id),
+            "events",
+            last_stream_id,
         )
 
         return result, last_stream_id
@@ -1091,20 +1093,19 @@ class RoomEventSource:
     async def get_new_events(
         self,
         user: UserID,
-        from_key: str,
+        from_key: RoomStreamToken,
         limit: int,
         room_ids: List[str],
         is_guest: bool,
         explicit_room_id: Optional[str] = None,
-    ) -> Tuple[List[EventBase], str]:
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         # We just ignore the key for now.
 
         to_key = self.get_current_key()
 
-        from_token = RoomStreamToken.parse(from_key)
-        if from_token.topological:
+        if from_key.topological:
             logger.warning("Stream has topological part!!!! %r", from_key)
-            from_key = "s%s" % (from_token.stream,)
+            from_key = RoomStreamToken(None, from_key.stream)
 
         app_service = self.store.get_app_service_by_user_id(user.to_string())
         if app_service:
@@ -1133,14 +1134,14 @@ class RoomEventSource:
                 events[:] = events[:limit]
 
             if events:
-                end_key = events[-1].internal_metadata.after
+                end_key = RoomStreamToken.parse(events[-1].internal_metadata.after)
             else:
                 end_key = to_key
 
         return (events, end_key)
 
-    def get_current_key(self) -> str:
-        return "s%d" % (self.store.get_room_max_stream_ordering(),)
+    def get_current_key(self) -> RoomStreamToken:
+        return RoomStreamToken(None, self.store.get_room_max_stream_ordering())
 
     def get_current_key_for_room(self, room_id: str) -> Awaitable[str]:
         return self.store.get_room_events_max_id(room_id)
@@ -1260,10 +1261,10 @@ class RoomShutdownHandler:
             # We now wait for the create room to come back in via replication so
             # that we can assume that all the joins/invites have propogated before
             # we try and auto join below.
-            #
-            # TODO: Currently the events stream is written to from master
             await self._replication.wait_for_stream_position(
-                self.hs.config.worker.writers.events, "events", stream_id
+                self.hs.config.worker.events_shard_config.get_instance(new_room_id),
+                "events",
+                stream_id,
             )
         else:
             new_room_id = None
@@ -1293,7 +1294,9 @@ class RoomShutdownHandler:
 
                 # Wait for leave to come in over replication before trying to forget.
                 await self._replication.wait_for_stream_position(
-                    self.hs.config.worker.writers.events, "events", stream_id
+                    self.hs.config.worker.events_shard_config.get_instance(room_id),
+                    "events",
+                    stream_id,
                 )
 
                 await self.room_member_handler.forget(target_requester.user, room_id)
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 5dd7b28391..4a13c8e912 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -38,7 +38,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
 
 class RoomListHandler(BaseHandler):
     def __init__(self, hs):
-        super(RoomListHandler, self).__init__(hs)
+        super().__init__(hs)
         self.enable_room_list_search = hs.config.enable_room_list_search
         self.response_cache = ResponseCache(hs, "room_list")
         self.remote_response_cache = ResponseCache(
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 32b7e323fa..8feba8c90a 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -40,7 +40,7 @@ from synapse.events.validator import EventValidator
 from synapse.storage.roommember import RoomsForUser
 from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
 from synapse.util.async_helpers import Linearizer
-from synapse.util.distributor import user_joined_room, user_left_room
+from synapse.util.distributor import user_left_room
 
 from ._base import BaseHandler
 
@@ -51,14 +51,12 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class RoomMemberHandler:
+class RoomMemberHandler(metaclass=abc.ABCMeta):
     # TODO(paul): This handler currently contains a messy conflation of
     #   low-level API that works on UserID objects and so on, and REST-level
     #   API that takes ID strings and returns pagination chunks. These concerns
     #   ought to be separated out a lot better.
 
-    __metaclass__ = abc.ABCMeta
-
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
@@ -82,13 +80,6 @@ class RoomMemberHandler:
         self._enable_lookup = hs.config.enable_3pid_lookup
         self.allow_per_room_profiles = self.config.allow_per_room_profiles
 
-        self._event_stream_writer_instance = hs.config.worker.writers.events
-        self._is_on_event_persistence_instance = (
-            self._event_stream_writer_instance == hs.get_instance_name()
-        )
-        if self._is_on_event_persistence_instance:
-            self.persist_event_storage = hs.get_storage().persistence
-
         self._join_rate_limiter_local = Ratelimiter(
             clock=self.clock,
             rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
@@ -149,17 +140,6 @@ class RoomMemberHandler:
         raise NotImplementedError()
 
     @abc.abstractmethod
-    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
-        """Notifies distributor on master process that the user has joined the
-        room.
-
-        Args:
-            target
-            room_id
-        """
-        raise NotImplementedError()
-
-    @abc.abstractmethod
     async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Notifies distributor on master process that the user has left the
         room.
@@ -221,7 +201,6 @@ class RoomMemberHandler:
 
         prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
 
-        newly_joined = False
         if event.membership == Membership.JOIN:
             newly_joined = True
             if prev_member_event_id:
@@ -246,12 +225,7 @@ class RoomMemberHandler:
             requester, event, context, extra_users=[target], ratelimit=ratelimit,
         )
 
-        if event.membership == Membership.JOIN and newly_joined:
-            # Only fire user_joined_room if the user has actually joined the
-            # room. Don't bother if the user is just changing their profile
-            # info.
-            await self._user_joined_room(target, room_id)
-        elif event.membership == Membership.LEAVE:
+        if event.membership == Membership.LEAVE:
             if prev_member_event_id:
                 prev_member_event = await self.store.get_event(prev_member_event_id)
                 if prev_member_event.membership == Membership.JOIN:
@@ -726,17 +700,7 @@ class RoomMemberHandler:
             (EventTypes.Member, event.state_key), None
         )
 
-        if event.membership == Membership.JOIN:
-            # Only fire user_joined_room if the user has actually joined the
-            # room. Don't bother if the user is just changing their profile
-            # info.
-            newly_joined = True
-            if prev_member_event_id:
-                prev_member_event = await self.store.get_event(prev_member_event_id)
-                newly_joined = prev_member_event.membership != Membership.JOIN
-            if newly_joined:
-                await self._user_joined_room(target_user, room_id)
-        elif event.membership == Membership.LEAVE:
+        if event.membership == Membership.LEAVE:
             if prev_member_event_id:
                 prev_member_event = await self.store.get_event(prev_member_event_id)
                 if prev_member_event.membership == Membership.JOIN:
@@ -1002,10 +966,9 @@ class RoomMemberHandler:
 
 class RoomMemberMasterHandler(RoomMemberHandler):
     def __init__(self, hs):
-        super(RoomMemberMasterHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.distributor = hs.get_distributor()
-        self.distributor.declare("user_joined_room")
         self.distributor.declare("user_left_room")
 
     async def _is_remote_room_too_complex(
@@ -1085,7 +1048,6 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         event_id, stream_id = await self.federation_handler.do_invite_join(
             remote_room_hosts, room_id, user.to_string(), content
         )
-        await self._user_joined_room(user, room_id)
 
         # Check the room we just joined wasn't too large, if we didn't fetch the
         # complexity of it before.
@@ -1228,11 +1190,6 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         )
         return event.event_id, stream_id
 
-    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
-        """Implements RoomMemberHandler._user_joined_room
-        """
-        user_joined_room(self.distributor, target, room_id)
-
     async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_left_room
         """
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 897338fd54..f2e88f6a5b 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
 
 class RoomMemberWorkerHandler(RoomMemberHandler):
     def __init__(self, hs):
-        super(RoomMemberWorkerHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self._remote_join_client = ReplRemoteJoin.make_client(hs)
         self._remote_reject_client = ReplRejectInvite.make_client(hs)
@@ -57,8 +57,6 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
             content=content,
         )
 
-        await self._user_joined_room(user, room_id)
-
         return ret["event_id"], ret["stream_id"]
 
     async def remote_reject_invite(
@@ -81,13 +79,6 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
         )
         return ret["event_id"], ret["stream_id"]
 
-    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
-        """Implements RoomMemberHandler._user_joined_room
-        """
-        await self._notify_change_client(
-            user_id=target.to_string(), room_id=room_id, change="joined"
-        )
-
     async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_left_room
         """
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 66b063f991..285c481a96 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -21,9 +21,10 @@ import saml2
 import saml2.response
 from saml2.client import Saml2Client
 
-from synapse.api.errors import AuthError, SynapseError
+from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
 from synapse.config.saml2_config import SamlAttributeRequirement
+from synapse.http.server import respond_with_html
 from synapse.http.servlet import parse_string
 from synapse.http.site import SynapseRequest
 from synapse.module_api import ModuleApi
@@ -41,7 +42,11 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-@attr.s
+class MappingException(Exception):
+    """Used to catch errors when mapping the SAML2 response to a user."""
+
+
+@attr.s(slots=True)
 class Saml2SessionData:
     """Data we track about SAML2 sessions"""
 
@@ -68,6 +73,7 @@ class SamlHandler:
             hs.config.saml2_grandfathered_mxid_source_attribute
         )
         self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements
+        self._error_template = hs.config.sso_error_template
 
         # plugin to do custom mapping from saml response to mxid
         self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
@@ -84,6 +90,25 @@ class SamlHandler:
         # a lock on the mappings
         self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
 
+    def _render_error(
+        self, request, error: str, error_description: Optional[str] = None
+    ) -> None:
+        """Render the error template and respond to the request with it.
+
+        This is used to show errors to the user. The template of this page can
+        be found under `synapse/res/templates/sso_error.html`.
+
+        Args:
+            request: The incoming request from the browser.
+                We'll respond with an HTML page describing the error.
+            error: A technical identifier for this error.
+            error_description: A human-readable description of the error.
+        """
+        html = self._error_template.render(
+            error=error, error_description=error_description
+        )
+        respond_with_html(request, 400, html)
+
     def handle_redirect_request(
         self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
     ) -> bytes:
@@ -134,49 +159,6 @@ class SamlHandler:
         # the dict.
         self.expire_sessions()
 
-        # Pull out the user-agent and IP from the request.
-        user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
-            0
-        ].decode("ascii", "surrogateescape")
-        ip_address = self.hs.get_ip_from_request(request)
-
-        user_id, current_session = await self._map_saml_response_to_user(
-            resp_bytes, relay_state, user_agent, ip_address
-        )
-
-        # Complete the interactive auth session or the login.
-        if current_session and current_session.ui_auth_session_id:
-            await self._auth_handler.complete_sso_ui_auth(
-                user_id, current_session.ui_auth_session_id, request
-            )
-
-        else:
-            await self._auth_handler.complete_sso_login(user_id, request, relay_state)
-
-    async def _map_saml_response_to_user(
-        self,
-        resp_bytes: str,
-        client_redirect_url: str,
-        user_agent: str,
-        ip_address: str,
-    ) -> Tuple[str, Optional[Saml2SessionData]]:
-        """
-        Given a sample response, retrieve the cached session and user for it.
-
-        Args:
-            resp_bytes: The SAML response.
-            client_redirect_url: The redirect URL passed in by the client.
-            user_agent: The user agent of the client making the request.
-            ip_address: The IP address of the client making the request.
-
-        Returns:
-             Tuple of the user ID and SAML session associated with this response.
-
-        Raises:
-            SynapseError if there was a problem with the response.
-            RedirectException: some mapping providers may raise this if they need
-                to redirect to an interstitial page.
-        """
         try:
             saml2_auth = self._saml_client.parse_authn_request_response(
                 resp_bytes,
@@ -189,12 +171,23 @@ class SamlHandler:
             # in the (user-visible) exception message, so let's log the exception here
             # so we can track down the session IDs later.
             logger.warning(str(e))
-            raise SynapseError(400, "Unexpected SAML2 login.")
+            self._render_error(
+                request, "unsolicited_response", "Unexpected SAML2 login."
+            )
+            return
         except Exception as e:
-            raise SynapseError(400, "Unable to parse SAML2 response: %s." % (e,))
+            self._render_error(
+                request,
+                "invalid_response",
+                "Unable to parse SAML2 response: %s." % (e,),
+            )
+            return
 
         if saml2_auth.not_signed:
-            raise SynapseError(400, "SAML2 response was not signed.")
+            self._render_error(
+                request, "unsigned_respond", "SAML2 response was not signed."
+            )
+            return
 
         logger.debug("SAML2 response: %s", saml2_auth.origxml)
         for assertion in saml2_auth.assertions:
@@ -213,15 +206,73 @@ class SamlHandler:
             saml2_auth.in_response_to, None
         )
 
+        # Ensure that the attributes of the logged in user meet the required
+        # attributes.
         for requirement in self._saml2_attribute_requirements:
-            _check_attribute_requirement(saml2_auth.ava, requirement)
+            if not _check_attribute_requirement(saml2_auth.ava, requirement):
+                self._render_error(
+                    request, "unauthorised", "You are not authorised to log in here."
+                )
+                return
+
+        # Pull out the user-agent and IP from the request.
+        user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
+            0
+        ].decode("ascii", "surrogateescape")
+        ip_address = self.hs.get_ip_from_request(request)
+
+        # Call the mapper to register/login the user
+        try:
+            user_id = await self._map_saml_response_to_user(
+                saml2_auth, relay_state, user_agent, ip_address
+            )
+        except MappingException as e:
+            logger.exception("Could not map user")
+            self._render_error(request, "mapping_error", str(e))
+            return
+
+        # Complete the interactive auth session or the login.
+        if current_session and current_session.ui_auth_session_id:
+            await self._auth_handler.complete_sso_ui_auth(
+                user_id, current_session.ui_auth_session_id, request
+            )
+
+        else:
+            await self._auth_handler.complete_sso_login(user_id, request, relay_state)
+
+    async def _map_saml_response_to_user(
+        self,
+        saml2_auth: saml2.response.AuthnResponse,
+        client_redirect_url: str,
+        user_agent: str,
+        ip_address: str,
+    ) -> str:
+        """
+        Given a SAML response, retrieve the user ID for it and possibly register the user.
+
+        Args:
+            saml2_auth: The parsed SAML2 response.
+            client_redirect_url: The redirect URL passed in by the client.
+            user_agent: The user agent of the client making the request.
+            ip_address: The IP address of the client making the request.
+
+        Returns:
+             The user ID associated with this response.
+
+        Raises:
+            MappingException if there was a problem mapping the response to a user.
+            RedirectException: some mapping providers may raise this if they need
+                to redirect to an interstitial page.
+        """
 
         remote_user_id = self._user_mapping_provider.get_remote_user_id(
             saml2_auth, client_redirect_url
         )
 
         if not remote_user_id:
-            raise Exception("Failed to extract remote user id from SAML response")
+            raise MappingException(
+                "Failed to extract remote user id from SAML response"
+            )
 
         with (await self._mapping_lock.queue(self._auth_provider_id)):
             # first of all, check if we already have a mapping for this user
@@ -235,7 +286,7 @@ class SamlHandler:
             )
             if registered_user_id is not None:
                 logger.info("Found existing mapping %s", registered_user_id)
-                return registered_user_id, current_session
+                return registered_user_id
 
             # backwards-compatibility hack: see if there is an existing user with a
             # suitable mapping from the uid
@@ -260,7 +311,7 @@ class SamlHandler:
                     await self._datastore.record_user_external_id(
                         self._auth_provider_id, remote_user_id, registered_user_id
                     )
-                    return registered_user_id, current_session
+                    return registered_user_id
 
             # Map saml response to user attributes using the configured mapping provider
             for i in range(1000):
@@ -277,7 +328,7 @@ class SamlHandler:
 
                 localpart = attribute_dict.get("mxid_localpart")
                 if not localpart:
-                    raise Exception(
+                    raise MappingException(
                         "Error parsing SAML2 response: SAML mapping provider plugin "
                         "did not return a mxid_localpart value"
                     )
@@ -294,8 +345,8 @@ class SamlHandler:
             else:
                 # Unable to generate a username in 1000 iterations
                 # Break and return error to the user
-                raise SynapseError(
-                    500, "Unable to generate a Matrix ID from the SAML response"
+                raise MappingException(
+                    "Unable to generate a Matrix ID from the SAML response"
                 )
 
             logger.info("Mapped SAML user to local part %s", localpart)
@@ -310,7 +361,7 @@ class SamlHandler:
             await self._datastore.record_user_external_id(
                 self._auth_provider_id, remote_user_id, registered_user_id
             )
-            return registered_user_id, current_session
+            return registered_user_id
 
     def expire_sessions(self):
         expire_before = self._clock.time_msec() - self._saml2_session_lifetime
@@ -323,11 +374,11 @@ class SamlHandler:
             del self._outstanding_requests_dict[reqid]
 
 
-def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement):
+def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool:
     values = ava.get(req.attribute, [])
     for v in values:
         if v == req.value:
-            return
+            return True
 
     logger.info(
         "SAML2 attribute %s did not match required value '%s' (was '%s')",
@@ -335,7 +386,7 @@ def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement):
         req.value,
         values,
     )
-    raise AuthError(403, "You are not authorized to log in here.")
+    return False
 
 
 DOT_REPLACE_PATTERN = re.compile(
@@ -390,7 +441,7 @@ class DefaultSamlMappingProvider:
             return saml_response.ava["uid"][0]
         except KeyError:
             logger.warning("SAML2 response lacks a 'uid' attestation")
-            raise SynapseError(400, "'uid' not in SAML2 response")
+            raise MappingException("'uid' not in SAML2 response")
 
     def saml_response_to_user_attributes(
         self,
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index d58f9788c5..6a76c20d79 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
 
 class SearchHandler(BaseHandler):
     def __init__(self, hs):
-        super(SearchHandler, self).__init__(hs)
+        super().__init__(hs)
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 4d245b618b..a5d67f828f 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -27,7 +27,7 @@ class SetPasswordHandler(BaseHandler):
     """Handler which deals with changing user account passwords"""
 
     def __init__(self, hs):
-        super(SetPasswordHandler, self).__init__(hs)
+        super().__init__(hs)
         self._auth_handler = hs.get_auth_handler()
         self._device_handler = hs.get_device_handler()
         self._password_policy_handler = hs.get_password_policy_handler()
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index e2ddb628ff..9b3a4f638b 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -89,14 +89,12 @@ class TimelineBatch:
     events = attr.ib(type=List[EventBase])
     limited = attr.ib(bool)
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
         to tell if room needs to be part of the sync result.
         """
         return bool(self.events)
 
-    __bool__ = __nonzero__  # python3
-
 
 # We can't freeze this class, because we need to update it after it's instantiated to
 # update its unread count. This is because we calculate the unread count for a room only
@@ -114,7 +112,7 @@ class JoinedSyncResult:
     summary = attr.ib(type=Optional[JsonDict])
     unread_count = attr.ib(type=int)
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
         to tell if room needs to be part of the sync result.
         """
@@ -127,8 +125,6 @@ class JoinedSyncResult:
             # else in the result, we don't need to send it.
         )
 
-    __bool__ = __nonzero__  # python3
-
 
 @attr.s(slots=True, frozen=True)
 class ArchivedSyncResult:
@@ -137,26 +133,22 @@ class ArchivedSyncResult:
     state = attr.ib(type=StateMap[EventBase])
     account_data = attr.ib(type=List[JsonDict])
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
         to tell if room needs to be part of the sync result.
         """
         return bool(self.timeline or self.state or self.account_data)
 
-    __bool__ = __nonzero__  # python3
-
 
 @attr.s(slots=True, frozen=True)
 class InvitedSyncResult:
     room_id = attr.ib(type=str)
     invite = attr.ib(type=EventBase)
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         """Invited rooms should always be reported to the client"""
         return True
 
-    __bool__ = __nonzero__  # python3
-
 
 @attr.s(slots=True, frozen=True)
 class GroupsSyncResult:
@@ -164,11 +156,9 @@ class GroupsSyncResult:
     invite = attr.ib(type=JsonDict)
     leave = attr.ib(type=JsonDict)
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         return bool(self.join or self.invite or self.leave)
 
-    __bool__ = __nonzero__  # python3
-
 
 @attr.s(slots=True, frozen=True)
 class DeviceLists:
@@ -181,13 +171,11 @@ class DeviceLists:
     changed = attr.ib(type=Collection[str])
     left = attr.ib(type=Collection[str])
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         return bool(self.changed or self.left)
 
-    __bool__ = __nonzero__  # python3
-
 
-@attr.s
+@attr.s(slots=True)
 class _RoomChanges:
     """The set of room entries to include in the sync, plus the set of joined
     and left room IDs since last sync.
@@ -227,7 +215,7 @@ class SyncResult:
     device_one_time_keys_count = attr.ib(type=JsonDict)
     groups = attr.ib(type=Optional[GroupsSyncResult])
 
-    def __nonzero__(self) -> bool:
+    def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
         to tell if the notifier needs to wait for more events when polling for
         events.
@@ -243,8 +231,6 @@ class SyncResult:
             or self.groups
         )
 
-    __bool__ = __nonzero__  # python3
-
 
 class SyncHandler:
     def __init__(self, hs: "HomeServer"):
@@ -378,7 +364,7 @@ class SyncHandler:
         sync_config = sync_result_builder.sync_config
 
         with Measure(self.clock, "ephemeral_by_room"):
-            typing_key = since_token.typing_key if since_token else "0"
+            typing_key = since_token.typing_key if since_token else 0
 
             room_ids = sync_result_builder.joined_room_ids
 
@@ -402,7 +388,7 @@ class SyncHandler:
                 event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
                 ephemeral_by_room.setdefault(room_id, []).append(event_copy)
 
-            receipt_key = since_token.receipt_key if since_token else "0"
+            receipt_key = since_token.receipt_key if since_token else 0
 
             receipt_source = self.event_sources.sources["receipt"]
             receipts, receipt_key = await receipt_source.get_new_events(
@@ -533,7 +519,7 @@ class SyncHandler:
             if len(recents) > timeline_limit:
                 limited = True
                 recents = recents[-timeline_limit:]
-                room_key = recents[0].internal_metadata.before
+                room_key = RoomStreamToken.parse(recents[0].internal_metadata.before)
 
             prev_batch_token = now_token.copy_and_replace("room_key", room_key)
 
@@ -1310,12 +1296,11 @@ class SyncHandler:
         presence_source = self.event_sources.sources["presence"]
 
         since_token = sync_result_builder.since_token
+        presence_key = None
+        include_offline = False
         if since_token and not sync_result_builder.full_state:
             presence_key = since_token.presence_key
             include_offline = True
-        else:
-            presence_key = None
-            include_offline = False
 
         presence, presence_key = await presence_source.get_new_events(
             user=user,
@@ -1323,6 +1308,7 @@ class SyncHandler:
             is_guest=sync_config.is_guest,
             include_offline=include_offline,
         )
+        assert presence_key
         sync_result_builder.now_token = now_token.copy_and_replace(
             "presence_key", presence_key
         )
@@ -1485,7 +1471,7 @@ class SyncHandler:
         if rooms_changed:
             return True
 
-        stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
+        stream_id = since_token.room_key.stream
         for room_id in sync_result_builder.joined_room_ids:
             if self.store.has_room_changed_since(room_id, stream_id):
                 return True
@@ -1751,7 +1737,7 @@ class SyncHandler:
                             continue
 
                 leave_token = now_token.copy_and_replace(
-                    "room_key", "s%d" % (event.stream_ordering,)
+                    "room_key", RoomStreamToken(None, event.stream_ordering)
                 )
                 room_entries.append(
                     RoomSyncResultBuilder(
@@ -2038,7 +2024,7 @@ def _calculate_state(
     return {event_id_to_key[e]: e for e in state_ids}
 
 
-@attr.s
+@attr.s(slots=True)
 class SyncResultBuilder:
     """Used to help build up a new SyncResult for a user
 
@@ -2074,7 +2060,7 @@ class SyncResultBuilder:
     to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list))
 
 
-@attr.s
+@attr.s(slots=True)
 class RoomSyncResultBuilder:
     """Stores information needed to create either a `JoinedSyncResult` or
     `ArchivedSyncResult`.
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index e21f8dbc58..79393c8829 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -37,7 +37,7 @@ class UserDirectoryHandler(StateDeltasHandler):
     """
 
     def __init__(self, hs):
-        super(UserDirectoryHandler, self).__init__(hs)
+        super().__init__(hs)
 
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()