summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/auth.py4
-rw-r--r--synapse/app/generic_worker.py2
-rw-r--r--synapse/app/homeserver.py17
-rw-r--r--synapse/config/_base.pyi2
-rw-r--r--synapse/config/cache.py164
-rw-r--r--synapse/config/database.py6
-rw-r--r--synapse/config/homeserver.py4
-rw-r--r--synapse/config/oidc_config.py177
-rw-r--r--synapse/config/server_notices_config.py2
-rw-r--r--synapse/config/spam_checker.py38
-rw-r--r--synapse/config/sso.py17
-rw-r--r--synapse/events/spamcheck.py78
-rw-r--r--synapse/federation/send_queue.py84
-rw-r--r--synapse/federation/sender/__init__.py12
-rw-r--r--synapse/federation/sender/per_destination_queue.py6
-rw-r--r--synapse/federation/sender/transaction_manager.py6
-rw-r--r--synapse/handlers/auth.py4
-rw-r--r--synapse/handlers/federation.py32
-rw-r--r--synapse/handlers/message.py2
-rw-r--r--synapse/handlers/oidc_handler.py998
-rw-r--r--synapse/handlers/room.py42
-rw-r--r--synapse/handlers/room_member.py5
-rw-r--r--synapse/handlers/saml_handler.py28
-rw-r--r--synapse/handlers/search.py44
-rw-r--r--synapse/http/client.py13
-rw-r--r--synapse/metrics/_exposition.py12
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py4
-rw-r--r--synapse/push/push_rule_evaluator.py4
-rw-r--r--synapse/python_dependencies.py1
-rw-r--r--synapse/replication/http/streams.py4
-rw-r--r--synapse/replication/slave/storage/_base.py50
-rw-r--r--synapse/replication/slave/storage/account_data.py6
-rw-r--r--synapse/replication/slave/storage/client_ips.py3
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py6
-rw-r--r--synapse/replication/slave/storage/devices.py6
-rw-r--r--synapse/replication/slave/storage/events.py6
-rw-r--r--synapse/replication/slave/storage/groups.py6
-rw-r--r--synapse/replication/slave/storage/presence.py6
-rw-r--r--synapse/replication/slave/storage/push_rule.py6
-rw-r--r--synapse/replication/slave/storage/pushers.py6
-rw-r--r--synapse/replication/slave/storage/receipts.py6
-rw-r--r--synapse/replication/slave/storage/room.py4
-rw-r--r--synapse/replication/tcp/client.py6
-rw-r--r--synapse/replication/tcp/commands.py33
-rw-r--r--synapse/replication/tcp/handler.py98
-rw-r--r--synapse/replication/tcp/redis.py7
-rw-r--r--synapse/replication/tcp/resource.py33
-rw-r--r--synapse/replication/tcp/streams/_base.py86
-rw-r--r--synapse/replication/tcp/streams/events.py4
-rw-r--r--synapse/replication/tcp/streams/federation.py36
-rw-r--r--synapse/res/templates/notice_expiry.html2
-rw-r--r--synapse/res/templates/notice_expiry.txt2
-rw-r--r--synapse/res/templates/sso_error.html18
-rw-r--r--synapse/rest/admin/__init__.py2
-rw-r--r--synapse/rest/admin/rooms.py26
-rw-r--r--synapse/rest/client/v1/login.py28
-rw-r--r--synapse/rest/oidc/__init__.py27
-rw-r--r--synapse/rest/oidc/callback_resource.py31
-rw-r--r--synapse/server.py6
-rw-r--r--synapse/server.pyi5
-rw-r--r--synapse/state/__init__.py4
-rw-r--r--synapse/storage/_base.py3
-rw-r--r--synapse/storage/data_stores/__init__.py9
-rw-r--r--synapse/storage/data_stores/main/__init__.py23
-rw-r--r--synapse/storage/data_stores/main/cache.py84
-rw-r--r--synapse/storage/data_stores/main/censor_events.py208
-rw-r--r--synapse/storage/data_stores/main/client_ips.py3
-rw-r--r--synapse/storage/data_stores/main/devices.py60
-rw-r--r--synapse/storage/data_stores/main/end_to_end_keys.py98
-rw-r--r--synapse/storage/data_stores/main/event_federation.py83
-rw-r--r--synapse/storage/data_stores/main/event_push_actions.py74
-rw-r--r--synapse/storage/data_stores/main/events.py1216
-rw-r--r--synapse/storage/data_stores/main/events_worker.py157
-rw-r--r--synapse/storage/data_stores/main/metrics.py128
-rw-r--r--synapse/storage/data_stores/main/purge_events.py399
-rw-r--r--synapse/storage/data_stores/main/rejections.py11
-rw-r--r--synapse/storage/data_stores/main/relations.py60
-rw-r--r--synapse/storage/data_stores/main/room.py80
-rw-r--r--synapse/storage/data_stores/main/roommember.py100
-rw-r--r--synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql28
-rw-r--r--synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres30
-rw-r--r--synapse/storage/data_stores/main/search.py23
-rw-r--r--synapse/storage/data_stores/main/signatures.py30
-rw-r--r--synapse/storage/data_stores/main/state.py35
-rw-r--r--synapse/storage/data_stores/state/store.py6
-rw-r--r--synapse/storage/database.py74
-rw-r--r--synapse/storage/persist_events.py27
-rw-r--r--synapse/storage/prepare_database.py2
-rw-r--r--synapse/storage/util/id_generators.py169
-rw-r--r--synapse/util/caches/__init__.py144
-rw-r--r--synapse/util/caches/descriptors.py36
-rw-r--r--synapse/util/caches/expiringcache.py29
-rw-r--r--synapse/util/caches/lrucache.py52
-rw-r--r--synapse/util/caches/response_cache.py2
-rw-r--r--synapse/util/caches/stream_change_cache.py33
-rw-r--r--synapse/util/caches/ttlcache.py2
96 files changed, 3953 insertions, 1942 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 1ad5ff9410..e009b1a760 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -37,7 +37,7 @@ from synapse.api.errors import (
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
 from synapse.types import StateMap, UserID
-from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
+from synapse.util.caches import register_cache
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.metrics import Measure
 
@@ -73,7 +73,7 @@ class Auth(object):
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
 
-        self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
+        self.token_cache = LruCache(10000)
         register_cache("cache", "token_cache", self.token_cache)
 
         self._auth_blocking = AuthBlocking(self.hs)
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 667ad20428..bccb1140b2 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -122,6 +122,7 @@ from synapse.rest.client.v2_alpha.register import RegisterRestServlet
 from synapse.rest.client.versions import VersionsRestServlet
 from synapse.rest.key.v2 import KeyApiV2Resource
 from synapse.server import HomeServer
+from synapse.storage.data_stores.main.censor_events import CensorEventsStore
 from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore
 from synapse.storage.data_stores.main.monthly_active_users import (
     MonthlyActiveUsersWorkerStore,
@@ -442,6 +443,7 @@ class GenericWorkerSlavedStore(
     SlavedGroupServerStore,
     SlavedAccountDataStore,
     SlavedPusherStore,
+    CensorEventsStore,
     SlavedEventStore,
     SlavedKeyStore,
     RoomStore,
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index cbd1ea475a..d7f337e586 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -69,7 +69,6 @@ from synapse.server import HomeServer
 from synapse.storage import DataStore
 from synapse.storage.engines import IncorrectDatabaseSetup
 from synapse.storage.prepare_database import UpgradeDatabaseException
-from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
 from synapse.util.module_loader import load_module
@@ -192,6 +191,11 @@ class SynapseHomeServer(HomeServer):
                 }
             )
 
+            if self.get_config().oidc_enabled:
+                from synapse.rest.oidc import OIDCResource
+
+                resources["/_synapse/oidc"] = OIDCResource(self)
+
             if self.get_config().saml2_enabled:
                 from synapse.rest.saml2 import SAML2Resource
 
@@ -422,6 +426,13 @@ def setup(config_options):
                 # Check if it needs to be reprovisioned every day.
                 hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
 
+            # Load the OIDC provider metadatas, if OIDC is enabled.
+            if hs.config.oidc_enabled:
+                oidc = hs.get_oidc_handler()
+                # Loading the provider metadata also ensures the provider config is valid.
+                yield defer.ensureDeferred(oidc.load_metadata())
+                yield defer.ensureDeferred(oidc.load_jwks())
+
             _base.start(hs, config.listeners)
 
             hs.get_datastore().db.updates.start_doing_background_updates()
@@ -504,8 +515,8 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
 
     daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
     stats["daily_sent_messages"] = daily_sent_messages
-    stats["cache_factor"] = CACHE_SIZE_FACTOR
-    stats["event_cache_size"] = hs.config.event_cache_size
+    stats["cache_factor"] = hs.config.caches.global_factor
+    stats["event_cache_size"] = hs.config.caches.event_cache_size
 
     #
     # Performance statistics
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 3053fc9d27..9e576060d4 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -13,6 +13,7 @@ from synapse.config import (
     key,
     logger,
     metrics,
+    oidc_config,
     password,
     password_auth_providers,
     push,
@@ -59,6 +60,7 @@ class RootConfig:
     saml2: saml2_config.SAML2Config
     cas: cas.CasConfig
     sso: sso.SSOConfig
+    oidc: oidc_config.OIDCConfig
     jwt: jwt_config.JWTConfig
     password: password.PasswordConfig
     email: emailconfig.EmailConfig
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
new file mode 100644
index 0000000000..91036a012e
--- /dev/null
+++ b/synapse/config/cache.py
@@ -0,0 +1,164 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Callable, Dict
+
+from ._base import Config, ConfigError
+
+# The prefix for all cache factor-related environment variables
+_CACHES = {}
+_CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR"
+_DEFAULT_FACTOR_SIZE = 0.5
+_DEFAULT_EVENT_CACHE_SIZE = "10K"
+
+
+class CacheProperties(object):
+    def __init__(self):
+        # The default factor size for all caches
+        self.default_factor_size = float(
+            os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
+        )
+        self.resize_all_caches_func = None
+
+
+properties = CacheProperties()
+
+
+def add_resizable_cache(cache_name: str, cache_resize_callback: Callable):
+    """Register a cache that's size can dynamically change
+
+    Args:
+        cache_name: A reference to the cache
+        cache_resize_callback: A callback function that will be ran whenever
+            the cache needs to be resized
+    """
+    _CACHES[cache_name.lower()] = cache_resize_callback
+
+    # Ensure all loaded caches are sized appropriately
+    #
+    # This method should only run once the config has been read,
+    # as it uses values read from it
+    if properties.resize_all_caches_func:
+        properties.resize_all_caches_func()
+
+
+class CacheConfig(Config):
+    section = "caches"
+    _environ = os.environ
+
+    @staticmethod
+    def reset():
+        """Resets the caches to their defaults. Used for tests."""
+        properties.default_factor_size = float(
+            os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
+        )
+        properties.resize_all_caches_func = None
+        _CACHES.clear()
+
+    def generate_config_section(self, **kwargs):
+        return """\
+        ## Caching ##
+
+        # Caching can be configured through the following options.
+        #
+        # A cache 'factor' is a multiplier that can be applied to each of
+        # Synapse's caches in order to increase or decrease the maximum
+        # number of entries that can be stored.
+
+        # The number of events to cache in memory. Not affected by
+        # caches.global_factor.
+        #
+        #event_cache_size: 10K
+
+        caches:
+           # Controls the global cache factor, which is the default cache factor
+           # for all caches if a specific factor for that cache is not otherwise
+           # set.
+           #
+           # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment
+           # variable. Setting by environment variable takes priority over
+           # setting through the config file.
+           #
+           # Defaults to 0.5, which will half the size of all caches.
+           #
+           #global_factor: 1.0
+
+           # A dictionary of cache name to cache factor for that individual
+           # cache. Overrides the global cache factor for a given cache.
+           #
+           # These can also be set through environment variables comprised
+           # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital
+           # letters and underscores. Setting by environment variable
+           # takes priority over setting through the config file.
+           # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0
+           #
+           per_cache_factors:
+             #get_users_who_share_room_with_user: 2.0
+        """
+
+    def read_config(self, config, **kwargs):
+        self.event_cache_size = self.parse_size(
+            config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE)
+        )
+        self.cache_factors = {}  # type: Dict[str, float]
+
+        cache_config = config.get("caches") or {}
+        self.global_factor = cache_config.get(
+            "global_factor", properties.default_factor_size
+        )
+        if not isinstance(self.global_factor, (int, float)):
+            raise ConfigError("caches.global_factor must be a number.")
+
+        # Set the global one so that it's reflected in new caches
+        properties.default_factor_size = self.global_factor
+
+        # Load cache factors from the config
+        individual_factors = cache_config.get("per_cache_factors") or {}
+        if not isinstance(individual_factors, dict):
+            raise ConfigError("caches.per_cache_factors must be a dictionary")
+
+        # Override factors from environment if necessary
+        individual_factors.update(
+            {
+                key[len(_CACHE_PREFIX) + 1 :].lower(): float(val)
+                for key, val in self._environ.items()
+                if key.startswith(_CACHE_PREFIX + "_")
+            }
+        )
+
+        for cache, factor in individual_factors.items():
+            if not isinstance(factor, (int, float)):
+                raise ConfigError(
+                    "caches.per_cache_factors.%s must be a number" % (cache.lower(),)
+                )
+            self.cache_factors[cache.lower()] = factor
+
+        # Resize all caches (if necessary) with the new factors we've loaded
+        self.resize_all_caches()
+
+        # Store this function so that it can be called from other classes without
+        # needing an instance of Config
+        properties.resize_all_caches_func = self.resize_all_caches
+
+    def resize_all_caches(self):
+        """Ensure all cache sizes are up to date
+
+        For each cache, run the mapped callback function with either
+        a specific cache factor or the default, global one.
+        """
+        for cache_name, callback in _CACHES.items():
+            new_factor = self.cache_factors.get(cache_name, self.global_factor)
+            callback(new_factor)
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 5b662d1b01..1064c2697b 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -68,10 +68,6 @@ database:
   name: sqlite3
   args:
     database: %(database_path)s
-
-# Number of events to cache in memory.
-#
-#event_cache_size: 10K
 """
 
 
@@ -116,8 +112,6 @@ class DatabaseConfig(Config):
         self.databases = []
 
     def read_config(self, config, **kwargs):
-        self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
-
         # We *experimentally* support specifying multiple databases via the
         # `databases` key. This is a map from a label to database config in the
         # same format as the `database` config option, plus an extra
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index be6c6afa74..2c7b3a699f 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -17,6 +17,7 @@
 from ._base import RootConfig
 from .api import ApiConfig
 from .appservice import AppServiceConfig
+from .cache import CacheConfig
 from .captcha import CaptchaConfig
 from .cas import CasConfig
 from .consent_config import ConsentConfig
@@ -27,6 +28,7 @@ from .jwt_config import JWTConfig
 from .key import KeyConfig
 from .logger import LoggingConfig
 from .metrics import MetricsConfig
+from .oidc_config import OIDCConfig
 from .password import PasswordConfig
 from .password_auth_providers import PasswordAuthProviderConfig
 from .push import PushConfig
@@ -54,6 +56,7 @@ class HomeServerConfig(RootConfig):
     config_classes = [
         ServerConfig,
         TlsConfig,
+        CacheConfig,
         DatabaseConfig,
         LoggingConfig,
         RatelimitConfig,
@@ -66,6 +69,7 @@ class HomeServerConfig(RootConfig):
         AppServiceConfig,
         KeyConfig,
         SAML2Config,
+        OIDCConfig,
         CasConfig,
         SSOConfig,
         JWTConfig,
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
new file mode 100644
index 0000000000..5af110745e
--- /dev/null
+++ b/synapse/config/oidc_config.py
@@ -0,0 +1,177 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.python_dependencies import DependencyException, check_requirements
+from synapse.util.module_loader import load_module
+
+from ._base import Config, ConfigError
+
+DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider"
+
+
+class OIDCConfig(Config):
+    section = "oidc"
+
+    def read_config(self, config, **kwargs):
+        self.oidc_enabled = False
+
+        oidc_config = config.get("oidc_config")
+
+        if not oidc_config or not oidc_config.get("enabled", False):
+            return
+
+        try:
+            check_requirements("oidc")
+        except DependencyException as e:
+            raise ConfigError(e.message)
+
+        public_baseurl = self.public_baseurl
+        if public_baseurl is None:
+            raise ConfigError("oidc_config requires a public_baseurl to be set")
+        self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"
+
+        self.oidc_enabled = True
+        self.oidc_discover = oidc_config.get("discover", True)
+        self.oidc_issuer = oidc_config["issuer"]
+        self.oidc_client_id = oidc_config["client_id"]
+        self.oidc_client_secret = oidc_config["client_secret"]
+        self.oidc_client_auth_method = oidc_config.get(
+            "client_auth_method", "client_secret_basic"
+        )
+        self.oidc_scopes = oidc_config.get("scopes", ["openid"])
+        self.oidc_authorization_endpoint = oidc_config.get("authorization_endpoint")
+        self.oidc_token_endpoint = oidc_config.get("token_endpoint")
+        self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
+        self.oidc_jwks_uri = oidc_config.get("jwks_uri")
+        self.oidc_subject_claim = oidc_config.get("subject_claim", "sub")
+        self.oidc_skip_verification = oidc_config.get("skip_verification", False)
+
+        ump_config = oidc_config.get("user_mapping_provider", {})
+        ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
+        ump_config.setdefault("config", {})
+
+        (
+            self.oidc_user_mapping_provider_class,
+            self.oidc_user_mapping_provider_config,
+        ) = load_module(ump_config)
+
+        # Ensure loaded user mapping module has defined all necessary methods
+        required_methods = [
+            "get_remote_user_id",
+            "map_user_attributes",
+        ]
+        missing_methods = [
+            method
+            for method in required_methods
+            if not hasattr(self.oidc_user_mapping_provider_class, method)
+        ]
+        if missing_methods:
+            raise ConfigError(
+                "Class specified by oidc_config."
+                "user_mapping_provider.module is missing required "
+                "methods: %s" % (", ".join(missing_methods),)
+            )
+
+    def generate_config_section(self, config_dir_path, server_name, **kwargs):
+        return """\
+        # Enable OpenID Connect for registration and login. Uses authlib.
+        #
+        oidc_config:
+            # enable OpenID Connect. Defaults to false.
+            #
+            #enabled: true
+
+            # use the OIDC discovery mechanism to discover endpoints. Defaults to true.
+            #
+            #discover: true
+
+            # the OIDC issuer. Used to validate tokens and discover the providers endpoints. Required.
+            #
+            #issuer: "https://accounts.example.com/"
+
+            # oauth2 client id to use. Required.
+            #
+            #client_id: "provided-by-your-issuer"
+
+            # oauth2 client secret to use. Required.
+            #
+            #client_secret: "provided-by-your-issuer"
+
+            # auth method to use when exchanging the token.
+            # Valid values are "client_secret_basic" (default), "client_secret_post" and "none".
+            #
+            #client_auth_method: "client_auth_basic"
+
+            # list of scopes to ask. This should include the "openid" scope. Defaults to ["openid"].
+            #
+            #scopes: ["openid"]
+
+            # the oauth2 authorization endpoint. Required if provider discovery is disabled.
+            #
+            #authorization_endpoint: "https://accounts.example.com/oauth2/auth"
+
+            # the oauth2 token endpoint. Required if provider discovery is disabled.
+            #
+            #token_endpoint: "https://accounts.example.com/oauth2/token"
+
+            # the OIDC userinfo endpoint. Required if discovery is disabled and the "openid" scope is not asked.
+            #
+            #userinfo_endpoint: "https://accounts.example.com/userinfo"
+
+            # URI where to fetch the JWKS. Required if discovery is disabled and the "openid" scope is used.
+            #
+            #jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
+
+            # skip metadata verification. Defaults to false.
+            # Use this if you are connecting to a provider that is not OpenID Connect compliant.
+            # Avoid this in production.
+            #
+            #skip_verification: false
+
+
+            # An external module can be provided here as a custom solution to mapping
+            # attributes returned from a OIDC provider onto a matrix user.
+            #
+            user_mapping_provider:
+              # The custom module's class. Uncomment to use a custom module.
+              # Default is {mapping_provider!r}.
+              #
+              #module: mapping_provider.OidcMappingProvider
+
+              # Custom configuration values for the module. Below options are intended
+              # for the built-in provider, they should be changed if using a custom
+              # module. This section will be passed as a Python dictionary to the
+              # module's `parse_config` method.
+              #
+              # Below is the config of the default mapping provider, based on Jinja2
+              # templates. Those templates are used to render user attributes, where the
+              # userinfo object is available through the `user` variable.
+              #
+              config:
+                # name of the claim containing a unique identifier for the user.
+                # Defaults to `sub`, which OpenID Connect compliant providers should provide.
+                #
+                #subject_claim: "sub"
+
+                # Jinja2 template for the localpart of the MXID
+                #
+                localpart_template: "{{{{ user.preferred_username }}}}"
+
+                # Jinja2 template for the display name to set on first login. Optional.
+                #
+                #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}"
+        """.format(
+            mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
+        )
diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py
index 6ea2ea8869..6c427b6f92 100644
--- a/synapse/config/server_notices_config.py
+++ b/synapse/config/server_notices_config.py
@@ -51,7 +51,7 @@ class ServerNoticesConfig(Config):
             None if server notices are not enabled.
 
         server_notices_mxid_avatar_url (str|None):
-            The display name to use for the server notices user.
+            The MXC URL for the avatar of the server notices user.
             None if server notices are not enabled.
 
         server_notices_room_name (str|None):
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
index 36e0ddab5c..3d067d29db 100644
--- a/synapse/config/spam_checker.py
+++ b/synapse/config/spam_checker.py
@@ -13,6 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any, Dict, List, Tuple
+
+from synapse.config import ConfigError
 from synapse.util.module_loader import load_module
 
 from ._base import Config
@@ -22,16 +25,35 @@ class SpamCheckerConfig(Config):
     section = "spamchecker"
 
     def read_config(self, config, **kwargs):
-        self.spam_checker = None
+        self.spam_checkers = []  # type: List[Tuple[Any, Dict]]
+
+        spam_checkers = config.get("spam_checker") or []
+        if isinstance(spam_checkers, dict):
+            # The spam_checker config option used to only support one
+            # spam checker, and thus was simply a dictionary with module
+            # and config keys. Support this old behaviour by checking
+            # to see if the option resolves to a dictionary
+            self.spam_checkers.append(load_module(spam_checkers))
+        elif isinstance(spam_checkers, list):
+            for spam_checker in spam_checkers:
+                if not isinstance(spam_checker, dict):
+                    raise ConfigError("spam_checker syntax is incorrect")
 
-        provider = config.get("spam_checker", None)
-        if provider is not None:
-            self.spam_checker = load_module(provider)
+                self.spam_checkers.append(load_module(spam_checker))
+        else:
+            raise ConfigError("spam_checker syntax is incorrect")
 
     def generate_config_section(self, **kwargs):
         return """\
-        #spam_checker:
-        #  module: "my_custom_project.SuperSpamChecker"
-        #  config:
-        #    example_option: 'things'
+        # Spam checkers are third-party modules that can block specific actions
+        # of local users, such as creating rooms and registering undesirable
+        # usernames, as well as remote users by redacting incoming events.
+        #
+        spam_checker:
+           #- module: "my_custom_project.SuperSpamChecker"
+           #  config:
+           #    example_option: 'things'
+           #- module: "some_other_project.BadEventStopper"
+           #  config:
+           #    example_stop_events_from: ['@bad:example.com']
         """
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index cac6bc0139..aff642f015 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -36,17 +36,13 @@ class SSOConfig(Config):
         if not template_dir:
             template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
 
-        self.sso_redirect_confirm_template_dir = template_dir
+        self.sso_template_dir = template_dir
         self.sso_account_deactivated_template = self.read_file(
-            os.path.join(
-                self.sso_redirect_confirm_template_dir, "sso_account_deactivated.html"
-            ),
+            os.path.join(self.sso_template_dir, "sso_account_deactivated.html"),
             "sso_account_deactivated_template",
         )
         self.sso_auth_success_template = self.read_file(
-            os.path.join(
-                self.sso_redirect_confirm_template_dir, "sso_auth_success.html"
-            ),
+            os.path.join(self.sso_template_dir, "sso_auth_success.html"),
             "sso_auth_success_template",
         )
 
@@ -137,6 +133,13 @@ class SSOConfig(Config):
             #
             #   This template has no additional variables.
             #
+            # * HTML page to display to users if something goes wrong during the
+            #   OpenID Connect authentication process: 'sso_error.html'.
+            #
+            #   When rendering, this template is given two variables:
+            #     * error: the technical name of the error
+            #     * error_description: a human-readable message for the error
+            #
             # You can see the default templates at:
             # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
             #
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index a23b6b7b61..1ffc9525d1 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import inspect
-from typing import Dict
+from typing import Any, Dict, List
 
 from synapse.spam_checker_api import SpamCheckerApi
 
@@ -26,24 +26,17 @@ if MYPY:
 
 class SpamChecker(object):
     def __init__(self, hs: "synapse.server.HomeServer"):
-        self.spam_checker = None
+        self.spam_checkers = []  # type: List[Any]
 
-        module = None
-        config = None
-        try:
-            module, config = hs.config.spam_checker
-        except Exception:
-            pass
-
-        if module is not None:
+        for module, config in hs.config.spam_checkers:
             # Older spam checkers don't accept the `api` argument, so we
             # try and detect support.
             spam_args = inspect.getfullargspec(module)
             if "api" in spam_args.args:
                 api = SpamCheckerApi(hs)
-                self.spam_checker = module(config=config, api=api)
+                self.spam_checkers.append(module(config=config, api=api))
             else:
-                self.spam_checker = module(config=config)
+                self.spam_checkers.append(module(config=config))
 
     def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool:
         """Checks if a given event is considered "spammy" by this server.
@@ -58,10 +51,11 @@ class SpamChecker(object):
         Returns:
             True if the event is spammy.
         """
-        if self.spam_checker is None:
-            return False
+        for spam_checker in self.spam_checkers:
+            if spam_checker.check_event_for_spam(event):
+                return True
 
-        return self.spam_checker.check_event_for_spam(event)
+        return False
 
     def user_may_invite(
         self, inviter_userid: str, invitee_userid: str, room_id: str
@@ -78,12 +72,14 @@ class SpamChecker(object):
         Returns:
             True if the user may send an invite, otherwise False
         """
-        if self.spam_checker is None:
-            return True
+        for spam_checker in self.spam_checkers:
+            if (
+                spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
+                is False
+            ):
+                return False
 
-        return self.spam_checker.user_may_invite(
-            inviter_userid, invitee_userid, room_id
-        )
+        return True
 
     def user_may_create_room(self, userid: str) -> bool:
         """Checks if a given user may create a room
@@ -96,10 +92,11 @@ class SpamChecker(object):
         Returns:
             True if the user may create a room, otherwise False
         """
-        if self.spam_checker is None:
-            return True
+        for spam_checker in self.spam_checkers:
+            if spam_checker.user_may_create_room(userid) is False:
+                return False
 
-        return self.spam_checker.user_may_create_room(userid)
+        return True
 
     def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
         """Checks if a given user may create a room alias
@@ -113,10 +110,11 @@ class SpamChecker(object):
         Returns:
             True if the user may create a room alias, otherwise False
         """
-        if self.spam_checker is None:
-            return True
+        for spam_checker in self.spam_checkers:
+            if spam_checker.user_may_create_room_alias(userid, room_alias) is False:
+                return False
 
-        return self.spam_checker.user_may_create_room_alias(userid, room_alias)
+        return True
 
     def user_may_publish_room(self, userid: str, room_id: str) -> bool:
         """Checks if a given user may publish a room to the directory
@@ -130,10 +128,11 @@ class SpamChecker(object):
         Returns:
             True if the user may publish the room, otherwise False
         """
-        if self.spam_checker is None:
-            return True
+        for spam_checker in self.spam_checkers:
+            if spam_checker.user_may_publish_room(userid, room_id) is False:
+                return False
 
-        return self.spam_checker.user_may_publish_room(userid, room_id)
+        return True
 
     def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
         """Checks if a user ID or display name are considered "spammy" by this server.
@@ -150,13 +149,14 @@ class SpamChecker(object):
         Returns:
             True if the user is spammy.
         """
-        if self.spam_checker is None:
-            return False
-
-        # For backwards compatibility, if the method does not exist on the spam checker, fallback to not interfering.
-        checker = getattr(self.spam_checker, "check_username_for_spam", None)
-        if not checker:
-            return False
-        # Make a copy of the user profile object to ensure the spam checker
-        # cannot modify it.
-        return checker(user_profile.copy())
+        for spam_checker in self.spam_checkers:
+            # For backwards compatibility, only run if the method exists on the
+            # spam checker
+            checker = getattr(spam_checker, "check_username_for_spam", None)
+            if checker:
+                # Make a copy of the user profile object to ensure the spam checker
+                # cannot modify it.
+                if checker(user_profile.copy()):
+                    return True
+
+        return False
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index e1700ca8aa..52f4f54215 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -31,6 +31,7 @@ Events are replicated via a separate events stream.
 
 import logging
 from collections import namedtuple
+from typing import Dict, List, Tuple, Type
 
 from six import iteritems
 
@@ -56,21 +57,35 @@ class FederationRemoteSendQueue(object):
         self.notifier = hs.get_notifier()
         self.is_mine_id = hs.is_mine_id
 
-        self.presence_map = {}  # Pending presence map user_id -> UserPresenceState
-        self.presence_changed = SortedDict()  # Stream position -> list[user_id]
+        # Pending presence map user_id -> UserPresenceState
+        self.presence_map = {}  # type: Dict[str, UserPresenceState]
+
+        # Stream position -> list[user_id]
+        self.presence_changed = SortedDict()  # type: SortedDict[int, List[str]]
 
         # Stores the destinations we need to explicitly send presence to about a
         # given user.
         # Stream position -> (user_id, destinations)
-        self.presence_destinations = SortedDict()
+        self.presence_destinations = (
+            SortedDict()
+        )  # type: SortedDict[int, Tuple[str, List[str]]]
+
+        # (destination, key) -> EDU
+        self.keyed_edu = {}  # type: Dict[Tuple[str, tuple], Edu]
 
-        self.keyed_edu = {}  # (destination, key) -> EDU
-        self.keyed_edu_changed = SortedDict()  # stream position -> (destination, key)
+        # stream position -> (destination, key)
+        self.keyed_edu_changed = (
+            SortedDict()
+        )  # type: SortedDict[int, Tuple[str, tuple]]
 
-        self.edus = SortedDict()  # stream position -> Edu
+        self.edus = SortedDict()  # type: SortedDict[int, Edu]
 
+        # stream ID for the next entry into presence_changed/keyed_edu_changed/edus.
         self.pos = 1
-        self.pos_time = SortedDict()
+
+        # map from stream ID to the time that stream entry was generated, so that we
+        # can clear out entries after a while
+        self.pos_time = SortedDict()  # type: SortedDict[int, int]
 
         # EVERYTHING IS SAD. In particular, python only makes new scopes when
         # we make a new function, so we need to make a new function so the inner
@@ -158,8 +173,10 @@ class FederationRemoteSendQueue(object):
             for edu_key in self.keyed_edu_changed.values():
                 live_keys.add(edu_key)
 
-            to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys]
-            for edu_key in to_del:
+            keys_to_del = [
+                edu_key for edu_key in self.keyed_edu if edu_key not in live_keys
+            ]
+            for edu_key in keys_to_del:
                 del self.keyed_edu[edu_key]
 
             # Delete things out of edu map
@@ -250,19 +267,23 @@ class FederationRemoteSendQueue(object):
         self._clear_queue_before_pos(token)
 
     async def get_replication_rows(
-        self, from_token, to_token, limit, federation_ack=None
-    ):
+        self, instance_name: str, from_token: int, to_token: int, target_row_count: int
+    ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
         """Get rows to be sent over federation between the two tokens
 
         Args:
-            from_token (int)
-            to_token(int)
-            limit (int)
-            federation_ack (int): Optional. The position where the worker is
-                explicitly acknowledged it has handled. Allows us to drop
-                data from before that point
+            instance_name: the name of the current process
+            from_token: the previous stream token: the starting point for fetching the
+                updates
+            to_token: the new stream token: the point to get updates up to
+            target_row_count: a target for the number of rows to be returned.
+
+        Returns: a triplet `(updates, new_last_token, limited)`, where:
+           * `updates` is a list of `(token, row)` entries.
+           * `new_last_token` is the new position in stream.
+           * `limited` is whether there are more updates to fetch.
         """
-        # TODO: Handle limit.
+        # TODO: Handle target_row_count.
 
         # To handle restarts where we wrap around
         if from_token > self.pos:
@@ -270,12 +291,7 @@ class FederationRemoteSendQueue(object):
 
         # list of tuple(int, BaseFederationRow), where the first is the position
         # of the federation stream.
