summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/auth.py6
-rw-r--r--synapse/appservice/api.py2
-rw-r--r--synapse/config/_base.py10
-rw-r--r--synapse/config/oidc_config.py6
-rw-r--r--synapse/crypto/context_factory.py6
-rw-r--r--synapse/crypto/keyring.py70
-rw-r--r--synapse/handlers/device.py30
-rw-r--r--synapse/handlers/federation.py16
-rw-r--r--synapse/handlers/message.py108
-rw-r--r--synapse/handlers/oidc_handler.py42
-rw-r--r--synapse/handlers/sync.py10
-rw-r--r--synapse/http/client.py187
-rw-r--r--synapse/notifier.py55
-rw-r--r--synapse/push/pusherpool.py18
-rw-r--r--synapse/python_dependencies.py13
-rw-r--r--synapse/replication/slave/storage/_base.py2
-rw-r--r--synapse/replication/tcp/client.py12
-rw-r--r--synapse/res/templates/sso_error.html2
-rw-r--r--synapse/rest/admin/__init__.py2
-rw-r--r--synapse/rest/admin/event_reports.py88
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py14
-rw-r--r--synapse/state/__init__.py64
-rw-r--r--synapse/storage/databases/main/__init__.py8
-rw-r--r--synapse/storage/databases/main/account_data.py4
-rw-r--r--synapse/storage/databases/main/deviceinbox.py4
-rw-r--r--synapse/storage/databases/main/devices.py6
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py2
-rw-r--r--synapse/storage/databases/main/events.py16
-rw-r--r--synapse/storage/databases/main/events_worker.py4
-rw-r--r--synapse/storage/databases/main/group_server.py2
-rw-r--r--synapse/storage/databases/main/presence.py4
-rw-r--r--synapse/storage/databases/main/push_rule.py8
-rw-r--r--synapse/storage/databases/main/pusher.py4
-rw-r--r--synapse/storage/databases/main/receipts.py2
-rw-r--r--synapse/storage/databases/main/registration.py18
-rw-r--r--synapse/storage/databases/main/room.py101
-rw-r--r--synapse/storage/databases/main/roommember.py14
-rw-r--r--synapse/storage/databases/main/schema/delta/56/event_labels.sql2
-rw-r--r--synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres4
-rw-r--r--synapse/storage/databases/main/schema/delta/58/18stream_positions.sql22
-rw-r--r--synapse/storage/databases/main/stats.py2
-rw-r--r--synapse/storage/databases/main/tags.py4
-rw-r--r--synapse/storage/databases/main/user_erasure_store.py2
-rw-r--r--synapse/storage/persist_events.py14
-rw-r--r--synapse/storage/roommember.py2
-rw-r--r--synapse/storage/util/id_generators.py282
-rw-r--r--synapse/types.py15
48 files changed, 924 insertions, 387 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index a95753dcc7..e40b582bd5 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.20.0rc5"
+__version__ = "1.20.1"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 75388643ee..1071a0576e 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -218,11 +218,7 @@ class Auth:
             # Deny the request if the user account has expired.
             if self._account_validity.enabled and not allow_expired:
                 user_id = user.to_string()
-                expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
-                if (
-                    expiration_ts is not None
-                    and self.clock.time_msec() >= expiration_ts
-                ):
+                if await self.store.is_account_expired(user_id, self.clock.time_msec()):
                     raise AuthError(
                         403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
                     )
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 9523c3a5d9..2c1dae5984 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -178,7 +178,7 @@ class ApplicationServiceApi(SimpleHttpClient):
                 urllib.parse.quote(protocol),
             )
             try:
-                info = await self.get_json(uri, {})
+                info = await self.get_json(uri)
 
                 if not _is_valid_3pe_metadata(info):
                     logger.warning(
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index bb9bf8598d..05a66841c3 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -194,7 +194,10 @@ class Config:
             return file_stream.read()
 
     def read_templates(
-        self, filenames: List[str], custom_template_directory: Optional[str] = None,
+        self,
+        filenames: List[str],
+        custom_template_directory: Optional[str] = None,
+        autoescape: bool = False,
     ) -> List[jinja2.Template]:
         """Load a list of template files from disk using the given variables.
 
@@ -210,6 +213,9 @@ class Config:
             custom_template_directory: A directory to try to look for the templates
                 before using the default Synapse template directory instead.
 
+            autoescape: Whether to autoescape variables before inserting them into the
+                template.
+
         Raises:
             ConfigError: if the file's path is incorrect or otherwise cannot be read.
 
@@ -233,7 +239,7 @@ class Config:
             search_directories.insert(0, custom_template_directory)
 
         loader = jinja2.FileSystemLoader(search_directories)
-        env = jinja2.Environment(loader=loader, autoescape=True)
+        env = jinja2.Environment(loader=loader, autoescape=autoescape)
 
         # Update the environment with our custom filters
         env.filters.update(
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index e0939bce84..70fc8a2f62 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -56,6 +56,7 @@ class OIDCConfig(Config):
         self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
         self.oidc_jwks_uri = oidc_config.get("jwks_uri")
         self.oidc_skip_verification = oidc_config.get("skip_verification", False)
+        self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False)
 
         ump_config = oidc_config.get("user_mapping_provider", {})
         ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
@@ -158,6 +159,11 @@ class OIDCConfig(Config):
           #
           #skip_verification: true
 
+          # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
+          # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
+          #
+          #allow_existing_users: true
+
           # An external module can be provided here as a custom solution to mapping
           # attributes returned from a OIDC provider onto a matrix user.
           #
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 2b03f5ac76..79668a402e 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -45,7 +45,11 @@ _TLS_VERSION_MAP = {
 
 class ServerContextFactory(ContextFactory):
     """Factory for PyOpenSSL SSL contexts that are used to handle incoming
-    connections."""
+    connections.
+
+    TODO: replace this with an implementation of IOpenSSLServerConnectionCreator,
+    per https://github.com/matrix-org/synapse/issues/1691
+    """
 
     def __init__(self, config):
         # TODO: once pyOpenSSL exposes TLS_METHOD and SSL_CTX_set_min_proto_version,
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 42e4087a92..c04ad77cf9 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -42,7 +42,6 @@ from synapse.api.errors import (
 )
 from synapse.logging.context import (
     PreserveLoggingContext,
-    current_context,
     make_deferred_yieldable,
     preserve_fn,
     run_in_background,
@@ -233,8 +232,6 @@ class Keyring:
         """
 
         try:
-            ctx = current_context()
-
             # map from server name to a set of outstanding request ids
             server_to_request_ids = {}
 
@@ -265,12 +262,8 @@ class Keyring:
 
                 # if there are no more requests for this server, we can drop the lock.
                 if not server_requests:
-                    with PreserveLoggingContext(ctx):
-                        logger.debug("Releasing key lookup lock on %s", server_name)
-
-                    # ... but not immediately, as that can cause stack explosions if
-                    # we get a long queue of lookups.
-                    self.clock.call_later(0, drop_server_lock, server_name)
+                    logger.debug("Releasing key lookup lock on %s", server_name)
+                    drop_server_lock(server_name)
 
                 return res
 
@@ -335,20 +328,32 @@ class Keyring:
                         )
 
                     # look for any requests which weren't satisfied
-                    with PreserveLoggingContext():
-                        for verify_request in remaining_requests:
-                            verify_request.key_ready.errback(
-                                SynapseError(
-                                    401,
-                                    "No key for %s with ids in %s (min_validity %i)"
-                                    % (
-                                        verify_request.server_name,
-                                        verify_request.key_ids,
-                                        verify_request.minimum_valid_until_ts,
-                                    ),
-                                    Codes.UNAUTHORIZED,
-                                )
+                    while remaining_requests:
+                        verify_request = remaining_requests.pop()
+                        rq_str = (
+                            "VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)"
+                            % (
+                                verify_request.server_name,
+                                verify_request.key_ids,
+                                verify_request.minimum_valid_until_ts,
                             )
+                        )
+
+                        # If we run the errback immediately, it may cancel our
+                        # loggingcontext while we are still in it, so instead we
+                        # schedule it for the next time round the reactor.
+                        #
+                        # (this also ensures that we don't get a stack overflow if we
+                        # has a massive queue of lookups waiting for this server).
+                        self.clock.call_later(
+                            0,
+                            verify_request.key_ready.errback,
+                            SynapseError(
+                                401,
+                                "Failed to find any key to satisfy %s" % (rq_str,),
+                                Codes.UNAUTHORIZED,
+                            ),
+                        )
             except Exception as err:
                 # we don't really expect to get here, because any errors should already
                 # have been caught and logged. But if we do, let's log the error and make
@@ -410,10 +415,23 @@ class Keyring:
                     # key was not valid at this point
                     continue
 
-                with PreserveLoggingContext():
-                    verify_request.key_ready.callback(
-                        (server_name, key_id, fetch_key_result.verify_key)
-                    )
+                # we have a valid key for this request. If we run the callback
+                # immediately, it may cancel our loggingcontext while we are still in
+                # it, so instead we schedule it for the next time round the reactor.
+                #
+                # (this also ensures that we don't get a stack overflow if we had
+                # a massive queue of lookups waiting for this server).
+                logger.debug(
+                    "Found key %s:%s for %s",
+                    server_name,
+                    key_id,
+                    verify_request.request_name,
+                )
+                self.clock.call_later(
+                    0,
+                    verify_request.key_ready.callback,
+                    (server_name, key_id, fetch_key_result.verify_key),
+                )
                 completed.append(verify_request)
                 break
 
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 55a9787439..4149520d6c 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -20,6 +20,7 @@ from typing import Any, Dict, List, Optional
 from synapse.api import errors
 from synapse.api.constants import EventTypes
 from synapse.api.errors import (
+    Codes,
     FederationDeniedError,
     HttpResponseException,
     RequestSendFailed,
@@ -265,6 +266,24 @@ class DeviceHandler(DeviceWorkerHandler):
 
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
+    def _check_device_name_length(self, name: str):
+        """
+        Checks whether a device name is longer than the maximum allowed length.
+
+        Args:
+            name: The name of the device.
+
+        Raises:
+            SynapseError: if the device name is too long.
+        """
+        if name and len(name) > MAX_DEVICE_DISPLAY_NAME_LEN:
+            raise SynapseError(
+                400,
+                "Device display name is too long (max %i)"
+                % (MAX_DEVICE_DISPLAY_NAME_LEN,),
+                errcode=Codes.TOO_LARGE,
+            )
+
     async def check_device_registered(
         self, user_id, device_id, initial_device_display_name=None
     ):
@@ -282,6 +301,9 @@ class DeviceHandler(DeviceWorkerHandler):
         Returns:
             str: device id (generated if none was supplied)
         """
+
+        self._check_device_name_length(initial_device_display_name)
+
         if device_id is not None:
             new_device = await self.store.store_device(
                 user_id=user_id,
@@ -397,12 +419,8 @@ class DeviceHandler(DeviceWorkerHandler):
 
         # Reject a new displayname which is too long.
         new_display_name = content.get("display_name")
-        if new_display_name and len(new_display_name) > MAX_DEVICE_DISPLAY_NAME_LEN:
-            raise SynapseError(
-                400,
-                "Device display name is too long (max %i)"
-                % (MAX_DEVICE_DISPLAY_NAME_LEN,),
-            )
+
+        self._check_device_name_length(new_display_name)
 
         try:
             await self.store.update_device(
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ea9264e751..9f773aefa7 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -74,6 +74,8 @@ from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.types import (
     JsonDict,
     MutableStateMap,
+    PersistedEventPosition,
+    RoomStreamToken,
     StateMap,
     UserID,
     get_domain_from_id,
@@ -2956,7 +2958,7 @@ class FederationHandler(BaseHandler):
             )
             return result["max_stream_id"]
         else:
-            max_stream_id = await self.storage.persistence.persist_events(
+            max_stream_token = await self.storage.persistence.persist_events(
                 event_and_contexts, backfilled=backfilled
             )
 
@@ -2967,12 +2969,12 @@ class FederationHandler(BaseHandler):
 
             if not backfilled:  # Never notify for backfilled events
                 for event, _ in event_and_contexts:
-                    await self._notify_persisted_event(event, max_stream_id)
+                    await self._notify_persisted_event(event, max_stream_token)
 
-            return max_stream_id
+            return max_stream_token.stream
 
     async def _notify_persisted_event(
-        self, event: EventBase, max_stream_id: int
+        self, event: EventBase, max_stream_token: RoomStreamToken
     ) -> None:
         """Checks to see if notifier/pushers should be notified about the
         event or not.
@@ -2998,9 +3000,11 @@ class FederationHandler(BaseHandler):
         elif event.internal_metadata.is_outlier():
             return
 
-        event_stream_id = event.internal_metadata.stream_ordering
+        event_pos = PersistedEventPosition(
+            self._instance_name, event.internal_metadata.stream_ordering
+        )
         self.notifier.on_new_room_event(
-            event, event_stream_id, max_stream_id, extra_users=extra_users
+            event, event_pos, max_stream_token, extra_users=extra_users
         )
 
     async def _clean_room_for_join(self, room_id: str) -> None:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index a8fe5cf4e2..ee271e85e5 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1138,7 +1138,7 @@ class EventCreationHandler:
             if prev_state_ids:
                 raise AuthError(403, "Changing the room create event is forbidden")
 
-        event_stream_id, max_stream_id = await self.storage.persistence.persist_event(
+        event_pos, max_stream_token = await self.storage.persistence.persist_event(
             event, context=context
         )
 
@@ -1149,7 +1149,7 @@ class EventCreationHandler:
         def _notify():
             try:
                 self.notifier.on_new_room_event(
-                    event, event_stream_id, max_stream_id, extra_users=extra_users
+                    event, event_pos, max_stream_token, extra_users=extra_users
                 )
             except Exception:
                 logger.exception("Error notifying about new room event")
@@ -1161,7 +1161,7 @@ class EventCreationHandler:
             # matters as sometimes presence code can take a while.
             run_in_background(self._bump_active_time, requester.user)
 
-        return event_stream_id
+        return event_pos.stream
 
     async def _bump_active_time(self, user: UserID) -> None:
         try:
@@ -1182,54 +1182,7 @@ class EventCreationHandler:
         )
 
         for room_id in room_ids:
-            # For each room we need to find a joined member we can use to send
-            # the dummy event with.
-
-            latest_event_ids = await self.store.get_prev_events_for_room(room_id)
-
-            members = await self.state.get_current_users_in_room(
-                room_id, latest_event_ids=latest_event_ids
-            )
-            dummy_event_sent = False
-            for user_id in members:
-                if not self.hs.is_mine_id(user_id):
-                    continue
-                requester = create_requester(user_id)
-                try:
-                    event, context = await self.create_event(
-                        requester,
-                        {
-                            "type": "org.matrix.dummy_event",
-                            "content": {},
-                            "room_id": room_id,
-                            "sender": user_id,
-                        },
-                        prev_event_ids=latest_event_ids,
-                    )
-
-                    event.internal_metadata.proactively_send = False
-
-                    # Since this is a dummy-event it is OK if it is sent by a
-                    # shadow-banned user.
-                    await self.send_nonmember_event(
-                        requester,
-                        event,
-                        context,
-                        ratelimit=False,
-                        ignore_shadow_ban=True,
-                    )
-                    dummy_event_sent = True
-                    break
-                except ConsentNotGivenError:
-                    logger.info(
-                        "Failed to send dummy event into room %s for user %s due to "
-                        "lack of consent. Will try another user" % (room_id, user_id)
-                    )
-                except AuthError:
-                    logger.info(
-                        "Failed to send dummy event into room %s for user %s due to "
-                        "lack of power. Will try another user" % (room_id, user_id)
-                    )
+            dummy_event_sent = await self._send_dummy_event_for_room(room_id)
 
             if not dummy_event_sent:
                 # Did not find a valid user in the room, so remove from future attempts
@@ -1242,6 +1195,59 @@ class EventCreationHandler:
                 now = self.clock.time_msec()
                 self._rooms_to_exclude_from_dummy_event_insertion[room_id] = now
 
+    async def _send_dummy_event_for_room(self, room_id: str) -> bool:
+        """Attempt to send a dummy event for the given room.
+
+        Args:
+            room_id: room to try to send an event from
+
+        Returns:
+            True if a dummy event was successfully sent. False if no user was able
+            to send an event.
+        """
+
+        # For each room we need to find a joined member we can use to send
+        # the dummy event with.
+        latest_event_ids = await self.store.get_prev_events_for_room(room_id)
+        members = await self.state.get_current_users_in_room(
+            room_id, latest_event_ids=latest_event_ids
+        )
+        for user_id in members:
+            if not self.hs.is_mine_id(user_id):
+                continue
+            requester = create_requester(user_id)
+            try:
+                event, context = await self.create_event(
+                    requester,
+                    {
+                        "type": "org.matrix.dummy_event",
+                        "content": {},
+                        "room_id": room_id,
+                        "sender": user_id,
+                    },
+                    prev_event_ids=latest_event_ids,
+                )
+
+                event.internal_metadata.proactively_send = False
+
+                # Since this is a dummy-event it is OK if it is sent by a
+                # shadow-banned user.
+                await self.send_nonmember_event(
+                    requester, event, context, ratelimit=False, ignore_shadow_ban=True,
+                )
+                return True
+            except ConsentNotGivenError:
+                logger.info(
+                    "Failed to send dummy event into room %s for user %s due to "
+                    "lack of consent. Will try another user" % (room_id, user_id)
+                )
+            except AuthError:
+                logger.info(
+                    "Failed to send dummy event into room %s for user %s due to "
+                    "lack of power. Will try another user" % (room_id, user_id)
+                )
+        return False
+
     def _expire_rooms_to_exclude_from_dummy_event_insertion(self):
         expire_before = self.clock.time_msec() - _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY
         to_expire = set()
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 4230dbaf99..0e06e4408d 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -114,6 +114,7 @@ class OidcHandler:
             hs.config.oidc_user_mapping_provider_config
         )  # type: OidcMappingProvider
         self._skip_verification = hs.config.oidc_skip_verification  # type: bool
+        self._allow_existing_users = hs.config.oidc_allow_existing_users  # type: bool
 
         self._http_client = hs.get_proxied_http_client()
         self._auth_handler = hs.get_auth_handler()
@@ -849,7 +850,8 @@ class OidcHandler:
         If we don't find the user that way, we should register the user,
         mapping the localpart and the display name from the UserInfo.
 
-        If a user already exists with the mxid we've mapped, raise an exception.
+        If a user already exists with the mxid we've mapped and allow_existing_users
+        is disabled, raise an exception.
 
         Args:
             userinfo: an object representing the user
@@ -905,21 +907,31 @@ class OidcHandler:
 
         localpart = map_username_to_mxid_localpart(attributes["localpart"])
 
-        user_id = UserID(localpart, self._hostname)
-        if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()):
-            # This mxid is taken
-            raise MappingException(
-                "mxid '{}' is already taken".format(user_id.to_string())
+        user_id = UserID(localpart, self._hostname).to_string()
+        users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+        if users:
+            if self._allow_existing_users:
+                if len(users) == 1:
+                    registered_user_id = next(iter(users))
+                elif user_id in users:
+                    registered_user_id = user_id
+                else:
+                    raise MappingException(
+                        "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
+                            user_id, list(users.keys())
+                        )
+                    )
+            else:
+                # This mxid is taken
+                raise MappingException("mxid '{}' is already taken".format(user_id))
+        else:
+            # It's the first time this user is logging in and the mapped mxid was
+            # not taken, register the user
+            registered_user_id = await self._registration_handler.register_user(
+                localpart=localpart,
+                default_display_name=attributes["display_name"],
+                user_agent_ips=(user_agent, ip_address),
             )
-
-        # It's the first time this user is logging in and the mapped mxid was
-        # not taken, register the user
-        registered_user_id = await self._registration_handler.register_user(
-            localpart=localpart,
-            default_display_name=attributes["display_name"],
-            user_agent_ips=(user_agent, ip_address),
-        )
-
         await self._datastore.record_user_external_id(
             self._auth_provider_id, remote_user_id, registered_user_id,
         )
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 9b3a4f638b..e948efef2e 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -967,7 +967,7 @@ class SyncHandler:
             raise NotImplementedError()
         else:
             joined_room_ids = await self.get_rooms_for_user_at(
-                user_id, now_token.room_stream_id
+                user_id, now_token.room_key
             )
         sync_result_builder = SyncResultBuilder(
             sync_config,
@@ -1916,7 +1916,7 @@ class SyncHandler:
             raise Exception("Unrecognized rtype: %r", room_builder.rtype)
 
     async def get_rooms_for_user_at(
-        self, user_id: str, stream_ordering: int
+        self, user_id: str, room_key: RoomStreamToken
     ) -> FrozenSet[str]:
         """Get set of joined rooms for a user at the given stream ordering.
 
@@ -1942,15 +1942,15 @@ class SyncHandler:
         # If the membership's stream ordering is after the given stream
         # ordering, we need to go and work out if the user was in the room
         # before.
-        for room_id, membership_stream_ordering in joined_rooms:
-            if membership_stream_ordering <= stream_ordering:
+        for room_id, event_pos in joined_rooms:
+            if not event_pos.persisted_after(room_key):
                 joined_room_ids.add(room_id)
                 continue
 
             logger.info("User joined room after current token: %s", room_id)
 
             extrems = await self.store.get_forward_extremeties_for_room(
-                room_id, stream_ordering
+                room_id, event_pos.stream
             )
             users_in_room = await self.state.get_current_users_in_room(room_id, extrems)
             if user_id in users_in_room:
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 13fcab3378..4694adc400 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -17,6 +17,18 @@
 import logging
 import urllib
 from io import BytesIO
+from typing import (
+    Any,
+    BinaryIO,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 
 import treq
 from canonicaljson import encode_canonical_json
@@ -37,6 +49,7 @@ from twisted.web._newclient import ResponseDone
 from twisted.web.client import Agent, HTTPConnectionPool, readBody
 from twisted.web.http import PotentialDataLoss
 from twisted.web.http_headers import Headers
+from twisted.web.iweb import IResponse
 
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
 from synapse.http import (
@@ -57,6 +70,19 @@ incoming_responses_counter = Counter(
     "synapse_http_client_responses", "", ["method", "code"]
 )
 
+# the type of the headers list, to be passed to the t.w.h.Headers.
+# Actually we can mix str and bytes keys, but Mapping treats 'key' as invariant so
+# we simplify.
+RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValue"]]
+
+# the value actually has to be a List, but List is invariant so we can't specify that
+# the entries can either be Lists or bytes.
+RawHeaderValue = Sequence[Union[str, bytes]]
+
+# the type of the query params, to be passed into `urlencode`
+QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
+QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
+
 
 def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
     """
@@ -285,13 +311,26 @@ class SimpleHttpClient:
                 ip_blacklist=self._ip_blacklist,
             )
 
-    async def request(self, method, uri, data=None, headers=None):
+    async def request(
+        self,
+        method: str,
+        uri: str,
+        data: Optional[bytes] = None,
+        headers: Optional[Headers] = None,
+    ) -> IResponse:
         """
         Args:
-            method (str): HTTP method to use.
-            uri (str): URI to query.
-            data (bytes): Data to send in the request body, if applicable.
-            headers (t.w.http_headers.Headers): Request headers.
+            method: HTTP method to use.
+            uri: URI to query.
+            data: Data to send in the request body, if applicable.
+            headers: Request headers.
+
+        Returns:
+            Response object, once the headers have been read.
+
+        Raises:
+            RequestTimedOutError if the request times out before the headers are read
+
         """
         # A small wrapper around self.agent.request() so we can easily attach
         # counters to it
@@ -324,6 +363,8 @@ class SimpleHttpClient:
                     headers=headers,
                     **self._extra_treq_args
                 )
+                # we use our own timeout mechanism rather than treq's as a workaround
+                # for https://twistedmatrix.com/trac/ticket/9534.
                 request_deferred = timeout_deferred(
                     request_deferred,
                     60,
@@ -353,18 +394,26 @@ class SimpleHttpClient:
                 set_tag("error_reason", e.args[0])
                 raise
 
-    async def post_urlencoded_get_json(self, uri, args={}, headers=None):
+    async def post_urlencoded_get_json(
+        self,
+        uri: str,
+        args: Mapping[str, Union[str, List[str]]] = {},
+        headers: Optional[RawHeaders] = None,
+    ) -> Any:
         """
         Args:
-            uri (str):
-            args (dict[str, str|List[str]]): query params
-            headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
-               header name to a list of values for that header
+            uri: uri to query
+            args: parameters to be url-encoded in the body
+            headers: a map from header name to a list of values for that header
 
         Returns:
-            object: parsed json
+            parsed json
 
         Raises:
+            RequestTimedOutException: if there is a timeout before the response headers
+               are received. Note there is currently no timeout on reading the response
+               body.
+
             HttpResponseException: On a non-2xx HTTP response.
 
             ValueError: if the response was not JSON
@@ -398,19 +447,24 @@ class SimpleHttpClient:
                 response.code, response.phrase.decode("ascii", errors="replace"), body
             )
 
-    async def post_json_get_json(self, uri, post_json, headers=None):
+    async def post_json_get_json(
+        self, uri: str, post_json: Any, headers: Optional[RawHeaders] = None
+    ) -> Any:
         """
 
         Args:
-            uri (str):
-            post_json (object):
-            headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
-               header name to a list of values for that header
+            uri: URI to query.
+            post_json: request body, to be encoded as json
+            headers: a map from header name to a list of values for that header
 
         Returns:
-            object: parsed json
+            parsed json
 
         Raises:
+            RequestTimedOutException: if there is a timeout before the response headers
+               are received. Note there is currently no timeout on reading the response
+               body.
+
             HttpResponseException: On a non-2xx HTTP response.
 
             ValueError: if the response was not JSON
@@ -440,21 +494,22 @@ class SimpleHttpClient:
                 response.code, response.phrase.decode("ascii", errors="replace"), body
             )
 
-    async def get_json(self, uri, args={}, headers=None):
-        """ Gets some json from the given URI.
+    async def get_json(
+        self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None,
+    ) -> Any:
+        """Gets some json from the given URI.
 
         Args:
-            uri (str): The URI to request, not including query parameters
-            args (dict): A dictionary used to create query strings, defaults to
-                None.
-                **Note**: The value of each key is assumed to be an iterable
-                and *not* a string.
-            headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
-               header name to a list of values for that header
+            uri: The URI to request, not including query parameters
+            args: A dictionary used to create query string
+            headers: a map from header name to a list of values for that header
         Returns:
-            Succeeds when we get *any* 2xx HTTP response, with the
-            HTTP body as JSON.
+            Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.
         Raises:
+            RequestTimedOutException: if there is a timeout before the response headers
+               are received. Note there is currently no timeout on reading the response
+               body.
+
             HttpResponseException On a non-2xx HTTP response.
 
             ValueError: if the response was not JSON
@@ -466,22 +521,27 @@ class SimpleHttpClient:
         body = await self.get_raw(uri, args, headers=headers)
         return json_decoder.decode(body.decode("utf-8"))
 
-    async def put_json(self, uri, json_body, args={}, headers=None):
-        """ Puts some json to the given URI.
+    async def put_json(
+        self,
+        uri: str,
+        json_body: Any,
+        args: QueryParams = {},
+        headers: RawHeaders = None,
+    ) -> Any:
+        """Puts some json to the given URI.
 
         Args:
-            uri (str): The URI to request, not including query parameters
-            json_body (dict): The JSON to put in the HTTP body,
-            args (dict): A dictionary used to create query strings, defaults to
-                None.
-                **Note**: The value of each key is assumed to be an iterable
-                and *not* a string.
-            headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
-               header name to a list of values for that header
+            uri: The URI to request, not including query parameters
+            json_body: The JSON to put in the HTTP body,
+            args: A dictionary used to create query strings
+            headers: a map from header name to a list of values for that header
         Returns:
-            Succeeds when we get *any* 2xx HTTP response, with the
-            HTTP body as JSON.
+            Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.
         Raises:
+             RequestTimedOutException: if there is a timeout before the response headers
+               are received. Note there is currently no timeout on reading the response
+               body.
+
             HttpResponseException On a non-2xx HTTP response.
 
             ValueError: if the response was not JSON
@@ -513,21 +573,23 @@ class SimpleHttpClient:
                 response.code, response.phrase.decode("ascii", errors="replace"), body
             )
 
-    async def get_raw(self, uri, args={}, headers=None):
-        """ Gets raw text from the given URI.
+    async def get_raw(
+        self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None
+    ) -> bytes:
+        """Gets raw text from the given URI.
 
         Args:
-            uri (str): The URI to request, not including query parameters
-            args (dict): A dictionary used to create query strings, defaults to
-                None.
-                **Note**: The value of each key is assumed to be an iterable
-                and *not* a string.
-            headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
-               header name to a list of values for that header
+            uri: The URI to request, not including query parameters
+            args: A dictionary used to create query strings
+            headers: a map from header name to a list of values for that header
         Returns:
-            Succeeds when we get *any* 2xx HTTP response, with the
+            Succeeds when we get a 2xx HTTP response, with the
             HTTP body as bytes.
         Raises:
+            RequestTimedOutException: if there is a timeout before the response headers
+               are received. Note there is currently no timeout on reading the response
+               body.
+
             HttpResponseException on a non-2xx HTTP response.
         """
         if len(args):
@@ -552,16 +614,29 @@ class SimpleHttpClient:
     # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
     # The two should be factored out.
 
-    async def get_file(self, url, output_stream, max_size=None, headers=None):
+    async def get_file(
+        self,
+        url: str,
+        output_stream: BinaryIO,
+        max_size: Optional[int] = None,
+        headers: Optional[RawHeaders] = None,
+    ) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
         """GETs a file from a given URL
         Args:
-            url (str): The URL to GET
-            output_stream (file): File to write the response body to.
-            headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
-               header name to a list of values for that header
+            url: The URL to GET
+            output_stream: File to write the response body to.
+            headers: A map from header name to a list of values for that header
         Returns:
-            A (int,dict,string,int) tuple of the file length, dict of the response
+            A tuple of the file length, dict of the response
             headers, absolute URI of the response and HTTP response code.
+
+        Raises:
+            RequestTimedOutException: if there is a timeout before the response headers
+               are received. Note there is currently no timeout on reading the response
+               body.
+
+            SynapseError: if the response is not a 2xx, the remote file is too large, or
+               another exception happens during the download.
         """
 
         actual_headers = {b"User-Agent": [self.user_agent]}
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 37546e6c21..b6b231c15d 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -42,7 +42,13 @@ from synapse.logging.utils import log_function
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.streams.config import PaginationConfig
-from synapse.types import Collection, RoomStreamToken, StreamToken, UserID
+from synapse.types import (
+    Collection,
+    PersistedEventPosition,
+    RoomStreamToken,
+    StreamToken,
+    UserID,
+)
 from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
 from synapse.util.metrics import Measure
 from synapse.visibility import filter_events_for_client
@@ -187,7 +193,7 @@ class Notifier:
         self.store = hs.get_datastore()
         self.pending_new_room_events = (
             []
-        )  # type: List[Tuple[int, EventBase, Collection[UserID]]]
+        )  # type: List[Tuple[PersistedEventPosition, EventBase, Collection[UserID]]]
 
         # Called when there are new things to stream over replication
         self.replication_callbacks = []  # type: List[Callable[[], None]]
@@ -246,8 +252,8 @@ class Notifier:
     def on_new_room_event(
         self,
         event: EventBase,
-        room_stream_id: int,
-        max_room_stream_id: int,
+        event_pos: PersistedEventPosition,
+        max_room_stream_token: RoomStreamToken,
         extra_users: Collection[UserID] = [],
     ):
         """ Used by handlers to inform the notifier something has happened
@@ -261,16 +267,16 @@ class Notifier:
         until all previous events have been persisted before notifying
         the client streams.
         """
-        self.pending_new_room_events.append((room_stream_id, event, extra_users))
-        self._notify_pending_new_room_events(max_room_stream_id)
+        self.pending_new_room_events.append((event_pos, event, extra_users))
+        self._notify_pending_new_room_events(max_room_stream_token)
 
         self.notify_replication()
 
-    def _notify_pending_new_room_events(self, max_room_stream_id: int):
+    def _notify_pending_new_room_events(self, max_room_stream_token: RoomStreamToken):
         """Notify for the room events that were queued waiting for a previous
         event to be persisted.
         Args:
-            max_room_stream_id: The highest stream_id below which all
+            max_room_stream_token: The highest stream_id below which all
                 events have been persisted.
         """
         pending = self.pending_new_room_events
@@ -279,11 +285,9 @@ class Notifier:
         users = set()  # type: Set[UserID]
         rooms = set()  # type: Set[str]
 
-        for room_stream_id, event, extra_users in pending:
-            if room_stream_id > max_room_stream_id:
-                self.pending_new_room_events.append(
-                    (room_stream_id, event, extra_users)
-                )
+        for event_pos, event, extra_users in pending:
+            if event_pos.persisted_after(max_room_stream_token):
+                self.pending_new_room_events.append((event_pos, event, extra_users))
             else:
                 if (
                     event.type == EventTypes.Member
@@ -296,33 +300,32 @@ class Notifier:
 
         if users or rooms:
             self.on_new_event(
-                "room_key",
-                RoomStreamToken(None, max_room_stream_id),
-                users=users,
-                rooms=rooms,
+                "room_key", max_room_stream_token, users=users, rooms=rooms,
             )
-            self._on_updated_room_token(max_room_stream_id)
+            self._on_updated_room_token(max_room_stream_token)
 
-    def _on_updated_room_token(self, max_room_stream_id: int):
+    def _on_updated_room_token(self, max_room_stream_token: RoomStreamToken):
         """Poke services that might care that the room position has been
         updated.
         """
 
         # poke any interested application service.
         run_as_background_process(
-            "_notify_app_services", self._notify_app_services, max_room_stream_id
+            "_notify_app_services", self._notify_app_services, max_room_stream_token
         )
 
         run_as_background_process(
-            "_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_id
+            "_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_token
         )
 
         if self.federation_sender:
-            self.federation_sender.notify_new_events(max_room_stream_id)
+            self.federation_sender.notify_new_events(max_room_stream_token.stream)
 
-    async def _notify_app_services(self, max_room_stream_id: int):
+    async def _notify_app_services(self, max_room_stream_token: RoomStreamToken):
         try:
-            await self.appservice_handler.notify_interested_services(max_room_stream_id)
+            await self.appservice_handler.notify_interested_services(
+                max_room_stream_token.stream
+            )
         except Exception:
             logger.exception("Error notifying application services of event")
 
@@ -332,9 +335,9 @@ class Notifier:
         except Exception:
             logger.exception("Error notifying application services of event")
 
-    async def _notify_pusher_pool(self, max_room_stream_id: int):
+    async def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken):
         try:
-            await self._pusher_pool.on_new_notifications(max_room_stream_id)
+            await self._pusher_pool.on_new_notifications(max_room_stream_token.stream)
         except Exception:
             logger.exception("Error pusher pool of event")
 
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index cc839ffce4..76150e117b 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -60,6 +60,8 @@ class PusherPool:
         self.store = self.hs.get_datastore()
         self.clock = self.hs.get_clock()
 
+        self._account_validity = hs.config.account_validity
+
         # We shard the handling of push notifications by user ID.
         self._pusher_shard_config = hs.config.push.pusher_shard_config
         self._instance_name = hs.get_instance_name()
@@ -202,6 +204,14 @@ class PusherPool:
             )
 
             for u in users_affected:
+                # Don't push if the user account has expired
+                if self._account_validity.enabled:
+                    expired = await self.store.is_account_expired(
+                        u, self.clock.time_msec()
+                    )
+                    if expired:
+                        continue
+
                 if u in self.pushers:
                     for p in self.pushers[u].values():
                         p.on_new_notifications(max_stream_id)
@@ -222,6 +232,14 @@ class PusherPool:
             )
 
             for u in users_affected:
+                # Don't push if the user account has expired
+                if self._account_validity.enabled:
+                    expired = await self.store.is_account_expired(
+                        u, self.clock.time_msec()
+                    )
+                    if expired:
+                        continue
+
                 if u in self.pushers:
                     for p in self.pushers[u].values():
                         p.on_new_receipts(min_stream_id, max_stream_id)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 67f019fd22..288631477e 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -37,6 +37,9 @@ logger = logging.getLogger(__name__)
 # installed when that optional dependency requirement is specified. It is passed
 # to setup() as extras_require in setup.py
 #
+# Note that these both represent runtime dependencies (and the versions
+# installed are checked at runtime).
+#
 # [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers.
 
 REQUIREMENTS = [
@@ -92,20 +95,12 @@ CONDITIONAL_REQUIREMENTS = {
     "oidc": ["authlib>=0.14.0"],
     "systemd": ["systemd-python>=231"],
     "url_preview": ["lxml>=3.5.0"],
-    # Dependencies which are exclusively required by unit test code. This is
-    # NOT a list of all modules that are necessary to run the unit tests.
-    # Tests assume that all optional dependencies are installed.
-    #
-    # parameterized_class decorator was introduced in parameterized 0.7.0
-    "test": ["mock>=2.0", "parameterized>=0.7.0"],
     "sentry": ["sentry-sdk>=0.7.2"],
     "opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"],
     "jwt": ["pyjwt>=1.6.4"],
     # hiredis is not a *strict* dependency, but it makes things much faster.
     # (if it is not installed, we fall back to slow code.)
     "redis": ["txredisapi>=1.4.7", "hiredis"],
-    # We pin black so that our tests don't start failing on new releases.
-    "lint": ["isort==5.0.3", "black==19.10b0", "flake8-comprehensions", "flake8"],
 }
 
 ALL_OPTIONAL_REQUIREMENTS = set()  # type: Set[str]
@@ -113,7 +108,7 @@ ALL_OPTIONAL_REQUIREMENTS = set()  # type: Set[str]
 for name, optional_deps in CONDITIONAL_REQUIREMENTS.items():
     # Exclude systemd as it's a system-based requirement.
     # Exclude lint as it's a dev-based requirement.
-    if name not in ["systemd", "lint"]:
+    if name not in ["systemd"]:
         ALL_OPTIONAL_REQUIREMENTS = set(optional_deps) | ALL_OPTIONAL_REQUIREMENTS
 
 
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index d25fa49e1a..d0089fe06c 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -31,11 +31,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
             self._cache_id_gen = MultiWriterIdGenerator(
                 db_conn,
                 database,
+                stream_name="caches",
                 instance_name=hs.get_instance_name(),
                 table="cache_invalidation_stream_by_instance",
                 instance_column="instance_name",
                 id_column="stream_id",
                 sequence_name="cache_invalidation_stream_seq",
+                writers=[],
             )  # type: Optional[MultiWriterIdGenerator]
         else:
             self._cache_id_gen = None
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e82b9e386f..55af3d41ea 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -29,7 +29,7 @@ from synapse.replication.tcp.streams.events import (
     EventsStreamEventRow,
     EventsStreamRow,
 )
-from synapse.types import UserID
+from synapse.types import PersistedEventPosition, RoomStreamToken, UserID
 from synapse.util.async_helpers import timeout_deferred
 from synapse.util.metrics import Measure
 
@@ -151,8 +151,14 @@ class ReplicationDataHandler:
                 extra_users = ()  # type: Tuple[UserID, ...]
                 if event.type == EventTypes.Member:
                     extra_users = (UserID.from_string(event.state_key),)
-                max_token = self.store.get_room_max_stream_ordering()
-                self.notifier.on_new_room_event(event, token, max_token, extra_users)
+
+                max_token = RoomStreamToken(
+                    None, self.store.get_room_max_stream_ordering()
+                )
+                event_pos = PersistedEventPosition(instance_name, token)
+                self.notifier.on_new_room_event(
+                    event, event_pos, max_token, extra_users
+                )
 
         # Notify any waiting deferreds. The list is ordered by position so we
         # just iterate through the list until we reach a position that is
diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html
index af8459719a..944bc9c9ca 100644
--- a/synapse/res/templates/sso_error.html
+++ b/synapse/res/templates/sso_error.html
@@ -12,7 +12,7 @@
     <p>
         There was an error during authentication:
     </p>
-    <div id="errormsg" style="margin:20px 80px">{{ error_description }}</div>
+    <div id="errormsg" style="margin:20px 80px">{{ error_description | e }}</div>
     <p>
         If you are seeing this page after clicking a link sent to you via email, make
         sure you only click the confirmation link once, and that you open the
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 4a75c06480..5c5f00b213 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -31,6 +31,7 @@ from synapse.rest.admin.devices import (
     DeviceRestServlet,
     DevicesRestServlet,
 )
+from synapse.rest.admin.event_reports import EventReportsRestServlet
 from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
 from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
 from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
@@ -216,6 +217,7 @@ def register_servlets(hs, http_server):
     DeviceRestServlet(hs).register(http_server)
     DevicesRestServlet(hs).register(http_server)
     DeleteDevicesRestServlet(hs).register(http_server)
+    EventReportsRestServlet(hs).register(http_server)
 
 
 def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
new file mode 100644
index 0000000000..5b8d0594cd
--- /dev/null
+++ b/synapse/rest/admin/event_reports.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
+
+logger = logging.getLogger(__name__)
+
+
+class EventReportsRestServlet(RestServlet):
+    """
+    List all reported events that are known to the homeserver. Results are returned
+    in a dictionary containing report information. Supports pagination.
+    The requester must have administrator access in Synapse.
+
+    GET /_synapse/admin/v1/event_reports
+    returns:
+        200 OK with list of reports if success otherwise an error.
+
+    Args:
+        The parameters `from` and `limit` are required only for pagination.
+        By default, a `limit` of 100 is used.
+        The parameter `dir` can be used to define the order of results.
+        The parameter `user_id` can be used to filter by user id.
+        The parameter `room_id` can be used to filter by room id.
+    Returns:
+        A list of reported events and an integer representing the total number of
+        reported events that exist given this query
+    """
+
+    PATTERNS = admin_patterns("/event_reports$")
+
+    def __init__(self, hs):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+
+    async def on_GET(self, request):
+        await assert_requester_is_admin(self.auth, request)
+
+        start = parse_integer(request, "from", default=0)
+        limit = parse_integer(request, "limit", default=100)
+        direction = parse_string(request, "dir", default="b")
+        user_id = parse_string(request, "user_id")
+        room_id = parse_string(request, "room_id")
+
+        if start < 0:
+            raise SynapseError(
+                400,
+                "The start parameter must be a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        if limit < 0:
+            raise SynapseError(
+                400,
+                "The limit parameter must be a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        if direction not in ("f", "b"):
+            raise SynapseError(
+                400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
+            )
+
+        event_reports, total = await self.store.get_event_reports_paginate(
+            start, limit, direction, user_id, room_id
+        )
+        ret = {"event_reports": event_reports, "total": total}
+        if (start + limit) < total:
+            ret["next_token"] = start + len(event_reports)
+
+        return 200, ret
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 987765e877..dce6c4d168 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -450,7 +450,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
             raise OEmbedError() from e
 
-    async def _download_url(self, url, user):
+    async def _download_url(self, url: str, user):
         # TODO: we should probably honour robots.txt... except in practice
         # we're most likely being explicitly triggered by a human rather than a
         # bot, so are we really a robot?
@@ -460,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
 
         # If this URL can be accessed via oEmbed, use that instead.
-        url_to_download = url
+        url_to_download = url  # type: Optional[str]
         oembed_url = self._get_oembed_url(url)
         if oembed_url:
             # The result might be a new URL to download, or it might be HTML content.
@@ -520,9 +520,15 @@ class PreviewUrlResource(DirectServeJsonResource):
                 # FIXME: we should calculate a proper expiration based on the
                 # Cache-Control and Expire headers.  But for now, assume 1 hour.
                 expires = ONE_HOUR
-                etag = headers["ETag"][0] if "ETag" in headers else None
+                etag = (
+                    headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None
+                )
         else:
-            html_bytes = oembed_result.html.encode("utf-8")  # type: ignore
+            # we can only get here if we did an oembed request and have an oembed_result.html
+            assert oembed_result.html is not None
+            assert oembed_url is not None
+
+            html_bytes = oembed_result.html.encode("utf-8")
             with self.media_storage.store_into_file(file_info) as (f, fname, finish):
                 f.write(html_bytes)
                 await finish()
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 56d6afb863..5a5ea39e01 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -25,7 +25,6 @@ from typing import (
     Sequence,
     Set,
     Union,
-    cast,
     overload,
 )
 
@@ -42,7 +41,7 @@ from synapse.logging.utils import log_function
 from synapse.state import v1, v2
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.roommember import ProfileInfo
-from synapse.types import Collection, MutableStateMap, StateMap
+from synapse.types import Collection, StateMap
 from synapse.util import Clock
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -472,10 +471,9 @@ class StateResolutionHandler:
     def __init__(self, hs):
         self.clock = hs.get_clock()
 
-        # dict of set of event_ids -> _StateCacheEntry.
-        self._state_cache = None
         self.resolve_linearizer = Linearizer(name="state_resolve_lock")
 
+        # dict of set of event_ids -> _StateCacheEntry.
         self._state_cache = ExpiringCache(
             cache_name="state_cache",
             clock=self.clock,
@@ -519,57 +517,28 @@ class StateResolutionHandler:
         Returns:
             The resolved state
         """
-        logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
-
         group_names = frozenset(state_groups_ids.keys())
 
         with (await self.resolve_linearizer.queue(group_names)):
-            if self._state_cache is not None:
-                cache = self._state_cache.get(group_names, None)
-                if cache:
-                    return cache
+            cache = self._state_cache.get(group_names, None)
+            if cache:
+                return cache
 
             logger.info(
-                "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
+                "Resolving state for %s with groups %s", room_id, list(group_names),
             )
 
             state_groups_histogram.observe(len(state_groups_ids))
 
-            # start by assuming we won't have any conflicted state, and build up the new
-            # state map by iterating through the state groups. If we discover a conflict,
-            # we give up and instead use `resolve_events_with_store`.
-            #
-            # XXX: is this actually worthwhile, or should we just let
-            # resolve_events_with_store do it?
-            new_state = {}  # type: MutableStateMap[str]
-            conflicted_state = False
-            for st in state_groups_ids.values():
-                for key, e_id in st.items():
-                    if key in new_state:
-                        conflicted_state = True
-                        break
-                    new_state[key] = e_id
-                if conflicted_state:
-                    break
-
-            if conflicted_state:
-                logger.info("Resolving conflicted state for %r", room_id)
-                with Measure(self.clock, "state._resolve_events"):
-                    # resolve_events_with_store returns a StateMap, but we can
-                    # treat it as a MutableStateMap as it is above. It isn't
-                    # actually mutated anymore (and is frozen in
-                    # _make_state_cache_entry below).
-                    new_state = cast(
-                        MutableStateMap,
-                        await resolve_events_with_store(
-                            self.clock,
-                            room_id,
-                            room_version,
-                            list(state_groups_ids.values()),
-                            event_map=event_map,
-                            state_res_store=state_res_store,
-                        ),
-                    )
+            with Measure(self.clock, "state._resolve_events"):
+                new_state = await resolve_events_with_store(
+                    self.clock,
+                    room_id,
+                    room_version,
+                    list(state_groups_ids.values()),
+                    event_map=event_map,
+                    state_res_store=state_res_store,
+                )
 
             # if the new state matches any of the input state groups, we can
             # use that state group again. Otherwise we will generate a state_id
@@ -579,8 +548,7 @@ class StateResolutionHandler:
             with Measure(self.clock, "state.create_group_ids"):
                 cache = _make_state_cache_entry(new_state, state_groups_ids)
 
-            if self._state_cache is not None:
-                self._state_cache[group_names] = cache
+            self._state_cache[group_names] = cache
 
             return cache
 
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index ccb3384db9..0cb12f4c61 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -160,14 +160,20 @@ class DataStore(
         )
 
         if isinstance(self.database_engine, PostgresEngine):
+            # We set the `writers` to an empty list here as we don't care about
+            # missing updates over restarts, as we'll not have anything in our
+            # caches to invalidate. (This reduces the amount of writes to the DB
+            # that happen).
             self._cache_id_gen = MultiWriterIdGenerator(
                 db_conn,
                 database,
-                instance_name="master",
+                stream_name="caches",
+                instance_name=hs.get_instance_name(),
                 table="cache_invalidation_stream_by_instance",
                 instance_column="instance_name",
                 id_column="stream_id",
                 sequence_name="cache_invalidation_stream_seq",
+                writers=[],
             )
         else:
             self._cache_id_gen = None
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index c5a36990e4..ef81d73573 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -339,7 +339,7 @@ class AccountDataStore(AccountDataWorkerStore):
         """
         content_json = json_encoder.encode(content)
 
-        with await self._account_data_id_gen.get_next() as next_id:
+        async with self._account_data_id_gen.get_next() as next_id:
             # no need to lock here as room_account_data has a unique constraint
             # on (user_id, room_id, account_data_type) so simple_upsert will
             # retry if there is a conflict.
@@ -387,7 +387,7 @@ class AccountDataStore(AccountDataWorkerStore):
         """
         content_json = json_encoder.encode(content)
 
-        with await self._account_data_id_gen.get_next() as next_id:
+        async with self._account_data_id_gen.get_next() as next_id:
             # no need to lock here as account_data has a unique constraint on
             # (user_id, account_data_type) so simple_upsert will retry if
             # there is a conflict.
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index e4fd979a33..2d151b9134 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -394,7 +394,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                 rows.append((destination, stream_id, now_ms, edu_json))
             txn.executemany(sql, rows)
 
-        with await self._device_inbox_id_gen.get_next() as stream_id:
+        async with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@@ -443,7 +443,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                 txn, stream_id, local_messages_by_user_then_device
             )
 
-        with await self._device_inbox_id_gen.get_next() as stream_id:
+        async with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_from_remote_to_device_inbox",
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index c04374e43d..fdf394c612 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -377,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore):
             THe new stream ID.
         """
 
-        with await self._device_list_id_gen.get_next() as stream_id:
+        async with self._device_list_id_gen.get_next() as stream_id:
             await self.db_pool.runInteraction(
                 "add_user_sig_change_to_streams",
                 self._add_user_signature_change_txn,
@@ -1093,7 +1093,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         if not device_ids:
             return
 
-        with await self._device_list_id_gen.get_next_mult(
+        async with self._device_list_id_gen.get_next_mult(
             len(device_ids)
         ) as stream_ids:
             await self.db_pool.runInteraction(
@@ -1108,7 +1108,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             return stream_ids[-1]
 
         context = get_active_span_text_map()
-        with await self._device_list_id_gen.get_next_mult(
+        async with self._device_list_id_gen.get_next_mult(
             len(hosts) * len(device_ids)
         ) as stream_ids:
             await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index c8df0bcb3f..22e1ed15d0 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -831,7 +831,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             key (dict): the key data
         """
 
-        with await self._cross_signing_id_gen.get_next() as stream_id:
+        async with self._cross_signing_id_gen.get_next() as stream_id:
             return await self.db_pool.runInteraction(
                 "add_e2e_cross_signing_key",
                 self._set_e2e_cross_signing_key_txn,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 9a80f419e3..18def01f50 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,7 +17,7 @@
 import itertools
 import logging
 from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
 
 import attr
 from prometheus_client import Counter
@@ -156,15 +156,15 @@ class PersistEventsStore:
         # Note: Multiple instances of this function cannot be in flight at
         # the same time for the same room.
         if backfilled:
-            stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
+            stream_ordering_manager = self._backfill_id_gen.get_next_mult(
                 len(events_and_contexts)
             )
         else:
-            stream_ordering_manager = await self._stream_id_gen.get_next_mult(
+            stream_ordering_manager = self._stream_id_gen.get_next_mult(
                 len(events_and_contexts)
             )
 
-        with stream_ordering_manager as stream_orderings:
+        async with stream_ordering_manager as stream_orderings:
             for (event, context), stream in zip(events_and_contexts, stream_orderings):
                 event.internal_metadata.stream_ordering = stream
 
@@ -1108,6 +1108,10 @@ class PersistEventsStore:
     def _store_room_members_txn(self, txn, events, backfilled):
         """Store a room member in the database.
         """
+
+        def str_or_none(val: Any) -> Optional[str]:
+            return val if isinstance(val, str) else None
+
         self.db_pool.simple_insert_many_txn(
             txn,
             table="room_memberships",
@@ -1118,8 +1122,8 @@ class PersistEventsStore:
                     "sender": event.user_id,
                     "room_id": event.room_id,
                     "membership": event.membership,
-                    "display_name": event.content.get("displayname", None),
-                    "avatar_url": event.content.get("avatar_url", None),
+                    "display_name": str_or_none(event.content.get("displayname")),
+                    "avatar_url": str_or_none(event.content.get("avatar_url")),
                 }
                 for event in events
             ],
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index de9e8d1dc6..f95679ebc4 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -83,21 +83,25 @@ class EventsWorkerStore(SQLBaseStore):
             self._stream_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
                 db=database,
+                stream_name="events",
                 instance_name=hs.get_instance_name(),
                 table="events",
                 instance_column="instance_name",
                 id_column="stream_ordering",
                 sequence_name="events_stream_seq",
+                writers=hs.config.worker.writers.events,
             )
             self._backfill_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
                 db=database,
+                stream_name="backfill",
                 instance_name=hs.get_instance_name(),
                 table="events",
                 instance_column="instance_name",
                 id_column="stream_ordering",
                 sequence_name="events_backfill_stream_seq",
                 positive=False,
+                writers=hs.config.worker.writers.events,
             )
         else:
             # We shouldn't be running in worker mode with SQLite, but its useful
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index ccfbb2135e..7218191965 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -1265,7 +1265,7 @@ class GroupServerStore(GroupServerWorkerStore):
 
             return next_id
 
-        with await self._group_updates_id_gen.get_next() as next_id:
+        async with self._group_updates_id_gen.get_next() as next_id:
             res = await self.db_pool.runInteraction(
                 "register_user_group_membership",
                 _register_user_group_membership_txn,
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index c9f655dfb7..dbbb99cb95 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -23,11 +23,11 @@ from synapse.util.iterutils import batch_iter
 
 class PresenceStore(SQLBaseStore):
     async def update_presence(self, presence_states):
-        stream_ordering_manager = await self._presence_id_gen.get_next_mult(
+        stream_ordering_manager = self._presence_id_gen.get_next_mult(
             len(presence_states)
         )
 
-        with stream_ordering_manager as stream_orderings:
+        async with stream_ordering_manager as stream_orderings:
             await self.db_pool.runInteraction(
                 "update_presence",
                 self._update_presence_txn,
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index e20a16f907..711d5aa23d 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
     ) -> None:
         conditions_json = json_encoder.encode(conditions)
         actions_json = json_encoder.encode(actions)
-        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+        async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
 
             if before or after:
@@ -585,7 +585,7 @@ class PushRuleStore(PushRulesWorkerStore):
                 txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
             )
 
-        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+        async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
 
             await self.db_pool.runInteraction(
@@ -616,7 +616,7 @@ class PushRuleStore(PushRulesWorkerStore):
         Raises:
             NotFoundError if the rule does not exist.
         """
-        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+        async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
             await self.db_pool.runInteraction(
                 "_set_push_rule_enabled_txn",
@@ -754,7 +754,7 @@ class PushRuleStore(PushRulesWorkerStore):
                 data={"actions": actions_json},
             )
 
-        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+        async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
 
             await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index c388468273..df8609b97b 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
         last_stream_ordering,
         profile_tag="",
     ) -> None:
-        with await self._pushers_id_gen.get_next() as stream_id:
+        async with self._pushers_id_gen.get_next() as stream_id:
             # no need to lock because `pushers` has a unique key on
             # (app_id, pushkey, user_name) so simple_upsert will retry
             await self.db_pool.simple_upsert(
@@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
                 },
             )
 
-        with await self._pushers_id_gen.get_next() as stream_id:
+        async with self._pushers_id_gen.get_next() as stream_id:
             await self.db_pool.runInteraction(
                 "delete_pusher", delete_pusher_txn, stream_id
             )
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 5867d52b62..c10a16ffa3 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -577,7 +577,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 "insert_receipt_conv", graph_to_linear
             )
 
-        with await self._receipts_id_gen.get_next() as stream_id:
+        async with self._receipts_id_gen.get_next() as stream_id:
             event_ts = await self.db_pool.runInteraction(
                 "insert_linearized_receipt",
                 self.insert_linearized_receipt_txn,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 675e81fe34..48ce7ecd16 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -116,6 +116,20 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="get_expiration_ts_for_user",
         )
 
+    async def is_account_expired(self, user_id: str, current_ts: int) -> bool:
+        """
+        Returns whether an user account is expired.
+
+        Args:
+            user_id: The user's ID
+            current_ts: The current timestamp
+
+        Returns:
+            Whether the user account has expired
+        """
+        expiration_ts = await self.get_expiration_ts_for_user(user_id)
+        return expiration_ts is not None and current_ts >= expiration_ts
+
     async def set_account_validity_for_user(
         self,
         user_id: str,
@@ -379,7 +393,7 @@ class RegistrationWorkerStore(SQLBaseStore):
 
     async def get_user_by_external_id(
         self, auth_provider: str, external_id: str
-    ) -> str:
+    ) -> Optional[str]:
         """Look up a user by their external auth id
 
         Args:
@@ -387,7 +401,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             external_id: id on that system
 
         Returns:
-            str|None: the mxid of the user, or None if they are not known
+            the mxid of the user, or None if they are not known
         """
         return await self.db_pool.simple_select_one_onecol(
             table="user_external_ids",
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index bd6f9553c6..3c7630857f 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1137,7 +1137,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                         },
                     )
 
-            with await self._public_room_id_gen.get_next() as next_id:
+            async with self._public_room_id_gen.get_next() as next_id:
                 await self.db_pool.runInteraction(
                     "store_room_txn", store_room_txn, next_id
                 )
@@ -1204,7 +1204,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                     },
                 )
 
-        with await self._public_room_id_gen.get_next() as next_id:
+        async with self._public_room_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
                 "set_room_is_public", set_room_is_public_txn, next_id
             )
@@ -1284,7 +1284,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                     },
                 )
 
-        with await self._public_room_id_gen.get_next() as next_id:
+        async with self._public_room_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
                 "set_room_is_public_appservice",
                 set_room_is_public_appservice_txn,
@@ -1328,6 +1328,101 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             desc="add_event_report",
         )
 