-        rows = []
-
-        # There should be only one reader, so lets delete everything its
-        # acknowledged its seen.
-        if federation_ack:
-            self._clear_queue_before_pos(federation_ack)
+        rows = []  # type: List[Tuple[int, BaseFederationRow]]
 
         # Fetch changed presence
         i = self.presence_changed.bisect_right(from_token)
@@ -332,7 +348,11 @@ class FederationRemoteSendQueue(object):
         # Sort rows based on pos
         rows.sort()
 
-        return [(pos, row.TypeId, row.to_data()) for pos, row in rows]
+        return (
+            [(pos, (row.TypeId, row.to_data())) for pos, row in rows],
+            to_token,
+            False,
+        )
 
 
 class BaseFederationRow(object):
@@ -341,7 +361,7 @@ class BaseFederationRow(object):
     Specifies how to identify, serialize and deserialize the different types.
     """
 
-    TypeId = None  # Unique string that ids the type. Must be overriden in sub classes.
+    TypeId = ""  # Unique string that ids the type. Must be overriden in sub classes.
 
     @staticmethod
     def from_data(data):
@@ -454,10 +474,14 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))):  # Edu
         buff.edus.setdefault(self.edu.destination, []).append(self.edu)
 
 
-TypeToRow = {
-    Row.TypeId: Row
-    for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow,)
-}
+_rowtypes = (
+    PresenceRow,
+    PresenceDestinationsRow,
+    KeyedEduRow,
+    EduRow,
+)  # type: Tuple[Type[BaseFederationRow], ...]
+
+TypeToRow = {Row.TypeId: Row for Row in _rowtypes}
 
 
 ParsedFederationStreamData = namedtuple(
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index a477578e44..d473576902 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Hashable, Iterable, List, Optional, Set
+from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple
 
 from six import itervalues
 
@@ -498,14 +498,16 @@ class FederationSender(object):
 
         self._get_per_destination_queue(destination).attempt_new_transaction()
 
-    def get_current_token(self) -> int:
+    @staticmethod
+    def get_current_token() -> int:
         # Dummy implementation for case where federation sender isn't offloaded
         # to a worker.
         return 0
 
+    @staticmethod
     async def get_replication_rows(
-        self, from_token, to_token, limit, federation_ack=None
-    ):
+        instance_name: str, from_token: int, to_token: int, target_row_count: int
+    ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
         # Dummy implementation for case where federation sender isn't offloaded
         # to a worker.
-        return []
+        return [], 0, False
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index e13cd20ffa..276a2b596f 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -15,11 +15,10 @@
 # limitations under the License.
 import datetime
 import logging
-from typing import Dict, Hashable, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple
 
 from prometheus_client import Counter
 
-import synapse.server
 from synapse.api.errors import (
     FederationDeniedError,
     HttpResponseException,
@@ -34,6 +33,9 @@ from synapse.storage.presence import UserPresenceState
 from synapse.types import ReadReceipt
 from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
 
+if TYPE_CHECKING:
+    import synapse.server
+
 # This is defined in the Matrix spec and enforced by the receiver.
 MAX_EDUS_PER_TRANSACTION = 100
 
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 3c2a02a3b3..a2752a54a5 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -13,11 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import List
+from typing import TYPE_CHECKING, List
 
 from canonicaljson import json
 
-import synapse.server
 from synapse.api.errors import HttpResponseException
 from synapse.events import EventBase
 from synapse.federation.persistence import TransactionActions
@@ -31,6 +30,9 @@ from synapse.logging.opentracing import (
 )
 from synapse.util.metrics import measure_func
 
+if TYPE_CHECKING:
+    import synapse.server
+
 logger = logging.getLogger(__name__)
 
 
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 5c20e29171..524281d2f1 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -126,13 +126,13 @@ class AuthHandler(BaseHandler):
         # It notifies the user they are about to give access to their matrix account
         # to the client.
         self._sso_redirect_confirm_template = load_jinja2_templates(
-            hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
+            hs.config.sso_template_dir, ["sso_redirect_confirm.html"],
         )[0]
         # The following template is shown during user interactive authentication
         # in the fallback auth scenario. It notifies the user that they are
         # authenticating for an operation to occur on their account.
         self._sso_auth_confirm_template = load_jinja2_templates(
-            hs.config.sso_redirect_confirm_template_dir, ["sso_auth_confirm.html"],
+            hs.config.sso_template_dir, ["sso_auth_confirm.html"],
         )[0]
         # The following template is shown after a successful user interactive
         # authentication session. It tells the user they can close the window.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 4e5c645525..81d859f807 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -2681,8 +2681,7 @@ class FederationHandler(BaseHandler):
         member_handler = self.hs.get_room_member_handler()
         await member_handler.send_membership_event(None, event, context)
 
-    @defer.inlineCallbacks
-    def add_display_name_to_third_party_invite(
+    async def add_display_name_to_third_party_invite(
         self, room_version, event_dict, event, context
     ):
         key = (
@@ -2690,10 +2689,10 @@ class FederationHandler(BaseHandler):
             event.content["third_party_invite"]["signed"]["token"],
         )
         original_invite = None
-        prev_state_ids = yield context.get_prev_state_ids()
+        prev_state_ids = await context.get_prev_state_ids()
         original_invite_id = prev_state_ids.get(key)
         if original_invite_id:
-            original_invite = yield self.store.get_event(
+            original_invite = await self.store.get_event(
                 original_invite_id, allow_none=True
             )
         if original_invite:
@@ -2714,14 +2713,13 @@ class FederationHandler(BaseHandler):
 
         builder = self.event_builder_factory.new(room_version, event_dict)
         EventValidator().validate_builder(builder)
-        event, context = yield self.event_creation_handler.create_new_client_event(
+        event, context = await self.event_creation_handler.create_new_client_event(
             builder=builder
         )
         EventValidator().validate_new(event, self.config)
         return (event, context)
 
-    @defer.inlineCallbacks
-    def _check_signature(self, event, context):
+    async def _check_signature(self, event, context):
         """
         Checks that the signature in the event is consistent with its invite.
 
@@ -2738,12 +2736,12 @@ class FederationHandler(BaseHandler):
         signed = event.content["third_party_invite"]["signed"]
         token = signed["token"]
 
-        prev_state_ids = yield context.get_prev_state_ids()
+        prev_state_ids = await context.get_prev_state_ids()
         invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
 
         invite_event = None
         if invite_event_id:
-            invite_event = yield self.store.get_event(invite_event_id, allow_none=True)
+            invite_event = await self.store.get_event(invite_event_id, allow_none=True)
 
         if not invite_event:
             raise AuthError(403, "Could not find invite")
@@ -2792,7 +2790,7 @@ class FederationHandler(BaseHandler):
                             raise
                         try:
                             if "key_validity_url" in public_key_object:
-                                yield self._check_key_revocation(
+                                await self._check_key_revocation(
                                     public_key, public_key_object["key_validity_url"]
                                 )
                         except Exception:
@@ -2806,8 +2804,7 @@ class FederationHandler(BaseHandler):
                 last_exception = e
         raise last_exception
 
-    @defer.inlineCallbacks
-    def _check_key_revocation(self, public_key, url):
+    async def _check_key_revocation(self, public_key, url):
         """
         Checks whether public_key has been revoked.
 
@@ -2821,7 +2818,7 @@ class FederationHandler(BaseHandler):
                 for revocation.
         """
         try:
-            response = yield self.http_client.get_json(url, {"public_key": public_key})
+            response = await self.http_client.get_json(url, {"public_key": public_key})
         except Exception:
             raise SynapseError(502, "Third party certificate could not be checked")
         if "valid" not in response or not response["valid"]:
@@ -2916,8 +2913,7 @@ class FederationHandler(BaseHandler):
         else:
             user_joined_room(self.distributor, user, room_id)
 
-    @defer.inlineCallbacks
-    def get_room_complexity(self, remote_room_hosts, room_id):
+    async def get_room_complexity(self, remote_room_hosts, room_id):
         """
         Fetch the complexity of a remote room over federation.
 
@@ -2931,12 +2927,12 @@ class FederationHandler(BaseHandler):
         """
 
         for host in remote_room_hosts:
-            res = yield self.federation_client.get_room_complexity(host, room_id)
+            res = await self.federation_client.get_room_complexity(host, room_id)
 
             # We got a result, return it.
             if res:
-                defer.returnValue(res)
+                return res
 
         # We fell off the bottom, couldn't get the complexity from anyone. Oh
         # well.
-        defer.returnValue(None)
+        return None
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index a622a600b4..0242521cc6 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -72,7 +72,6 @@ class MessageHandler(object):
         self.state_store = self.storage.state
         self._event_serializer = hs.get_event_client_serializer()
         self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
-        self._is_worker_app = bool(hs.config.worker_app)
 
         # The scheduled call to self._expire_event. None if no call is currently
         # scheduled.
@@ -260,7 +259,6 @@ class MessageHandler(object):
         Args:
             event (EventBase): The event to schedule the expiry of.
         """
-        assert not self._is_worker_app
 
         expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
         if not isinstance(expiry_ts, int) or event.is_state():
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
new file mode 100644
index 0000000000..178f263439
--- /dev/null
+++ b/synapse/handlers/oidc_handler.py
@@ -0,0 +1,998 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import logging
+from typing import Dict, Generic, List, Optional, Tuple, TypeVar
+from urllib.parse import urlencode
+
+import attr
+import pymacaroons
+from authlib.common.security import generate_token
+from authlib.jose import JsonWebToken
+from authlib.oauth2.auth import ClientAuth
+from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
+from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
+from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
+from jinja2 import Environment, Template
+from pymacaroons.exceptions import (
+    MacaroonDeserializationException,
+    MacaroonInvalidSignatureException,
+)
+from typing_extensions import TypedDict
+
+from twisted.web.client import readBody
+
+from synapse.config import ConfigError
+from synapse.http.server import finish_request
+from synapse.http.site import SynapseRequest
+from synapse.push.mailer import load_jinja2_templates
+from synapse.server import HomeServer
+from synapse.types import UserID, map_username_to_mxid_localpart
+
+logger = logging.getLogger(__name__)
+
+SESSION_COOKIE_NAME = b"oidc_session"
+
+#: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
+#: OpenID.Core sec 3.1.3.3.
+Token = TypedDict(
+    "Token",
+    {
+        "access_token": str,
+        "token_type": str,
+        "id_token": Optional[str],
+        "refresh_token": Optional[str],
+        "expires_in": int,
+        "scope": Optional[str],
+    },
+)
+
+#: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but
+#: there is no real point of doing this in our case.
+JWK = Dict[str, str]
+
+#: A JWK Set, as per RFC7517 sec 5.
+JWKS = TypedDict("JWKS", {"keys": List[JWK]})
+
+
+class OidcError(Exception):
+    """Used to catch errors when calling the token_endpoint
+    """
+
+    def __init__(self, error, error_description=None):
+        self.error = error
+        self.error_description = error_description
+
+    def __str__(self):
+        if self.error_description:
+            return "{}: {}".format(self.error, self.error_description)
+        return self.error
+
+
+class MappingException(Exception):
+    """Used to catch errors when mapping the UserInfo object
+    """
+
+
+class OidcHandler:
+    """Handles requests related to the OpenID Connect login flow.
+    """
+
+    def __init__(self, hs: HomeServer):
+        self._callback_url = hs.config.oidc_callback_url  # type: str
+        self._scopes = hs.config.oidc_scopes  # type: List[str]
+        self._client_auth = ClientAuth(
+            hs.config.oidc_client_id,
+            hs.config.oidc_client_secret,
+            hs.config.oidc_client_auth_method,
+        )  # type: ClientAuth
+        self._client_auth_method = hs.config.oidc_client_auth_method  # type: str
+        self._subject_claim = hs.config.oidc_subject_claim
+        self._provider_metadata = OpenIDProviderMetadata(
+            issuer=hs.config.oidc_issuer,
+            authorization_endpoint=hs.config.oidc_authorization_endpoint,
+            token_endpoint=hs.config.oidc_token_endpoint,
+            userinfo_endpoint=hs.config.oidc_userinfo_endpoint,
+            jwks_uri=hs.config.oidc_jwks_uri,
+        )  # type: OpenIDProviderMetadata
+        self._provider_needs_discovery = hs.config.oidc_discover  # type: bool
+        self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class(
+            hs.config.oidc_user_mapping_provider_config
+        )  # type: OidcMappingProvider
+        self._skip_verification = hs.config.oidc_skip_verification  # type: bool
+
+        self._http_client = hs.get_proxied_http_client()
+        self._auth_handler = hs.get_auth_handler()
+        self._registration_handler = hs.get_registration_handler()
+        self._datastore = hs.get_datastore()
+        self._clock = hs.get_clock()
+        self._hostname = hs.hostname  # type: str
+        self._server_name = hs.config.server_name  # type: str
+        self._macaroon_secret_key = hs.config.macaroon_secret_key
+        self._error_template = load_jinja2_templates(
+            hs.config.sso_template_dir, ["sso_error.html"]
+        )[0]
+
+        # identifier for the external_ids table
+        self._auth_provider_id = "oidc"
+
+    def _render_error(
+        self, request, error: str, error_description: Optional[str] = None
+    ) -> None:
+        """Renders the error template and respond with it.
+
+        This is used to show errors to the user. The template of this page can
+        be found under ``synapse/res/templates/sso_error.html``.
+
+        Args:
+            request: The incoming request from the browser.
+                We'll respond with an HTML page describing the error.
+            error: A technical identifier for this error. Those include
+                well-known OAuth2/OIDC error types like invalid_request or
+                access_denied.
+            error_description: A human-readable description of the error.
+        """
+        html_bytes = self._error_template.render(
+            error=error, error_description=error_description
+        ).encode("utf-8")
+
+        request.setResponseCode(400)
+        request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+        request.setHeader(b"Content-Length", b"%i" % len(html_bytes))
+        request.write(html_bytes)
+        finish_request(request)
+
+    def _validate_metadata(self):
+        """Verifies the provider metadata.
+
+        This checks the validity of the currently loaded provider. Not
+        everything is checked, only:
+
+          - ``issuer``
+          - ``authorization_endpoint``
+          - ``token_endpoint``
+          - ``response_types_supported`` (checks if "code" is in it)
+          - ``jwks_uri``
+
+        Raises:
+            ValueError: if something in the provider is not valid
+        """
+        # Skip verification to allow non-compliant providers (e.g. issuers not running on a secure origin)
+        if self._skip_verification is True:
+            return
+
+        m = self._provider_metadata
+        m.validate_issuer()
+        m.validate_authorization_endpoint()
+        m.validate_token_endpoint()
+
+        if m.get("token_endpoint_auth_methods_supported") is not None:
+            m.validate_token_endpoint_auth_methods_supported()
+            if (
+                self._client_auth_method
+                not in m["token_endpoint_auth_methods_supported"]
+            ):
+                raise ValueError(
+                    '"{auth_method}" not in "token_endpoint_auth_methods_supported" ({supported!r})'.format(
+                        auth_method=self._client_auth_method,
+                        supported=m["token_endpoint_auth_methods_supported"],
+                    )
+                )
+
+        if m.get("response_types_supported") is not None:
+            m.validate_response_types_supported()
+
+            if "code" not in m["response_types_supported"]:
+                raise ValueError(
+                    '"code" not in "response_types_supported" (%r)'
+                    % (m["response_types_supported"],)
+                )
+
+        # If the openid scope was not requested, we need a userinfo endpoint to fetch user infos
+        if self._uses_userinfo:
+            if m.get("userinfo_endpoint") is None:
+                raise ValueError(
+                    'provider has no "userinfo_endpoint", even though it is required because the "openid" scope is not requested'
+                )
+        else:
+            # If we're not using userinfo, we need a valid jwks to validate the ID token
+            if m.get("jwks") is None:
+                if m.get("jwks_uri") is not None:
+                    m.validate_jwks_uri()
+                else:
+                    raise ValueError('"jwks_uri" must be set')
+
+    @property
+    def _uses_userinfo(self) -> bool:
+        """Returns True if the ``userinfo_endpoint`` should be used.
+
+        This is based on the requested scopes: if the scopes include
+        ``openid``, the provider should give use an ID token containing the
+        user informations. If not, we should fetch them using the
+        ``access_token`` with the ``userinfo_endpoint``.
+        """
+
+        # Maybe that should be user-configurable and not inferred?
+        return "openid" not in self._scopes
+
+    async def load_metadata(self) -> OpenIDProviderMetadata:
+        """Load and validate the provider metadata.
+
+        The values metadatas are discovered if ``oidc_config.discovery`` is
+        ``True`` and then cached.
+
+        Raises:
+            ValueError: if something in the provider is not valid
+
+        Returns:
+            The provider's metadata.
+        """
+        # If we are using the OpenID Discovery documents, it needs to be loaded once
+        # FIXME: should there be a lock here?
+        if self._provider_needs_discovery:
+            url = get_well_known_url(self._provider_metadata["issuer"], external=True)
+            metadata_response = await self._http_client.get_json(url)
+            # TODO: maybe update the other way around to let user override some values?
+            self._provider_metadata.update(metadata_response)
+            self._provider_needs_discovery = False
+
+        self._validate_metadata()
+
+        return self._provider_metadata
+
+    async def load_jwks(self, force: bool = False) -> JWKS:
+        """Load the JSON Web Key Set used to sign ID tokens.
+
+        If we're not using the ``userinfo_endpoint``, user infos are extracted
+        from the ID token, which is a JWT signed by keys given by the provider.
+        The keys are then cached.
+
+        Args:
+            force: Force reloading the keys.
+
+        Returns:
+            The key set
+
+            Looks like this::
+
+                {
+                    'keys': [
+                        {
+                            'kid': 'abcdef',
+                            'kty': 'RSA',
+                            'alg': 'RS256',
+                            'use': 'sig',
+                            'e': 'XXXX',
+                            'n': 'XXXX',
+                        }
+                    ]
+                }
+        """
+        if self._uses_userinfo:
+            # We're not using jwt signing, return an empty jwk set
+            return {"keys": []}
+
+        # First check if the JWKS are loaded in the provider metadata.
+        # It can happen either if the provider gives its JWKS in the discovery
+        # document directly or if it was already loaded once.
+        metadata = await self.load_metadata()
+        jwk_set = metadata.get("jwks")
+        if jwk_set is not None and not force:
+            return jwk_set
+
+        # Loading the JWKS using the `jwks_uri` metadata
+        uri = metadata.get("jwks_uri")
+        if not uri:
+            raise RuntimeError('Missing "jwks_uri" in metadata')
+
+        jwk_set = await self._http_client.get_json(uri)
+
+        # Caching the JWKS in the provider's metadata
+        self._provider_metadata["jwks"] = jwk_set
+        return jwk_set
+
+    async def _exchange_code(self, code: str) -> Token:
+        """Exchange an authorization code for a token.
+
+        This calls the ``token_endpoint`` with the authorization code we
+        received in the callback to exchange it for a token. The call uses the
+        ``ClientAuth`` to authenticate with the client with its ID and secret.
+
+        Args:
+            code: The autorization code we got from the callback.
+
+        Returns:
+            A dict containing various tokens.
+
+            May look like this::
+
+                {
+                    'token_type': 'bearer',
+                    'access_token': 'abcdef',
+                    'expires_in': 3599,
+                    'id_token': 'ghijkl',
+                    'refresh_token': 'mnopqr',
+                }
+
+        Raises:
+            OidcError: when the ``token_endpoint`` returned an error.
+        """
+        metadata = await self.load_metadata()
+        token_endpoint = metadata.get("token_endpoint")
+        headers = {
+            "Content-Type": "application/x-www-form-urlencoded",
+            "User-Agent": self._http_client.user_agent,
+            "Accept": "application/json",
+        }
+
+        args = {
+            "grant_type": "authorization_code",
+            "code": code,
+            "redirect_uri": self._callback_url,
+        }
+        body = urlencode(args, True)
+
+        # Fill the body/headers with credentials
+        uri, headers, body = self._client_auth.prepare(
+            method="POST", uri=token_endpoint, headers=headers, body=body
+        )
+        headers = {k: [v] for (k, v) in headers.items()}
+
+        # Do the actual request
+        # We're not using the SimpleHttpClient util methods as we don't want to
+        # check the HTTP status code and we do the body encoding ourself.
+        response = await self._http_client.request(
+            method="POST", uri=uri, data=body.encode("utf-8"), headers=headers,
+        )
+
+        # This is used in multiple error messages below
+        status = "{code} {phrase}".format(
+            code=response.code, phrase=response.phrase.decode("utf-8")
+        )
+
+        resp_body = await readBody(response)
+
+        if response.code >= 500:
+            # In case of a server error, we should first try to decode the body
+            # and check for an error field. If not, we respond with a generic
+            # error message.
+            try:
+                resp = json.loads(resp_body.decode("utf-8"))
+                error = resp["error"]
+                description = resp.get("error_description", error)
+            except (ValueError, KeyError):
+                # Catch ValueError for the JSON decoding and KeyError for the "error" field
+                error = "server_error"
+                description = (
+                    (
+                        'Authorization server responded with a "{status}" error '
+                        "while exchanging the authorization code."
+                    ).format(status=status),
+                )
+
+            raise OidcError(error, description)
+
+        # Since it is a not a 5xx code, body should be a valid JSON. It will
+        # raise if not.
+        resp = json.loads(resp_body.decode("utf-8"))
+
+        if "error" in resp:
+            error = resp["error"]
+            # In case the authorization server responded with an error field,
+            # it should be a 4xx code. If not, warn about it but don't do
+            # anything special and report the original error message.
+            if response.code < 400:
+                logger.debug(
+                    "Invalid response from the authorization server: "
+                    'responded with a "{status}" '
+                    "but body has an error field: {error!r}".format(
+                        status=status, error=resp["error"]
+                    )
+                )
+
+            description = resp.get("error_description", error)
+            raise OidcError(error, description)
+
+        # Now, this should not be an error. According to RFC6749 sec 5.1, it
+        # should be a 200 code. We're a bit more flexible than that, and will
+        # only throw on a 4xx code.
+        if response.code >= 400:
+            description = (
+                'Authorization server responded with a "{status}" error '
+                'but did not include an "error" field in its response.'.format(
+                    status=status
+                )
+            )
+            logger.warning(description)
+            # Body was still valid JSON. Might be useful to log it for debugging.
+            logger.warning("Code exchange response: {resp!r}".format(resp=resp))
+            raise OidcError("server_error", description)
+
+        return resp
+
+    async def _fetch_userinfo(self, token: Token) -> UserInfo:
+        """Fetch user informations from the ``userinfo_endpoint``.
+
+        Args:
+            token: the token given by the ``token_endpoint``.
+                Must include an ``access_token`` field.
+
+        Returns:
+            UserInfo: an object representing the user.
+        """
+        metadata = await self.load_metadata()
+
+        resp = await self._http_client.get_json(
+            metadata["userinfo_endpoint"],
+            headers={"Authorization": ["Bearer {}".format(token["access_token"])]},
+        )
+
+        return UserInfo(resp)
+
+    async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
+        """Return an instance of UserInfo from token's ``id_token``.
+
+        Args:
+            token: the token given by the ``token_endpoint``.
+                Must include an ``id_token`` field.
+            nonce: the nonce value originally sent in the initial authorization
+                request. This value should match the one inside the token.
+
+        Returns:
+            An object representing the user.
+        """
+        metadata = await self.load_metadata()
+        claims_params = {
+            "nonce": nonce,
+            "client_id": self._client_auth.client_id,
+        }
+        if "access_token" in token:
+            # If we got an `access_token`, there should be an `at_hash` claim
+            # in the `id_token` that we can check against.
+            claims_params["access_token"] = token["access_token"]
+            claims_cls = CodeIDToken
+        else:
+            claims_cls = ImplicitIDToken
+
+        alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
+
+        jwt = JsonWebToken(alg_values)
+
+        claim_options = {"iss": {"values": [metadata["issuer"]]}}
+
+        # Try to decode the keys in cache first, then retry by forcing the keys
+        # to be reloaded
+        jwk_set = await self.load_jwks()
+        try:
+            claims = jwt.decode(
+                token["id_token"],
+                key=jwk_set,
+                claims_cls=claims_cls,
+                claims_options=claim_options,
+                claims_params=claims_params,
+            )
+        except ValueError:
+            jwk_set = await self.load_jwks(force=True)  # try reloading the jwks
+            claims = jwt.decode(
+                token["id_token"],
+                key=jwk_set,
+                claims_cls=claims_cls,
+                claims_options=claim_options,
+                claims_params=claims_params,
+            )
+
+        claims.validate(leeway=120)  # allows 2 min of clock skew
+        return UserInfo(claims)
+
+    async def handle_redirect_request(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> None:
+        """Handle an incoming request to /login/sso/redirect
+
+        It redirects the browser to the authorization endpoint with a few
+        parameters:
+
+          - ``client_id``: the client ID set in ``oidc_config.client_id``
+          - ``response_type``: ``code``
+          - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback``
+          - ``scope``: the list of scopes set in ``oidc_config.scopes``
+          - ``state``: a random string
+          - ``nonce``: a random string
+
+        In addition to redirecting the client, we are setting a cookie with
+        a signed macaroon token containing the state, the nonce and the
+        client_redirect_url params. Those are then checked when the client
+        comes back from the provider.
+
+
+        Args:
+            request: the incoming request from the browser.
+                We'll respond to it with a redirect and a cookie.
+            client_redirect_url: the URL that we should redirect the client to
+                when everything is done
+        """
+
+        state = generate_token()
+        nonce = generate_token()
+
+        cookie = self._generate_oidc_session_token(
+            state=state, nonce=nonce, client_redirect_url=client_redirect_url.decode(),
+        )
+        request.addCookie(
+            SESSION_COOKIE_NAME,
+            cookie,
+            path="/_synapse/oidc",
+            max_age="3600",
+            httpOnly=True,
+            sameSite="lax",
+        )
+
+        metadata = await self.load_metadata()
+        authorization_endpoint = metadata.get("authorization_endpoint")
+        uri = prepare_grant_uri(
+            authorization_endpoint,
+            client_id=self._client_auth.client_id,
+            response_type="code",
+            redirect_uri=self._callback_url,
+            scope=self._scopes,
+            state=state,
+            nonce=nonce,
+        )
+        request.redirect(uri)
+        finish_request(request)
+
+    async def handle_oidc_callback(self, request: SynapseRequest) -> None:
+        """Handle an incoming request to /_synapse/oidc/callback
+
+        Since we might want to display OIDC-related errors in a user-friendly
+        way, we don't raise SynapseError from here. Instead, we call
+        ``self._render_error`` which displays an HTML page for the error.
+
+        Most of the OpenID Connect logic happens here:
+
+          - first, we check if there was any error returned by the provider and
+            display it
+          - then we fetch the session cookie, decode and verify it
+          - the ``state`` query parameter should match with the one stored in the
+            session cookie
+          - once we known this session is legit, exchange the code with the
+            provider using the ``token_endpoint`` (see ``_exchange_code``)
+          - once we have the token, use it to either extract the UserInfo from
+            the ``id_token`` (``_parse_id_token``), or use the ``access_token``
+            to fetch UserInfo from the ``userinfo_endpoint``
+            (``_fetch_userinfo``)
+          - map those UserInfo to a Matrix user (``_map_userinfo_to_user``) and
+            finish the login
+
+        Args:
+            request: the incoming request from the browser.
+        """
+
+        # The provider might redirect with an error.
+        # In that case, just display it as-is.
+        if b"error" in request.args:
+            error = request.args[b"error"][0].decode()
+            description = request.args.get(b"error_description", [b""])[0].decode()
+
+            # Most of the errors returned by the provider could be due by
+            # either the provider misbehaving or Synapse being misconfigured.
+            # The only exception of that is "access_denied", where the user
+            # probably cancelled the login flow. In other cases, log those errors.
+            if error != "access_denied":
+                logger.error("Error from the OIDC provider: %s %s", error, description)
+
+            self._render_error(request, error, description)
+            return
+
+        # Fetch the session cookie
+        session = request.getCookie(SESSION_COOKIE_NAME)
+        if session is None:
+            logger.info("No session cookie found")
+            self._render_error(request, "missing_session", "No session cookie found")
+            return
+
+        # Remove the cookie. There is a good chance that if the callback failed
+        # once, it will fail next time and the code will already be exchanged.
+        # Removing it early avoids spamming the provider with token requests.
+        request.addCookie(
+            SESSION_COOKIE_NAME,
+            b"",
+            path="/_synapse/oidc",
+            expires="Thu, Jan 01 1970 00:00:00 UTC",
+            httpOnly=True,
+            sameSite="lax",
+        )
+
+        # Check for the state query parameter
+        if b"state" not in request.args:
+            logger.info("State parameter is missing")
+            self._render_error(request, "invalid_request", "State parameter is missing")
+            return
+
+        state = request.args[b"state"][0].decode()
+
+        # Deserialize the session token and verify it.
+        try:
+            nonce, client_redirect_url = self._verify_oidc_session_token(session, state)
+        except MacaroonDeserializationException as e:
+            logger.exception("Invalid session")
+            self._render_error(request, "invalid_session", str(e))
+            return
+        except MacaroonInvalidSignatureException as e:
+            logger.exception("Could not verify session")
+            self._render_error(request, "mismatching_session", str(e))
+            return
+
+        # Exchange the code with the provider
+        if b"code" not in request.args:
+            logger.info("Code parameter is missing")
+            self._render_error(request, "invalid_request", "Code parameter is missing")
+            return
+
+        logger.info("Exchanging code")
+        code = request.args[b"code"][0].decode()
+        try:
+            token = await self._exchange_code(code)
+        except OidcError as e:
+            logger.exception("Could not exchange code")
+            self._render_error(request, e.error, e.error_description)
+            return
+
+        # Now that we have a token, get the userinfo, either by decoding the
+        # `id_token` or by fetching the `userinfo_endpoint`.
+        if self._uses_userinfo:
+            logger.info("Fetching userinfo")
+            try:
+                userinfo = await self._fetch_userinfo(token)
+            except Exception as e:
+                logger.exception("Could not fetch userinfo")
+                self._render_error(request, "fetch_error", str(e))
+                return
+        else:
+            logger.info("Extracting userinfo from id_token")
+            try:
+                userinfo = await self._parse_id_token(token, nonce=nonce)
+            except Exception as e:
+                logger.exception("Invalid id_token")
+                self._render_error(request, "invalid_token", str(e))
+                return
+
+        # Call the mapper to register/login the user
+        try:
+            user_id = await self._map_userinfo_to_user(userinfo, token)
+        except MappingException as e:
+            logger.exception("Could not map user")
+            self._render_error(request, "mapping_error", str(e))
+            return
+
+        # and finally complete the login
+        await self._auth_handler.complete_sso_login(
+            user_id, request, client_redirect_url
+        )
+
+    def _generate_oidc_session_token(
+        self,
+        state: str,
+        nonce: str,
+        client_redirect_url: str,
+        duration_in_ms: int = (60 * 60 * 1000),
+    ) -> str:
+        """Generates a signed token storing data about an OIDC session.
+
+        When Synapse initiates an authorization flow, it creates a random state
+        and a random nonce. Those parameters are given to the provider and
+        should be verified when the client comes back from the provider.
+        It is also used to store the client_redirect_url, which is used to
+        complete the SSO login flow.
+
+        Args:
+            state: The ``state`` parameter passed to the OIDC provider.
+            nonce: The ``nonce`` parameter passed to the OIDC provider.
+            client_redirect_url: The URL the client gave when it initiated the
+                flow.
+            duration_in_ms: An optional duration for the token in milliseconds.
+                Defaults to an hour.
+
+        Returns:
+            A signed macaroon token with the session informations.
+        """
+        macaroon = pymacaroons.Macaroon(
+            location=self._server_name, identifier="key", key=self._macaroon_secret_key,
+        )
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = session")
+        macaroon.add_first_party_caveat("state = %s" % (state,))
+        macaroon.add_first_party_caveat("nonce = %s" % (nonce,))
+        macaroon.add_first_party_caveat(
+            "client_redirect_url = %s" % (client_redirect_url,)
+        )
+        now = self._clock.time_msec()
+        expiry = now + duration_in_ms
+        macaroon.add_first_party_caveat("time < %d" % (expiry,))
+        return macaroon.serialize()
+
+    def _verify_oidc_session_token(self, session: str, state: str) -> Tuple[str, str]:
+        """Verifies and extract an OIDC session token.
+
+        This verifies that a given session token was issued by this homeserver
+        and extract the nonce and client_redirect_url caveats.
+
+        Args:
+            session: The session token to verify
+            state: The state the OIDC provider gave back
+
+        Returns:
+            The nonce and the client_redirect_url for this session
+        """
+        macaroon = pymacaroons.Macaroon.deserialize(session)
+
+        v = pymacaroons.Verifier()
+        v.satisfy_exact("gen = 1")
+        v.satisfy_exact("type = session")
+        v.satisfy_exact("state = %s" % (state,))
+        v.satisfy_general(lambda c: c.startswith("nonce = "))
+        v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
+        v.satisfy_general(self._verify_expiry)
+
+        v.verify(macaroon, self._macaroon_secret_key)
+
+        # Extract the `nonce` and `client_redirect_url` from the token
+        nonce = self._get_value_from_macaroon(macaroon, "nonce")
+        client_redirect_url = self._get_value_from_macaroon(
+            macaroon, "client_redirect_url"
+        )
+
+        return nonce, client_redirect_url
+
+    def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
+        """Extracts a caveat value from a macaroon token.
+
+        Args:
+            macaroon: the token
+            key: the key of the caveat to extract
+
+        Returns:
+            The extracted value
+
+        Raises:
+            Exception: if the caveat was not in the macaroon
+        """
+        prefix = key + " = "
+        for caveat in macaroon.caveats:
+            if caveat.caveat_id.startswith(prefix):
+                return caveat.caveat_id[len(prefix) :]
+        raise Exception("No %s caveat in macaroon" % (key,))
+
+    def _verify_expiry(self, caveat: str) -> bool:
+        prefix = "time < "
+        if not caveat.startswith(prefix):
+            return False
+        expiry = int(caveat[len(prefix) :])
+        now = self._clock.time_msec()
+        return now < expiry
+
+    async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str:
+        """Maps a UserInfo object to a mxid.
+
+        UserInfo should have a claim that uniquely identifies users. This claim
+        is usually `sub`, but can be configured with `oidc_config.subject_claim`.
+        It is then used as an `external_id`.
+
+        If we don't find the user that way, we should register the user,
+        mapping the localpart and the display name from the UserInfo.
+
+        If a user already exists with the mxid we've mapped, raise an exception.
+
+        Args:
+            userinfo: an object representing the user
+            token: a dict with the tokens obtained from the provider
+
+        Raises:
+            MappingException: if there was an error while mapping some properties
+
+        Returns:
+            The mxid of the user
+        """
+        try:
+            remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
+        except Exception as e:
+            raise MappingException(
+                "Failed to extract subject from OIDC response: %s" % (e,)
+            )
+
+        logger.info(
+            "Looking for existing mapping for user %s:%s",
+            self._auth_provider_id,
+            remote_user_id,
+        )
+
+        registered_user_id = await self._datastore.get_user_by_external_id(
+            self._auth_provider_id, remote_user_id,
+        )
+
+        if registered_user_id is not None:
+            logger.info("Found existing mapping %s", registered_user_id)
+            return registered_user_id
+
+        try:
+            attributes = await self._user_mapping_provider.map_user_attributes(
+                userinfo, token
+            )
+        except Exception as e:
+            raise MappingException(
+                "Could not extract user attributes from OIDC response: " + str(e)
+            )
+
+        logger.debug(
+            "Retrieved user attributes from user mapping provider: %r", attributes
+        )
+
+        if not attributes["localpart"]:
+            raise MappingException("localpart is empty")
+
+        localpart = map_username_to_mxid_localpart(attributes["localpart"])
+
+        user_id = UserID(localpart, self._hostname)
+        if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()):
+            # This mxid is taken
+            raise MappingException(
+                "mxid '{}' is already taken".format(user_id.to_string())
+            )
+
+        # It's the first time this user is logging in and the mapped mxid was
+        # not taken, register the user
+        registered_user_id = await self._registration_handler.register_user(
+            localpart=localpart, default_display_name=attributes["display_name"],
+        )
+
+        await self._datastore.record_user_external_id(
+            self._auth_provider_id, remote_user_id, registered_user_id,
+        )
+        return registered_user_id
+
+
+UserAttribute = TypedDict(
+    "UserAttribute", {"localpart": str, "display_name": Optional[str]}
+)
+C = TypeVar("C")
+
+
+class OidcMappingProvider(Generic[C]):
+    """A mapping provider maps a UserInfo object to user attributes.
+
+    It should provide the API described by this class.
+    """
+
+    def __init__(self, config: C):
+        """
+        Args:
+            config: A custom config object from this module, parsed by ``parse_config()``
+        """
+
+    @staticmethod
+    def parse_config(config: dict) -> C:
+        """Parse the dict provided by the homeserver's config
+
+        Args:
+            config: A dictionary containing configuration options for this provider
+
+        Returns:
+            A custom config object for this module
+        """
+        raise NotImplementedError()
+
+    def get_remote_user_id(self, userinfo: UserInfo) -> str:
+        """Get a unique user ID for this user.
+
+        Usually, in an OIDC-compliant scenario, it should be the ``sub`` claim from the UserInfo object.
+
+        Args:
+            userinfo: An object representing the user given by the OIDC provider
+
+        Returns:
+            A unique user ID
+        """
+        raise NotImplementedError()
+
+    async def map_user_attributes(
+        self, userinfo: UserInfo, token: Token
+    ) -> UserAttribute:
+        """Map a ``UserInfo`` objects into user attributes.
+
+        Args:
+            userinfo: An object representing the user given by the OIDC provider
+            token: A dict with the tokens returned by the provider
+
+        Returns:
+            A dict containing the ``localpart`` and (optionally) the ``display_name``
+        """
+        raise NotImplementedError()
+
+
+# Used to clear out "None" values in templates
+def jinja_finalize(thing):
+    return thing if thing is not None else ""
+
+
+env = Environment(finalize=jinja_finalize)
+
+
+@attr.s
+class JinjaOidcMappingConfig:
+    subject_claim = attr.ib()  # type: str
+    localpart_template = attr.ib()  # type: Template
+    display_name_template = attr.ib()  # type: Optional[Template]
+
+
+class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
+    """An implementation of a mapping provider based on Jinja templates.
+
+    This is the default mapping provider.
+    """
+
+    def __init__(self, config: JinjaOidcMappingConfig):
+        self._config = config
+
+    @staticmethod
+    def parse_config(config: dict) -> JinjaOidcMappingConfig:
+        subject_claim = config.get("subject_claim", "sub")
+
+        if "localpart_template" not in config:
+            raise ConfigError(
+                "missing key: oidc_config.user_mapping_provider.config.localpart_template"
+            )
+
+        try:
+            localpart_template = env.from_string(config["localpart_template"])
+        except Exception as e:
+            raise ConfigError(
+                "invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r"
+                % (e,)
+            )
+
+        display_name_template = None  # type: Optional[Template]
+        if "display_name_template" in config:
+            try:
+                display_name_template = env.from_string(config["display_name_template"])
+            except Exception as e:
+                raise ConfigError(
+                    "invalid jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r"
+                    % (e,)
+                )
+
+        return JinjaOidcMappingConfig(
+            subject_claim=subject_claim,
+            localpart_template=localpart_template,
+            display_name_template=display_name_template,
+        )
+
+    def get_remote_user_id(self, userinfo: UserInfo) -> str:
+        return userinfo[self._config.subject_claim]
+
+    async def map_user_attributes(
+        self, userinfo: UserInfo, token: Token
+    ) -> UserAttribute:
+        localpart = self._config.localpart_template.render(user=userinfo).strip()
+
+        display_name = None  # type: Optional[str]
+        if self._config.display_name_template is not None:
+            display_name = self._config.display_name_template.render(
+                user=userinfo
+            ).strip()
+
+            if display_name == "":
+                display_name = None
+
+        return UserAttribute(localpart=localpart, display_name=display_name)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index da12df7f53..73f9eeb399 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -25,8 +25,6 @@ from collections import OrderedDict
 
 from six import iteritems, string_types
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
 from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -103,8 +101,7 @@ class RoomCreationHandler(BaseHandler):
 
         self.third_party_event_rules = hs.get_third_party_event_rules()
 
-    @defer.inlineCallbacks
-    def upgrade_room(
+    async def upgrade_room(
         self, requester: Requester, old_room_id: str, new_version: RoomVersion
     ):
         """Replace a room with a new room with a different version
@@ -117,7 +114,7 @@ class RoomCreationHandler(BaseHandler):
         Returns:
             Deferred[unicode]: the new room id
         """
-        yield self.ratelimit(requester)
+        await self.ratelimit(requester)
 
         user_id = requester.user.to_string()
 
@@ -138,7 +135,7 @@ class RoomCreationHandler(BaseHandler):
         # If this user has sent multiple upgrade requests for the same room
         # and one of them is not complete yet, cache the response and
         # return it to all subsequent requests
-        ret = yield self._upgrade_response_cache.wrap(
+        ret = await self._upgrade_response_cache.wrap(
             (old_room_id, user_id),
             self._upgrade_room,
             requester,
@@ -856,8 +853,7 @@ class RoomCreationHandler(BaseHandler):
         for (etype, state_key), content in initial_state.items():
             await send(etype=etype, state_key=state_key, content=content)
 
-    @defer.inlineCallbacks
-    def _generate_room_id(
+    async def _generate_room_id(
         self, creator_id: str, is_public: str, room_version: RoomVersion,
     ):
         # autogen room IDs and try to create it. We may clash, so just
@@ -869,7 +865,7 @@ class RoomCreationHandler(BaseHandler):
                 gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
                 if isinstance(gen_room_id, bytes):
                     gen_room_id = gen_room_id.decode("utf-8")
-                yield self.store.store_room(
+                await self.store.store_room(
                     room_id=gen_room_id,
                     room_creator_user_id=creator_id,
                     is_public=is_public,
@@ -888,8 +884,7 @@ class RoomContextHandler(object):
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
 
-    @defer.inlineCallbacks
-    def get_event_context(self, user, room_id, event_id, limit, event_filter):
+    async def get_event_context(self, user, room_id, event_id, limit, event_filter):
         """Retrieves events, pagination tokens and state around a given event
         in a room.
 
@@ -908,7 +903,7 @@ class RoomContextHandler(object):
         before_limit = math.floor(limit / 2.0)
         after_limit = limit - before_limit
 
-        users = yield self.store.get_users_in_room(room_id)
+        users = await self.store.get_users_in_room(room_id)
         is_peeking = user.to_string() not in users
 
         def filter_evts(events):
@@ -916,17 +911,17 @@ class RoomContextHandler(object):
                 self.storage, user.to_string(), events, is_peeking=is_peeking
             )
 