+    async def get_event_reports_paginate(
+        self,
+        start: int,
+        limit: int,
+        direction: str = "b",
+        user_id: Optional[str] = None,
+        room_id: Optional[str] = None,
+    ) -> Tuple[List[Dict[str, Any]], int]:
+        """Retrieve a paginated list of event reports
+
+        Args:
+            start: event offset to begin the query from
+            limit: number of rows to retrieve
+            direction: Whether to fetch the most recent first (`"b"`) or the
+                oldest first (`"f"`)
+            user_id: search for user_id. Ignored if user_id is None
+            room_id: search for room_id. Ignored if room_id is None
+        Returns:
+            event_reports: json list of event reports
+            count: total number of event reports matching the filter criteria
+        """
+
+        def _get_event_reports_paginate_txn(txn):
+            filters = []
+            args = []
+
+            if user_id:
+                filters.append("er.user_id LIKE ?")
+                args.extend(["%" + user_id + "%"])
+            if room_id:
+                filters.append("er.room_id LIKE ?")
+                args.extend(["%" + room_id + "%"])
+
+            if direction == "b":
+                order = "DESC"
+            else:
+                order = "ASC"
+
+            where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+
+            sql = """
+                SELECT COUNT(*) as total_event_reports
+                FROM event_reports AS er
+                {}
+                """.format(
+                where_clause
+            )
+            txn.execute(sql, args)
+            count = txn.fetchone()[0]
+
+            sql = """
+                SELECT
+                    er.id,
+                    er.received_ts,
+                    er.room_id,
+                    er.event_id,
+                    er.user_id,
+                    er.reason,
+                    er.content,
+                    events.sender,
+                    room_aliases.room_alias,
+                    event_json.json AS event_json
+                FROM event_reports AS er
+                LEFT JOIN room_aliases
+                    ON room_aliases.room_id = er.room_id
+                JOIN events
+                    ON events.event_id = er.event_id
+                JOIN event_json
+                    ON event_json.event_id = er.event_id
+                {where_clause}
+                ORDER BY er.received_ts {order}
+                LIMIT ?
+                OFFSET ?
+            """.format(
+                where_clause=where_clause, order=order,
+            )
+
+            args += [limit, start]
+            txn.execute(sql, args)
+            event_reports = self.db_pool.cursor_to_dict(txn)
+
+            if count > 0:
+                for row in event_reports:
+                    try:
+                        row["content"] = db_to_json(row["content"])
+                        row["event_json"] = db_to_json(row["event_json"])
+                    except Exception:
+                        continue
+
+            return event_reports, count
+
+        return await self.db_pool.runInteraction(
+            "get_event_reports_paginate", _get_event_reports_paginate_txn
+        )
+
     def get_current_public_room_stream_id(self):
         return self._public_room_id_gen.get_current_token()
 
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 4fa8767b01..86ffe2479e 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
 
@@ -37,7 +36,7 @@ from synapse.storage.roommember import (
     ProfileInfo,
     RoomsForUser,
 )
-from synapse.types import Collection, get_domain_from_id
+from synapse.types import Collection, PersistedEventPosition, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -387,7 +386,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # for rooms the server is participating in.
         if self._current_state_events_membership_up_to_date:
             sql = """
-                SELECT room_id, e.stream_ordering
+                SELECT room_id, e.instance_name, e.stream_ordering
                 FROM current_state_events AS c
                 INNER JOIN events AS e USING (room_id, event_id)
                 WHERE
@@ -397,7 +396,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             """
         else:
             sql = """
-                SELECT room_id, e.stream_ordering
+                SELECT room_id, e.instance_name, e.stream_ordering
                 FROM current_state_events AS c
                 INNER JOIN room_memberships AS m USING (room_id, event_id)
                 INNER JOIN events AS e USING (room_id, event_id)
@@ -408,7 +407,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             """
 
         txn.execute(sql, (user_id, Membership.JOIN))
-        return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
+        return frozenset(
+            GetRoomsForUserWithStreamOrdering(
+                room_id, PersistedEventPosition(instance, stream_id)
+            )
+            for room_id, instance, stream_id in txn
+        )
 
     async def get_users_server_still_shares_room_with(
         self, user_ids: Collection[str]
diff --git a/synapse/storage/databases/main/schema/delta/56/event_labels.sql b/synapse/storage/databases/main/schema/delta/56/event_labels.sql
index 5e29c1da19..ccf287971c 100644
--- a/synapse/storage/databases/main/schema/delta/56/event_labels.sql
+++ b/synapse/storage/databases/main/schema/delta/56/event_labels.sql
@@ -13,7 +13,7 @@
  * limitations under the License.
  */
 
--- room_id and topoligical_ordering are denormalised from the events table in order to
+-- room_id and topological_ordering are denormalised from the events table in order to
 -- make the index work.
 CREATE TABLE IF NOT EXISTS event_labels (
     event_id TEXT,
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
index 97c1e6a0c5..c31f9af82a 100644
--- a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
@@ -21,6 +21,8 @@ SELECT setval('events_stream_seq', (
 
 CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
 
+-- If the server has never backfilled a room then doing `-MIN(...)` will give
+-- a negative result, hence why we do `GREATEST(...)`
 SELECT setval('events_backfill_stream_seq', (
-    SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
+    SELECT GREATEST(COALESCE(-MIN(stream_ordering), 1), 1) FROM events
 ));
diff --git a/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql
new file mode 100644
index 0000000000..985fd949a2
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE stream_positions (
+    stream_name TEXT NOT NULL,
+    instance_name TEXT NOT NULL,
+    stream_id BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX stream_positions_idx ON stream_positions(stream_name, instance_name);
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index d7816a8606..5beb302be3 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -210,6 +210,7 @@ class StatsStore(StateDeltasStore):
         * topic
         * avatar
         * canonical_alias
+        * guest_access
 
         A is_federatable key can also be included with a boolean value.
 
@@ -234,6 +235,7 @@ class StatsStore(StateDeltasStore):
             "topic",
             "avatar",
             "canonical_alias",
+            "guest_access",
         ):
             field = fields.get(col, sentinel)
             if field is not sentinel and (not isinstance(field, str) or "\0" in field):
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 96ffe26cc9..9f120d3cb6 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore):
             )
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
-        with await self._account_data_id_gen.get_next() as next_id:
+        async with self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
@@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore):
             txn.execute(sql, (user_id, room_id, tag))
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
-        with await self._account_data_id_gen.get_next() as next_id:
+        async with self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index 2f7c95fc74..f9575b1f1f 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -100,7 +100,7 @@ class UserErasureStore(UserErasureWorkerStore):
                 return
 
             # They are there, delete them.
-            self.simple_delete_one_txn(
+            self.db_pool.simple_delete_one_txn(
                 txn, "erased_users", keyvalues={"user_id": user_id}
             )
 
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index d89f6ed128..603cd7d825 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.databases import Databases
 from synapse.storage.databases.main.events import DeltaState
-from synapse.types import Collection, StateMap
+from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.metrics import Measure
 
@@ -190,6 +190,7 @@ class EventsPersistenceStorage:
         self.persist_events_store = stores.persist_events
 
         self._clock = hs.get_clock()
+        self._instance_name = hs.get_instance_name()
         self.is_mine_id = hs.is_mine_id
         self._event_persist_queue = _EventPeristenceQueue()
         self._state_resolution_handler = hs.get_state_resolution_handler()
@@ -198,7 +199,7 @@ class EventsPersistenceStorage:
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         backfilled: bool = False,
-    ) -> int:
+    ) -> RoomStreamToken:
         """
         Write events to the database
         Args:
@@ -228,11 +229,11 @@ class EventsPersistenceStorage:
             defer.gatherResults(deferreds, consumeErrors=True)
         )
 
-        return self.main_store.get_current_events_token()
+        return RoomStreamToken(None, self.main_store.get_current_events_token())
 
     async def persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
-    ) -> Tuple[int, int]:
+    ) -> Tuple[PersistedEventPosition, RoomStreamToken]:
         """
         Returns:
             The stream ordering of `event`, and the stream ordering of the