-        event = yield self.store.get_event(
+        event = await self.store.get_event(
             event_id, get_prev_content=True, allow_none=True
         )
         if not event:
             return None
 
-        filtered = yield (filter_evts([event]))
+        filtered = await filter_evts([event])
         if not filtered:
             raise AuthError(403, "You don't have permission to access that event.")
 
-        results = yield self.store.get_events_around(
+        results = await self.store.get_events_around(
             room_id, event_id, before_limit, after_limit, event_filter
         )
 
@@ -934,8 +929,8 @@ class RoomContextHandler(object):
             results["events_before"] = event_filter.filter(results["events_before"])
             results["events_after"] = event_filter.filter(results["events_after"])
 
-        results["events_before"] = yield filter_evts(results["events_before"])
-        results["events_after"] = yield filter_evts(results["events_after"])
+        results["events_before"] = await filter_evts(results["events_before"])
+        results["events_after"] = await filter_evts(results["events_after"])
         # filter_evts can return a pruned event in case the user is allowed to see that
         # there's something there but not see the content, so use the event that's in
         # `filtered` rather than the event we retrieved from the datastore.
@@ -962,7 +957,7 @@ class RoomContextHandler(object):
         # first? Shouldn't we be consistent with /sync?
         # https://github.com/matrix-org/matrix-doc/issues/687
 
-        state = yield self.state_store.get_state_for_events(
+        state = await self.state_store.get_state_for_events(
             [last_event_id], state_filter=state_filter
         )
 
@@ -970,7 +965,7 @@ class RoomContextHandler(object):
         if event_filter:
             state_events = event_filter.filter(state_events)
 
-        results["state"] = yield filter_evts(state_events)
+        results["state"] = await filter_evts(state_events)
 
         # We use a dummy token here as we only care about the room portion of
         # the token, which we replace.
@@ -989,13 +984,12 @@ class RoomEventSource(object):
     def __init__(self, hs):
         self.store = hs.get_datastore()
 
-    @defer.inlineCallbacks
-    def get_new_events(
+    async def get_new_events(
         self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
     ):
         # We just ignore the key for now.
 
-        to_key = yield self.get_current_key()
+        to_key = await self.get_current_key()
 
         from_token = RoomStreamToken.parse(from_key)
         if from_token.topological:
@@ -1008,11 +1002,11 @@ class RoomEventSource(object):
             # See https://github.com/matrix-org/matrix-doc/issues/1144
             raise NotImplementedError()
         else:
-            room_events = yield self.store.get_membership_changes_for_user(
+            room_events = await self.store.get_membership_changes_for_user(
                 user.to_string(), from_key, to_key
             )
 
-            room_to_events = yield self.store.get_room_events_stream_for_rooms(
+            room_to_events = await self.store.get_room_events_stream_for_rooms(
                 room_ids=room_ids,
                 from_key=from_key,
                 to_key=to_key,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 53b49bc15f..ccc9659454 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -875,8 +875,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         self.distributor.declare("user_joined_room")
         self.distributor.declare("user_left_room")
 
-    @defer.inlineCallbacks
-    def _is_remote_room_too_complex(self, room_id, remote_room_hosts):
+    async def _is_remote_room_too_complex(self, room_id, remote_room_hosts):
         """
         Check if complexity of a remote room is too great.
 
@@ -888,7 +887,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
             if unable to be fetched
         """
         max_complexity = self.hs.config.limit_remote_rooms.complexity
-        complexity = yield self.federation_handler.get_room_complexity(
+        complexity = await self.federation_handler.get_room_complexity(
             remote_room_hosts, room_id
         )
 
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 96f2dd36ad..e7015c704f 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import re
-from typing import Optional, Tuple
+from typing import Callable, Dict, Optional, Set, Tuple
 
 import attr
 import saml2
@@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
 from synapse.http.server import finish_request
 from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
 from synapse.module_api import ModuleApi
 from synapse.module_api.errors import RedirectException
 from synapse.types import (
@@ -81,17 +82,19 @@ class SamlHandler:
 
         self._error_html_content = hs.config.saml2_error_html_content
 
-    def handle_redirect_request(self, client_redirect_url, ui_auth_session_id=None):
+    def handle_redirect_request(
+        self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
+    ) -> bytes:
         """Handle an incoming request to /login/sso/redirect
 
         Args:
-            client_redirect_url (bytes): the URL that we should redirect the
+            client_redirect_url: the URL that we should redirect the
                 client to when everything is done
-            ui_auth_session_id (Optional[str]): The session ID of the ongoing UI Auth (or
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
                 None if this is a login).
 
         Returns:
-            bytes: URL to redirect to
+            URL to redirect to
         """
         reqid, info = self._saml_client.prepare_for_authenticate(
             relay_state=client_redirect_url
@@ -109,15 +112,15 @@ class SamlHandler:
         # this shouldn't happen!
         raise Exception("prepare_for_authenticate didn't return a Location header")
 
-    async def handle_saml_response(self, request):
+    async def handle_saml_response(self, request: SynapseRequest) -> None:
         """Handle an incoming request to /_matrix/saml2/authn_response
 
         Args:
-            request (SynapseRequest): the incoming request from the browser. We'll
+            request: the incoming request from the browser. We'll
                 respond to it with a redirect.
 
         Returns:
-            Deferred[none]: Completes once we have handled the request.
+            Completes once we have handled the request.
         """
         resp_bytes = parse_string(request, "SAMLResponse", required=True)
         relay_state = parse_string(request, "RelayState", required=True)
@@ -310,6 +313,7 @@ DOT_REPLACE_PATTERN = re.compile(
 
 
 def dot_replace_for_mxid(username: str) -> str:
+    """Replace any characters which are not allowed in Matrix IDs with a dot."""
     username = username.lower()
     username = DOT_REPLACE_PATTERN.sub(".", username)
 
@@ -321,7 +325,7 @@ def dot_replace_for_mxid(username: str) -> str:
 MXID_MAPPER_MAP = {
     "hexencode": map_username_to_mxid_localpart,
     "dotreplace": dot_replace_for_mxid,
-}
+}  # type: Dict[str, Callable[[str], str]]
 
 
 @attr.s
@@ -349,7 +353,7 @@ class DefaultSamlMappingProvider(object):
 
     def get_remote_user_id(
         self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
-    ):
+    ) -> str:
         """Extracts the remote user id from the SAML response"""
         try:
             return saml_response.ava["uid"][0]
@@ -428,14 +432,14 @@ class DefaultSamlMappingProvider(object):
         return SamlConfig(mxid_source_attribute, mxid_mapper)
 
     @staticmethod
-    def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
+    def get_saml_attributes(config: SamlConfig) -> Tuple[Set[str], Set[str]]:
         """Returns the required attributes of a SAML
 
         Args:
             config: A SamlConfig object containing configuration params for this provider
 
         Returns:
-            tuple[set,set]: The first set equates to the saml auth response
+            The first set equates to the saml auth response
                 attributes that are required for the module to function, whereas the
                 second set consists of those attributes which can be used if
                 available, but are not necessary
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index ec1542d416..4d40d3ac9c 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -18,8 +18,6 @@ import logging
 
 from unpaddedbase64 import decode_base64, encode_base64
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.filtering import Filter
@@ -39,8 +37,7 @@ class SearchHandler(BaseHandler):
         self.state_store = self.storage.state
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def get_old_rooms_from_upgraded_room(self, room_id):
+    async def get_old_rooms_from_upgraded_room(self, room_id):
         """Retrieves room IDs of old rooms in the history of an upgraded room.
 
         We do so by checking the m.room.create event of the room for a
@@ -60,7 +57,7 @@ class SearchHandler(BaseHandler):
         historical_room_ids = []
 
         # The initial room must have been known for us to get this far
-        predecessor = yield self.store.get_room_predecessor(room_id)
+        predecessor = await self.store.get_room_predecessor(room_id)
 
         while True:
             if not predecessor:
@@ -75,7 +72,7 @@ class SearchHandler(BaseHandler):
 
             # Don't add it to the list until we have checked that we are in the room
             try:
-                next_predecessor_room = yield self.store.get_room_predecessor(
+                next_predecessor_room = await self.store.get_room_predecessor(
                     predecessor_room_id
                 )
             except NotFoundError:
@@ -89,8 +86,7 @@ class SearchHandler(BaseHandler):
 
         return historical_room_ids
 
-    @defer.inlineCallbacks
-    def search(self, user, content, batch=None):
+    async def search(self, user, content, batch=None):
         """Performs a full text search for a user.
 
         Args:
@@ -179,7 +175,7 @@ class SearchHandler(BaseHandler):
         search_filter = Filter(filter_dict)
 
         # TODO: Search through left rooms too
-        rooms = yield self.store.get_rooms_for_local_user_where_membership_is(
+        rooms = await self.store.get_rooms_for_local_user_where_membership_is(
             user.to_string(),
             membership_list=[Membership.JOIN],
             # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
@@ -192,7 +188,7 @@ class SearchHandler(BaseHandler):
             historical_room_ids = []
             for room_id in search_filter.rooms:
                 # Add any previous rooms to the search if they exist
-                ids = yield self.get_old_rooms_from_upgraded_room(room_id)
+                ids = await self.get_old_rooms_from_upgraded_room(room_id)
                 historical_room_ids += ids
 
             # Prevent any historical events from being filtered
@@ -223,7 +219,7 @@ class SearchHandler(BaseHandler):
         count = None
 
         if order_by == "rank":
-            search_result = yield self.store.search_msgs(room_ids, search_term, keys)
+            search_result = await self.store.search_msgs(room_ids, search_term, keys)
 
             count = search_result["count"]
 
@@ -238,7 +234,7 @@ class SearchHandler(BaseHandler):
 
             filtered_events = search_filter.filter([r["event"] for r in results])
 
-            events = yield filter_events_for_client(
+            events = await filter_events_for_client(
                 self.storage, user.to_string(), filtered_events
             )
 
@@ -267,7 +263,7 @@ class SearchHandler(BaseHandler):
             # But only go around 5 times since otherwise synapse will be sad.
             while len(room_events) < search_filter.limit() and i < 5:
                 i += 1
-                search_result = yield self.store.search_rooms(
+                search_result = await self.store.search_rooms(
                     room_ids,
                     search_term,
                     keys,
@@ -288,7 +284,7 @@ class SearchHandler(BaseHandler):
 
                 filtered_events = search_filter.filter([r["event"] for r in results])
 
-                events = yield filter_events_for_client(
+                events = await filter_events_for_client(
                     self.storage, user.to_string(), filtered_events
                 )
 
@@ -343,11 +339,11 @@ class SearchHandler(BaseHandler):
         # If client has asked for "context" for each event (i.e. some surrounding
         # events and state), fetch that
         if event_context is not None:
-            now_token = yield self.hs.get_event_sources().get_current_token()
+            now_token = await self.hs.get_event_sources().get_current_token()
 
             contexts = {}
             for event in allowed_events:
-                res = yield self.store.get_events_around(
+                res = await self.store.get_events_around(
                     event.room_id, event.event_id, before_limit, after_limit
                 )
 
@@ -357,11 +353,11 @@ class SearchHandler(BaseHandler):
                     len(res["events_after"]),
                 )
 
-                res["events_before"] = yield filter_events_for_client(
+                res["events_before"] = await filter_events_for_client(
                     self.storage, user.to_string(), res["events_before"]
                 )
 
-                res["events_after"] = yield filter_events_for_client(
+                res["events_after"] = await filter_events_for_client(
                     self.storage, user.to_string(), res["events_after"]
                 )
 
@@ -390,7 +386,7 @@ class SearchHandler(BaseHandler):
                         [(EventTypes.Member, sender) for sender in senders]
                     )
 
-                    state = yield self.state_store.get_state_for_event(
+                    state = await self.state_store.get_state_for_event(
                         last_event_id, state_filter
                     )
 
@@ -412,10 +408,10 @@ class SearchHandler(BaseHandler):
         time_now = self.clock.time_msec()
 
         for context in contexts.values():
-            context["events_before"] = yield self._event_serializer.serialize_events(
+            context["events_before"] = await self._event_serializer.serialize_events(
                 context["events_before"], time_now
             )
-            context["events_after"] = yield self._event_serializer.serialize_events(
+            context["events_after"] = await self._event_serializer.serialize_events(
                 context["events_after"], time_now
             )
 
@@ -423,7 +419,7 @@ class SearchHandler(BaseHandler):
         if include_state:
             rooms = {e.room_id for e in allowed_events}
             for room_id in rooms:
-                state = yield self.state_handler.get_current_state(room_id)
+                state = await self.state_handler.get_current_state(room_id)
                 state_results[room_id] = list(state.values())
 
             state_results.values()
@@ -437,7 +433,7 @@ class SearchHandler(BaseHandler):
                 {
                     "rank": rank_map[e.event_id],
                     "result": (
-                        yield self._event_serializer.serialize_event(e, time_now)
+                        await self._event_serializer.serialize_event(e, time_now)
                     ),
                     "context": contexts.get(e.event_id, {}),
                 }
@@ -452,7 +448,7 @@ class SearchHandler(BaseHandler):
         if state_results:
             s = {}
             for room_id, state in state_results.items():
-                s[room_id] = yield self._event_serializer.serialize_events(
+                s[room_id] = await self._event_serializer.serialize_events(
                     state, time_now
                 )
 
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 3797545824..3cef747a4d 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -49,7 +49,6 @@ from synapse.http.proxyagent import ProxyAgent
 from synapse.logging.context import make_deferred_yieldable
 from synapse.logging.opentracing import set_tag, start_active_span, tags
 from synapse.util.async_helpers import timeout_deferred
-from synapse.util.caches import CACHE_SIZE_FACTOR
 
 logger = logging.getLogger(__name__)
 
@@ -241,7 +240,10 @@ class SimpleHttpClient(object):
         # tends to do so in batches, so we need to allow the pool to keep
         # lots of idle connections around.
         pool = HTTPConnectionPool(self.reactor)
-        pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
+        # XXX: The justification for using the cache factor here is that larger instances
+        # will need both more cache and more connections.
+        # Still, this should probably be a separate dial
+        pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5))
         pool.cachedConnectionTimeout = 2 * 60
 
         self.agent = ProxyAgent(
@@ -359,6 +361,7 @@ class SimpleHttpClient(object):
         actual_headers = {
             b"Content-Type": [b"application/x-www-form-urlencoded"],
             b"User-Agent": [self.user_agent],
+            b"Accept": [b"application/json"],
         }
         if headers:
             actual_headers.update(headers)
@@ -399,6 +402,7 @@ class SimpleHttpClient(object):
         actual_headers = {
             b"Content-Type": [b"application/json"],
             b"User-Agent": [self.user_agent],
+            b"Accept": [b"application/json"],
         }
         if headers:
             actual_headers.update(headers)
@@ -434,6 +438,10 @@ class SimpleHttpClient(object):
 
             ValueError: if the response was not JSON
         """
+        actual_headers = {b"Accept": [b"application/json"]}
+        if headers:
+            actual_headers.update(headers)
+
         body = yield self.get_raw(uri, args, headers=headers)
         return json.loads(body)
 
@@ -467,6 +475,7 @@ class SimpleHttpClient(object):
         actual_headers = {
             b"Content-Type": [b"application/json"],
             b"User-Agent": [self.user_agent],
+            b"Accept": [b"application/json"],
         }
         if headers:
             actual_headers.update(headers)
diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py
index a248103191..ab7f948ed4 100644
--- a/synapse/metrics/_exposition.py
+++ b/synapse/metrics/_exposition.py
@@ -33,6 +33,8 @@ from prometheus_client import REGISTRY
 
 from twisted.web.resource import Resource
 
+from synapse.util import caches
+
 try:
     from prometheus_client.samples import Sample
 except ImportError:
@@ -103,13 +105,15 @@ def nameify_sample(sample):
 
 
 def generate_latest(registry, emit_help=False):
-    output = []
 
-    for metric in registry.collect():
+    # Trigger the cache metrics to be rescraped, which updates the common
+    # metrics but do not produce metrics themselves
+    for collector in caches.collectors_by_name.values():
+        collector.collect()
 
-        if metric.name.startswith("__unused"):
-            continue
+    output = []
 
+    for metric in registry.collect():
         if not metric.samples:
             # No samples, don't bother.
             continue
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 433ca2f416..e75d964ac8 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -51,6 +51,7 @@ push_rules_delta_state_cache_metric = register_cache(
     "cache",
     "push_rules_delta_state_cache_metric",
     cache=[],  # Meaningless size, as this isn't a cache that stores values
+    resizable=False,
 )
 
 
@@ -67,7 +68,8 @@ class BulkPushRuleEvaluator(object):
         self.room_push_rule_cache_metrics = register_cache(
             "cache",
             "room_push_rule_cache",
-            cache=[],  # Meaningless size, as this isn't a cache that stores values
+            cache=[],  # Meaningless size, as this isn't a cache that stores values,
+            resizable=False,
         )
 
     @defer.inlineCallbacks
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 4cd702b5fa..11032491af 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -22,7 +22,7 @@ from six import string_types
 
 from synapse.events import EventBase
 from synapse.types import UserID
-from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
+from synapse.util.caches import register_cache
 from synapse.util.caches.lrucache import LruCache
 
 logger = logging.getLogger(__name__)
@@ -165,7 +165,7 @@ class PushRuleEvaluatorForEvent(object):
 
 
 # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
-regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
+regex_cache = LruCache(50000)
 register_cache("cache", "regex_push_cache", regex_cache)
 
 
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 39c99a2802..8b4312e5a3 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -92,6 +92,7 @@ CONDITIONAL_REQUIREMENTS = {
         'eliot<1.8.0;python_version<"3.5.3"',
     ],
     "saml2": ["pysaml2>=4.5.0"],
+    "oidc": ["authlib>=0.14.0"],
     "systemd": ["systemd-python>=231"],
     "url_preview": ["lxml>=3.5.0"],
     "test": ["mock>=2.0", "parameterized"],
diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
index 0459f582bf..b705a8e16c 100644
--- a/synapse/replication/http/streams.py
+++ b/synapse/replication/http/streams.py
@@ -52,9 +52,9 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
 
         self._instance_name = hs.get_instance_name()
 
-        # We pull the streams from the replication steamer (if we try and make
+        # We pull the streams from the replication handler (if we try and make
         # them ourselves we end up in an import loop).
-        self.streams = hs.get_replication_streamer().get_streams()
+        self.streams = hs.get_tcp_replication().get_streams()
 
     @staticmethod
     def _serialize_payload(stream_name, from_token, upto_token):
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 5d7c8871a4..2904bd0235 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -18,14 +18,10 @@ from typing import Optional
 
 import six
 
-from synapse.storage.data_stores.main.cache import (
-    CURRENT_STATE_CACHE_NAME,
-    CacheInvalidationWorkerStore,
-)
+from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
-
-from ._slaved_id_tracker import SlavedIdTracker
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 
 logger = logging.getLogger(__name__)
 
@@ -41,40 +37,16 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
     def __init__(self, database: Database, db_conn, hs):
         super(BaseSlavedStore, self).__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
-            self._cache_id_gen = SlavedIdTracker(
-                db_conn, "cache_invalidation_stream", "stream_id"
-            )  # type: Optional[SlavedIdTracker]
+            self._cache_id_gen = MultiWriterIdGenerator(
+                db_conn,
+                database,
+                instance_name=hs.get_instance_name(),
+                table="cache_invalidation_stream_by_instance",
+                instance_column="instance_name",
+                id_column="stream_id",
+                sequence_name="cache_invalidation_stream_seq",
+            )  # type: Optional[MultiWriterIdGenerator]
         else:
             self._cache_id_gen = None
 
         self.hs = hs
-
-    def get_cache_stream_token(self):
-        if self._cache_id_gen:
-            return self._cache_id_gen.get_current_token()
-        else:
-            return 0
-
-    def process_replication_rows(self, stream_name, token, rows):
-        if stream_name == "caches":
-            if self._cache_id_gen:
-                self._cache_id_gen.advance(token)
-            for row in rows:
-                if row.cache_func == CURRENT_STATE_CACHE_NAME:
-                    if row.keys is None:
-                        raise Exception(
-                            "Can't send an 'invalidate all' for current state cache"
-                        )
-
-                    room_id = row.keys[0]
-                    members_changed = set(row.keys[1:])
-                    self._invalidate_state_caches(room_id, members_changed)
-                else:
-                    self._attempt_to_invalidate_cache(row.cache_func, row.keys)
-
-    def _invalidate_cache_and_stream(self, txn, cache_func, keys):
-        txn.call_after(cache_func.invalidate, keys)
-        txn.call_after(self._send_invalidation_poke, cache_func, keys)
-
-    def _send_invalidation_poke(self, cache_func, keys):
-        self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys)
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 65e54b1c71..2a4f5c7cfd 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -32,7 +32,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
     def get_max_account_data_stream_id(self):
         return self._account_data_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "tag_account_data":
             self._account_data_id_gen.advance(token)
             for row in rows:
@@ -51,6 +51,4 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
                     (row.user_id, row.room_id, row.data_type)
                 )
                 self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        return super(SlavedAccountDataStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index fbf996e33a..1a38f53dfb 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -15,7 +15,6 @@
 
 from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY
 from synapse.storage.database import Database
-from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.caches.descriptors import Cache
 
 from ._base import BaseSlavedStore
@@ -26,7 +25,7 @@ class SlavedClientIpStore(BaseSlavedStore):
         super(SlavedClientIpStore, self).__init__(database, db_conn, hs)
 
         self.client_ip_last_seen = Cache(
-            name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
+            name="client_ip_last_seen", keylen=4, max_entries=50000
         )
 
     def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index c923751e50..6e7fd259d4 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -43,7 +43,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
             expiry_ms=30 * 60 * 1000,
         )
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "to_device":
             self._device_inbox_id_gen.advance(token)
             for row in rows:
@@ -55,6 +55,4 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
                     self._device_federation_outbox_stream_cache.entity_has_changed(
                         row.entity, token
                     )
-        return super(SlavedDeviceInboxStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 58fb0eaae3..9d8067342f 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -48,7 +48,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
             "DeviceListFederationStreamChangeCache", device_list_max
         )
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == DeviceListsStream.NAME:
             self._device_list_id_gen.advance(token)
             self._invalidate_caches_for_devices(token, rows)
@@ -56,9 +56,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
             self._device_list_id_gen.advance(token)
             for row in rows:
                 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
-        return super(SlavedDeviceStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
     def _invalidate_caches_for_devices(self, token, rows):
         for row in rows:
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 15011259df..b313720a4b 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -93,7 +93,7 @@ class SlavedEventStore(
     def get_room_min_stream_ordering(self):
         return self._backfill_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "events":
             self._stream_id_gen.advance(token)
             for row in rows:
@@ -111,9 +111,7 @@ class SlavedEventStore(
                     row.relates_to,
                     backfilled=True,
                 )
-        return super(SlavedEventStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
     def _process_event_stream_row(self, token, row):
         data = row.data
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 01bcf0e882..1851e7d525 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -37,12 +37,10 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
     def get_group_stream_token(self):
         return self._group_updates_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "groups":
             self._group_updates_id_gen.advance(token)
             for row in rows:
                 self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
 
-        return super(SlavedGroupServerStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index fae3125072..bd79ba99be 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -41,12 +41,10 @@ class SlavedPresenceStore(BaseSlavedStore):
     def get_current_presence_token(self):
         return self._presence_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "presence":
             self._presence_id_gen.advance(token)
             for row in rows:
                 self.presence_stream_cache.entity_has_changed(row.user_id, token)
                 self._get_presence_for_user.invalidate((row.user_id,))
-        return super(SlavedPresenceStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 6138796da4..5d5816d7eb 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -37,13 +37,11 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
     def get_max_push_rules_stream_id(self):
         return self._push_rules_stream_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "push_rules":
             self._push_rules_stream_id_gen.advance(token)
             for row in rows:
                 self.get_push_rules_for_user.invalidate((row.user_id,))
                 self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
                 self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
-        return super(SlavedPushRuleStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 67be337945..cb78b49acb 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -31,9 +31,7 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
     def get_pushers_stream_token(self):
         return self._pushers_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "pushers":
             self._pushers_id_gen.advance(token)
-        return super(SlavedPusherStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 993432edcb..be716cc558 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -51,7 +51,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
         self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
         self.get_receipts_for_room.invalidate((room_id, receipt_type))
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "receipts":
             self._receipts_id_gen.advance(token)
             for row in rows:
@@ -60,6 +60,4 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
                 )
                 self._receipts_stream_cache.entity_has_changed(row.room_id, token)
 
-        return super(SlavedReceiptsStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 10dda8708f..8873bf37e5 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -30,8 +30,8 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
     def get_current_public_room_stream_id(self):
         return self._public_room_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "public_rooms":
             self._public_room_id_gen.advance(token)
 
-        return super(RoomStore, self).process_replication_rows(stream_name, token, rows)
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 3bbf3c3569..20cb8a654f 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -100,10 +100,10 @@ class ReplicationDataHandler:
             token: stream token for this batch of rows
             rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
         """
-        self.store.process_replication_rows(stream_name, token, rows)
+        self.store.process_replication_rows(stream_name, instance_name, token, rows)
 
-    async def on_position(self, stream_name: str, token: int):
-        self.store.process_replication_rows(stream_name, token, [])
+    async def on_position(self, stream_name: str, instance_name: str, token: int):
+        self.store.process_replication_rows(stream_name, instance_name, token, [])
 
     def on_remote_server_up(self, server: str):
         """Called when get a new REMOTE_SERVER_UP command."""
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index f58e384d17..c04f622816 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -341,37 +341,6 @@ class RemovePusherCommand(Command):
         return " ".join((self.app_id, self.push_key, self.user_id))
 
 
-class InvalidateCacheCommand(Command):
-    """Sent by the client to invalidate an upstream cache.
-
-    THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE
-    NOT DISASTROUS IF WE DROP ON THE FLOOR.
-
-    Mainly used to invalidate destination retry timing caches.
-
-    Format::
-
-        INVALIDATE_CACHE <cache_func> <keys_json>
-
-    Where <keys_json> is a json list.
-    """
-
-    NAME = "INVALIDATE_CACHE"
-
-    def __init__(self, cache_func, keys):
-        self.cache_func = cache_func
-        self.keys = keys
-
-    @classmethod
-    def from_line(cls, line):
-        cache_func, keys_json = line.split(" ", 1)
-
-        return cls(cache_func, json.loads(keys_json))
-
-    def to_line(self):
-        return " ".join((self.cache_func, _json_encoder.encode(self.keys)))
-
-
 class UserIpCommand(Command):
     """Sent periodically when a worker sees activity from a client.
 
@@ -439,7 +408,6 @@ _COMMANDS = (
     UserSyncCommand,
     FederationAckCommand,
     RemovePusherCommand,
-    InvalidateCacheCommand,
     UserIpCommand,
     RemoteServerUpCommand,
     ClearUserSyncsCommand,
@@ -467,7 +435,6 @@ VALID_CLIENT_COMMANDS = (
     ClearUserSyncsCommand.NAME,
     FederationAckCommand.NAME,
     RemovePusherCommand.NAME,
-    InvalidateCacheCommand.NAME,
     UserIpCommand.NAME,
     ErrorCommand.NAME,
     RemoteServerUpCommand.NAME,
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 4328b38e9d..acfa66a7a8 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -15,18 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Iterable,
-    Iterator,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    TypeVar,
-)
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
 
 from prometheus_client import Counter
 
@@ -38,7 +27,6 @@ from synapse.replication.tcp.commands import (
     ClearUserSyncsCommand,
     Command,
     FederationAckCommand,
-    InvalidateCacheCommand,
     PositionCommand,
     RdataCommand,
     RemoteServerUpCommand,
@@ -48,7 +36,12 @@ from synapse.replication.tcp.commands import (
     UserSyncCommand,
 )
 from synapse.replication.tcp.protocol import AbstractConnection
-from synapse.replication.tcp.streams import STREAMS_MAP, Stream
+from synapse.replication.tcp.streams import (
+    STREAMS_MAP,
+    CachesStream,
+    FederationStream,
+    Stream,
+)
 from synapse.util.async_helpers import Linearizer
 
 logger = logging.getLogger(__name__)
@@ -85,6 +78,26 @@ class ReplicationCommandHandler:
             stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
         }  # type: Dict[str, Stream]
 
+        # List of streams that this instance is the source of
+        self._streams_to_replicate = []  # type: List[Stream]
+
+        for stream in self._streams.values():
+            if stream.NAME == CachesStream.NAME:
+                # All workers can write to the cache invalidation stream.
+                self._streams_to_replicate.append(stream)
+                continue
+
+            # Only add any other streams if we're on master.
+            if hs.config.worker_app is not None:
+                continue
+
+            if stream.NAME == FederationStream.NAME and hs.config.send_federation:
+                # We only support federation stream if federation sending
+                # has been disabled on the master.
+                continue
+
+            self._streams_to_replicate.append(stream)
+
         self._position_linearizer = Linearizer(
             "replication_position", clock=self._clock
         )
@@ -162,16 +175,33 @@ class ReplicationCommandHandler:
             port = hs.config.worker_replication_port
             hs.get_reactor().connectTCP(host, port, self._factory)
 
+    def get_streams(self) -> Dict[str, Stream]:
+        """Get a map from stream name to all streams.
+        """
+        return self._streams
+
+    def get_streams_to_replicate(self) -> List[Stream]:
+        """Get a list of streams that this instances replicates.
+        """
+        return self._streams_to_replicate
+
     async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
-        # We only want to announce positions by the writer of the streams.
-        # Currently this is just the master process.
-        if not self._is_master:
-            return
+        self.send_positions_to_connection(conn)
 
-        for stream_name, stream in self._streams.items():
-            current_token = stream.current_token()
+    def send_positions_to_connection(self, conn: AbstractConnection):
+        """Send current position of all streams this process is source of to
+        the connection.
+        """
+
+        # We respond with current position of all streams this instance
+        # replicates.
+        for stream in self.get_streams_to_replicate():
             self.send_command(
-                PositionCommand(stream_name, self._instance_name, current_token)
+                PositionCommand(
+                    stream.NAME,
+                    self._instance_name,
+                    stream.current_token(self._instance_name),
+                )
             )
 
     async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
@@ -208,18 +238,6 @@ class ReplicationCommandHandler:
 
             self._notifier.on_new_replication_data()
 
-    async def on_INVALIDATE_CACHE(
-        self, conn: AbstractConnection, cmd: InvalidateCacheCommand
-    ):
-        invalidate_cache_counter.inc()
-
-        if self._is_master:
-            # We invalidate the cache locally, but then also stream that to other
-            # workers.
-            await self._store.invalidate_cache_and_stream(
-                cmd.cache_func, tuple(cmd.keys)
-            )
-
     async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
         user_ip_cache_counter.inc()
 
@@ -293,7 +311,7 @@ class ReplicationCommandHandler:
             rows: a list of Stream.ROW_TYPE objects as returned by
                 Stream.parse_row.
         """
-        logger.debug("Received rdata %s -> %s", stream_name, token)
+        logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token)
         await self._replication_data_handler.on_rdata(
             stream_name, instance_name, token, rows
         )
@@ -324,7 +342,7 @@ class ReplicationCommandHandler:
             self._pending_batches.pop(stream_name, [])
 
             # Find where we previously streamed up to.
-            current_token = stream.current_token()
+            current_token = stream.current_token(cmd.instance_name)
 
             # If the position token matches our current token then we're up to
             # date and there's nothing to do. Otherwise, fetch all updates
@@ -361,7 +379,9 @@ class ReplicationCommandHandler:
             logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
 
             # We've now caught up to position sent to us, notify handler.
-            await self._replication_data_handler.on_position(stream_name, cmd.token)
+            await self._replication_data_handler.on_position(
+                cmd.stream_name, cmd.instance_name, cmd.token
+            )
 
             self._streams_by_connection.setdefault(conn, set()).add(stream_name)
 
@@ -489,12 +509,6 @@ class ReplicationCommandHandler:
         cmd = RemovePusherCommand(app_id, push_key, user_id)
         self.send_command(cmd)
 
-    def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
-        """Poke the master to invalidate a cache.
-        """
-        cmd = InvalidateCacheCommand(cache_func.__name__, keys)
-        self.send_command(cmd)
-
     def send_user_ip(
         self,
         user_id: str,
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 55bfa71dfd..e776b63183 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -70,7 +70,6 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         logger.info("Connected to redis")
         super().connectionMade()
         run_as_background_process("subscribe-replication", self._send_subscribe)
-        self.handler.new_connection(self)
 
     async def _send_subscribe(self):
         # it's important to make sure that we only send the REPLICATE command once we
@@ -81,9 +80,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         logger.info(
             "Successfully subscribed to redis stream, sending REPLICATE command"
         )
+        self.handler.new_connection(self)
         await self._async_send_command(ReplicateCommand())
         logger.info("REPLICATE successfully sent")
 
+        # We send out our positions when there is a new connection in case the
+        # other side missed updates. We do this for Redis connections as the
+        # otherside won't know we've connected and so won't issue a REPLICATE.
+        self.handler.send_positions_to_connection(self)
+
     def messageReceived(self, pattern: str, channel: str, message: str):
         """Received a message from redis.
         """
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 33d2f589ac..41569305df 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,7 +17,6 @@
 
 import logging
 import random
-from typing import Dict, List
 
 from prometheus_client import Counter
 
@@ -25,7 +24,6 @@ from twisted.internet.protocol import Factory
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
-from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream
 from synapse.util.metrics import Measure
 
 stream_updates_counter = Counter(
@@ -71,26 +69,11 @@ class ReplicationStreamer(object):
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
+        self._instance_name = hs.get_instance_name()
 
         self._replication_torture_level = hs.config.replication_torture_level
 
-        # Work out list of streams that this instance is the source of.
-        self.streams = []  # type: List[Stream]
-        if hs.config.worker_app is None:
-            for stream in STREAMS_MAP.values():
-                if stream == FederationStream and hs.config.send_federation:
-                    # We only support federation stream if federation sending
-                    # hase been disabled on the master.
-                    continue
-
-                self.streams.append(stream(hs))
-
-        self.streams_by_name = {stream.NAME: stream for stream in self.streams}
-
-        # Only bother registering the notifier callback if we have streams to
-        # publish.
-        if self.streams:
-            self.notifier.add_replication_callback(self.on_notifier_poke)
+        self.notifier.add_replication_callback(self.on_notifier_poke)
 
         # Keeps track of whether we are currently checking for updates
         self.is_looping = False
@@ -98,10 +81,8 @@ class ReplicationStreamer(object):
 
         self.command_handler = hs.get_tcp_replication()
 
-    def get_streams(self) -> Dict[str, Stream]:
-        """Get a mapp from stream name to stream instance.
-        """
-        return self.streams_by_name
+        # Set of streams to replicate.
+        self.streams = self.command_handler.get_streams_to_replicate()
 
     def on_notifier_poke(self):
         """Checks if there is actually any new data and sends it to the
@@ -145,7 +126,9 @@ class ReplicationStreamer(object):
                         random.shuffle(all_streams)
 
                     for stream in all_streams:
-                        if stream.last_token == stream.current_token():
+                        if stream.last_token == stream.current_token(
+                            self._instance_name
+                        ):
                             continue
 
                         if self._replication_torture_level:
@@ -157,7 +140,7 @@ class ReplicationStreamer(object):
                             "Getting stream: %s: %s -> %s",
                             stream.NAME,
                             stream.last_token,
-                            stream.current_token(),
+                            stream.current_token(self._instance_name),
                         )
                         try:
                             updates, current_token, limited = await stream.get_updates()
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b0f87c365b..b48a6a3e91 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -95,19 +95,25 @@ class Stream(object):
     def __init__(
         self,
         local_instance_name: str,
-        current_token_function: Callable[[], Token],
+        current_token_function: Callable[[str], Token],
         update_function: UpdateFunction,
     ):
         """Instantiate a Stream
 
-        current_token_function and update_function are callbacks which should be
-        implemented by subclasses.
+        `current_token_function` and `update_function` are callbacks which
+        should be implemented by subclasses.
 
-        current_token_function is called to get the current token of the underlying
-        stream.
+        `current_token_function` takes an instance name, which is a writer to
+        the stream, and returns the position in the stream of the writer (as
+        viewed from the current process). On the writer process this is where
+        the writer has successfully written up to, whereas on other processes
+        this is the position which we have received updates up to over
+        replication. (Note that most streams have a single writer and so their
+        implementations ignore the instance name passed in).
 
-        update_function is called to get updates for this stream between a pair of
-        stream tokens. See the UpdateFunction type definition for more info.
+        `update_function` is called to get updates for this stream between a
+        pair of stream tokens. See the `UpdateFunction` type definition for more
+        info.
 
         Args:
             local_instance_name: The instance name of the current process
@@ -119,13 +125,13 @@ class Stream(object):
         self.update_function = update_function
 
         # The token from which we last asked for updates
-        self.last_token = self.current_token()
+        self.last_token = self.current_token(self.local_instance_name)
 
     def discard_updates_and_advance(self):
         """Called when the stream should advance but the updates would be discarded,
         e.g. when there are no currently connected workers.
         """
-        self.last_token = self.current_token()
+        self.last_token = self.current_token(self.local_instance_name)
 
     async def get_updates(self) -> StreamUpdateResult:
         """Gets all updates since the last time this function was called (or
@@ -137,7 +143,7 @@ class Stream(object):
             position in stream, and `limited` is whether there are more updates
             to fetch.
         """
-        current_token = self.current_token()
+        current_token = self.current_token(self.local_instance_name)
         updates, current_token, limited = await self.get_updates_since(
             self.local_instance_name, self.last_token, current_token
         )
@@ -169,6 +175,16 @@ class Stream(object):
         return updates, upto_token, limited
 
 
+def current_token_without_instance(
+    current_token: Callable[[], int]
+) -> Callable[[str], int]:
+    """Takes a current token callback function for a single writer stream
+    that doesn't take an instance name parameter and wraps it in a function that
+    does accept an instance name parameter but ignores it.
+    """
+    return lambda instance_name: current_token()
+
+
 def db_query_to_update_function(
     query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
 ) -> UpdateFunction:
@@ -234,7 +250,7 @@ class BackfillStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_current_backfill_token,
+            current_token_without_instance(store.get_current_backfill_token),
             db_query_to_update_function(store.get_all_new_backfill_event_rows),
         )
 
@@ -270,7 +286,9 @@ class PresenceStream(Stream):
             update_function = make_http_update_function(hs, self.NAME)
 
         super().__init__(
-            hs.get_instance_name(), store.get_current_presence_token, update_function
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_current_presence_token),
+            update_function,
         )
 
 
@@ -295,7 +313,9 @@ class TypingStream(Stream):
             update_function = make_http_update_function(hs, self.NAME)
 
         super().__init__(
-            hs.get_instance_name(), typing_handler.get_current_token, update_function
+            hs.get_instance_name(),
+            current_token_without_instance(typing_handler.get_current_token),
+            update_function,
         )
 
 
@@ -318,7 +338,7 @@ class ReceiptsStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_max_receipt_stream_id,
+            current_token_without_instance(store.get_max_receipt_stream_id),
             db_query_to_update_function(store.get_all_updated_receipts),
         )
 
@@ -338,7 +358,7 @@ class PushRulesStream(Stream):
             hs.get_instance_name(), self._current_token, self._update_function
         )
 
-    def _current_token(self) -> int:
+    def _current_token(self, instance_name: str) -> int:
         push_rules_token, _ = self.store.get_push_rules_stream_token()
         return push_rules_token
 
@@ -372,7 +392,7 @@ class PushersStream(Stream):
 
         super().__init__(
             hs.get_instance_name(),
-            store.get_pushers_stream_token,
+            current_token_without_instance(store.get_pushers_stream_token),
             db_query_to_update_function(store.get_all_updated_pushers_rows),
         )
 
@@ -401,12 +421,26 @@ class CachesStream(Stream):
     ROW_TYPE = CachesStreamRow
 
     def __init__(self, hs):
-        store = hs.get_datastore()
+        self.store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_cache_stream_token,
-            db_query_to_update_function(store.get_all_updated_caches),
+            self.store.get_cache_stream_token,
+            self._update_function,
+        )
+
+    async def _update_function(
+        self, instance_name: str, from_token: int, upto_token: int, limit: int
+    ):
+        rows = await self.store.get_all_updated_caches(
+            instance_name, from_token, upto_token, limit
         )
+        updates = [(row[0], row[1:]) for row in rows]
+        limited = False
+        if len(updates) >= limit:
+            upto_token = updates[-1][0]
+            limited = True
+
+        return updates, upto_token, limited
 
 
 class PublicRoomsStream(Stream):
@@ -430,7 +464,7 @@ class PublicRoomsStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_current_public_room_stream_id,
+            current_token_without_instance(store.get_current_public_room_stream_id),
             db_query_to_update_function(store.get_all_new_public_rooms),
         )
 
@@ -451,7 +485,7 @@ class DeviceListsStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_device_stream_token,
+            current_token_without_instance(store.get_device_stream_token),
             db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
         )
 
@@ -469,7 +503,7 @@ class ToDeviceStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_to_device_stream_token,
+            current_token_without_instance(store.get_to_device_stream_token),
             db_query_to_update_function(store.get_all_new_device_messages),
         )
 
@@ -489,7 +523,7 @@ class TagAccountDataStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_max_account_data_stream_id,
+            current_token_without_instance(store.get_max_account_data_stream_id),
             db_query_to_update_function(store.get_all_updated_tags),
         )
 
@@ -509,7 +543,7 @@ class AccountDataStream(Stream):
         self.store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            self.store.get_max_account_data_stream_id,
+            current_token_without_instance(self.store.get_max_account_data_stream_id),
             db_query_to_update_function(self._update_function),
         )
 
@@ -540,7 +574,7 @@ class GroupServerStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_group_stream_token,
+            current_token_without_instance(store.get_group_stream_token),
             db_query_to_update_function(store.get_all_groups_changes),
         )
 
@@ -558,7 +592,7 @@ class UserSignatureStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_device_stream_token,
+            current_token_without_instance(store.get_device_stream_token),
             db_query_to_update_function(
                 store.get_all_user_signature_changes_for_remotes
             ),
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 890e75d827..f370390331 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -20,7 +20,7 @@ from typing import List, Tuple, Type
 
 import attr
 
-from ._base import Stream, StreamUpdateResult, Token
+from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
 
 
 """Handling of the 'events' replication stream
@@ -119,7 +119,7 @@ class EventsStream(Stream):
         self._store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            self._store.get_current_events_token,
+            current_token_without_instance(self._store.get_current_events_token),
             self._update_function,
         )
 
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index e8bd52e389..9bcd13b009 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,7 +15,11 @@
 # limitations under the License.
 from collections import namedtuple
 
-from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
+from synapse.replication.tcp.streams._base import (
+    Stream,
+    current_token_without_instance,
+    make_http_update_function,
+)
 
 
 class FederationStream(Stream):
@@ -35,21 +39,35 @@ class FederationStream(Stream):
     ROW_TYPE = FederationStreamRow
 
     def __init__(self, hs):
-        # Not all synapse instances will have a federation sender instance,
-        # whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
-        # so we stub the stream out when that is the case.
-        if hs.config.worker_app is None or hs.should_send_federation():
+        if hs.config.worker_app is None:
+            # master process: get updates from the FederationRemoteSendQueue.
+            # (if the master is configured to send federation itself, federation_sender
+            # will be a real FederationSender, which has stubs for current_token and
+            # get_replication_rows.)
             federation_sender = hs.get_federation_sender()
-            current_token = federation_sender.get_current_token
-            update_function = db_query_to_update_function(
-                federation_sender.get_replication_rows
+            current_token = current_token_without_instance(
+                federation_sender.get_current_token
             )
+            update_function = federation_sender.get_replication_rows
+
+        elif hs.should_send_federation():
+            # federation sender: Query master process
+            update_function = make_http_update_function(hs, self.NAME)
+            current_token = self._stub_current_token
+
         else:
-            current_token = lambda: 0
+            # other worker: stub out the update function (we're not interested in
+            # any updates so when we get a POSITION we do nothing)
             update_function = self._stub_update_function
+            current_token = self._stub_current_token
 
         super().__init__(hs.get_instance_name(), current_token, update_function)
 
     @staticmethod
+    def _stub_current_token(instance_name: str) -> int:
+        # dummy current-token method for use on workers
+        return 0
+
+    @staticmethod
     async def _stub_update_function(instance_name, from_token, upto_token, limit):
         return [], upto_token, False
diff --git a/synapse/res/templates/notice_expiry.html b/synapse/res/templates/notice_expiry.html
index f0d7c66e1b..6b94d8c367 100644
--- a/synapse/res/templates/notice_expiry.html
+++ b/synapse/res/templates/notice_expiry.html
@@ -30,7 +30,7 @@
                         <tr>
                           <td colspan="2">
                             <div class="noticetext">Your account will expire on {{ expiration_ts|format_ts("%d-%m-%Y") }}. This means that you will lose access to your account after this date.</div>
-                            <div class="noticetext">To extend the validity of your account, please click on the link bellow (or copy and paste it into a new browser tab):</div>
+                            <div class="noticetext">To extend the validity of your account, please click on the link below (or copy and paste it into a new browser tab):</div>
                             <div class="noticetext"><a href="{{ url }}">{{ url }}</a></div>
                           </td>
                         </tr>
diff --git a/synapse/res/templates/notice_expiry.txt b/synapse/res/templates/notice_expiry.txt
index 41f1c4279c..4ec27e8831 100644
--- a/synapse/res/templates/notice_expiry.txt
+++ b/synapse/res/templates/notice_expiry.txt
@@ -2,6 +2,6 @@ Hi {{ display_name }},
 
 Your account will expire on {{ expiration_ts|format_ts("%d-%m-%Y") }}. This means that you will lose access to your account after this date.
 
-To extend the validity of your account, please click on the link bellow (or copy and paste it to a new browser tab):
+To extend the validity of your account, please click on the link below (or copy and paste it to a new browser tab):
 
 {{ url }}
diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html
new file mode 100644
index 0000000000..43a211386b
--- /dev/null
+++ b/synapse/res/templates/sso_error.html
@@ -0,0 +1,18 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <title>SSO error</title>
+</head>
+<body>
+    <p>Oops! Something went wrong during authentication.</p>
+    <p>
+        Try logging in again from your Matrix client and if the problem persists
+        please contact the server's administrator.
+    </p>
+    <p>Error: <code>{{ error }}</code></p>
+    {% if error_description %}
+    <pre><code>{{ error_description }}</code></pre>
+    {% endif %}
+</body>
+</html>
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index ed70d448a1..6b85148a32 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -32,6 +32,7 @@ from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
 from synapse.rest.admin.rooms import (
     JoinRoomAliasServlet,
     ListRoomRestServlet,
+    RoomRestServlet,
     ShutdownRoomRestServlet,
 )
 from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
@@ -193,6 +194,7 @@ def register_servlets(hs, http_server):
     """
     register_servlets_for_client_rest_resource(hs, http_server)
     ListRoomRestServlet(hs).register(http_server)
+    RoomRestServlet(hs).register(http_server)
     JoinRoomAliasServlet(hs).register(http_server)
     PurgeRoomServlet(hs).register(http_server)
     SendServerNoticeServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index d1bdb64111..7d40001988 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -26,6 +26,7 @@ from synapse.http.servlet import (
 )
 from synapse.rest.admin._base import (
     admin_patterns,
+    assert_requester_is_admin,
     assert_user_is_admin,
     historical_admin_path_patterns,
 )
@@ -169,7 +170,7 @@ class ListRoomRestServlet(RestServlet):
     in a dictionary containing room information. Supports pagination.
     """
 
-    PATTERNS = admin_patterns("/rooms")
+    PATTERNS = admin_patterns("/rooms$")
 
     def __init__(self, hs):
         self.store = hs.get_datastore()
@@ -253,6 +254,29 @@ class ListRoomRestServlet(RestServlet):
         return 200, response
 
 
+class RoomRestServlet(RestServlet):
+    """Get room details.
+
+    TODO: Add on_POST to allow room creation without joining the room
+    """
+
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$")
+
+    def __init__(self, hs):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+
+    async def on_GET(self, request, room_id):
+        await assert_requester_is_admin(self.auth, request)
+
+        ret = await self.store.get_room_with_stats(room_id)
+        if not ret:
+            raise NotFoundError("Room not found")
+
+        return 200, ret
+
+
 class JoinRoomAliasServlet(RestServlet):
 
     PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 4de2f97d06..de7eca21f8 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -83,6 +83,7 @@ class LoginRestServlet(RestServlet):
         self.jwt_algorithm = hs.config.jwt_algorithm
         self.saml2_enabled = hs.config.saml2_enabled
         self.cas_enabled = hs.config.cas_enabled
+        self.oidc_enabled = hs.config.oidc_enabled
         self.auth_handler = self.hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
         self.handlers = hs.get_handlers()
@@ -96,9 +97,7 @@ class LoginRestServlet(RestServlet):
         flows = []
         if self.jwt_enabled:
             flows.append({"type": LoginRestServlet.JWT_TYPE})
-        if self.saml2_enabled:
-            flows.append({"type": LoginRestServlet.SSO_TYPE})
-            flows.append({"type": LoginRestServlet.TOKEN_TYPE})
+
         if self.cas_enabled:
             flows.append({"type": LoginRestServlet.SSO_TYPE})
 
@@ -114,6 +113,11 @@ class LoginRestServlet(RestServlet):
             # fall back to the fallback API if they don't understand one of the
             # login flow types returned.
             flows.append({"type": LoginRestServlet.TOKEN_TYPE})
+        elif self.saml2_enabled:
+            flows.append({"type": LoginRestServlet.SSO_TYPE})
+            flows.append({"type": LoginRestServlet.TOKEN_TYPE})
+        elif self.oidc_enabled:
+            flows.append({"type": LoginRestServlet.SSO_TYPE})
 
         flows.extend(
             ({"type": t} for t in self.auth_handler.get_supported_login_types())
@@ -465,6 +469,22 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
         return self._saml_handler.handle_redirect_request(client_redirect_url)
 
 
+class OIDCRedirectServlet(RestServlet):
+    """Implementation for /login/sso/redirect for the OIDC login flow."""
+
+    PATTERNS = client_patterns("/login/sso/redirect", v1=True)
+
+    def __init__(self, hs):
+        self._oidc_handler = hs.get_oidc_handler()
+
+    async def on_GET(self, request):
+        args = request.args
+        if b"redirectUrl" not in args:
+            return 400, "Redirect URL not specified for SSO auth"
+        client_redirect_url = args[b"redirectUrl"][0]
+        await self._oidc_handler.handle_redirect_request(request, client_redirect_url)
+
+
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
     if hs.config.cas_enabled:
@@ -472,3 +492,5 @@ def register_servlets(hs, http_server):
         CasTicketServlet(hs).register(http_server)
     elif hs.config.saml2_enabled:
         SAMLRedirectServlet(hs).register(http_server)
+    elif hs.config.oidc_enabled:
+        OIDCRedirectServlet(hs).register(http_server)
diff --git a/synapse/rest/oidc/__init__.py b/synapse/rest/oidc/__init__.py
new file mode 100644
index 0000000000..d958dd65bb
--- /dev/null
+++ b/synapse/rest/oidc/__init__.py
@@ -0,0 +1,27 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.web.resource import Resource
+
+from synapse.rest.oidc.callback_resource import OIDCCallbackResource
+
+logger = logging.getLogger(__name__)
+
+
+class OIDCResource(Resource):
+    def __init__(self, hs):
+        Resource.__init__(self)
+        self.putChild(b"callback", OIDCCallbackResource(hs))
diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/oidc/callback_resource.py
new file mode 100644
index 0000000000..c03194f001
--- /dev/null
+++ b/synapse/rest/oidc/callback_resource.py
@@ -0,0 +1,31 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.http.server import DirectServeResource, wrap_html_request_handler
+
+logger = logging.getLogger(__name__)
+
+
+class OIDCCallbackResource(DirectServeResource):
+    isLeaf = 1
+
+    def __init__(self, hs):
+        super().__init__()
+        self._oidc_handler = hs.get_oidc_handler()
+
+    @wrap_html_request_handler
+    async def _async_render_GET(self, request):
+        return await self._oidc_handler.handle_oidc_callback(request)
diff --git a/synapse/server.py b/synapse/server.py
index bf97a16c09..b4aea81e24 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -204,6 +204,7 @@ class HomeServer(object):
         "account_validity_handler",
         "cas_handler",
         "saml_handler",
+        "oidc_handler",
         "event_client_serializer",
         "password_policy_handler",
         "storage",
@@ -562,6 +563,11 @@ class HomeServer(object):
 
         return SamlHandler(self)
 
+    def build_oidc_handler(self):
+        from synapse.handlers.oidc_handler import OidcHandler
+
+        return OidcHandler(self)
+
     def build_event_client_serializer(self):
         return EventClientSerializer(self)
 
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 18043a2593..31a9cc0389 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -13,6 +13,7 @@ import synapse.handlers.device
 import synapse.handlers.e2e_keys
 import synapse.handlers.message
 import synapse.handlers.presence
+import synapse.handlers.register
 import synapse.handlers.room
 import synapse.handlers.room_member
 import synapse.handlers.set_password
@@ -128,3 +129,7 @@ class HomeServer(object):
         pass
     def get_storage(self) -> synapse.storage.Storage:
         pass
+    def get_registration_handler(self) -> synapse.handlers.register.RegistrationHandler:
+        pass
+    def get_macaroon_generator(self) -> synapse.handlers.auth.MacaroonGenerator:
+        pass
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 4afefc6b1d..2fa529fcd0 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -35,7 +35,6 @@ from synapse.state import v1, v2
 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.types import StateMap
 from synapse.util.async_helpers import Linearizer
-from synapse.util.caches import get_cache_factor_for
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.metrics import Measure, measure_func
 
@@ -53,7 +52,6 @@ state_groups_histogram = Histogram(
 KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
 
 
-SIZE_OF_CACHE = 100000 * get_cache_factor_for("state_cache")
 EVICTION_TIMEOUT_SECONDS = 60 * 60
 
 
@@ -447,7 +445,7 @@ class StateResolutionHandler(object):
         self._state_cache = ExpiringCache(
             cache_name="state_cache",
             clock=self.clock,
-            max_len=SIZE_OF_CACHE,
+            max_len=100000,
             expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
             iterable=True,
             reset_expiry_on_get=True,
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 13de5f1f62..59073c0a42 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -47,6 +47,9 @@ class SQLBaseStore(metaclass=ABCMeta):
         self.db = database
         self.rand = random.SystemRandom()
 
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        pass
+
     def _invalidate_state_caches(self, room_id, members_changed):
         """Invalidates caches that are based on the current state, but does
         not stream invalidations down replication.
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
index e1d03429ca..791961b296 100644
--- a/synapse/storage/data_stores/__init__.py
+++ b/synapse/storage/data_stores/__init__.py
@@ -15,6 +15,7 @@
 
 import logging
 
+from synapse.storage.data_stores.main.events import PersistEventsStore
 from synapse.storage.data_stores.state import StateGroupDataStore
 from synapse.storage.database import Database, make_conn
 from synapse.storage.engines import create_engine
@@ -39,6 +40,7 @@ class DataStores(object):
         self.databases = []
         self.main = None
         self.state = None
+        self.persist_events = None
 
         for database_config in hs.config.database.databases:
             db_name = database_config.name
@@ -64,6 +66,13 @@ class DataStores(object):
 
                     self.main = main_store_class(database, db_conn, hs)
 
+                    # If we're on a process that can persist events (currently
+                    # master), also instantiate a `PersistEventsStore`
+                    if hs.config.worker.worker_app is None:
+                        self.persist_events = PersistEventsStore(
+                            hs, database, self.main
+                        )
+
                 if "state" in database_config.data_stores:
                     logger.info("Starting 'state' data store")
 
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index ceba10882c..5df9dce79d 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -26,13 +26,15 @@ from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
     ChainedIdGenerator,
     IdGenerator,
+    MultiWriterIdGenerator,
     StreamIdGenerator,
 )
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from .account_data import AccountDataStore
 from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
-from .cache import CacheInvalidationStore
+from .cache import CacheInvalidationWorkerStore
+from .censor_events import CensorEventsStore
 from .client_ips import ClientIpStore
 from .deviceinbox import DeviceInboxStore
 from .devices import DeviceStore
@@ -41,16 +43,17 @@ from .e2e_room_keys import EndToEndRoomKeyStore
 from .end_to_end_keys import EndToEndKeyStore
 from .event_federation import EventFederationStore
 from .event_push_actions import EventPushActionsStore
-from .events import EventsStore
 from .events_bg_updates import EventsBackgroundUpdatesStore
 from .filtering import FilteringStore
 from .group_server import GroupServerStore
 from .keys import KeyStore
 from .media_repository import MediaRepositoryStore
+from .metrics import ServerMetricsStore
 from .monthly_active_users import MonthlyActiveUsersStore
 from .openid import OpenIdStore
 from .presence import PresenceStore, UserPresenceState
 from .profile import ProfileStore
+from .purge_events import PurgeEventsStore
 from .push_rule import PushRuleStore
 from .pusher import PusherStore
 from .receipts import ReceiptsStore
@@ -87,7 +90,7 @@ class DataStore(
     StateStore,
     SignatureStore,
     ApplicationServiceStore,
-    EventsStore,
+    PurgeEventsStore,
     EventFederationStore,
     MediaRepositoryStore,
     RejectionsStore,
@@ -112,8 +115,10 @@ class DataStore(
     MonthlyActiveUsersStore,
     StatsStore,
     RelationsStore,
-    CacheInvalidationStore,
+    CensorEventsStore,
     UIAuthStore,
+    CacheInvalidationWorkerStore,
+    ServerMetricsStore,
 ):
     def __init__(self, database: Database, db_conn, hs):
         self.hs = hs
@@ -170,8 +175,14 @@ class DataStore(
         )
 
         if isinstance(self.database_engine, PostgresEngine):
-            self._cache_id_gen = StreamIdGenerator(
-                db_conn, "cache_invalidation_stream", "stream_id"
+            self._cache_id_gen = MultiWriterIdGenerator(
+                db_conn,
+                database,
+                instance_name="master",
+                table="cache_invalidation_stream_by_instance",
+                instance_column="instance_name",
+                id_column="stream_id",
+                sequence_name="cache_invalidation_stream_seq",
             )
         else:
             self._cache_id_gen = None
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index 4dc5da3fe8..342a87a46b 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -16,11 +16,10 @@
 
 import itertools
 import logging
-from typing import Any, Iterable, Optional, Tuple
-
-from twisted.internet import defer
+from typing import Any, Iterable, Optional
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.util.iterutils import batch_iter
 
@@ -33,47 +32,58 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
 
 
 class CacheInvalidationWorkerStore(SQLBaseStore):
-    def get_all_updated_caches(self, last_id, current_id, limit):
+    def __init__(self, database: Database, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        self._instance_name = hs.get_instance_name()
+
+    async def get_all_updated_caches(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ):
+        """Fetches cache invalidation rows between the two given IDs written
+        by the given instance. Returns at most `limit` rows.
+        """
+
         if last_id == current_id:
-            return defer.succeed([])
+            return []
 
         def get_all_updated_caches_txn(txn):
             # We purposefully don't bound by the current token, as we want to
             # send across cache invalidations as quickly as possible. Cache
             # invalidations are idempotent, so duplicates are fine.
-            sql = (
-                "SELECT stream_id, cache_func, keys, invalidation_ts"
-                " FROM cache_invalidation_stream"
-                " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
-            )
-            txn.execute(sql, (last_id, limit))
+            sql = """
+                SELECT stream_id, cache_func, keys, invalidation_ts
+                FROM cache_invalidation_stream_by_instance
+                WHERE stream_id > ? AND instance_name = ?
+                ORDER BY stream_id ASC
+                LIMIT ?
+            """
+            txn.execute(sql, (last_id, instance_name, limit))
             return txn.fetchall()
 
-        return self.db.runInteraction(
+        return await self.db.runInteraction(
             "get_all_updated_caches", get_all_updated_caches_txn
         )
 
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == "caches":
+            if self._cache_id_gen:
+                self._cache_id_gen.advance(instance_name, token)
 
-class CacheInvalidationStore(CacheInvalidationWorkerStore):
-    async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
-        """Invalidates the cache and adds it to the cache stream so slaves
-        will know to invalidate their caches.
+            for row in rows:
+                if row.cache_func == CURRENT_STATE_CACHE_NAME:
+                    if row.keys is None:
+                        raise Exception(
+                            "Can't send an 'invalidate all' for current state cache"
+                        )
 
-        This should only be used to invalidate caches where slaves won't
-        otherwise know from other replication streams that the cache should
-        be invalidated.
-        """
-        cache_func = getattr(self, cache_name, None)
-        if not cache_func:
-            return
-
-        cache_func.invalidate(keys)
-        await self.runInteraction(
-            "invalidate_cache_and_stream",
-            self._send_invalidation_to_replication,
-            cache_func.__name__,
-            keys,
-        )
+                    room_id = row.keys[0]
+                    members_changed = set(row.keys[1:])
+                    self._invalidate_state_caches(room_id, members_changed)
+                else:
+                    self._attempt_to_invalidate_cache(row.cache_func, row.keys)
+
+        super().process_replication_rows(stream_name, instance_name, token, rows)
 
     def _invalidate_cache_and_stream(self, txn, cache_func, keys):
         """Invalidates the cache and adds it to the cache stream so slaves
@@ -147,10 +157,7 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
             # the transaction. However, we want to only get an ID when we want
             # to use it, here, so we need to call __enter__ manually, and have
             # __exit__ called after the transaction finishes.
-            ctx = self._cache_id_gen.get_next()
-            stream_id = ctx.__enter__()
-            txn.call_on_exception(ctx.__exit__, None, None, None)
-            txn.call_after(ctx.__exit__, None, None, None)
+            stream_id = self._cache_id_gen.get_next_txn(txn)
             txn.call_after(self.hs.get_notifier().on_new_replication_data)
 
             if keys is not None:
@@ -158,17 +165,18 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
 
             self.db.simple_insert_txn(
                 txn,
-                table="cache_invalidation_stream",
+                table="cache_invalidation_stream_by_instance",
                 values={
                     "stream_id": stream_id,
+                    "instance_name": self._instance_name,
                     "cache_func": cache_name,
                     "keys": keys,
                     "invalidation_ts": self.clock.time_msec(),
                 },
             )
 
-    def get_cache_stream_token(self):
+    def get_cache_stream_token(self, instance_name):
         if self._cache_id_gen:
-            return self._cache_id_gen.get_current_token()
+            return self._cache_id_gen.get_current_token(instance_name)
         else:
             return 0
diff --git a/synapse/storage/data_stores/main/censor_events.py b/synapse/storage/data_stores/main/censor_events.py
new file mode 100644
index 0000000000..2d48261724
--- /dev/null
+++ b/synapse/storage/data_stores/main/censor_events.py
@@ -0,0 +1,208 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+from twisted.internet import defer
+
+from synapse.events.utils import prune_event_dict
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.data_stores.main.events import encode_json
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+logger = logging.getLogger(__name__)
+
+
+class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs: "HomeServer"):
+        super().__init__(database, db_conn, hs)
+
+        def _censor_redactions():
+            return run_as_background_process(
+                "_censor_redactions", self._censor_redactions
+            )
+
+        if self.hs.config.redaction_retention_period is not None:
+            hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
+
+    async def _censor_redactions(self):
+        """Censors all redactions older than the configured period that haven't
+        been censored yet.
+
+        By censor we mean update the event_json table with the redacted event.
+        """
+
+        if self.hs.config.redaction_retention_period is None:
+            return
+
+        if not (
+            await self.db.updates.has_completed_background_update(
+                "redactions_have_censored_ts_idx"
+            )
+        ):
+            # We don't want to run this until the appropriate index has been
+            # created.
+            return
+
+        before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
+
+        # We fetch all redactions that:
+        #   1. point to an event we have,
+        #   2. has a received_ts from before the cut off, and
+        #   3. we haven't yet censored.
+        #
+        # This is limited to 100 events to ensure that we don't try and do too
+        # much at once. We'll get called again so this should eventually catch
+        # up.
+        sql = """
+            SELECT redactions.event_id, redacts FROM redactions
+            LEFT JOIN events AS original_event ON (
+                redacts = original_event.event_id
+            )
+            WHERE NOT have_censored
+            AND redactions.received_ts <= ?
+            ORDER BY redactions.received_ts ASC
+            LIMIT ?
+        """
+
+        rows = await self.db.execute(
+            "_censor_redactions_fetch", None, sql, before_ts, 100
+        )
+
+        updates = []
+
+        for redaction_id, event_id in rows:
+            redaction_event = await self.get_event(redaction_id, allow_none=True)
+            original_event = await self.get_event(
+                event_id, allow_rejected=True, allow_none=True
+            )
+
+            # The SQL above ensures that we have both the redaction and
+            # original event, so if the `get_event` calls return None it
+            # means that the redaction wasn't allowed. Either way we know that
+            # the result won't change so we mark the fact that we've checked.
+            if (
+                redaction_event
+                and original_event
+                and original_event.internal_metadata.is_redacted()
+            ):
+                # Redaction was allowed
+                pruned_json = encode_json(
+                    prune_event_dict(
+                        original_event.room_version, original_event.get_dict()
+                    )
+                )
+            else:
+                # Redaction wasn't allowed
+                pruned_json = None
+
+            updates.append((redaction_id, event_id, pruned_json))
+
+        def _update_censor_txn(txn):
+            for redaction_id, event_id, pruned_json in updates:
+                if pruned_json:
+                    self._censor_event_txn(txn, event_id, pruned_json)
+
+                self.db.simple_update_one_txn(
+                    txn,
+                    table="redactions",
+                    keyvalues={"event_id": redaction_id},
+                    updatevalues={"have_censored": True},
+                )
+
+        await self.db.runInteraction("_update_censor_txn", _update_censor_txn)
+
+    def _censor_event_txn(self, txn, event_id, pruned_json):
+        """Censor an event by replacing its JSON in the event_json table with the
+        provided pruned JSON.
+
+        Args:
+            txn (LoggingTransaction): The database transaction.
+            event_id (str): The ID of the event to censor.
+            pruned_json (str): The pruned JSON
+        """
+        self.db.simple_update_one_txn(
+            txn,
+            table="event_json",
+            keyvalues={"event_id": event_id},
+            updatevalues={"json": pruned_json},
+        )
+
+    @defer.inlineCallbacks
+    def expire_event(self, event_id):
+        """Retrieve and expire an event that has expired, and delete its associated
+        expiry timestamp. If the event can't be retrieved, delete its associated
+        timestamp so we don't try to expire it again in the future.
+
+        Args:
+             event_id (str): The ID of the event to delete.
+        """
+        # Try to retrieve the event's content from the database or the event cache.
+        event = yield self.get_event(event_id)
+
+        def delete_expired_event_txn(txn):
+            # Delete the expiry timestamp associated with this event from the database.
+            self._delete_event_expiry_txn(txn, event_id)
+
+            if not event:
+                # If we can't find the event, log a warning and delete the expiry date
+                # from the database so that we don't try to expire it again in the
+                # future.
+                logger.warning(
+                    "Can't expire event %s because we don't have it.", event_id
+                )
+                return
+
+            # Prune the event's dict then convert it to JSON.
+            pruned_json = encode_json(
+                prune_event_dict(event.room_version, event.get_dict())
+            )
+
+            # Update the event_json table to replace the event's JSON with the pruned
+            # JSON.
+            self._censor_event_txn(txn, event.event_id, pruned_json)
+
+            # We need to invalidate the event cache entry for this event because we
+            # changed its content in the database. We can't call
+            # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
+            # right type.
+            txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
+            # Send that invalidation to replication so that other workers also invalidate
+            # the event cache.
+            self._send_invalidation_to_replication(
+                txn, "_get_event_cache", (event.event_id,)
+            )
+
+        yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn)
+
+    def _delete_event_expiry_txn(self, txn, event_id):
+        """Delete the expiry timestamp associated with an event ID without deleting the
+        actual event.
+
+        Args:
+            txn (LoggingTransaction): The transaction to use to perform the deletion.
+            event_id (str): The event ID to delete the associated expiry timestamp of.
+        """
+        return self.db.simple_delete_txn(
+            txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
+        )
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 92bc06919b..71f8d43a76 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -22,7 +22,6 @@ from twisted.internet import defer
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import Database, make_tuple_comparison_clause
-from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.caches.descriptors import Cache
 
 logger = logging.getLogger(__name__)
@@ -361,7 +360,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
     def __init__(self, database: Database, db_conn, hs):
 
         self.client_ip_last_seen = Cache(
-            name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
+            name="client_ip_last_seen", keylen=4, max_entries=50000
         )
 
         super(ClientIpStore, self).__init__(database, db_conn, hs)
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 03f5141e6c..fe6d6ecfe0 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -55,6 +55,10 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
 
 BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
+BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX = (
+    "drop_device_lists_outbound_last_success_non_unique_idx"
+)
+
 
 class DeviceWorkerStore(SQLBaseStore):
     def get_device(self, user_id, device_id):
@@ -342,32 +346,23 @@ class DeviceWorkerStore(SQLBaseStore):
 
     def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
         # We update the device_lists_outbound_last_success with the successfully
-        # poked users. We do the join to see which users need to be inserted and
-        # which updated.
+        # poked users.
         sql = """
-            SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
+            SELECT user_id, coalesce(max(o.stream_id), 0)
             FROM device_lists_outbound_pokes as o
-            LEFT JOIN device_lists_outbound_last_success as s
-                USING (destination, user_id)
             WHERE destination = ? AND o.stream_id <= ?
             GROUP BY user_id
         """
         txn.execute(sql, (destination, stream_id))
         rows = txn.fetchall()
 
-        sql = """
-            UPDATE device_lists_outbound_last_success
-            SET stream_id = ?
-            WHERE destination = ? AND user_id = ?
-        """
-        txn.executemany(sql, ((row[1], destination, row[0]) for row in rows if row[2]))
-
-        sql = """
-            INSERT INTO device_lists_outbound_last_success
-            (destination, user_id, stream_id) VALUES (?, ?, ?)
-        """
-        txn.executemany(
-            sql, ((destination, row[0], row[1]) for row in rows if not row[2])
+        self.db.simple_upsert_many_txn(
+            txn=txn,
+            table="device_lists_outbound_last_success",
+            key_names=("destination", "user_id"),
+            key_values=((destination, user_id) for user_id, _ in rows),
+            value_names=("stream_id",),
+            value_values=((stream_id,) for _, stream_id in rows),
         )
 
         # Delete all sent outbound pokes
@@ -725,6 +720,21 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
             BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
         )
 
+        # create a unique index on device_lists_outbound_last_success
+        self.db.updates.register_background_index_update(
+            "device_lists_outbound_last_success_unique_idx",
+            index_name="device_lists_outbound_last_success_unique_idx",
+            table="device_lists_outbound_last_success",
+            columns=["destination", "user_id"],
+            unique=True,
+        )
+
+        # once that completes, we can remove the old non-unique index.
+        self.db.updates.register_background_update_handler(
+            BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX,
+            self._drop_device_lists_outbound_last_success_non_unique_idx,
+        )
+
     @defer.inlineCallbacks
     def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
         def f(conn):
@@ -799,6 +809,20 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
 
         return rows
 
+    async def _drop_device_lists_outbound_last_success_non_unique_idx(
+        self, progress, batch_size
+    ):
+        def f(txn):
+            txn.execute("DROP INDEX IF EXISTS device_lists_outbound_last_success_idx")
+
+        await self.db.runInteraction(
+            "drop_device_lists_outbound_last_success_non_unique_idx", f,
+        )
+        await self.db.updates._end_background_update(
+            BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX
+        )
+        return 1
+
 
 class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
     def __init__(self, database: Database, db_conn, hs):
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index bcf746b7ef..20698bfd16 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -25,7 +25,9 @@ from twisted.internet import defer
 
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import make_in_list_sql_clause
 from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
 
 
 class EndToEndKeyWorkerStore(SQLBaseStore):
@@ -268,53 +270,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             "count_e2e_one_time_keys", _count_e2e_one_time_keys
         )
 
-    def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None):
-        """Returns a user's cross-signing key.
-
-        Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            user_id (str): the user whose key is being requested
-            key_type (str): the type of key that is being requested: either 'master'
-                for a master key, 'self_signing' for a self-signing key, or
-                'user_signing' for a user-signing key
-            from_user_id (str): if specified, signatures made by this user on
-                the key will be included in the result
-
-        Returns:
-            dict of the key data or None if not found
-        """
-        sql = (
-            "SELECT keydata "
-            "  FROM e2e_cross_signing_keys "
-            " WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1"
-        )
-        txn.execute(sql, (user_id, key_type))
-        row = txn.fetchone()
-        if not row:
-            return None
-        key = json.loads(row[0])
-
-        device_id = None
-        for k in key["keys"].values():
-            device_id = k
-
-        if from_user_id is not None:
-            sql = (
-                "SELECT key_id, signature "
-                "  FROM e2e_cross_signing_signatures "
-                " WHERE user_id = ? "
-                "   AND target_user_id = ? "
-                "   AND target_device_id = ? "
-            )
-            txn.execute(sql, (from_user_id, user_id, device_id))
-            row = txn.fetchone()
-            if row:
-                key.setdefault("signatures", {}).setdefault(from_user_id, {})[
-                    row[0]
-                ] = row[1]
-
-        return key
-
+    @defer.inlineCallbacks
     def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
         """Returns a user's cross-signing key.
 
@@ -329,13 +285,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         Returns:
             dict of the key data or None if not found
         """
-        return self.db.runInteraction(
-            "get_e2e_cross_signing_key",
-            self._get_e2e_cross_signing_key_txn,
-            user_id,
-            key_type,
-            from_user_id,
-        )
+        res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
+        user_keys = res.get(user_id)
+        if not user_keys:
+            return None
+        return user_keys.get(key_type)
 
     @cached(num_args=1)
     def _get_bare_e2e_cross_signing_keys(self, user_id):
@@ -391,26 +345,24 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         """
         result = {}
 
-        batch_size = 100
-        chunks = [
-            user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size)
-        ]
-        for user_chunk in chunks:
-            sql = """
+        for user_chunk in batch_iter(user_ids, 100):
+            clause, params = make_in_list_sql_clause(
+                txn.database_engine, "k.user_id", user_chunk
+            )
+            sql = (
+                """
                 SELECT k.user_id, k.keytype, k.keydata, k.stream_id
                   FROM e2e_cross_signing_keys k
                   INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
                                 FROM e2e_cross_signing_keys
                                GROUP BY user_id, keytype) s
                  USING (user_id, stream_id, keytype)
-                 WHERE k.user_id IN (%s)
-            """ % (
-                ",".join("?" for u in user_chunk),
+                 WHERE
+            """
+                + clause
             )
-            query_params = []
-            query_params.extend(user_chunk)
 
-            txn.execute(sql, query_params)
+            txn.execute(sql, params)
             rows = self.db.cursor_to_dict(txn)
 
             for row in rows:
@@ -453,15 +405,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
                     device_id = k
                 devices[(user_id, device_id)] = key_type
 
-        device_list = list(devices)
-
-        # split into batches
-        batch_size = 100
-        chunks = [
-            device_list[i : i + batch_size]
-            for i in range(0, len(device_list), batch_size)
-        ]
-        for user_chunk in chunks:
+        for batch in batch_iter(devices.keys(), size=100):
             sql = """
                 SELECT target_user_id, target_device_id, key_id, signature
                   FROM e2e_cross_signing_signatures
@@ -469,11 +413,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
                    AND (%s)
             """ % (
                 " OR ".join(
-                    "(target_user_id = ? AND target_device_id = ?)" for d in devices
+                    "(target_user_id = ? AND target_device_id = ?)" for _ in batch
                 )
             )
             query_params = [from_user_id]
-            for item in devices:
+            for item in batch:
                 # item is a (user_id, device_id) tuple
                 query_params.extend(item)
 
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index b99439cc37..24ce8c4330 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -640,89 +640,6 @@ class EventFederationStore(EventFederationWorkerStore):
             self._delete_old_forward_extrem_cache, 60 * 60 * 1000
         )
 
-    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
-        min_depth = self._get_min_depth_interaction(txn, room_id)
-
-        if min_depth is not None and depth >= min_depth:
-            return
-
-        self.db.simple_upsert_txn(
-            txn,
-            table="room_depth",
-            keyvalues={"room_id": room_id},
-            values={"min_depth": depth},
-        )
-
-    def _handle_mult_prev_events(self, txn, events):
-        """
-        For the given event, update the event edges table and forward and
-        backward extremities tables.
-        """
-        self.db.simple_insert_many_txn(
-            txn,
-            table="event_edges",
-            values=[
-                {
-                    "event_id": ev.event_id,
-                    "prev_event_id": e_id,
-                    "room_id": ev.room_id,
-                    "is_state": False,
-                }
-                for ev in events
-                for e_id in ev.prev_event_ids()
-            ],
-        )
-
-        self._update_backward_extremeties(txn, events)
-
-    def _update_backward_extremeties(self, txn, events):
-        """Updates the event_backward_extremities tables based on the new/updated
-        events being persisted.
-
-        This is called for new events *and* for events that were outliers, but
-        are now being persisted as non-outliers.
-
-        Forward extremities are handled when we first start persisting the events.
-        """
-        events_by_room = {}
-        for ev in events:
-            events_by_room.setdefault(ev.room_id, []).append(ev)
-
-        query = (
-            "INSERT INTO event_backward_extremities (event_id, room_id)"
-            " SELECT ?, ? WHERE NOT EXISTS ("
-            " SELECT 1 FROM event_backward_extremities"
-            " WHERE event_id = ? AND room_id = ?"
-            " )"
-            " AND NOT EXISTS ("
-            " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
-            " AND outlier = ?"
-            " )"
-        )
-
-        txn.executemany(
-            query,
-            [
-                (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
-                for ev in events
-                for e_id in ev.prev_event_ids()
-                if not ev.internal_metadata.is_outlier()
-            ],
-        )
-
-        query = (
-            "DELETE FROM event_backward_extremities"
-            " WHERE event_id = ? AND room_id = ?"
-        )
-        txn.executemany(
-            query,
-            [
-                (ev.event_id, ev.room_id)
-                for ev in events
-                if not ev.internal_metadata.is_outlier()
-            ],
-        )
-
     def _delete_old_forward_extrem_cache(self):
         def _delete_old_forward_extrem_cache_txn(txn):
             # Delete entries older than a month, while making sure we don't delete
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index 8eed590929..0321274de2 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -652,69 +652,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             self._start_rotate_notifs, 30 * 60 * 1000
         )
 
-    def _set_push_actions_for_event_and_users_txn(
-        self, txn, events_and_contexts, all_events_and_contexts
-    ):
-        """Handles moving push actions from staging table to main
-        event_push_actions table for all events in `events_and_contexts`.
-
-        Also ensures that all events in `all_events_and_contexts` are removed
-        from the push action staging area.
-
-        Args:
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
-            all_events_and_contexts (list[(EventBase, EventContext)]): all
-                events that we were going to persist. This includes events
-                we've already persisted, etc, that wouldn't appear in
-                events_and_context.
-        """
-
-        sql = """
-            INSERT INTO event_push_actions (
-                room_id, event_id, user_id, actions, stream_ordering,
-                topological_ordering, notif, highlight
-            )
-            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
-            FROM event_push_actions_staging
-            WHERE event_id = ?
-        """
-
-        if events_and_contexts:
-            txn.executemany(
-                sql,
-                (
-                    (
-                        event.room_id,
-                        event.internal_metadata.stream_ordering,
-                        event.depth,
-                        event.event_id,
-                    )
-                    for event, _ in events_and_contexts
-                ),
-            )
-
-        for event, _ in events_and_contexts:
-            user_ids = self.db.simple_select_onecol_txn(
-                txn,
-                table="event_push_actions_staging",
-                keyvalues={"event_id": event.event_id},
-                retcol="user_id",
-            )
-
-            for uid in user_ids:
-                txn.call_after(
-                    self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
-                    (event.room_id, uid),
-                )
-
-        # Now we delete the staging area for *all* events that were being
-        # persisted.
-        txn.executemany(
-            "DELETE FROM event_push_actions_staging WHERE event_id = ?",
-            ((event.event_id,) for event, _ in all_events_and_contexts),
-        )
-
     @defer.inlineCallbacks
     def get_push_actions_for_user(
         self, user_id, before=None, limit=50, only_highlight=False
@@ -763,17 +700,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         )
         return result[0] or 0
 
-    def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
-        # Sad that we have to blow away the cache for the whole room here
-        txn.call_after(
-            self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
-            (room_id,),
-        )
-        txn.execute(
-            "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
-            (room_id, event_id),
-        )
-
     def _remove_old_push_actions_before_txn(
         self, txn, room_id, user_id, stream_ordering
     ):
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index e71c23541d..a97f8b3934 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -17,39 +17,44 @@
 
 import itertools
 import logging
-from collections import Counter as c_counter, OrderedDict, namedtuple
+from collections import OrderedDict, namedtuple
 from functools import wraps
-from typing import Dict, List, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
 
-from six import iteritems, text_type
+from six import integer_types, iteritems, text_type
 from six.moves import range
 
+import attr
 from canonicaljson import json
 from prometheus_client import Counter
 
 from twisted.internet import defer
 
 import synapse.metrics
-from synapse.api.constants import EventContentFields, EventTypes
-from synapse.api.errors import SynapseError
+from synapse.api.constants import (
+    EventContentFields,
+    EventTypes,
+    Membership,
+    RelationTypes,
+)
 from synapse.api.room_versions import RoomVersions
+from synapse.crypto.event_signing import compute_event_reference_hash
 from synapse.events import EventBase  # noqa: F401
 from synapse.events.snapshot import EventContext  # noqa: F401
-from synapse.events.utils import prune_event_dict
 from synapse.logging.utils import log_function
-from synapse.metrics import BucketCollector
-from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import make_in_list_sql_clause
-from synapse.storage.data_stores.main.event_federation import EventFederationStore
-from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.data_stores.main.state import StateGroupWorkerStore
+from synapse.storage.data_stores.main.search import SearchEntry
 from synapse.storage.database import Database, LoggingTransaction
-from synapse.storage.persist_events import DeltaState
-from synapse.types import RoomStreamToken, StateMap, get_domain_from_id
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.types import StateMap, get_domain_from_id
 from synapse.util.frozenutils import frozendict_json_encoder
 from synapse.util.iterutils import batch_iter
 
+if TYPE_CHECKING:
+    from synapse.storage.data_stores.main import DataStore
+    from synapse.server import HomeServer
+
+
 logger = logging.getLogger(__name__)
 
 persist_event_counter = Counter("synapse_storage_events_persisted_events", "")
@@ -94,58 +99,49 @@ def _retry_on_integrity_error(func):
     return f
 
 
-# inherits from EventFederationStore so that we can call _update_backward_extremities
-# and _handle_mult_prev_events (though arguably those could both be moved in here)
-class EventsStore(
-    StateGroupWorkerStore, EventFederationStore, EventsWorkerStore,
-):
-    def __init__(self, database: Database, db_conn, hs):
-        super(EventsStore, self).__init__(database, db_conn, hs)
+@attr.s(slots=True)
+class DeltaState:
+    """Deltas to use to update the `current_state_events` table.
 
-        # Collect metrics on the number of forward extremities that exist.
-        # Counter of number of extremities to count
-        self._current_forward_extremities_amount = c_counter()
+    Attributes:
+        to_delete: List of type/state_keys to delete from current state
+        to_insert: Map of state to upsert into current state
+        no_longer_in_room: The server is not longer in the room, so the room
+            should e.g. be removed from `current_state_events` table.
+    """
 
-        BucketCollector(
-            "synapse_forward_extremities",
-            lambda: self._current_forward_extremities_amount,
-            buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"],
-        )
+    to_delete = attr.ib(type=List[Tuple[str, str]])
+    to_insert = attr.ib(type=StateMap[str])
+    no_longer_in_room = attr.ib(type=bool, default=False)
 
-        # Read the extrems every 60 minutes
-        def read_forward_extremities():
-            # run as a background process to make sure that the database transactions
-            # have a logcontext to report to
-            return run_as_background_process(
-                "read_forward_extremities", self._read_forward_extremities
-            )
 
-        hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
+class PersistEventsStore:
+    """Contains all the functions for writing events to the database.
 
-        def _censor_redactions():
-            return run_as_background_process(
-                "_censor_redactions", self._censor_redactions
-            )
+    Should only be instantiated on one process (when using a worker mode setup).
+
+    Note: This is not part of the `DataStore` mixin.
+    """
 
-        if self.hs.config.redaction_retention_period is not None:
-            hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
+    def __init__(self, hs: "HomeServer", db: Database, main_data_store: "DataStore"):
+        self.hs = hs
+        self.db = db
+        self.store = main_data_store
+        self.database_engine = db.engine
+        self._clock = hs.get_clock()
 
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
         self.is_mine_id = hs.is_mine_id
 
-    @defer.inlineCallbacks
-    def _read_forward_extremities(self):
-        def fetch(txn):
-            txn.execute(
-                """
-                select count(*) c from event_forward_extremities
-                group by room_id
-                """
-            )
-            return txn.fetchall()
+        # Ideally we'd move these ID gens here, unfortunately some other ID
+        # generators are chained off them so doing so is a bit of a PITA.
+        self._backfill_id_gen = self.store._backfill_id_gen  # type: StreamIdGenerator
+        self._stream_id_gen = self.store._stream_id_gen  # type: StreamIdGenerator
 
-        res = yield self.db.runInteraction("read_forward_extremities", fetch)
-        self._current_forward_extremities_amount = c_counter([x[0] for x in res])
+        # This should only exist on master for now
+        assert (
+            hs.config.worker.worker_app is None
+        ), "Can only instantiate PersistEventsStore on master"
 
     @_retry_on_integrity_error
     @defer.inlineCallbacks
@@ -237,10 +233,10 @@ class EventsStore(
                 event_counter.labels(event.type, origin_type, origin_entity).inc()
 
             for room_id, new_state in iteritems(current_state_for_room):
-                self.get_current_state_ids.prefill((room_id,), new_state)
+                self.store.get_current_state_ids.prefill((room_id,), new_state)
 
             for room_id, latest_event_ids in iteritems(new_forward_extremeties):
-                self.get_latest_event_ids_in_room.prefill(
+                self.store.get_latest_event_ids_in_room.prefill(
                     (room_id,), list(latest_event_ids)
                 )
 
@@ -586,7 +582,7 @@ class EventsStore(
                 )
 
             txn.call_after(
-                self._curr_state_delta_stream_cache.entity_has_changed,
+                self.store._curr_state_delta_stream_cache.entity_has_changed,
                 room_id,
                 stream_id,
             )
@@ -606,10 +602,13 @@ class EventsStore(
 
             for member in members_changed:
                 txn.call_after(
-                    self.get_rooms_for_user_with_stream_ordering.invalidate, (member,)
+                    self.store.get_rooms_for_user_with_stream_ordering.invalidate,
+                    (member,),
                 )
 
-            self._invalidate_state_caches_and_stream(txn, room_id, members_changed)
+            self.store._invalidate_state_caches_and_stream(
+                txn, room_id, members_changed
+            )
 
     def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str):
         """Update the room version in the database based off current state
@@ -647,7 +646,9 @@ class EventsStore(
             self.db.simple_delete_txn(
                 txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
             )
-            txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
+            txn.call_after(
+                self.store.get_latest_event_ids_in_room.invalidate, (room_id,)
+            )
 
         self.db.simple_insert_many_txn(
             txn,
@@ -713,10 +714,10 @@ class EventsStore(
         depth_updates = {}
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
-            txn.call_after(self._invalidate_get_event_cache, event.event_id)
+            txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
             if not backfilled:
                 txn.call_after(
-                    self._events_stream_cache.entity_has_changed,
+                    self.store._events_stream_cache.entity_has_changed,
                     event.room_id,
                     event.internal_metadata.stream_ordering,
                 )
@@ -1088,13 +1089,15 @@ class EventsStore(
 
         def prefill():
             for cache_entry in to_prefill:
-                self._get_event_cache.prefill((cache_entry[0].event_id,), cache_entry)
+                self.store._get_event_cache.prefill(
+                    (cache_entry[0].event_id,), cache_entry
+                )
 
         txn.call_after(prefill)
 
     def _store_redaction(self, txn, event):
         # invalidate the cache for the redacted event
-        txn.call_after(self._invalidate_get_event_cache, event.redacts)
+        txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
 
         self.db.simple_insert_txn(
             txn,
@@ -1106,783 +1109,484 @@ class EventsStore(
             },
         )
 
-    async def _censor_redactions(self):
-        """Censors all redactions older than the configured period that haven't
-        been censored yet.
+    def insert_labels_for_event_txn(
+        self, txn, event_id, labels, room_id, topological_ordering
+    ):
+        """Store the mapping between an event's ID and its labels, with one row per
+        (event_id, label) tuple.
 
-        By censor we mean update the event_json table with the redacted event.
+        Args:
+            txn (LoggingTransaction): The transaction to execute.
+            event_id (str): The event's ID.
+            labels (list[str]): A list of text labels.
+            room_id (str): The ID of the room the event was sent to.
+            topological_ordering (int): The position of the event in the room's topology.
         """
+        return self.db.simple_insert_many_txn(
+            txn=txn,
+            table="event_labels",
+            values=[
+                {
+                    "event_id": event_id,
+                    "label": label,
+                    "room_id": room_id,
+                    "topological_ordering": topological_ordering,
+                }
+                for label in labels
+            ],
+        )
 
-        if self.hs.config.redaction_retention_period is None:
-            return
-
-        if not (
-            await self.db.updates.has_completed_background_update(
-                "redactions_have_censored_ts_idx"
-            )
-        ):
-            # We don't want to run this until the appropriate index has been
-            # created.
-            return
-
-        before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
+    def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
+        """Save the expiry timestamp associated with a given event ID.
 
-        # We fetch all redactions that:
-        #   1. point to an event we have,
-        #   2. has a received_ts from before the cut off, and
-        #   3. we haven't yet censored.
-        #
-        # This is limited to 100 events to ensure that we don't try and do too
-        # much at once. We'll get called again so this should eventually catch
-        # up.
-        sql = """
-            SELECT redactions.event_id, redacts FROM redactions
-            LEFT JOIN events AS original_event ON (
-                redacts = original_event.event_id
-            )
-            WHERE NOT have_censored
-            AND redactions.received_ts <= ?
-            ORDER BY redactions.received_ts ASC
-            LIMIT ?
+        Args:
+            txn (LoggingTransaction): The database transaction to use.
+            event_id (str): The event ID the expiry timestamp is associated with.
+            expiry_ts (int): The timestamp at which to expire (delete) the event.
         """
-
-        rows = await self.db.execute(
-            "_censor_redactions_fetch", None, sql, before_ts, 100
+        return self.db.simple_insert_txn(
+            txn=txn,
+            table="event_expiry",
+            values={"event_id": event_id, "expiry_ts": expiry_ts},
         )
 
-        updates = []
+    def _store_event_reference_hashes_txn(self, txn, events):
+        """Store a hash for a PDU
+        Args:
+            txn (cursor):
+            events (list): list of Events.
+        """
 
-        for redaction_id, event_id in rows:
-            redaction_event = await self.get_event(redaction_id, allow_none=True)
-            original_event = await self.get_event(
-                event_id, allow_rejected=True, allow_none=True
+        vals = []
+        for event in events:
+            ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
+            vals.append(
+                {
+                    "event_id": event.event_id,
+                    "algorithm": ref_alg,
+                    "hash": memoryview(ref_hash_bytes),
+                }
             )
 
-            # The SQL above ensures that we have both the redaction and
-            # original event, so if the `get_event` calls return None it
-            # means that the redaction wasn't allowed. Either way we know that
-            # the result won't change so we mark the fact that we've checked.
-            if (
-                redaction_event
-                and original_event
-                and original_event.internal_metadata.is_redacted()
-            ):
-                # Redaction was allowed
-                pruned_json = encode_json(
-                    prune_event_dict(
-                        original_event.room_version, original_event.get_dict()
-                    )
-                )
-            else:
-                # Redaction wasn't allowed
-                pruned_json = None
-
-            updates.append((redaction_id, event_id, pruned_json))
-
-        def _update_censor_txn(txn):
-            for redaction_id, event_id, pruned_json in updates:
-                if pruned_json:
-                    self._censor_event_txn(txn, event_id, pruned_json)
-
-                self.db.simple_update_one_txn(
-                    txn,
-                    table="redactions",
-                    keyvalues={"event_id": redaction_id},
-                    updatevalues={"have_censored": True},
-                )
-
-        await self.db.runInteraction("_update_censor_txn", _update_censor_txn)
+        self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
 
-    def _censor_event_txn(self, txn, event_id, pruned_json):
-        """Censor an event by replacing its JSON in the event_json table with the
-        provided pruned JSON.
-
-        Args:
-            txn (LoggingTransaction): The database transaction.
-            event_id (str): The ID of the event to censor.
-            pruned_json (str): The pruned JSON
+    def _store_room_members_txn(self, txn, events, backfilled):
+        """Store a room member in the database.
         """
-        self.db.simple_update_one_txn(
+        self.db.simple_insert_many_txn(
             txn,
-            table="event_json",
-            keyvalues={"event_id": event_id},
-            updatevalues={"json": pruned_json},
+            table="room_memberships",
+            values=[
+                {
+                    "event_id": event.event_id,
+                    "user_id": event.state_key,
+                    "sender": event.user_id,
+                    "room_id": event.room_id,
+                    "membership": event.membership,
+                    "display_name": event.content.get("displayname", None),
+                    "avatar_url": event.content.get("avatar_url", None),
+                }
+                for event in events
+            ],
         )
 
-    @defer.inlineCallbacks
-    def count_daily_messages(self):
-        """
-        Returns an estimate of the number of messages sent in the last day.
-
-        If it has been significantly less or more than one day since the last
-        call to this function, it will return None.
-        """
-
-        def _count_messages(txn):
-            sql = """
-                SELECT COALESCE(COUNT(*), 0) FROM events
-                WHERE type = 'm.room.message'
-                AND stream_ordering > ?
-            """
-            txn.execute(sql, (self.stream_ordering_day_ago,))
-            (count,) = txn.fetchone()
-            return count
-
-        ret = yield self.db.runInteraction("count_messages", _count_messages)
-        return ret
-
-    @defer.inlineCallbacks
-    def count_daily_sent_messages(self):
-        def _count_messages(txn):
-            # This is good enough as if you have silly characters in your own
-            # hostname then thats your own fault.
-            like_clause = "%:" + self.hs.hostname
-
-            sql = """
-                SELECT COALESCE(COUNT(*), 0) FROM events
-                WHERE type = 'm.room.message'
-                    AND sender LIKE ?
-                AND stream_ordering > ?
-            """
-
-            txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
-            (count,) = txn.fetchone()
-            return count
-
-        ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages)
-        return ret
-
-    @defer.inlineCallbacks
-    def count_daily_active_rooms(self):
-        def _count(txn):
-            sql = """
-                SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
-                WHERE type = 'm.room.message'
-                AND stream_ordering > ?
-            """
-            txn.execute(sql, (self.stream_ordering_day_ago,))
-            (count,) = txn.fetchone()
-            return count
-
-        ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
-        return ret
-
-    @cached(num_args=5, max_entries=10)
-    def get_all_new_events(
-        self,
-        last_backfill_id,
-        last_forward_id,
-        current_backfill_id,
-        current_forward_id,
-        limit,
-    ):
-        """Get all the new events that have arrived at the server either as
-        new events or as backfilled events"""
-        have_backfill_events = last_backfill_id != current_backfill_id
-        have_forward_events = last_forward_id != current_forward_id
-
-        if not have_backfill_events and not have_forward_events:
-            return defer.succeed(AllNewEventsResult([], [], [], [], []))
-
-        def get_all_new_events_txn(txn):
-            sql = (
-                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts"
-                " FROM events AS e"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " WHERE ? < stream_ordering AND stream_ordering <= ?"
-                " ORDER BY stream_ordering ASC"
-                " LIMIT ?"
+        for event in events:
+            txn.call_after(
+                self.store._membership_stream_cache.entity_has_changed,
+                event.state_key,
+                event.internal_metadata.stream_ordering,
             )
-            if have_forward_events:
-                txn.execute(sql, (last_forward_id, current_forward_id, limit))
-                new_forward_events = txn.fetchall()
-
-                if len(new_forward_events) == limit:
-                    upper_bound = new_forward_events[-1][0]
-                else:
-                    upper_bound = current_forward_id
-
-                sql = (
-                    "SELECT event_stream_ordering, event_id, state_group"
-                    " FROM ex_outlier_stream"
-                    " WHERE ? > event_stream_ordering"
-                    " AND event_stream_ordering >= ?"
-                    " ORDER BY event_stream_ordering DESC"
-                )
-                txn.execute(sql, (last_forward_id, upper_bound))
-                forward_ex_outliers = txn.fetchall()
-            else:
-                new_forward_events = []
-                forward_ex_outliers = []
-
-            sql = (
-                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts"
-                " FROM events AS e"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " WHERE ? > stream_ordering AND stream_ordering >= ?"
-                " ORDER BY stream_ordering DESC"
-                " LIMIT ?"
+            txn.call_after(
+                self.store.get_invited_rooms_for_local_user.invalidate,
+                (event.state_key,),
             )
-            if have_backfill_events:
-                txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
-                new_backfill_events = txn.fetchall()
 
-                if len(new_backfill_events) == limit:
-                    upper_bound = new_backfill_events[-1][0]
-                else:
-                    upper_bound = current_backfill_id
-
-                sql = (
-                    "SELECT -event_stream_ordering, event_id, state_group"
-                    " FROM ex_outlier_stream"
-                    " WHERE ? > event_stream_ordering"
-                    " AND event_stream_ordering >= ?"
-                    " ORDER BY event_stream_ordering DESC"
-                )
-                txn.execute(sql, (-last_backfill_id, -upper_bound))
-                backward_ex_outliers = txn.fetchall()
-            else:
-                new_backfill_events = []
-                backward_ex_outliers = []
-
-            return AllNewEventsResult(
-                new_forward_events,
-                new_backfill_events,
-                forward_ex_outliers,
-                backward_ex_outliers,
+            # We update the local_invites table only if the event is "current",
+            # i.e., its something that has just happened. If the event is an
+            # outlier it is only current if its an "out of band membership",
+            # like a remote invite or a rejection of a remote invite.
+            is_new_state = not backfilled and (
+                not event.internal_metadata.is_outlier()
+                or event.internal_metadata.is_out_of_band_membership()
             )
+            is_mine = self.is_mine_id(event.state_key)
+            if is_new_state and is_mine:
+                if event.membership == Membership.INVITE:
+                    self.db.simple_insert_txn(
+                        txn,
+                        table="local_invites",
+                        values={
+                            "event_id": event.event_id,
+                            "invitee": event.state_key,
+                            "inviter": event.sender,
+                            "room_id": event.room_id,
+                            "stream_id": event.internal_metadata.stream_ordering,
+                        },
+                    )
+                else:
+                    sql = (
+                        "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
+                        " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+                        " AND replaced_by is NULL"
+                    )
 
-        return self.db.runInteraction("get_all_new_events", get_all_new_events_txn)
-
-    def purge_history(self, room_id, token, delete_local_events):
-        """Deletes room history before a certain point
-
-        Args:
-            room_id (str):
+                    txn.execute(
+                        sql,
+                        (
+                            event.internal_metadata.stream_ordering,
+                            event.event_id,
+                            event.room_id,
+                            event.state_key,
+                        ),
+                    )
 
-            token (str): A topological token to delete events before
+                # We also update the `local_current_membership` table with
+                # latest invite info. This will usually get updated by the
+                # `current_state_events` handling, unless its an outlier.
+                if event.internal_metadata.is_outlier():
+                    # This should only happen for out of band memberships, so
+                    # we add a paranoia check.
+                    assert event.internal_metadata.is_out_of_band_membership()
+
+                    self.db.simple_upsert_txn(
+                        txn,
+                        table="local_current_membership",
+                        keyvalues={
+                            "room_id": event.room_id,
+                            "user_id": event.state_key,
+                        },
+                        values={
+                            "event_id": event.event_id,
+                            "membership": event.membership,
+                        },
+                    )
 
-            delete_local_events (bool):
-                if True, we will delete local events as well as remote ones
-                (instead of just marking them as outliers and deleting their
-                state groups).
+    def _handle_event_relations(self, txn, event):
+        """Handles inserting relation data during peristence of events
 
-        Returns:
-            Deferred[set[int]]: The set of state groups that are referenced by
-            deleted events.
+        Args:
+            txn
+            event (EventBase)
         """
+        relation = event.content.get("m.relates_to")
+        if not relation:
+            # No relations
+            return
 
-        return self.db.runInteraction(
-            "purge_history",
-            self._purge_history_txn,
-            room_id,
-            token,
-            delete_local_events,
-        )
+        rel_type = relation.get("rel_type")
+        if rel_type not in (
+            RelationTypes.ANNOTATION,
+            RelationTypes.REFERENCE,
+            RelationTypes.REPLACE,
+        ):
+            # Unknown relation type
+            return
 
-    def _purge_history_txn(self, txn, room_id, token_str, delete_local_events):
-        token = RoomStreamToken.parse(token_str)
-
-        # Tables that should be pruned:
-        #     event_auth
-        #     event_backward_extremities
-        #     event_edges
-        #     event_forward_extremities
-        #     event_json
-        #     event_push_actions
-        #     event_reference_hashes
-        #     event_search
-        #     event_to_state_groups
-        #     events
-        #     rejections
-        #     room_depth
-        #     state_groups
-        #     state_groups_state
-
-        # we will build a temporary table listing the events so that we don't
-        # have to keep shovelling the list back and forth across the
-        # connection. Annoyingly the python sqlite driver commits the
-        # transaction on CREATE, so let's do this first.
-        #
-        # furthermore, we might already have the table from a previous (failed)
-        # purge attempt, so let's drop the table first.
+        parent_id = relation.get("event_id")
+        if not parent_id:
+            # Invalid relation
+            return
 
-        txn.execute("DROP TABLE IF EXISTS events_to_purge")
+        aggregation_key = relation.get("key")
 
-        txn.execute(
-            "CREATE TEMPORARY TABLE events_to_purge ("
-            "    event_id TEXT NOT NULL,"
-            "    should_delete BOOLEAN NOT NULL"
-            ")"
+        self.db.simple_insert_txn(
+            txn,
+            table="event_relations",
+            values={
+                "event_id": event.event_id,
+                "relates_to_id": parent_id,
+                "relation_type": rel_type,
+                "aggregation_key": aggregation_key,
+            },
         )
 
-        # First ensure that we're not about to delete all the forward extremeties
-        txn.execute(
-            "SELECT e.event_id, e.depth FROM events as e "
-            "INNER JOIN event_forward_extremities as f "
-            "ON e.event_id = f.event_id "
-            "AND e.room_id = f.room_id "
-            "WHERE f.room_id = ?",
-            (room_id,),
+        txn.call_after(self.store.get_relations_for_event.invalidate_many, (parent_id,))
+        txn.call_after(
+            self.store.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
         )
-        rows = txn.fetchall()
-        max_depth = max(row[1] for row in rows)
-
-        if max_depth < token.topological:
-            # We need to ensure we don't delete all the events from the database
-            # otherwise we wouldn't be able to send any events (due to not
-            # having any backwards extremeties)
-            raise SynapseError(
-                400, "topological_ordering is greater than forward extremeties"
-            )
-
-        logger.info("[purge] looking for events to delete")
-
-        should_delete_expr = "state_key IS NULL"
-        should_delete_params = ()
-        if not delete_local_events:
-            should_delete_expr += " AND event_id NOT LIKE ?"
-
-            # We include the parameter twice since we use the expression twice
-            should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname)
 
-        should_delete_params += (room_id, token.topological)
+        if rel_type == RelationTypes.REPLACE:
+            txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
 
-        # Note that we insert events that are outliers and aren't going to be
-        # deleted, as nothing will happen to them.
-        txn.execute(
-            "INSERT INTO events_to_purge"
-            " SELECT event_id, %s"
-            " FROM events AS e LEFT JOIN state_events USING (event_id)"
-            " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?"
-            % (should_delete_expr, should_delete_expr),
-            should_delete_params,
-        )
+    def _handle_redaction(self, txn, redacted_event_id):
+        """Handles receiving a redaction and checking whether we need to remove
+        any redacted relations from the database.
 
-        # We create the indices *after* insertion as that's a lot faster.
+        Args:
+            txn
+            redacted_event_id (str): The event that was redacted.
+        """
 
-        # create an index on should_delete because later we'll be looking for
-        # the should_delete / shouldn't_delete subsets
-        txn.execute(
-            "CREATE INDEX events_to_purge_should_delete"
-            " ON events_to_purge(should_delete)"
+        self.db.simple_delete_txn(
+            txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
         )
 
-        # We do joins against events_to_purge for e.g. calculating state
-        # groups to purge, etc., so lets make an index.
-        txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)")
-
-        txn.execute("SELECT event_id, should_delete FROM events_to_purge")
-        event_rows = txn.fetchall()
-        logger.info(
-            "[purge] found %i events before cutoff, of which %i can be deleted",
-            len(event_rows),
-            sum(1 for e in event_rows if e[1]),
-        )
+    def _store_room_topic_txn(self, txn, event):
+        if hasattr(event, "content") and "topic" in event.content:
+            self.store_event_search_txn(
+                txn, event, "content.topic", event.content["topic"]
+            )
 
-        logger.info("[purge] Finding new backward extremities")
+    def _store_room_name_txn(self, txn, event):
+        if hasattr(event, "content") and "name" in event.content:
+            self.store_event_search_txn(
+                txn, event, "content.name", event.content["name"]
+            )
 
-        # We calculate the new entries for the backward extremeties by finding
-        # events to be purged that are pointed to by events we're not going to
-        # purge.
-        txn.execute(
-            "SELECT DISTINCT e.event_id FROM events_to_purge AS e"
-            " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
-            " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id"
-            " WHERE ep2.event_id IS NULL"
-        )
-        new_backwards_extrems = txn.fetchall()
+    def _store_room_message_txn(self, txn, event):
+        if hasattr(event, "content") and "body" in event.content:
+            self.store_event_search_txn(
+                txn, event, "content.body", event.content["body"]
+            )
 
-        logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
+    def _store_retention_policy_for_room_txn(self, txn, event):
+        if hasattr(event, "content") and (
+            "min_lifetime" in event.content or "max_lifetime" in event.content
+        ):
+            if (
+                "min_lifetime" in event.content
+                and not isinstance(event.content.get("min_lifetime"), integer_types)
+            ) or (
+                "max_lifetime" in event.content
+                and not isinstance(event.content.get("max_lifetime"), integer_types)
+            ):
+                # Ignore the event if one of the value isn't an integer.
+                return
 
-        txn.execute(
-            "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,)
-        )
+            self.db.simple_insert_txn(
+                txn=txn,
+                table="room_retention",
+                values={
+                    "room_id": event.room_id,
+                    "event_id": event.event_id,
+                    "min_lifetime": event.content.get("min_lifetime"),
+                    "max_lifetime": event.content.get("max_lifetime"),
+                },
+            )
 
-        # Update backward extremeties
-        txn.executemany(
-            "INSERT INTO event_backward_extremities (room_id, event_id)"
-            " VALUES (?, ?)",
-            [(room_id, event_id) for event_id, in new_backwards_extrems],
-        )
+            self.store._invalidate_cache_and_stream(
+                txn, self.store.get_retention_policy_for_room, (event.room_id,)
+            )
 
-        logger.info("[purge] finding state groups referenced by deleted events")
+    def store_event_search_txn(self, txn, event, key, value):
+        """Add event to the search table
 
-        # Get all state groups that are referenced by events that are to be
-        # deleted.
-        txn.execute(
-            """
-            SELECT DISTINCT state_group FROM events_to_purge
-            INNER JOIN event_to_state_groups USING (event_id)
+        Args:
+            txn (cursor):
+            event (EventBase):
+            key (str):
+            value (str):
         """
+        self.store.store_search_entries_txn(
+            txn,
+            (
+                SearchEntry(
+                    key=key,
+                    value=value,
+                    event_id=event.event_id,
+                    room_id=event.room_id,
+                    stream_ordering=event.internal_metadata.stream_ordering,
+                    origin_server_ts=event.origin_server_ts,
+                ),
+            ),
         )
 
-        referenced_state_groups = {sg for sg, in txn}
-        logger.info(
-            "[purge] found %i referenced state groups", len(referenced_state_groups)
-        )
+    def _set_push_actions_for_event_and_users_txn(
+        self, txn, events_and_contexts, all_events_and_contexts
+    ):
+        """Handles moving push actions from staging table to main
+        event_push_actions table for all events in `events_and_contexts`.
 
-        logger.info("[purge] removing events from event_to_state_groups")
-        txn.execute(
-            "DELETE FROM event_to_state_groups "
-            "WHERE event_id IN (SELECT event_id from events_to_purge)"
-        )
-        for event_id, _ in event_rows:
-            txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
+        Also ensures that all events in `all_events_and_contexts` are removed
+        from the push action staging area.
 
-        # Delete all remote non-state events
-        for table in (
-            "events",
-            "event_json",
-            "event_auth",
-            "event_edges",
-            "event_forward_extremities",
-            "event_reference_hashes",
-            "event_search",
-            "rejections",
-        ):
-            logger.info("[purge] removing events from %s", table)
+        Args:
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+            all_events_and_contexts (list[(EventBase, EventContext)]): all
+                events that we were going to persist. This includes events
+                we've already persisted, etc, that wouldn't appear in
+                events_and_context.
+        """
 
-            txn.execute(
-                "DELETE FROM %s WHERE event_id IN ("
-                "    SELECT event_id FROM events_to_purge WHERE should_delete"
-                ")" % (table,)
+        sql = """
+            INSERT INTO event_push_actions (
+                room_id, event_id, user_id, actions, stream_ordering,
+                topological_ordering, notif, highlight
             )
+            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
+            FROM event_push_actions_staging
+            WHERE event_id = ?
+        """
 
-        # event_push_actions lacks an index on event_id, and has one on
-        # (room_id, event_id) instead.
-        for table in ("event_push_actions",):
-            logger.info("[purge] removing events from %s", table)
+        if events_and_contexts:
+            txn.executemany(
+                sql,
+                (
+                    (
+                        event.room_id,
+                        event.internal_metadata.stream_ordering,
+                        event.depth,
+                        event.event_id,
+                    )
+                    for event, _ in events_and_contexts
+                ),
+            )
 
-            txn.execute(
-                "DELETE FROM %s WHERE room_id = ? AND event_id IN ("
-                "    SELECT event_id FROM events_to_purge WHERE should_delete"
-                ")" % (table,),
-                (room_id,),
+        for event, _ in events_and_contexts:
+            user_ids = self.db.simple_select_onecol_txn(
+                txn,
+                table="event_push_actions_staging",
+                keyvalues={"event_id": event.event_id},
+                retcol="user_id",
             )
 
-        # Mark all state and own events as outliers
-        logger.info("[purge] marking remaining events as outliers")
-        txn.execute(
-            "UPDATE events SET outlier = ?"
-            " WHERE event_id IN ("
-            "    SELECT event_id FROM events_to_purge "
-            "    WHERE NOT should_delete"
-            ")",
-            (True,),
+            for uid in user_ids:
+                txn.call_after(
+                    self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+                    (event.room_id, uid),
+                )
+
+        # Now we delete the staging area for *all* events that were being
+        # persisted.
+        txn.executemany(
+            "DELETE FROM event_push_actions_staging WHERE event_id = ?",
+            ((event.event_id,) for event, _ in all_events_and_contexts),
         )
 
-        # synapse tries to take out an exclusive lock on room_depth whenever it
-        # persists events (because upsert), and once we run this update, we
-        # will block that for the rest of our transaction.
-        #
-        # So, let's stick it at the end so that we don't block event
-        # persistence.
-        #
-        # We do this by calculating the minimum depth of the backwards
-        # extremities. However, the events in event_backward_extremities
-        # are ones we don't have yet so we need to look at the events that
-        # point to it via event_edges table.
-        txn.execute(
-            """
-            SELECT COALESCE(MIN(depth), 0)
-            FROM event_backward_extremities AS eb
-            INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id
-            INNER JOIN events AS e ON e.event_id = eg.event_id
-            WHERE eb.room_id = ?
-        """,
+    def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
+        # Sad that we have to blow away the cache for the whole room here
+        txn.call_after(
+            self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many,
             (room_id,),
         )
-        (min_depth,) = txn.fetchone()
-
-        logger.info("[purge] updating room_depth to %d", min_depth)
-
         txn.execute(
-            "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
-            (min_depth, room_id),
+            "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
+            (room_id, event_id),
         )
 
-        # finally, drop the temp table. this will commit the txn in sqlite,
-        # so make sure to keep this actually last.
-        txn.execute("DROP TABLE events_to_purge")
-
-        logger.info("[purge] done")
-
-        return referenced_state_groups
-
-    def purge_room(self, room_id):
-        """Deletes all record of a room
+    def _store_rejections_txn(self, txn, event_id, reason):
+        self.db.simple_insert_txn(
+            txn,
+            table="rejections",
+            values={
+                "event_id": event_id,
+                "reason": reason,
+                "last_check": self._clock.time_msec(),
+            },
+        )
 
-        Args:
-            room_id (str)
+    def _store_event_state_mappings_txn(
+        self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
+    ):
+        state_groups = {}
+        for event, context in events_and_contexts:
+            if event.internal_metadata.is_outlier():
+                continue
 
-        Returns:
-            Deferred[List[int]]: The list of state groups to delete.
-        """
+            # if the event was rejected, just give it the same state as its
+            # predecessor.
+            if context.rejected:
+                state_groups[event.event_id] = context.state_group_before_event
+                continue
 
-        return self.db.runInteraction("purge_room", self._purge_room_txn, room_id)
+            state_groups[event.event_id] = context.state_group
 
-    def _purge_room_txn(self, txn, room_id):
-        # First we fetch all the state groups that should be deleted, before
-        # we delete that information.
-        txn.execute(
-            """
-                SELECT DISTINCT state_group FROM events
-                INNER JOIN event_to_state_groups USING(event_id)
-                WHERE events.room_id = ?
-            """,
-            (room_id,),
+        self.db.simple_insert_many_txn(
+            txn,
+            table="event_to_state_groups",
+            values=[
+                {"state_group": state_group_id, "event_id": event_id}
+                for event_id, state_group_id in iteritems(state_groups)
+            ],
         )
 
-        state_groups = [row[0] for row in txn]
-
-        # Now we delete tables which lack an index on room_id but have one on event_id
-        for table in (
-            "event_auth",
-            "event_edges",
-            "event_push_actions_staging",
-            "event_reference_hashes",
-            "event_relations",
-            "event_to_state_groups",
-            "redactions",
-            "rejections",
-            "state_events",
-        ):
-            logger.info("[purge] removing %s from %s", room_id, table)
-
-            txn.execute(
-                """
-                DELETE FROM %s WHERE event_id IN (
-                  SELECT event_id FROM events WHERE room_id=?
-                )
-                """
-                % (table,),
-                (room_id,),
+        for event_id, state_group_id in iteritems(state_groups):
+            txn.call_after(
+                self.store._get_state_group_for_event.prefill,
+                (event_id,),
+                state_group_id,
             )
 
-        # and finally, the tables with an index on room_id (or no useful index)
-        for table in (
-            "current_state_events",
-            "event_backward_extremities",
-            "event_forward_extremities",
-            "event_json",
-            "event_push_actions",
-            "event_search",
-            "events",
-            "group_rooms",
-            "public_room_list_stream",
-            "receipts_graph",
-            "receipts_linearized",
-            "room_aliases",
-            "room_depth",
-            "room_memberships",
-            "room_stats_state",
-            "room_stats_current",
-            "room_stats_historical",
-            "room_stats_earliest_token",
-            "rooms",
-            "stream_ordering_to_exterm",
-            "users_in_public_rooms",
-            "users_who_share_private_rooms",
-            # no useful index, but let's clear them anyway
-            "appservice_room_list",
-            "e2e_room_keys",
-            "event_push_summary",
-            "pusher_throttle",
-            "group_summary_rooms",
-            "local_invites",
-            "room_account_data",
-            "room_tags",
-            "local_current_membership",
-        ):
-            logger.info("[purge] removing %s from %s", room_id, table)
-            txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
-
-        # Other tables we do NOT need to clear out:
-        #
-        #  - blocked_rooms
-        #    This is important, to make sure that we don't accidentally rejoin a blocked
-        #    room after it was purged
-        #
-        #  - user_directory
-        #    This has a room_id column, but it is unused
-        #
-
-        # Other tables that we might want to consider clearing out include:
-        #
-        #  - event_reports
-        #       Given that these are intended for abuse management my initial
-        #       inclination is to leave them in place.
-        #
-        #  - current_state_delta_stream
-        #  - ex_outlier_stream
-        #  - room_tags_revisions
-        #       The problem with these is that they are largeish and there is no room_id
-        #       index on them. In any case we should be clearing out 'stream' tables
-        #       periodically anyway (#5888)
-
-        # TODO: we could probably usefully do a bunch of cache invalidation here
+    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+        min_depth = self.store._get_min_depth_interaction(txn, room_id)
 
-        logger.info("[purge] done")
-
-        return state_groups
-
-    async def is_event_after(self, event_id1, event_id2):
-        """Returns True if event_id1 is after event_id2 in the stream
-        """
-        to_1, so_1 = await self._get_event_ordering(event_id1)
-        to_2, so_2 = await self._get_event_ordering(event_id2)
-        return (to_1, so_1) > (to_2, so_2)
+        if min_depth is not None and depth >= min_depth:
+            return
 
-    @cachedInlineCallbacks(max_entries=5000)
-    def _get_event_ordering(self, event_id):
-        res = yield self.db.simple_select_one(
-            table="events",
-            retcols=["topological_ordering", "stream_ordering"],
-            keyvalues={"event_id": event_id},
-            allow_none=True,
+        self.db.simple_upsert_txn(
+            txn,
+            table="room_depth",
+            keyvalues={"room_id": room_id},
+            values={"min_depth": depth},
         )
 
-        if not res:
-            raise SynapseError(404, "Could not find event %s" % (event_id,))
-
-        return (int(res["topological_ordering"]), int(res["stream_ordering"]))
-
-    def insert_labels_for_event_txn(
-        self, txn, event_id, labels, room_id, topological_ordering
-    ):
-        """Store the mapping between an event's ID and its labels, with one row per
-        (event_id, label) tuple.
-
-        Args:
-            txn (LoggingTransaction): The transaction to execute.
-            event_id (str): The event's ID.
-            labels (list[str]): A list of text labels.
-            room_id (str): The ID of the room the event was sent to.
-            topological_ordering (int): The position of the event in the room's topology.
+    def _handle_mult_prev_events(self, txn, events):
         """
-        return self.db.simple_insert_many_txn(
-            txn=txn,
-            table="event_labels",
+        For the given event, update the event edges table and forward and
+        backward extremities tables.
+        """
+        self.db.simple_insert_many_txn(
+            txn,
+            table="event_edges",
             values=[
                 {
-                    "event_id": event_id,
-                    "label": label,
-                    "room_id": room_id,
-                    "topological_ordering": topological_ordering,
+                    "event_id": ev.event_id,
+                    "prev_event_id": e_id,
+                    "room_id": ev.room_id,
+                    "is_state": False,
                 }
-                for label in labels
+                for ev in events
+                for e_id in ev.prev_event_ids()
             ],
         )
 
-    def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
-        """Save the expiry timestamp associated with a given event ID.
-
-        Args:
-            txn (LoggingTransaction): The database transaction to use.
-            event_id (str): The event ID the expiry timestamp is associated with.
-            expiry_ts (int): The timestamp at which to expire (delete) the event.
-        """
-        return self.db.simple_insert_txn(
-            txn=txn,
-            table="event_expiry",
-            values={"event_id": event_id, "expiry_ts": expiry_ts},
-        )
-
-    @defer.inlineCallbacks
-    def expire_event(self, event_id):
-        """Retrieve and expire an event that has expired, and delete its associated
-        expiry timestamp. If the event can't be retrieved, delete its associated
-        timestamp so we don't try to expire it again in the future.
-
-        Args:
-             event_id (str): The ID of the event to delete.
-        """
-        # Try to retrieve the event's content from the database or the event cache.
-        event = yield self.get_event(event_id)
-
-        def delete_expired_event_txn(txn):
-            # Delete the expiry timestamp associated with this event from the database.
-            self._delete_event_expiry_txn(txn, event_id)
-
-            if not event:
-                # If we can't find the event, log a warning and delete the expiry date
-                # from the database so that we don't try to expire it again in the
-                # future.
-                logger.warning(
-                    "Can't expire event %s because we don't have it.", event_id
-                )
-                return
-
-            # Prune the event's dict then convert it to JSON.
-            pruned_json = encode_json(
-                prune_event_dict(event.room_version, event.get_dict())
-            )
-
-            # Update the event_json table to replace the event's JSON with the pruned
-            # JSON.
-            self._censor_event_txn(txn, event.event_id, pruned_json)
-
-            # We need to invalidate the event cache entry for this event because we
-            # changed its content in the database. We can't call
-            # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
-            # right type.
-            txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
-            # Send that invalidation to replication so that other workers also invalidate
-            # the event cache.
-            self._send_invalidation_to_replication(
-                txn, "_get_event_cache", (event.event_id,)
-            )
+        self._update_backward_extremeties(txn, events)
 
-        yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn)
+    def _update_backward_extremeties(self, txn, events):
+        """Updates the event_backward_extremities tables based on the new/updated
+        events being persisted.
 
-    def _delete_event_expiry_txn(self, txn, event_id):
-        """Delete the expiry timestamp associated with an event ID without deleting the
-        actual event.
+        This is called for new events *and* for events that were outliers, but
+        are now being persisted as non-outliers.
 
-        Args:
-            txn (LoggingTransaction): The transaction to use to perform the deletion.
-            event_id (str): The event ID to delete the associated expiry timestamp of.
+        Forward extremities are handled when we first start persisting the events.
         """
-        return self.db.simple_delete_txn(
-            txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
+        events_by_room = {}
+        for ev in events:
+            events_by_room.setdefault(ev.room_id, []).append(ev)
+
+        query = (
+            "INSERT INTO event_backward_extremities (event_id, room_id)"
+            " SELECT ?, ? WHERE NOT EXISTS ("
+            " SELECT 1 FROM event_backward_extremities"
+            " WHERE event_id = ? AND room_id = ?"
+            " )"
+            " AND NOT EXISTS ("
+            " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
+            " AND outlier = ?"
+            " )"
         )
 
-    def get_next_event_to_expire(self):
-        """Retrieve the entry with the lowest expiry timestamp in the event_expiry
-        table, or None if there's no more event to expire.
-
-        Returns: Deferred[Optional[Tuple[str, int]]]
-            A tuple containing the event ID as its first element and an expiry timestamp
-            as its second one, if there's at least one row in the event_expiry table.
-            None otherwise.
-        """
-
-        def get_next_event_to_expire_txn(txn):
-            txn.execute(
-                """
-                SELECT event_id, expiry_ts FROM event_expiry
-                ORDER BY expiry_ts ASC LIMIT 1
-                """
-            )
-
-            return txn.fetchone()
-
-        return self.db.runInteraction(
-            desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
+        txn.executemany(
+            query,
+            [
+                (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
+                for ev in events
+                for e_id in ev.prev_event_ids()
+                if not ev.internal_metadata.is_outlier()
+            ],
         )
 
-
-AllNewEventsResult = namedtuple(
-    "AllNewEventsResult",
-    [
-        "new_forward_events",
-        "new_backfill_events",
-        "forward_ex_outliers",
-        "backward_ex_outliers",
-    ],
-)
+        query = (
+            "DELETE FROM event_backward_extremities"
+            " WHERE event_id = ? AND room_id = ?"
+        )
+        txn.executemany(
+            query,
+            [
+                (ev.event_id, ev.room_id)
+                for ev in events
+                if not ev.internal_metadata.is_outlier()
+            ],
+        )
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 73df6b33ba..970c31bd05 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -27,7 +27,7 @@ from constantly import NamedConstant, Names
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes
-from synapse.api.errors import NotFoundError
+from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.room_versions import (
     KNOWN_ROOM_VERSIONS,
     EventFormatVersions,
@@ -40,7 +40,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.database import Database
 from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import Cache
+from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
@@ -75,7 +75,10 @@ class EventsWorkerStore(SQLBaseStore):
         super(EventsWorkerStore, self).__init__(database, db_conn, hs)
 
         self._get_event_cache = Cache(
-            "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
+            "*getEvent*",
+            keylen=3,
+            max_entries=hs.config.caches.event_cache_size,
+            apply_cache_factor_from_config=False,
         )
 
         self._event_fetch_lock = threading.Condition()
@@ -1154,4 +1157,152 @@ class EventsWorkerStore(SQLBaseStore):
         rows = await self.db.runInteraction(
             "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token
         )
+
         return rows, to_token, True
+
+    @cached(num_args=5, max_entries=10)
+    def get_all_new_events(
+        self,
+        last_backfill_id,
+        last_forward_id,
+        current_backfill_id,
+        current_forward_id,
+        limit,
+    ):
+        """Get all the new events that have arrived at the server either as
+        new events or as backfilled events"""
+        have_backfill_events = last_backfill_id != current_backfill_id
+        have_forward_events = last_forward_id != current_forward_id
+
+        if not have_backfill_events and not have_forward_events:
+            return defer.succeed(AllNewEventsResult([], [], [], [], []))
+
+        def get_all_new_events_txn(txn):
+            sql = (
+                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? < stream_ordering AND stream_ordering <= ?"
+                " ORDER BY stream_ordering ASC"
+                " LIMIT ?"
+            )
+            if have_forward_events:
+                txn.execute(sql, (last_forward_id, current_forward_id, limit))
+                new_forward_events = txn.fetchall()
+
+                if len(new_forward_events) == limit:
+                    upper_bound = new_forward_events[-1][0]
+                else:
+                    upper_bound = current_forward_id
+
+                sql = (
+                    "SELECT event_stream_ordering, event_id, state_group"
+                    " FROM ex_outlier_stream"
+                    " WHERE ? > event_stream_ordering"
+                    " AND event_stream_ordering >= ?"
+                    " ORDER BY event_stream_ordering DESC"
+                )
+                txn.execute(sql, (last_forward_id, upper_bound))
+                forward_ex_outliers = txn.fetchall()
+            else:
+                new_forward_events = []
+                forward_ex_outliers = []
+
+            sql = (
+                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? > stream_ordering AND stream_ordering >= ?"
+                " ORDER BY stream_ordering DESC"
+                " LIMIT ?"
+            )
+            if have_backfill_events:
+                txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
+                new_backfill_events = txn.fetchall()
+
+                if len(new_backfill_events) == limit:
+                    upper_bound = new_backfill_events[-1][0]
+                else:
+                    upper_bound = current_backfill_id
+
+                sql = (
+                    "SELECT -event_stream_ordering, event_id, state_group"
+                    " FROM ex_outlier_stream"
+                    " WHERE ? > event_stream_ordering"
+                    " AND event_stream_ordering >= ?"
+                    " ORDER BY event_stream_ordering DESC"
+                )
+                txn.execute(sql, (-last_backfill_id, -upper_bound))
+                backward_ex_outliers = txn.fetchall()
+            else:
+                new_backfill_events = []
+                backward_ex_outliers = []
+
+            return AllNewEventsResult(
+                new_forward_events,
+                new_backfill_events,
+                forward_ex_outliers,
+                backward_ex_outliers,
+            )
+
+        return self.db.runInteraction("get_all_new_events", get_all_new_events_txn)
+
+    async def is_event_after(self, event_id1, event_id2):
+        """Returns True if event_id1 is after event_id2 in the stream
+        """
+        to_1, so_1 = await self._get_event_ordering(event_id1)
+        to_2, so_2 = await self._get_event_ordering(event_id2)
+        return (to_1, so_1) > (to_2, so_2)
+
+    @cachedInlineCallbacks(max_entries=5000)
+    def _get_event_ordering(self, event_id):
+        res = yield self.db.simple_select_one(
+            table="events",
+            retcols=["topological_ordering", "stream_ordering"],
+            keyvalues={"event_id": event_id},
+            allow_none=True,
+        )
+
+        if not res:
+            raise SynapseError(404, "Could not find event %s" % (event_id,))
+
+        return (int(res["topological_ordering"]), int(res["stream_ordering"]))
+
+    def get_next_event_to_expire(self):
+        """Retrieve the entry with the lowest expiry timestamp in the event_expiry
+        table, or None if there's no more event to expire.
+
+        Returns: Deferred[Optional[Tuple[str, int]]]
+            A tuple containing the event ID as its first element and an expiry timestamp
+            as its second one, if there's at least one row in the event_expiry table.
+            None otherwise.
+        """
+
+        def get_next_event_to_expire_txn(txn):
+            txn.execute(
+                """
+                SELECT event_id, expiry_ts FROM event_expiry
+                ORDER BY expiry_ts ASC LIMIT 1
+                """
+            )
+
+            return txn.fetchone()
+
+        return self.db.runInteraction(
+            desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
+        )
+
+
+AllNewEventsResult = namedtuple(
+    "AllNewEventsResult",
+    [
+        "new_forward_events",
+        "new_backfill_events",
+        "forward_ex_outliers",
+        "backward_ex_outliers",
+    ],
+)
diff --git a/synapse/storage/data_stores/main/metrics.py b/synapse/storage/data_stores/main/metrics.py
new file mode 100644
index 0000000000..dad5bbc602
--- /dev/null
+++ b/synapse/storage/data_stores/main/metrics.py
@@ -0,0 +1,128 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import typing
+from collections import Counter
+
+from twisted.internet import defer
+
+from synapse.metrics import BucketCollector
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.event_push_actions import (
+    EventPushActionsWorkerStore,
+)
+from synapse.storage.database import Database
+
+
+class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
+    """Functions to pull various metrics from the DB, for e.g. phone home
+    stats and prometheus metrics.
+    """
+
+    def __init__(self, database: Database, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        # Collect metrics on the number of forward extremities that exist.
+        # Counter of number of extremities to count
+        self._current_forward_extremities_amount = (
+            Counter()
+        )  # type: typing.Counter[int]
+
+        BucketCollector(
+            "synapse_forward_extremities",
+            lambda: self._current_forward_extremities_amount,
+            buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"],
+        )
+
+        # Read the extrems every 60 minutes
+        def read_forward_extremities():
+            # run as a background process to make sure that the database transactions
+            # have a logcontext to report to
+            return run_as_background_process(
+                "read_forward_extremities", self._read_forward_extremities
+            )
+
+        hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
+
+    async def _read_forward_extremities(self):
+        def fetch(txn):
+            txn.execute(
+                """
+                select count(*) c from event_forward_extremities
+                group by room_id
+                """
+            )
+            return txn.fetchall()
+
+        res = await self.db.runInteraction("read_forward_extremities", fetch)
+        self._current_forward_extremities_amount = Counter([x[0] for x in res])
+
+    @defer.inlineCallbacks
+    def count_daily_messages(self):
+        """
+        Returns an estimate of the number of messages sent in the last day.
+
+        If it has been significantly less or more than one day since the last
+        call to this function, it will return None.
+        """
+
+        def _count_messages(txn):
+            sql = """
+                SELECT COALESCE(COUNT(*), 0) FROM events
+                WHERE type = 'm.room.message'
+                AND stream_ordering > ?
+            """
+            txn.execute(sql, (self.stream_ordering_day_ago,))
+            (count,) = txn.fetchone()
+            return count
+
+        ret = yield self.db.runInteraction("count_messages", _count_messages)
+        return ret
+
+    @defer.inlineCallbacks
+    def count_daily_sent_messages(self):
+        def _count_messages(txn):
+            # This is good enough as if you have silly characters in your own
+            # hostname then thats your own fault.
+            like_clause = "%:" + self.hs.hostname
+
+            sql = """
+                SELECT COALESCE(COUNT(*), 0) FROM events
+                WHERE type = 'm.room.message'
+                    AND sender LIKE ?
+                AND stream_ordering > ?
+            """
+
+            txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
+            (count,) = txn.fetchone()
+            return count
+
+        ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages)
+        return ret
+
+    @defer.inlineCallbacks
+    def count_daily_active_rooms(self):
+        def _count(txn):
+            sql = """
+                SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+                WHERE type = 'm.room.message'
+                AND stream_ordering > ?
+            """
+            txn.execute(sql, (self.stream_ordering_day_ago,))
+            (count,) = txn.fetchone()
+            return count
+
+        ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
+        return ret
diff --git a/synapse/storage/data_stores/main/purge_events.py b/synapse/storage/data_stores/main/purge_events.py
new file mode 100644
index 0000000000..a93e1ef198
--- /dev/null
+++ b/synapse/storage/data_stores/main/purge_events.py
@@ -0,0 +1,399 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Any, Tuple
+
+from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.state import StateGroupWorkerStore
+from synapse.types import RoomStreamToken
+
+logger = logging.getLogger(__name__)
+
+
+class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
+    def purge_history(self, room_id, token, delete_local_events):
+        """Deletes room history before a certain point
+
+        Args:
+            room_id (str):
+
+            token (str): A topological token to delete events before
+
+            delete_local_events (bool):
+                if True, we will delete local events as well as remote ones
+                (instead of just marking them as outliers and deleting their
+                state groups).
+
+        Returns:
+            Deferred[set[int]]: The set of state groups that are referenced by
+            deleted events.
+        """
+
+        return self.db.runInteraction(
+            "purge_history",
+            self._purge_history_txn,
+            room_id,
+            token,
+            delete_local_events,
+        )
+
+    def _purge_history_txn(self, txn, room_id, token_str, delete_local_events):
+        token = RoomStreamToken.parse(token_str)
+
+        # Tables that should be pruned:
+        #     event_auth
+        #     event_backward_extremities
+        #     event_edges
+        #     event_forward_extremities
+        #     event_json
+        #     event_push_actions
+        #     event_reference_hashes
+        #     event_search
+        #     event_to_state_groups
+        #     events
+        #     rejections
+        #     room_depth
+        #     state_groups
+        #     state_groups_state
+
+        # we will build a temporary table listing the events so that we don't
+        # have to keep shovelling the list back and forth across the
+        # connection. Annoyingly the python sqlite driver commits the
+        # transaction on CREATE, so let's do this first.
+        #
+        # furthermore, we might already have the table from a previous (failed)
+        # purge attempt, so let's drop the table first.
+
+        txn.execute("DROP TABLE IF EXISTS events_to_purge")
+
+        txn.execute(
+            "CREATE TEMPORARY TABLE events_to_purge ("
+            "    event_id TEXT NOT NULL,"
+            "    should_delete BOOLEAN NOT NULL"
+            ")"
+        )
+
+        # First ensure that we're not about to delete all the forward extremeties
+        txn.execute(
+            "SELECT e.event_id, e.depth FROM events as e "
+            "INNER JOIN event_forward_extremities as f "
+            "ON e.event_id = f.event_id "
+            "AND e.room_id = f.room_id "
+            "WHERE f.room_id = ?",
+            (room_id,),
+        )
+        rows = txn.fetchall()
+        max_depth = max(row[1] for row in rows)
+
+        if max_depth < token.topological:
+            # We need to ensure we don't delete all the events from the database
+            # otherwise we wouldn't be able to send any events (due to not
+            # having any backwards extremeties)
+            raise SynapseError(
+                400, "topological_ordering is greater than forward extremeties"
+            )
+
+        logger.info("[purge] looking for events to delete")
+
+        should_delete_expr = "state_key IS NULL"
+        should_delete_params = ()  # type: Tuple[Any, ...]
+        if not delete_local_events:
+            should_delete_expr += " AND event_id NOT LIKE ?"
+
+            # We include the parameter twice since we use the expression twice
+            should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname)
+
+        should_delete_params += (room_id, token.topological)
+
+        # Note that we insert events that are outliers and aren't going to be
+        # deleted, as nothing will happen to them.
+        txn.execute(
+            "INSERT INTO events_to_purge"
+            " SELECT event_id, %s"
+            " FROM events AS e LEFT JOIN state_events USING (event_id)"
+            " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?"
+            % (should_delete_expr, should_delete_expr),
+            should_delete_params,
+        )
+
+        # We create the indices *after* insertion as that's a lot faster.
+
+        # create an index on should_delete because later we'll be looking for
+        # the should_delete / shouldn't_delete subsets
+        txn.execute(
+            "CREATE INDEX events_to_purge_should_delete"
+            " ON events_to_purge(should_delete)"
+        )
+
+        # We do joins against events_to_purge for e.g. calculating state
+        # groups to purge, etc., so lets make an index.
+        txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)")
+
+        txn.execute("SELECT event_id, should_delete FROM events_to_purge")
+        event_rows = txn.fetchall()
+        logger.info(
+            "[purge] found %i events before cutoff, of which %i can be deleted",
+            len(event_rows),
+            sum(1 for e in event_rows if e[1]),
+        )
+
+        logger.info("[purge] Finding new backward extremities")
+
+        # We calculate the new entries for the backward extremeties by finding
+        # events to be purged that are pointed to by events we're not going to
+        # purge.
+        txn.execute(
+            "SELECT DISTINCT e.event_id FROM events_to_purge AS e"
+            " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
+            " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id"
+            " WHERE ep2.event_id IS NULL"
+        )
+        new_backwards_extrems = txn.fetchall()
+
+        logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
+
+        txn.execute(
+            "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,)
+        )
+
+        # Update backward extremeties
+        txn.executemany(
+            "INSERT INTO event_backward_extremities (room_id, event_id)"
+            " VALUES (?, ?)",
+            [(room_id, event_id) for event_id, in new_backwards_extrems],
+        )
+
+        logger.info("[purge] finding state groups referenced by deleted events")
+
+        # Get all state groups that are referenced by events that are to be
+        # deleted.
+        txn.execute(
+            """
+            SELECT DISTINCT state_group FROM events_to_purge
+            INNER JOIN event_to_state_groups USING (event_id)
+        """
+        )
+
+        referenced_state_groups = {sg for sg, in txn}
+        logger.info(
+            "[purge] found %i referenced state groups", len(referenced_state_groups)
+        )
+
+        logger.info("[purge] removing events from event_to_state_groups")
+        txn.execute(
+            "DELETE FROM event_to_state_groups "
+            "WHERE event_id IN (SELECT event_id from events_to_purge)"
+        )
+        for event_id, _ in event_rows:
+            txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
+
+        # Delete all remote non-state events
+        for table in (
+            "events",
+            "event_json",
+            "event_auth",
+            "event_edges",
+            "event_forward_extremities",
+            "event_reference_hashes",
+            "event_search",
+            "rejections",
+        ):
+            logger.info("[purge] removing events from %s", table)
+
+            txn.execute(
+                "DELETE FROM %s WHERE event_id IN ("
+                "    SELECT event_id FROM events_to_purge WHERE should_delete"
+                ")" % (table,)
+            )
+
+        # event_push_actions lacks an index on event_id, and has one on
+        # (room_id, event_id) instead.
+        for table in ("event_push_actions",):
+            logger.info("[purge] removing events from %s", table)
+
+            txn.execute(
+                "DELETE FROM %s WHERE room_id = ? AND event_id IN ("
+                "    SELECT event_id FROM events_to_purge WHERE should_delete"
+                ")" % (table,),
+                (room_id,),
+            )
+
+        # Mark all state and own events as outliers
+        logger.info("[purge] marking remaining events as outliers")
+        txn.execute(
+            "UPDATE events SET outlier = ?"
+            " WHERE event_id IN ("
+            "    SELECT event_id FROM events_to_purge "
+            "    WHERE NOT should_delete"
+            ")",
+            (True,),
+        )
+
+        # synapse tries to take out an exclusive lock on room_depth whenever it
+        # persists events (because upsert), and once we run this update, we
+        # will block that for the rest of our transaction.
+        #
+        # So, let's stick it at the end so that we don't block event
+        # persistence.
+        #
+        # We do this by calculating the minimum depth of the backwards
+        # extremities. However, the events in event_backward_extremities
+        # are ones we don't have yet so we need to look at the events that
+        # point to it via event_edges table.
+        txn.execute(
+            """
+            SELECT COALESCE(MIN(depth), 0)
+            FROM event_backward_extremities AS eb
+            INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id
+            INNER JOIN events AS e ON e.event_id = eg.event_id
+            WHERE eb.room_id = ?
+        """,
+            (room_id,),
+        )
+        (min_depth,) = txn.fetchone()
+
+        logger.info("[purge] updating room_depth to %d", min_depth)
+
+        txn.execute(
+            "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
+            (min_depth, room_id),
+        )
+
+        # finally, drop the temp table. this will commit the txn in sqlite,
+        # so make sure to keep this actually last.
+        txn.execute("DROP TABLE events_to_purge")
+
+        logger.info("[purge] done")
+
+        return referenced_state_groups
+
+    def purge_room(self, room_id):
+        """Deletes all record of a room
+
+        Args:
+            room_id (str)
+
+        Returns:
+            Deferred[List[int]]: The list of state groups to delete.
+        """
+
+        return self.db.runInteraction("purge_room", self._purge_room_txn, room_id)
+
+    def _purge_room_txn(self, txn, room_id):
+        # First we fetch all the state groups that should be deleted, before
+        # we delete that information.
+        txn.execute(
+            """
+                SELECT DISTINCT state_group FROM events
+                INNER JOIN event_to_state_groups USING(event_id)
+                WHERE events.room_id = ?
+            """,
+            (room_id,),
+        )
+
+        state_groups = [row[0] for row in txn]
+
+        # Now we delete tables which lack an index on room_id but have one on event_id
+        for table in (
+            "event_auth",
+            "event_edges",
+            "event_push_actions_staging",
+            "event_reference_hashes",
+            "event_relations",
+            "event_to_state_groups",
+            "redactions",
+            "rejections",
+            "state_events",
+        ):
+            logger.info("[purge] removing %s from %s", room_id, table)
+
+            txn.execute(
+                """
+                DELETE FROM %s WHERE event_id IN (
+                  SELECT event_id FROM events WHERE room_id=?
+                )
+                """
+                % (table,),
+                (room_id,),
+            )
+
+        # and finally, the tables with an index on room_id (or no useful index)
+        for table in (
+            "current_state_events",
+            "event_backward_extremities",
+            "event_forward_extremities",
+            "event_json",
+            "event_push_actions",
+            "event_search",
+            "events",
+            "group_rooms",
+            "public_room_list_stream",
+            "receipts_graph",
+            "receipts_linearized",
+            "room_aliases",
+            "room_depth",
+            "room_memberships",
+            "room_stats_state",
+            "room_stats_current",
+            "room_stats_historical",
+            "room_stats_earliest_token",
+            "rooms",
+            "stream_ordering_to_exterm",
+            "users_in_public_rooms",
+            "users_who_share_private_rooms",
+            # no useful index, but let's clear them anyway
+            "appservice_room_list",
+            "e2e_room_keys",
+            "event_push_summary",
+            "pusher_throttle",
+            "group_summary_rooms",
+            "local_invites",
+            "room_account_data",
+            "room_tags",
+            "local_current_membership",
+        ):
+            logger.info("[purge] removing %s from %s", room_id, table)
+            txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
+
+        # Other tables we do NOT need to clear out:
+        #
+        #  - blocked_rooms
+        #    This is important, to make sure that we don't accidentally rejoin a blocked
+        #    room after it was purged
+        #
+        #  - user_directory
+        #    This has a room_id column, but it is unused
+        #
+
+        # Other tables that we might want to consider clearing out include:
+        #
+        #  - event_reports
+        #       Given that these are intended for abuse management my initial
+        #       inclination is to leave them in place.
+        #
+        #  - current_state_delta_stream
+        #  - ex_outlier_stream
+        #  - room_tags_revisions
+        #       The problem with these is that they are largeish and there is no room_id
+        #       index on them. In any case we should be clearing out 'stream' tables
+        #       periodically anyway (#5888)
+
+        # TODO: we could probably usefully do a bunch of cache invalidation here
+
+        logger.info("[purge] done")
+
+        return state_groups
diff --git a/synapse/storage/data_stores/main/rejections.py b/synapse/storage/data_stores/main/rejections.py
index 1c07c7a425..27e5a2084a 100644
--- a/synapse/storage/data_stores/main/rejections.py
+++ b/synapse/storage/data_stores/main/rejections.py
@@ -21,17 +21,6 @@ logger = logging.getLogger(__name__)
 
 
 class RejectionsStore(SQLBaseStore):
-    def _store_rejections_txn(self, txn, event_id, reason):
-        self.db.simple_insert_txn(
-            txn,
-            table="rejections",
-            values={
-                "event_id": event_id,
-                "reason": reason,
-                "last_check": self._clock.time_msec(),
-            },
-        )
-
     def get_rejection_reason(self, event_id):
         return self.db.simple_select_one_onecol(
             table="rejections",
diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py
index 046c2b4845..7d477f8d01 100644
--- a/synapse/storage/data_stores/main/relations.py
+++ b/synapse/storage/data_stores/main/relations.py
@@ -324,62 +324,4 @@ class RelationsWorkerStore(SQLBaseStore):
 
 
 class RelationsStore(RelationsWorkerStore):
-    def _handle_event_relations(self, txn, event):
-        """Handles inserting relation data during peristence of events
-
-        Args:
-            txn
-            event (EventBase)
-        """
-        relation = event.content.get("m.relates_to")
-        if not relation:
-            # No relations
-            return
-
-        rel_type = relation.get("rel_type")
-        if rel_type not in (
-            RelationTypes.ANNOTATION,
-            RelationTypes.REFERENCE,
-            RelationTypes.REPLACE,
-        ):
-            # Unknown relation type
-            return
-
-        parent_id = relation.get("event_id")
-        if not parent_id:
-            # Invalid relation
-            return
-
-        aggregation_key = relation.get("key")
-
-        self.db.simple_insert_txn(
-            txn,
-            table="event_relations",
-            values={
-                "event_id": event.event_id,
-                "relates_to_id": parent_id,
-                "relation_type": rel_type,
-                "aggregation_key": aggregation_key,
-            },
-        )
-
-        txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,))
-        txn.call_after(
-            self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
-        )
-
-        if rel_type == RelationTypes.REPLACE:
-            txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))
-
-    def _handle_redaction(self, txn, redacted_event_id):
-        """Handles receiving a redaction and checking whether we need to remove
-        any redacted relations from the database.
-
-        Args:
-            txn
-            redacted_event_id (str): The event that was redacted.
-        """
-
-        self.db.simple_delete_txn(
-            txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
-        )
+    pass
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index 147eba1df7..46f643c6b9 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -21,8 +21,6 @@ from abc import abstractmethod
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple
 
-from six import integer_types
-
 from canonicaljson import json
 
 from twisted.internet import defer
@@ -98,6 +96,37 @@ class RoomWorkerStore(SQLBaseStore):
             allow_none=True,
         )
 
+    def get_room_with_stats(self, room_id: str):
+        """Retrieve room with statistics.
+
+        Args:
+            room_id: The ID of the room to retrieve.
+        Returns:
+            A dict containing the room information, or None if the room is unknown.
+        """
+
+        def get_room_with_stats_txn(txn, room_id):
+            sql = """
+                SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
+                  curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
+                  rooms.creator, state.encryption, state.is_federatable AS federatable,
+                  rooms.is_public AS public, state.join_rules, state.guest_access,
+                  state.history_visibility, curr.current_state_events AS state_events
+                FROM rooms
+                LEFT JOIN room_stats_state state USING (room_id)
+                LEFT JOIN room_stats_current curr USING (room_id)
+                WHERE room_id = ?
+                """
+            txn.execute(sql, [room_id])
+            res = self.db.cursor_to_dict(txn)[0]
+            res["federatable"] = bool(res["federatable"])
+            res["public"] = bool(res["public"])
+            return res
+
+        return self.db.runInteraction(
+            "get_room_with_stats", get_room_with_stats_txn, room_id
+        )
+
     def get_public_room_ids(self):
         return self.db.simple_select_onecol(
             table="rooms",
@@ -1271,53 +1300,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
 
         return self.db.runInteraction("get_rooms", f)
 
-    def _store_room_topic_txn(self, txn, event):
-        if hasattr(event, "content") and "topic" in event.content:
-            self.store_event_search_txn(
-                txn, event, "content.topic", event.content["topic"]
-            )
-
-    def _store_room_name_txn(self, txn, event):
-        if hasattr(event, "content") and "name" in event.content:
-            self.store_event_search_txn(
-                txn, event, "content.name", event.content["name"]
-            )
-
-    def _store_room_message_txn(self, txn, event):
-        if hasattr(event, "content") and "body" in event.content:
-            self.store_event_search_txn(
-                txn, event, "content.body", event.content["body"]
-            )
-
-    def _store_retention_policy_for_room_txn(self, txn, event):
-        if hasattr(event, "content") and (
-            "min_lifetime" in event.content or "max_lifetime" in event.content
-        ):
-            if (
-                "min_lifetime" in event.content
-                and not isinstance(event.content.get("min_lifetime"), integer_types)
-            ) or (
-                "max_lifetime" in event.content
-                and not isinstance(event.content.get("max_lifetime"), integer_types)
-            ):
-                # Ignore the event if one of the value isn't an integer.
-                return
-
-            self.db.simple_insert_txn(
-                txn=txn,
-                table="room_retention",
-                values={
-                    "room_id": event.room_id,
-                    "event_id": event.event_id,
-                    "min_lifetime": event.content.get("min_lifetime"),
-                    "max_lifetime": event.content.get("max_lifetime"),
-                },
-            )
-
-            self._invalidate_cache_and_stream(
-                txn, self.get_retention_policy_for_room, (event.room_id,)
-            )
-
     def add_event_report(
         self, room_id, event_id, user_id, reason, content, received_ts
     ):
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index e626b7f6f7..48810a3e91 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -153,16 +153,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 self._check_safe_current_state_events_membership_updated_txn,
             )
 
-    @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
-    def get_hosts_in_room(self, room_id, cache_context):
-        """Returns the set of all hosts currently in the room
-        """
-        user_ids = yield self.get_users_in_room(
-            room_id, on_invalidate=cache_context.invalidate
-        )
-        hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
-        return hosts
-
     @cached(max_entries=100000, iterable=True)
     def get_users_in_room(self, room_id):
         return self.db.runInteraction(
@@ -1061,96 +1051,6 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
     def __init__(self, database: Database, db_conn, hs):
         super(RoomMemberStore, self).__init__(database, db_conn, hs)
 
-    def _store_room_members_txn(self, txn, events, backfilled):
-        """Store a room member in the database.
-        """
-        self.db.simple_insert_many_txn(
-            txn,
-            table="room_memberships",
-            values=[
-                {
-                    "event_id": event.event_id,
-                    "user_id": event.state_key,
-                    "sender": event.user_id,
-                    "room_id": event.room_id,
-                    "membership": event.membership,
-                    "display_name": event.content.get("displayname", None),
-                    "avatar_url": event.content.get("avatar_url", None),
-                }
-                for event in events
-            ],
-        )
-
-        for event in events:
-            txn.call_after(
-                self._membership_stream_cache.entity_has_changed,
-                event.state_key,
-                event.internal_metadata.stream_ordering,
-            )
-            txn.call_after(
-                self.get_invited_rooms_for_local_user.invalidate, (event.state_key,)
-            )
-
-            # We update the local_invites table only if the event is "current",
-            # i.e., its something that has just happened. If the event is an
-            # outlier it is only current if its an "out of band membership",
-            # like a remote invite or a rejection of a remote invite.
-            is_new_state = not backfilled and (
-                not event.internal_metadata.is_outlier()
-                or event.internal_metadata.is_out_of_band_membership()
-            )
-            is_mine = self.hs.is_mine_id(event.state_key)
-            if is_new_state and is_mine:
-                if event.membership == Membership.INVITE:
-                    self.db.simple_insert_txn(
-                        txn,
-                        table="local_invites",
-                        values={
-                            "event_id": event.event_id,
-                            "invitee": event.state_key,
-                            "inviter": event.sender,
-                            "room_id": event.room_id,
-                            "stream_id": event.internal_metadata.stream_ordering,
-                        },
-                    )
-                else:
-                    sql = (
-                        "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
-                        " room_id = ? AND invitee = ? AND locally_rejected is NULL"
-                        " AND replaced_by is NULL"
-                    )
-
-                    txn.execute(
-                        sql,
-                        (
-                            event.internal_metadata.stream_ordering,
-                            event.event_id,
-                            event.room_id,
-                            event.state_key,
-                        ),
-                    )
-
-                # We also update the `local_current_membership` table with
-                # latest invite info. This will usually get updated by the
-                # `current_state_events` handling, unless its an outlier.
-                if event.internal_metadata.is_outlier():
-                    # This should only happen for out of band memberships, so
-                    # we add a paranoia check.
-                    assert event.internal_metadata.is_out_of_band_membership()
-
-                    self.db.simple_upsert_txn(
-                        txn,
-                        table="local_current_membership",
-                        keyvalues={
-                            "room_id": event.room_id,
-                            "user_id": event.state_key,
-                        },
-                        values={
-                            "event_id": event.event_id,
-                            "membership": event.membership,
-                        },
-                    )
-
     @defer.inlineCallbacks
     def locally_reject_invite(self, user_id, room_id):
         sql = (
diff --git a/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql b/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql
new file mode 100644
index 0000000000..d5e6deb878
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql
@@ -0,0 +1,28 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- register a background update which will create a unique index on
+-- device_lists_outbound_last_success
+INSERT into background_updates (ordering, update_name, progress_json)
+    VALUES (5804, 'device_lists_outbound_last_success_unique_idx', '{}');
+
+-- once that completes, we can drop the old index.
+INSERT into background_updates (ordering, update_name, progress_json, depends_on)
+    VALUES (
+        5804,
+        'drop_device_lists_outbound_last_success_non_unique_idx',
+        '{}',
+        'device_lists_outbound_last_success_unique_idx'
+    );
diff --git a/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres
new file mode 100644
index 0000000000..aa46eb0e10
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres
@@ -0,0 +1,30 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We keep the old table here to enable us to roll back. It doesn't matter
+-- that we have dropped all the data here.
+TRUNCATE cache_invalidation_stream;
+
+CREATE TABLE cache_invalidation_stream_by_instance (
+    stream_id       BIGINT NOT NULL,
+    instance_name   TEXT NOT NULL,
+    cache_func      TEXT NOT NULL,
+    keys            TEXT[],
+    invalidation_ts BIGINT
+);
+
+CREATE UNIQUE INDEX cache_invalidation_stream_by_instance_id ON cache_invalidation_stream_by_instance(stream_id);
+
+CREATE SEQUENCE cache_invalidation_stream_seq;
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index 47ebb8a214..ee75b92344 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -347,29 +347,6 @@ class SearchStore(SearchBackgroundUpdateStore):
     def __init__(self, database: Database, db_conn, hs):
         super(SearchStore, self).__init__(database, db_conn, hs)
 
-    def store_event_search_txn(self, txn, event, key, value):
-        """Add event to the search table
-
-        Args:
-            txn (cursor):
-            event (EventBase):
-            key (str):
-            value (str):
-        """
-        self.store_search_entries_txn(
-            txn,
-            (
-                SearchEntry(
-                    key=key,
-                    value=value,
-                    event_id=event.event_id,
-                    room_id=event.room_id,
-                    stream_ordering=event.internal_metadata.stream_ordering,
-                    origin_server_ts=event.origin_server_ts,
-                ),
-            ),
-        )
-
     @defer.inlineCallbacks
     def search_msgs(self, room_ids, search_term, keys):
         """Performs a full text search over events with given keys.
diff --git a/synapse/storage/data_stores/main/signatures.py b/synapse/storage/data_stores/main/signatures.py
index 563216b63c..36244d9f5d 100644
--- a/synapse/storage/data_stores/main/signatures.py
+++ b/synapse/storage/data_stores/main/signatures.py
@@ -13,23 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import six
-
 from unpaddedbase64 import encode_base64
 
 from twisted.internet import defer
 
-from synapse.crypto.event_signing import compute_event_reference_hash
 from synapse.storage._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached, cachedList
 
-# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
-# despite being deprecated and removed in favor of memoryview
-if six.PY2:
-    db_binary_type = six.moves.builtins.buffer
-else:
-    db_binary_type = memoryview
-
 
 class SignatureWorkerStore(SQLBaseStore):
     @cached()
@@ -79,23 +69,3 @@ class SignatureWorkerStore(SQLBaseStore):
 
 class SignatureStore(SignatureWorkerStore):
     """Persistence for event signatures and hashes"""
-
-    def _store_event_reference_hashes_txn(self, txn, events):
-        """Store a hash for a PDU
-        Args:
-            txn (cursor):
-            events (list): list of Events.
-        """
-
-        vals = []
-        for event in events:
-            ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
-            vals.append(
-                {
-                    "event_id": event.event_id,
-                    "algorithm": ref_alg,
-                    "hash": db_binary_type(ref_hash_bytes),
-                }
-            )
-
-        self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 3a3b9a8e72..21052fcc7a 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -16,17 +16,12 @@
 import collections.abc
 import logging
 from collections import namedtuple
-from typing import Iterable, Tuple
-
-from six import iteritems
 
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
-from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
 from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
@@ -473,33 +468,3 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
 
     def __init__(self, database: Database, db_conn, hs):
         super(StateStore, self).__init__(database, db_conn, hs)
-
-    def _store_event_state_mappings_txn(
-        self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
-    ):
-        state_groups = {}
-        for event, context in events_and_contexts:
-            if event.internal_metadata.is_outlier():
-                continue
-
-            # if the event was rejected, just give it the same state as its
-            # predecessor.
-            if context.rejected:
-                state_groups[event.event_id] = context.state_group_before_event
-                continue
-
-            state_groups[event.event_id] = context.state_group
-
-        self.db.simple_insert_many_txn(
-            txn,
-            table="event_to_state_groups",
-            values=[
-                {"state_group": state_group_id, "event_id": event_id}
-                for event_id, state_group_id in iteritems(state_groups)
-            ],
-        )
-
-        for event_id, state_group_id in iteritems(state_groups):
-            txn.call_after(
-                self._get_state_group_for_event.prefill, (event_id,), state_group_id
-            )
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
index 57a5267663..f3ad1e4369 100644
--- a/synapse/storage/data_stores/state/store.py
+++ b/synapse/storage/data_stores/state/store.py
@@ -28,7 +28,6 @@ from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateSt
 from synapse.storage.database import Database
 from synapse.storage.state import StateFilter
 from synapse.types import StateMap
-from synapse.util.caches import get_cache_factor_for
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.dictionary_cache import DictionaryCache
 
@@ -90,11 +89,10 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         self._state_group_cache = DictionaryCache(
             "*stateGroupCache*",
             # TODO: this hasn't been tuned yet
-            50000 * get_cache_factor_for("stateGroupCache"),
+            50000,
         )
         self._state_group_members_cache = DictionaryCache(
-            "*stateGroupMembersCache*",
-            500000 * get_cache_factor_for("stateGroupMembersCache"),
+            "*stateGroupMembersCache*", 500000,
         )
 
     @cached(max_entries=10000, iterable=True)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 50f475bfd3..c3d0863429 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -49,6 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
+from synapse.types import Collection
 from synapse.util.stringutils import exception_to_unicode
 
 logger = logging.getLogger(__name__)
@@ -78,6 +79,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
     "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
     "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
     "event_search": "event_search_event_id_idx",
+    "device_lists_outbound_last_success": "device_lists_outbound_last_success_unique_idx",
 }
 
 
@@ -889,20 +891,24 @@ class Database(object):
         txn.execute(sql, list(allvalues.values()))
 
     def simple_upsert_many_txn(
-        self, txn, table, key_names, key_values, value_names, value_values
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        key_names: Collection[str],
+        key_values: Collection[Iterable[Any]],
+        value_names: Collection[str],
+        value_values: Iterable[Iterable[str]],
+    ) -> None:
         """
         Upsert, many times.
 
         Args:
-            table (str): The table to upsert into
-            key_names (list[str]): The key column names.
-            key_values (list[list]): A list of each row's key column values.
-            value_names (list[str]): The value column names. If empty, no
-                values will be used, even if value_values is provided.
-            value_values (list[list]): A list of each row's value column values.
-        Returns:
-            None
+            table: The table to upsert into
+            key_names: The key column names.
+            key_values: A list of each row's key column values.
+            value_names: The value column names
+            value_values: A list of each row's value column values.
+                Ignored if value_names is empty.
         """
         if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
             return self.simple_upsert_many_txn_native_upsert(
@@ -914,20 +920,24 @@ class Database(object):
             )
 
     def simple_upsert_many_txn_emulated(
-        self, txn, table, key_names, key_values, value_names, value_values
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        key_names: Iterable[str],
+        key_values: Collection[Iterable[Any]],
+        value_names: Collection[str],
+        value_values: Iterable[Iterable[str]],
+    ) -> None:
         """
         Upsert, many times, but without native UPSERT support or batching.
 
         Args:
-            table (str): The table to upsert into
-            key_names (list[str]): The key column names.
-            key_values (list[list]): A list of each row's key column values.
-            value_names (list[str]): The value column names. If empty, no
-                values will be used, even if value_values is provided.
-            value_values (list[list]): A list of each row's value column values.
-        Returns:
-            None
+            table: The table to upsert into
+            key_names: The key column names.
+            key_values: A list of each row's key column values.
+            value_names: The value column names
+            value_values: A list of each row's value column values.
+                Ignored if value_names is empty.
         """
         # No value columns, therefore make a blank list so that the following
         # zip() works correctly.
@@ -941,20 +951,24 @@ class Database(object):
             self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
 
     def simple_upsert_many_txn_native_upsert(
-        self, txn, table, key_names, key_values, value_names, value_values
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        key_names: Collection[str],
+        key_values: Collection[Iterable[Any]],
+        value_names: Collection[str],
+        value_values: Iterable[Iterable[Any]],
+    ) -> None:
         """
         Upsert, many times, using batching where possible.
 
         Args:
-            table (str): The table to upsert into
-            key_names (list[str]): The key column names.
-            key_values (list[list]): A list of each row's key column values.
-            value_names (list[str]): The value column names. If empty, no
-                values will be used, even if value_values is provided.
-            value_values (list[list]): A list of each row's value column values.
-        Returns:
-            None
+            table: The table to upsert into
+            key_names: The key column names.
+            key_values: A list of each row's key column values.
+            value_names: The value column names
+            value_values: A list of each row's value column values.
+                Ignored if value_names is empty.
         """
         allnames = []  # type: List[str]
         allnames.extend(key_names)
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 0f9ac1cf09..41881ea20b 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -23,7 +23,6 @@ from typing import Iterable, List, Optional, Set, Tuple
 from six import iteritems
 from six.moves import range
 
-import attr
 from prometheus_client import Counter, Histogram
 
 from twisted.internet import defer
@@ -35,6 +34,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.state import StateResolutionStore
 from synapse.storage.data_stores import DataStores
+from synapse.storage.data_stores.main.events import DeltaState
 from synapse.types import StateMap
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.metrics import Measure
@@ -73,22 +73,6 @@ stale_forward_extremities_counter = Histogram(
 )
 
 
-@attr.s(slots=True)
-class DeltaState:
-    """Deltas to use to update the `current_state_events` table.
-
-    Attributes:
-        to_delete: List of type/state_keys to delete from current state
-        to_insert: Map of state to upsert into current state
-        no_longer_in_room: The server is not longer in the room, so the room
-            should e.g. be removed from `current_state_events` table.
-    """
-
-    to_delete = attr.ib(type=List[Tuple[str, str]])
-    to_insert = attr.ib(type=StateMap[str])
-    no_longer_in_room = attr.ib(type=bool, default=False)
-
-
 class _EventPeristenceQueue(object):
     """Queues up events so that they can be persisted in bulk with only one
     concurrent transaction per room.
@@ -205,6 +189,7 @@ class EventsPersistenceStorage(object):
         # store for now.
         self.main_store = stores.main
         self.state_store = stores.state
+        self.persist_events_store = stores.persist_events
 
         self._clock = hs.get_clock()
         self.is_mine_id = hs.is_mine_id
@@ -445,7 +430,7 @@ class EventsPersistenceStorage(object):
                         if current_state is not None:
                             current_state_for_room[room_id] = current_state
 
-            await self.main_store._persist_events_and_state_updates(
+            await self.persist_events_store._persist_events_and_state_updates(
                 chunk,
                 current_state_for_room=current_state_for_room,
                 state_delta_for_room=state_delta_for_room,
@@ -491,13 +476,15 @@ class EventsPersistenceStorage(object):
         )
 
         # Remove any events which are prev_events of any existing events.
-        existing_prevs = await self.main_store._get_events_which_are_prevs(result)
+        existing_prevs = await self.persist_events_store._get_events_which_are_prevs(
+            result
+        )
         result.difference_update(existing_prevs)
 
         # Finally handle the case where the new events have soft-failed prev
         # events. If they do we need to remove them and their prev events,
         # otherwise we end up with dangling extremities.
-        existing_prevs = await self.main_store._get_prevs_before_rejected(
+        existing_prevs = await self.persist_events_store._get_prevs_before_rejected(
             e_id for event in new_events for e_id in event.prev_event_ids()
         )
         result.difference_update(existing_prevs)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 1712932f31..640f242584 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -29,6 +29,8 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
+# XXX: If you're about to bump this to 59 (or higher) please create an update
+# that drops the unused `cache_invalidation_stream` table, as per #7436!
 SCHEMA_VERSION = 58
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 9d851beaa5..86d04ea9ac 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -16,6 +16,11 @@
 import contextlib
 import threading
 from collections import deque
+from typing import Dict, Set, Tuple
+
+from typing_extensions import Deque
+
+from synapse.storage.database import Database, LoggingTransaction
 
 
 class IdGenerator(object):
@@ -87,7 +92,7 @@ class StreamIdGenerator(object):
             self._current = (max if step > 0 else min)(
                 self._current, _load_current_id(db_conn, table, column, step)
             )
-        self._unfinished_ids = deque()
+        self._unfinished_ids = deque()  # type: Deque[int]
 
     def get_next(self):
         """
@@ -163,7 +168,7 @@ class ChainedIdGenerator(object):
         self.chained_generator = chained_generator
         self._lock = threading.Lock()
         self._current_max = _load_current_id(db_conn, table, column)
-        self._unfinished_ids = deque()
+        self._unfinished_ids = deque()  # type: Deque[Tuple[int, int]]
 
     def get_next(self):
         """
@@ -198,3 +203,163 @@ class ChainedIdGenerator(object):
                 return stream_id - 1, chained_id
 
             return self._current_max, self.chained_generator.get_current_token()
+
+
+class MultiWriterIdGenerator:
+    """An ID generator that tracks a stream that can have multiple writers.
+
+    Uses a Postgres sequence to coordinate ID assignment, but positions of other
+    writers will only get updated when `advance` is called (by replication).
+
+    Note: Only works with Postgres.
+
+    Args:
+        db_conn
+        db
+        instance_name: The name of this instance.
+        table: Database table associated with stream.
+        instance_column: Column that stores the row's writer's instance name
+        id_column: Column that stores the stream ID.
+        sequence_name: The name of the postgres sequence used to generate new
+            IDs.
+    """
+
+    def __init__(
+        self,
+        db_conn,
+        db: Database,
+        instance_name: str,
+        table: str,
+        instance_column: str,
+        id_column: str,
+        sequence_name: str,
+    ):
+        self._db = db
+        self._instance_name = instance_name
+        self._sequence_name = sequence_name
+
+        # We lock as some functions may be called from DB threads.
+        self._lock = threading.Lock()
+
+        self._current_positions = self._load_current_ids(
+            db_conn, table, instance_column, id_column
+        )
+
+        # Set of local IDs that we're still processing. The current position
+        # should be less than the minimum of this set (if not empty).
+        self._unfinished_ids = set()  # type: Set[int]
+
+    def _load_current_ids(
+        self, db_conn, table: str, instance_column: str, id_column: str
+    ) -> Dict[str, int]:
+        sql = """
+            SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
+            GROUP BY %(instance)s
+        """ % {
+            "instance": instance_column,
+            "id": id_column,
+            "table": table,
+        }
+
+        cur = db_conn.cursor()
+        cur.execute(sql)
+
+        # `cur` is an iterable over returned rows, which are 2-tuples.
+        current_positions = dict(cur)
+
+        cur.close()
+
+        return current_positions
+
+    def _load_next_id_txn(self, txn):
+        txn.execute("SELECT nextval(?)", (self._sequence_name,))
+        (next_id,) = txn.fetchone()
+        return next_id
+
+    async def get_next(self):
+        """
+        Usage:
+            with await stream_id_gen.get_next() as stream_id:
+                # ... persist event ...
+        """
+        next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
+
+        # Assert the fetched ID is actually greater than what we currently
+        # believe the ID to be. If not, then the sequence and table have got
+        # out of sync somehow.
+        assert self.get_current_token() < next_id
+
+        with self._lock:
+            self._unfinished_ids.add(next_id)
+
+        @contextlib.contextmanager
+        def manager():
+            try:
+                yield next_id
+            finally:
+                self._mark_id_as_finished(next_id)
+
+        return manager()
+
+    def get_next_txn(self, txn: LoggingTransaction):
+        """
+        Usage:
+
+            stream_id = stream_id_gen.get_next(txn)
+            # ... persist event ...
+        """
+
+        next_id = self._load_next_id_txn(txn)
+
+        with self._lock:
+            self._unfinished_ids.add(next_id)
+
+        txn.call_after(self._mark_id_as_finished, next_id)
+        txn.call_on_exception(self._mark_id_as_finished, next_id)
+
+        return next_id
+
+    def _mark_id_as_finished(self, next_id: int):
+        """The ID has finished being processed so we should advance the
+        current poistion if possible.
+        """
+
+        with self._lock:
+            self._unfinished_ids.discard(next_id)
+
+            # Figure out if its safe to advance the position by checking there
+            # aren't any lower allocated IDs that are yet to finish.
+            if all(c > next_id for c in self._unfinished_ids):
+                curr = self._current_positions.get(self._instance_name, 0)
+                self._current_positions[self._instance_name] = max(curr, next_id)
+
+    def get_current_token(self, instance_name: str = None) -> int:
+        """Gets the current position of a named writer (defaults to current
+        instance).
+
+        Returns 0 if we don't have a position for the named writer (likely due
+        to it being a new writer).
+        """
+
+        if instance_name is None:
+            instance_name = self._instance_name
+
+        with self._lock:
+            return self._current_positions.get(instance_name, 0)
+
+    def get_positions(self) -> Dict[str, int]:
+        """Get a copy of the current positon map.
+        """
+
+        with self._lock:
+            return dict(self._current_positions)
+
+    def advance(self, instance_name: str, new_id: int):
+        """Advance the postion of the named writer to the given ID, if greater
+        than existing entry.
+        """
+
+        with self._lock:
+            self._current_positions[instance_name] = max(
+                new_id, self._current_positions.get(instance_name, 0)
+            )
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index da5077b471..4b8a0c7a8f 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019, 2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,27 +15,17 @@
 # limitations under the License.
 
 import logging
-import os
-from typing import Dict
+from typing import Callable, Dict, Optional
 
 import six
 from six.moves import intern
 
-from prometheus_client.core import REGISTRY, Gauge, GaugeMetricFamily
-
-logger = logging.getLogger(__name__)
-
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5))
+import attr
+from prometheus_client.core import Gauge
 
+from synapse.config.cache import add_resizable_cache
 
-def get_cache_factor_for(cache_name):
-    env_var = "SYNAPSE_CACHE_FACTOR_" + cache_name.upper()
-    factor = os.environ.get(env_var)
-    if factor:
-        return float(factor)
-
-    return CACHE_SIZE_FACTOR
-
+logger = logging.getLogger(__name__)
 
 caches_by_name = {}
 collectors_by_name = {}  # type: Dict
@@ -44,6 +34,7 @@ cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
 cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
 cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name"])
 cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"])
+cache_max_size = Gauge("synapse_util_caches_cache_max_size", "", ["name"])
 
 response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"])
 response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"])
@@ -53,67 +44,82 @@ response_cache_evicted = Gauge(
 response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"])
 
 
-def register_cache(cache_type, cache_name, cache, collect_callback=None):
-    """Register a cache object for metric collection.
+@attr.s
+class CacheMetric(object):
+
+    _cache = attr.ib()
+    _cache_type = attr.ib(type=str)
+    _cache_name = attr.ib(type=str)
+    _collect_callback = attr.ib(type=Optional[Callable])
+
+    hits = attr.ib(default=0)
+    misses = attr.ib(default=0)
+    evicted_size = attr.ib(default=0)
+
+    def inc_hits(self):
+        self.hits += 1
+
+    def inc_misses(self):
+        self.misses += 1
+
+    def inc_evictions(self, size=1):
+        self.evicted_size += size
+
+    def describe(self):
+        return []
+
+    def collect(self):
+        try:
+            if self._cache_type == "response_cache":
+                response_cache_size.labels(self._cache_name).set(len(self._cache))
+                response_cache_hits.labels(self._cache_name).set(self.hits)
+                response_cache_evicted.labels(self._cache_name).set(self.evicted_size)
+                response_cache_total.labels(self._cache_name).set(
+                    self.hits + self.misses
+                )
+            else:
+                cache_size.labels(self._cache_name).set(len(self._cache))
+                cache_hits.labels(self._cache_name).set(self.hits)
+                cache_evicted.labels(self._cache_name).set(self.evicted_size)
+                cache_total.labels(self._cache_name).set(self.hits + self.misses)
+                if getattr(self._cache, "max_size", None):
+                    cache_max_size.labels(self._cache_name).set(self._cache.max_size)
+            if self._collect_callback:
+                self._collect_callback()
+        except Exception as e:
+            logger.warning("Error calculating metrics for %s: %s", self._cache_name, e)
+            raise
+
+
+def register_cache(
+    cache_type: str,
+    cache_name: str,
+    cache,
+    collect_callback: Optional[Callable] = None,
+    resizable: bool = True,
+    resize_callback: Optional[Callable] = None,
+) -> CacheMetric:
+    """Register a cache object for metric collection and resizing.
 
     Args:
-        cache_type (str):
-        cache_name (str): name of the cache
-        cache (object): cache itself
-        collect_callback (callable|None): if not None, a function which is called during
-            metric collection to update additional metrics.
+        cache_type
+        cache_name: name of the cache
+        cache: cache itself
+        collect_callback: If given, a function which is called during metric
+            collection to update additional metrics.
+        resizable: Whether this cache supports being resized.
+        resize_callback: A function which can be called to resize the cache.
 
     Returns:
         CacheMetric: an object which provides inc_{hits,misses,evictions} methods
     """
+    if resizable:
+        if not resize_callback:
+            resize_callback = getattr(cache, "set_cache_factor")
+        add_resizable_cache(cache_name, resize_callback)
 
-    # Check if the metric is already registered. Unregister it, if so.
-    # This usually happens during tests, as at runtime these caches are
-    # effectively singletons.
+    metric = CacheMetric(cache, cache_type, cache_name, collect_callback)
     metric_name = "cache_%s_%s" % (cache_type, cache_name)
-    if metric_name in collectors_by_name.keys():
-        REGISTRY.unregister(collectors_by_name[metric_name])
-
-    class CacheMetric(object):
-
-        hits = 0
-        misses = 0
-        evicted_size = 0
-
-        def inc_hits(self):
-            self.hits += 1
-
-        def inc_misses(self):
-            self.misses += 1
-
-        def inc_evictions(self, size=1):
-            self.evicted_size += size
-
-        def describe(self):
-            return []
-
-        def collect(self):
-            try:
-                if cache_type == "response_cache":
-                    response_cache_size.labels(cache_name).set(len(cache))
-                    response_cache_hits.labels(cache_name).set(self.hits)
-                    response_cache_evicted.labels(cache_name).set(self.evicted_size)
-                    response_cache_total.labels(cache_name).set(self.hits + self.misses)
-                else:
-                    cache_size.labels(cache_name).set(len(cache))
-                    cache_hits.labels(cache_name).set(self.hits)
-                    cache_evicted.labels(cache_name).set(self.evicted_size)
-                    cache_total.labels(cache_name).set(self.hits + self.misses)
-                if collect_callback:
-                    collect_callback()
-            except Exception as e:
-                logger.warning("Error calculating metrics for %s: %s", cache_name, e)
-                raise
-
-            yield GaugeMetricFamily("__unused", "")
-
-    metric = CacheMetric()
-    REGISTRY.register(metric)
     caches_by_name[cache_name] = cache
     collectors_by_name[metric_name] = metric
     return metric
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 2e8f6543e5..cd48262420 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import functools
 import inspect
 import logging
@@ -30,7 +31,6 @@ from twisted.internet import defer
 from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches import get_cache_factor_for
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
 
@@ -81,7 +81,6 @@ class CacheEntry(object):
 class Cache(object):
     __slots__ = (
         "cache",
-        "max_entries",
         "name",
         "keylen",
         "thread",
@@ -89,7 +88,29 @@ class Cache(object):
         "_pending_deferred_cache",
     )
 
-    def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False):
+    def __init__(
+        self,
+        name: str,
+        max_entries: int = 1000,
+        keylen: int = 1,
+        tree: bool = False,
+        iterable: bool = False,
+        apply_cache_factor_from_config: bool = True,
+    ):
+        """
+        Args:
+            name: The name of the cache
+            max_entries: Maximum amount of entries that the cache will hold
+            keylen: The length of the tuple used as the cache key
+            tree: Use a TreeCache instead of a dict as the underlying cache type
+            iterable: If True, count each item in the cached object as an entry,
+                rather than each cached object
+            apply_cache_factor_from_config: Whether cache factors specified in the
+                config file affect `max_entries`
+
+        Returns:
+            Cache
+        """
         cache_type = TreeCache if tree else dict
         self._pending_deferred_cache = cache_type()
 
@@ -99,6 +120,7 @@ class Cache(object):
             cache_type=cache_type,
             size_callback=(lambda d: len(d)) if iterable else None,
             evicted_callback=self._on_evicted,
+            apply_cache_factor_from_config=apply_cache_factor_from_config,
         )
 
         self.name = name
@@ -111,6 +133,10 @@ class Cache(object):
             collect_callback=self._metrics_collection_callback,
         )
 
+    @property
+    def max_entries(self):
+        return self.cache.max_size
+
     def _on_evicted(self, evicted_count):
         self.metrics.inc_evictions(evicted_count)
 
@@ -370,13 +396,11 @@ class CacheDescriptor(_CacheDescriptorBase):
             cache_context=cache_context,
         )
 
-        max_entries = int(max_entries * get_cache_factor_for(orig.__name__))
-
         self.max_entries = max_entries
         self.tree = tree
         self.iterable = iterable
 
-    def __get__(self, obj, objtype=None):
+    def __get__(self, obj, owner):
         cache = Cache(
             name=self.orig.__name__,
             max_entries=self.max_entries,
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index cddf1ed515..2726b67b6d 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -18,6 +18,7 @@ from collections import OrderedDict
 
 from six import iteritems, itervalues
 
+from synapse.config import cache as cache_config
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util.caches import register_cache
 
@@ -51,15 +52,16 @@ class ExpiringCache(object):
                 an item on access. Defaults to False.
             iterable (bool): If true, the size is calculated by summing the
                 sizes of all entries, rather than the number of entries.
-
         """
         self._cache_name = cache_name
 
+        self._original_max_size = max_len
+
+        self._max_size = int(max_len * cache_config.properties.default_factor_size)
+
         self._clock = clock
 
-        self._max_len = max_len
         self._expiry_ms = expiry_ms
-
         self._reset_expiry_on_get = reset_expiry_on_get
 
         self._cache = OrderedDict()
@@ -82,9 +84,11 @@ class ExpiringCache(object):
     def __setitem__(self, key, value):
         now = self._clock.time_msec()
         self._cache[key] = _CacheEntry(now, value)
+        self.evict()
 
+    def evict(self):
         # Evict if there are now too many items
-        while self._max_len and len(self) > self._max_len:
+        while self._max_size and len(self) > self._max_size:
             _key, value = self._cache.popitem(last=False)
             if self.iterable:
                 self.metrics.inc_evictions(len(value.value))
@@ -170,6 +174,23 @@ class ExpiringCache(object):
         else:
             return len(self._cache)
 
+    def set_cache_factor(self, factor: float) -> bool:
+        """
+        Set the cache factor for this individual cache.
+
+        This will trigger a resize if it changes, which may require evicting
+        items from the cache.
+
+        Returns:
+            bool: Whether the cache changed size or not.
+        """
+        new_size = int(self._original_max_size * factor)
+        if new_size != self._max_size:
+            self._max_size = new_size
+            self.evict()
+            return True
+        return False
+
 
 class _CacheEntry(object):
     __slots__ = ["time", "value"]
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 1536cb64f3..29fabac3cd 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -13,10 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import threading
 from functools import wraps
+from typing import Callable, Optional, Type, Union
 
+from synapse.config import cache as cache_config
 from synapse.util.caches.treecache import TreeCache
 
 
@@ -52,17 +53,18 @@ class LruCache(object):
 
     def __init__(
         self,
-        max_size,
-        keylen=1,
-        cache_type=dict,
-        size_callback=None,
-        evicted_callback=None,
+        max_size: int,
+        keylen: int = 1,
+        cache_type: Type[Union[dict, TreeCache]] = dict,
+        size_callback: Optional[Callable] = None,
+        evicted_callback: Optional[Callable] = None,
+        apply_cache_factor_from_config: bool = True,
     ):
         """
         Args:
-            max_size (int):
+            max_size: The maximum amount of entries the cache can hold
 
-            keylen (int):
+            keylen: The length of the tuple used as the cache key
 
             cache_type (type):
                 type of underlying cache to be used. Typically one of dict
@@ -73,9 +75,23 @@ class LruCache(object):
             evicted_callback (func(int)|None):
                 if not None, called on eviction with the size of the evicted
                 entry
+
+            apply_cache_factor_from_config (bool): If true, `max_size` will be
+                multiplied by a cache factor derived from the homeserver config
         """
         cache = cache_type()
         self.cache = cache  # Used for introspection.
+
+        # Save the original max size, and apply the default size factor.
+        self._original_max_size = max_size
+        # We previously didn't apply the cache factor here, and as such some caches were
+        # not affected by the global cache factor. Add an option here to disable applying
+        # the cache factor when a cache is created
+        if apply_cache_factor_from_config:
+            self.max_size = int(max_size * cache_config.properties.default_factor_size)
+        else:
+            self.max_size = int(max_size)
+
         list_root = _Node(None, None, None, None)
         list_root.next_node = list_root
         list_root.prev_node = list_root
@@ -83,7 +99,7 @@ class LruCache(object):
         lock = threading.Lock()
 
         def evict():
-            while cache_len() > max_size:
+            while cache_len() > self.max_size:
                 todelete = list_root.prev_node
                 evicted_len = delete_node(todelete)
                 cache.pop(todelete.key, None)
@@ -236,6 +252,7 @@ class LruCache(object):
             return key in cache
 
         self.sentinel = object()
+        self._on_resize = evict
         self.get = cache_get
         self.set = cache_set
         self.setdefault = cache_set_default
@@ -266,3 +283,20 @@ class LruCache(object):
 
     def __contains__(self, key):
         return self.contains(key)
+
+    def set_cache_factor(self, factor: float) -> bool:
+        """
+        Set the cache factor for this individual cache.
+
+        This will trigger a resize if it changes, which may require evicting
+        items from the cache.
+
+        Returns:
+            bool: Whether the cache changed size or not.
+        """
+        new_size = int(self._original_max_size * factor)
+        if new_size != self.max_size:
+            self.max_size = new_size
+            self._on_resize()
+            return True
+        return False
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index b68f9fe0d4..a6c60888e5 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -38,7 +38,7 @@ class ResponseCache(object):
         self.timeout_sec = timeout_ms / 1000.0
 
         self._name = name
-        self._metrics = register_cache("response_cache", name, self)
+        self._metrics = register_cache("response_cache", name, self, resizable=False)
 
     def size(self):
         return len(self.pending_result_cache)
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index e54f80d76e..2a161bf244 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+import math
 from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union
 
 from six import integer_types
@@ -46,7 +47,8 @@ class StreamChangeCache:
         max_size=10000,
         prefilled_cache: Optional[Mapping[EntityType, int]] = None,
     ):
-        self._max_size = int(max_size * caches.CACHE_SIZE_FACTOR)
+        self._original_max_size = max_size
+        self._max_size = math.floor(max_size)
         self._entity_to_key = {}  # type: Dict[EntityType, int]
 
         # map from stream id to the a set of entities which changed at that stream id.
@@ -58,12 +60,31 @@ class StreamChangeCache:
         #
         self._earliest_known_stream_pos = current_stream_pos
         self.name = name
-        self.metrics = caches.register_cache("cache", self.name, self._cache)
+        self.metrics = caches.register_cache(
+            "cache", self.name, self._cache, resize_callback=self.set_cache_factor
+        )
 
         if prefilled_cache:
             for entity, stream_pos in prefilled_cache.items():
                 self.entity_has_changed(entity, stream_pos)
 
+    def set_cache_factor(self, factor: float) -> bool:
+        """
+        Set the cache factor for this individual cache.
+
+        This will trigger a resize if it changes, which may require evicting
+        items from the cache.
+
+        Returns:
+            bool: Whether the cache changed size or not.
+        """
+        new_size = math.floor(self._original_max_size * factor)
+        if new_size != self._max_size:
+            self.max_size = new_size
+            self._evict()
+            return True
+        return False
+
     def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool:
         """Returns True if the entity may have been updated since stream_pos
         """
@@ -171,6 +192,7 @@ class StreamChangeCache:
             e1 = self._cache[stream_pos] = set()
         e1.add(entity)
         self._entity_to_key[entity] = stream_pos
+        self._evict()
 
         # if the cache is too big, remove entries
         while len(self._cache) > self._max_size:
@@ -179,6 +201,13 @@ class StreamChangeCache:
             for entity in r:
                 del self._entity_to_key[entity]
 
+    def _evict(self):
+        while len(self._cache) > self._max_size:
+            k, r = self._cache.popitem(0)
+            self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
+            for entity in r:
+                self._entity_to_key.pop(entity, None)
+
     def get_max_pos_of_last_change(self, entity: EntityType) -> int:
 
         """Returns an upper bound of the stream id of the last change to an
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 99646c7cf0..6437aa907e 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -38,7 +38,7 @@ class TTLCache(object):
 
         self._timer = timer
 
-        self._metrics = register_cache("ttl", cache_name, self)
+        self._metrics = register_cache("ttl", cache_name, self, resizable=False)
 
     def set(self, key, value, ttl):
         """Add/update an entry in the cache