@@ -247,7 +248,10 @@ class EventsPersistenceStorage:
         await make_deferred_yieldable(deferred)
 
         max_persisted_id = self.main_store.get_current_events_token()
-        return (event.internal_metadata.stream_ordering, max_persisted_id)
+        event_stream_id = event.internal_metadata.stream_ordering
+
+        pos = PersistedEventPosition(self._instance_name, event_stream_id)
+        return pos, RoomStreamToken(None, max_persisted_id)
 
     def _maybe_start_persisting(self, room_id: str):
         async def persisting_queue(item):
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 8c4a83a840..f152f63321 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -25,7 +25,7 @@ RoomsForUser = namedtuple(
 )
 
 GetRoomsForUserWithStreamOrdering = namedtuple(
-    "_GetRoomsForUserWithStreamOrdering", ("room_id", "stream_ordering")
+    "_GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
 )
 
 
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 1de2b91587..4269eaf918 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -12,16 +12,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import contextlib
 import heapq
 import logging
 import threading
 from collections import deque
-from typing import Dict, List, Set
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Set, Union
 
+import attr
 from typing_extensions import Deque
 
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.util.sequence import PostgresSequenceGenerator
 
@@ -86,7 +87,7 @@ class StreamIdGenerator:
             upwards, -1 to grow downwards.
 
     Usage:
-        with await stream_id_gen.get_next() as stream_id:
+        async with stream_id_gen.get_next() as stream_id:
             # ... persist event ...
     """
 
@@ -101,10 +102,10 @@ class StreamIdGenerator:
             )
         self._unfinished_ids = deque()  # type: Deque[int]
 
-    async def get_next(self):
+    def get_next(self):
         """
         Usage:
-            with await stream_id_gen.get_next() as stream_id:
+            async with stream_id_gen.get_next() as stream_id:
                 # ... persist event ...
         """
         with self._lock:
@@ -113,7 +114,7 @@ class StreamIdGenerator:
 
             self._unfinished_ids.append(next_id)
 
-        @contextlib.contextmanager
+        @contextmanager
         def manager():
             try:
                 yield next_id
@@ -121,12 +122,12 @@ class StreamIdGenerator:
                 with self._lock:
                     self._unfinished_ids.remove(next_id)
 
-        return manager()
+        return _AsyncCtxManagerWrapper(manager())
 
-    async def get_next_mult(self, n):
+    def get_next_mult(self, n):
         """
         Usage:
-            with await stream_id_gen.get_next(n) as stream_ids:
+            async with stream_id_gen.get_next(n) as stream_ids:
                 # ... persist events ...
         """
         with self._lock:
@@ -140,7 +141,7 @@ class StreamIdGenerator:
             for next_id in next_ids:
                 self._unfinished_ids.append(next_id)
 
-        @contextlib.contextmanager
+        @contextmanager
         def manager():
             try:
                 yield next_ids
@@ -149,7 +150,7 @@ class StreamIdGenerator:
                     for next_id in next_ids:
                         self._unfinished_ids.remove(next_id)
 
-        return manager()
+        return _AsyncCtxManagerWrapper(manager())
 
     def get_current_token(self):
         """Returns the maximum stream id such that all stream ids less than or
@@ -184,12 +185,16 @@ class MultiWriterIdGenerator:
     Args:
         db_conn
         db
+        stream_name: A name for the stream.
         instance_name: The name of this instance.
         table: Database table associated with stream.
         instance_column: Column that stores the row's writer's instance name
         id_column: Column that stores the stream ID.
         sequence_name: The name of the postgres sequence used to generate new
             IDs.
+        writers: A list of known writers to use to populate current positions
+            on startup. Can be empty if nothing uses `get_current_token` or
+            `get_positions` (e.g. caches stream).
         positive: Whether the IDs are positive (true) or negative (false).
             When using negative IDs we go backwards from -1 to -2, -3, etc.
     """
@@ -198,16 +203,20 @@ class MultiWriterIdGenerator:
         self,
         db_conn,
         db: DatabasePool,
+        stream_name: str,
         instance_name: str,
         table: str,
         instance_column: str,
         id_column: str,
         sequence_name: str,
+        writers: List[str],
         positive: bool = True,
     ):
         self._db = db
+        self._stream_name = stream_name
         self._instance_name = instance_name
         self._positive = positive
+        self._writers = writers
         self._return_factor = 1 if positive else -1
 
         # We lock as some functions may be called from DB threads.
@@ -216,9 +225,7 @@ class MultiWriterIdGenerator:
         # Note: If we are a negative stream then we still store all the IDs as
         # positive to make life easier for us, and simply negate the IDs when we
         # return them.
-        self._current_positions = self._load_current_ids(
-            db_conn, table, instance_column, id_column
-        )
+        self._current_positions = {}  # type: Dict[str, int]
 
         # Set of local IDs that we're still processing. The current position
         # should be less than the minimum of this set (if not empty).
@@ -251,30 +258,84 @@ class MultiWriterIdGenerator:
 
         self._sequence_gen = PostgresSequenceGenerator(sequence_name)
 
+        # This goes and fills out the above state from the database.
+        self._load_current_ids(db_conn, table, instance_column, id_column)
+
     def _load_current_ids(
         self, db_conn, table: str, instance_column: str, id_column: str
-    ) -> Dict[str, int]:
-        # If positive stream aggregate via MAX. For negative stream use MIN
-        # *and* negate the result to get a positive number.
-        sql = """
-            SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
-            GROUP BY %(instance)s
-        """ % {
-            "instance": instance_column,
-            "id": id_column,
-            "table": table,
-            "agg": "MAX" if self._positive else "-MIN",
-        }
-
+    ):
         cur = db_conn.cursor()
-        cur.execute(sql)
 
-        # `cur` is an iterable over returned rows, which are 2-tuples.
-        current_positions = dict(cur)
+        # Load the current positions of all writers for the stream.
+        if self._writers:
+            sql = """
+                SELECT instance_name, stream_id FROM stream_positions
+                WHERE stream_name = ?
+            """
+            sql = self._db.engine.convert_param_style(sql)
 
-        cur.close()
+            cur.execute(sql, (self._stream_name,))
 
-        return current_positions
+            self._current_positions = {
+                instance: stream_id * self._return_factor
+                for instance, stream_id in cur
+                if instance in self._writers
+            }
+
+        # We set the `_persisted_upto_position` to be the minimum of all current
+        # positions. If empty we use the max stream ID from the DB table.
+        min_stream_id = min(self._current_positions.values(), default=None)
+
+        if min_stream_id is None:
+            # We add a GREATEST here to ensure that the result is always
+            # positive. (This can be a problem for e.g. backfill streams where
+            # the server has never backfilled).
+            sql = """
+                SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
+                FROM %(table)s
+            """ % {
+                "id": id_column,
+                "table": table,
+                "agg": "MAX" if self._positive else "-MIN",
+            }
+            cur.execute(sql)
+            (stream_id,) = cur.fetchone()
+            self._persisted_upto_position = stream_id
+        else:
+            # If we have a min_stream_id then we pull out everything greater
+            # than it from the DB so that we can prefill
+            # `_known_persisted_positions` and get a more accurate
+            # `_persisted_upto_position`.
+            #
+            # We also check if any of the later rows are from this instance, in
+            # which case we use that for this instance's current position. This
+            # is to handle the case where we didn't finish persisting to the
+            # stream positions table before restart (or the stream position
+            # table otherwise got out of date).
+
+            sql = """
+                SELECT %(instance)s, %(id)s FROM %(table)s
+                WHERE ? %(cmp)s %(id)s
+            """ % {
+                "id": id_column,
+                "table": table,
+                "instance": instance_column,
+                "cmp": "<=" if self._positive else ">=",
+            }
+            sql = self._db.engine.convert_param_style(sql)
+            cur.execute(sql, (min_stream_id,))
+
+            self._persisted_upto_position = min_stream_id
+
+            with self._lock:
+                for (instance, stream_id,) in cur:
+                    stream_id = self._return_factor * stream_id
+                    self._add_persisted_position(stream_id)
+
+                    if instance == self._instance_name:
+                        self._current_positions[instance] = stream_id
+
+        cur.close()
 
     def _load_next_id_txn(self, txn) -> int:
         return self._sequence_gen.get_next_id_txn(txn)
@@ -282,59 +343,23 @@ class MultiWriterIdGenerator:
     def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
         return self._sequence_gen.get_next_mult_txn(txn, n)
 
-    async def get_next(self):
+    def get_next(self):
         """
         Usage:
-            with await stream_id_gen.get_next() as stream_id:
+            async with stream_id_gen.get_next() as stream_id:
                 # ... persist event ...
         """
-        next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
-
-        # Assert the fetched ID is actually greater than what we currently
-        # believe the ID to be. If not, then the sequence and table have got
-        # out of sync somehow.
-        with self._lock:
-            assert self._current_positions.get(self._instance_name, 0) < next_id
-
-            self._unfinished_ids.add(next_id)
 
-        @contextlib.contextmanager
-        def manager():
-            try:
-                # Multiply by the return factor so that the ID has correct sign.
-                yield self._return_factor * next_id
-            finally:
-                self._mark_id_as_finished(next_id)
+        return _MultiWriterCtxManager(self)
 
-        return manager()
-
-    async def get_next_mult(self, n: int):
+    def get_next_mult(self, n: int):
         """
         Usage:
-            with await stream_id_gen.get_next_mult(5) as stream_ids:
+            async with stream_id_gen.get_next_mult(5) as stream_ids:
                 # ... persist events ...
         """
-        next_ids = await self._db.runInteraction(
-            "_load_next_mult_id", self._load_next_mult_id_txn, n
-        )
-
-        # Assert the fetched ID is actually greater than any ID we've already
-        # seen. If not, then the sequence and table have got out of sync
-        # somehow.
-        with self._lock:
-            assert max(self._current_positions.values(), default=0) < min(next_ids)
-
-            self._unfinished_ids.update(next_ids)
-
-        @contextlib.contextmanager
-        def manager():
-            try:
-                yield [self._return_factor * i for i in next_ids]
-            finally:
-                for i in next_ids:
-                    self._mark_id_as_finished(i)
 
-        return manager()
+        return _MultiWriterCtxManager(self, n)
 
     def get_next_txn(self, txn: LoggingTransaction):
         """
@@ -352,6 +377,21 @@ class MultiWriterIdGenerator:
         txn.call_after(self._mark_id_as_finished, next_id)
         txn.call_on_exception(self._mark_id_as_finished, next_id)
 
+        # Update the `stream_positions` table with newly updated stream
+        # ID (unless self._writers is not set in which case we don't
+        # bother, as nothing will read it).
+        #
+        # We only do this on the success path so that the persisted current
+        # position points to a persited row with the correct instance name.
+        if self._writers:
+            txn.call_after(
+                run_as_background_process,
+                "MultiWriterIdGenerator._update_table",
+                self._db.runInteraction,
+                "MultiWriterIdGenerator._update_table",
+                self._update_stream_positions_table_txn,
+            )
+
         return self._return_factor * next_id
 
     def _mark_id_as_finished(self, next_id: int):
@@ -482,3 +522,95 @@ class MultiWriterIdGenerator:
                 # There was a gap in seen positions, so there is nothing more to
                 # do.
                 break
+
+    def _update_stream_positions_table_txn(self, txn):
+        """Update the `stream_positions` table with newly persisted position.
+        """
+
+        if not self._writers:
+            return
+
+        # We upsert the value, ensuring on conflict that we always increase the
+        # value (or decrease if stream goes backwards).
+        sql = """
+            INSERT INTO stream_positions (stream_name, instance_name, stream_id)
+            VALUES (?, ?, ?)
+            ON CONFLICT (stream_name, instance_name)
+            DO UPDATE SET
+                stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
+        """ % {
+            "agg": "GREATEST" if self._positive else "LEAST",
+        }
+
+        pos = (self.get_current_token_for_writer(self._instance_name),)
+        txn.execute(sql, (self._stream_name, self._instance_name, pos))
+
+
+@attr.s(slots=True)
+class _AsyncCtxManagerWrapper:
+    """Helper class to convert a plain context manager to an async one.
+
+    This is mainly useful if you have a plain context manager but the interface
+    requires an async one.
+    """
+
+    inner = attr.ib()
+
+    async def __aenter__(self):
+        return self.inner.__enter__()
+
+    async def __aexit__(self, exc_type, exc, tb):
+        return self.inner.__exit__(exc_type, exc, tb)
+
+
+@attr.s(slots=True)
+class _MultiWriterCtxManager:
+    """Async context manager returned by MultiWriterIdGenerator
+    """
+
+    id_gen = attr.ib(type=MultiWriterIdGenerator)
+    multiple_ids = attr.ib(type=Optional[int], default=None)
+    stream_ids = attr.ib(type=List[int], factory=list)
+
+    async def __aenter__(self) -> Union[int, List[int]]:
+        self.stream_ids = await self.id_gen._db.runInteraction(
+            "_load_next_mult_id",
+            self.id_gen._load_next_mult_id_txn,
+            self.multiple_ids or 1,
+        )
+
+        # Assert the fetched ID is actually greater than any ID we've already
+        # seen. If not, then the sequence and table have got out of sync
+        # somehow.
+        with self.id_gen._lock:
+            assert max(self.id_gen._current_positions.values(), default=0) < min(
+                self.stream_ids
+            )
+
+            self.id_gen._unfinished_ids.update(self.stream_ids)
+
+        if self.multiple_ids is None:
+            return self.stream_ids[0] * self.id_gen._return_factor
+        else:
+            return [i * self.id_gen._return_factor for i in self.stream_ids]
+
+    async def __aexit__(self, exc_type, exc, tb):
+        for i in self.stream_ids:
+            self.id_gen._mark_id_as_finished(i)
+
+        if exc_type is not None:
+            return False
+
+        # Update the `stream_positions` table with newly updated stream
+        # ID (unless self._writers is not set in which case we don't
+        # bother, as nothing will read it).
+        #
+        # We only do this on the success path so that the persisted current
+        # position points to a persisted row with the correct instance name.
+        if self.id_gen._writers:
+            await self.id_gen._db.runInteraction(
+                "MultiWriterIdGenerator._update_table",
+                self.id_gen._update_stream_positions_table_txn,
+            )
+
+        return False
diff --git a/synapse/types.py b/synapse/types.py
index a6fc7df22c..ec39f9e1e8 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -495,6 +495,21 @@ class StreamToken:
 StreamToken.START = StreamToken.from_string("s0_0")
 
 
+@attr.s(slots=True, frozen=True)
+class PersistedEventPosition:
+    """Position of a newly persisted event with instance that persisted it.
+
+    This can be used to test whether the event is persisted before or after a
+    RoomStreamToken.
+    """
+
+    instance_name = attr.ib(type=str)
+    stream = attr.ib(type=int)
+
+    def persisted_after(self, token: RoomStreamToken) -> bool:
+        return token.stream < self.stream
+
+
 class ThirdPartyInstanceID(
     namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id"))
 ):