summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorAndrew Morgan <andrew@amorgan.xyz>2020-06-10 17:24:43 +0100
committerAndrew Morgan <andrew@amorgan.xyz>2020-06-10 17:26:48 +0100
commitcde3bda815c43494b0b824191fbe84ec86c7cbc0 (patch)
tree1a65220b47280bc8a5fd84f8feb2fabcfbc69298 /synapse/storage
parentMerge branch 'release-v1.13.0' of github.com:matrix-org/synapse into dinsic-r... (diff)
parentFix typo in PR link (diff)
downloadsynapse-cde3bda815c43494b0b824191fbe84ec86c7cbc0.tar.xz
Merge branch 'release-v1.14.0' of github.com:matrix-org/synapse into dinsic-release-v1.14.x
* 'release-v1.14.0' of github.com:matrix-org/synapse: (108 commits)
  Fix typo in PR link
  Update debian changelog
  1.14.0
  Improve changelog wording
  1.14.0rc2
  Fix sample config docs error (#7581)
  Fix up comments
  Fix specifying cache factors via env vars with * in name. (#7580)
  Don't apply cache factor to event cache. (#7578)
  Ensure ReplicationStreamer is always started when replication enabled. (#7579)
  Remove the changes to the debian changelog
  Not full release yet, this is rc1
  Merge event persistence move changelog entries
  More changelog fix
  Changelog fixes
  1.14.0
  Replace device_27_unique_idx bg update with a fg one (#7562)
  Fix incorrect exception handling in KeyUploadServlet.on_POST (#7563)
  Fix recording of federation stream token (#7564)
  Simplify reap_monthly_active_users (#7558)
  ...
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py11
-rw-r--r--synapse/storage/data_stores/__init__.py9
-rw-r--r--synapse/storage/data_stores/main/__init__.py40
-rw-r--r--synapse/storage/data_stores/main/account_data.py62
-rw-r--r--synapse/storage/data_stores/main/appservice.py2
-rw-r--r--synapse/storage/data_stores/main/cache.py146
-rw-r--r--synapse/storage/data_stores/main/censor_events.py208
-rw-r--r--synapse/storage/data_stores/main/client_ips.py3
-rw-r--r--synapse/storage/data_stores/main/devices.py100
-rw-r--r--synapse/storage/data_stores/main/end_to_end_keys.py98
-rw-r--r--synapse/storage/data_stores/main/event_federation.py83
-rw-r--r--synapse/storage/data_stores/main/event_push_actions.py74
-rw-r--r--synapse/storage/data_stores/main/events.py1232
-rw-r--r--synapse/storage/data_stores/main/events_worker.py192
-rw-r--r--synapse/storage/data_stores/main/group_server.py92
-rw-r--r--synapse/storage/data_stores/main/keys.py10
-rw-r--r--synapse/storage/data_stores/main/metrics.py128
-rw-r--r--synapse/storage/data_stores/main/monthly_active_users.py124
-rw-r--r--synapse/storage/data_stores/main/profile.py2
-rw-r--r--synapse/storage/data_stores/main/purge_events.py399
-rw-r--r--synapse/storage/data_stores/main/push_rule.py14
-rw-r--r--synapse/storage/data_stores/main/rejections.py11
-rw-r--r--synapse/storage/data_stores/main/relations.py60
-rw-r--r--synapse/storage/data_stores/main/room.py80
-rw-r--r--synapse/storage/data_stores/main/roommember.py148
-rw-r--r--synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres30
-rw-r--r--synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py80
-rw-r--r--synapse/storage/data_stores/main/search.py119
-rw-r--r--synapse/storage/data_stores/main/signatures.py30
-rw-r--r--synapse/storage/data_stores/main/state.py40
-rw-r--r--synapse/storage/data_stores/main/transactions.py9
-rw-r--r--synapse/storage/data_stores/state/store.py6
-rw-r--r--synapse/storage/database.py88
-rw-r--r--synapse/storage/persist_events.py37
-rw-r--r--synapse/storage/prepare_database.py15
-rw-r--r--synapse/storage/util/id_generators.py180
36 files changed, 2295 insertions, 1667 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 13de5f1f62..bfce541ca7 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -19,9 +19,6 @@ import random
 from abc import ABCMeta
 from typing import Any, Optional
 
-from six import PY2
-from six.moves import builtins
-
 from canonicaljson import json
 
 from synapse.storage.database import LoggingTransaction  # noqa: F401
@@ -47,6 +44,9 @@ class SQLBaseStore(metaclass=ABCMeta):
         self.db = database
         self.rand = random.SystemRandom()
 
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        pass
+
     def _invalidate_state_caches(self, room_id, members_changed):
         """Invalidates caches that are based on the current state, but does
         not stream invalidations down replication.
@@ -100,11 +100,6 @@ def db_to_json(db_content):
     if isinstance(db_content, memoryview):
         db_content = db_content.tobytes()
 
-    # psycopg2 on Python 2 returns buffer objects, which we need to cast to
-    # bytes to decode
-    if PY2 and isinstance(db_content, builtins.buffer):
-        db_content = bytes(db_content)
-
     # Decode it to a Unicode string before feeding it to json.loads, so we
     # consistenty get a Unicode-containing object out.
     if isinstance(db_content, (bytes, bytearray)):
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
index e1d03429ca..599ee470d4 100644
--- a/synapse/storage/data_stores/__init__.py
+++ b/synapse/storage/data_stores/__init__.py
@@ -15,6 +15,7 @@
 
 import logging
 
+from synapse.storage.data_stores.main.events import PersistEventsStore
 from synapse.storage.data_stores.state import StateGroupDataStore
 from synapse.storage.database import Database, make_conn
 from synapse.storage.engines import create_engine
@@ -39,6 +40,7 @@ class DataStores(object):
         self.databases = []
         self.main = None
         self.state = None
+        self.persist_events = None
 
         for database_config in hs.config.database.databases:
             db_name = database_config.name
@@ -64,6 +66,13 @@ class DataStores(object):
 
                     self.main = main_store_class(database, db_conn, hs)
 
+                    # If we're on a process that can persist events also
+                    # instantiate a `PersistEventsStore`
+                    if hs.config.worker.writers.events == hs.get_instance_name():
+                        self.persist_events = PersistEventsStore(
+                            hs, database, self.main
+                        )
+
                 if "state" in database_config.data_stores:
                     logger.info("Starting 'state' data store")
 
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index ceba10882c..4b4763c701 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -24,15 +24,16 @@ from synapse.config.homeserver import HomeServerConfig
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
-    ChainedIdGenerator,
     IdGenerator,
+    MultiWriterIdGenerator,
     StreamIdGenerator,
 )
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from .account_data import AccountDataStore
 from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
-from .cache import CacheInvalidationStore
+from .cache import CacheInvalidationWorkerStore
+from .censor_events import CensorEventsStore
 from .client_ips import ClientIpStore
 from .deviceinbox import DeviceInboxStore
 from .devices import DeviceStore
@@ -41,16 +42,17 @@ from .e2e_room_keys import EndToEndRoomKeyStore
 from .end_to_end_keys import EndToEndKeyStore
 from .event_federation import EventFederationStore
 from .event_push_actions import EventPushActionsStore
-from .events import EventsStore
 from .events_bg_updates import EventsBackgroundUpdatesStore
 from .filtering import FilteringStore
 from .group_server import GroupServerStore
 from .keys import KeyStore
 from .media_repository import MediaRepositoryStore
+from .metrics import ServerMetricsStore
 from .monthly_active_users import MonthlyActiveUsersStore
 from .openid import OpenIdStore
 from .presence import PresenceStore, UserPresenceState
 from .profile import ProfileStore
+from .purge_events import PurgeEventsStore
 from .push_rule import PushRuleStore
 from .pusher import PusherStore
 from .receipts import ReceiptsStore
@@ -87,7 +89,7 @@ class DataStore(
     StateStore,
     SignatureStore,
     ApplicationServiceStore,
-    EventsStore,
+    PurgeEventsStore,
     EventFederationStore,
     MediaRepositoryStore,
     RejectionsStore,
@@ -112,27 +114,16 @@ class DataStore(
     MonthlyActiveUsersStore,
     StatsStore,
     RelationsStore,
-    CacheInvalidationStore,
+    CensorEventsStore,
     UIAuthStore,
+    CacheInvalidationWorkerStore,
+    ServerMetricsStore,
 ):
     def __init__(self, database: Database, db_conn, hs):
         self.hs = hs
         self._clock = hs.get_clock()
         self.database_engine = database.engine
 
-        self._stream_id_gen = StreamIdGenerator(
-            db_conn,
-            "events",
-            "stream_ordering",
-            extra_tables=[("local_invites", "stream_id")],
-        )
-        self._backfill_id_gen = StreamIdGenerator(
-            db_conn,
-            "events",
-            "stream_ordering",
-            step=-1,
-            extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
-        )
         self._presence_id_gen = StreamIdGenerator(
             db_conn, "presence_stream", "stream_id"
         )
@@ -159,9 +150,6 @@ class DataStore(
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
         self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
-        self._push_rules_stream_id_gen = ChainedIdGenerator(
-            self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
-        )
         self._pushers_id_gen = StreamIdGenerator(
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
         )
@@ -170,8 +158,14 @@ class DataStore(
         )
 
         if isinstance(self.database_engine, PostgresEngine):
-            self._cache_id_gen = StreamIdGenerator(
-                db_conn, "cache_invalidation_stream", "stream_id"
+            self._cache_id_gen = MultiWriterIdGenerator(
+                db_conn,
+                database,
+                instance_name="master",
+                table="cache_invalidation_stream_by_instance",
+                instance_column="instance_name",
+                id_column="stream_id",
+                sequence_name="cache_invalidation_stream_seq",
             )
         else:
             self._cache_id_gen = None
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py
index 46b494b334..f9eef1b78e 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -16,6 +16,7 @@
 
 import abc
 import logging
+from typing import List, Tuple
 
 from canonicaljson import json
 
@@ -175,41 +176,64 @@ class AccountDataWorkerStore(SQLBaseStore):
             "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
         )
 
-    def get_all_updated_account_data(
-        self, last_global_id, last_room_id, current_id, limit
-    ):
-        """Get all the client account_data that has changed on the server
+    async def get_updated_global_account_data(
+        self, last_id: int, current_id: int, limit: int
+    ) -> List[Tuple[int, str, str]]:
+        """Get the global account_data that has changed, for the account_data stream
+
         Args:
-            last_global_id(int): The position to fetch from for top level data
-            last_room_id(int): The position to fetch from for per room data
-            current_id(int): The position to fetch up to.
+            last_id: the last stream_id from the previous batch.
+            current_id: the maximum stream_id to return up to
+            limit: the maximum number of rows to return
+
         Returns:
-            A deferred pair of lists of tuples of stream_id int, user_id string,
-            room_id string, and type string.
+            A list of tuples of stream_id int, user_id string,
+            and type string.
         """
-        if last_room_id == current_id and last_global_id == current_id:
-            return defer.succeed(([], []))
+        if last_id == current_id:
+            return []
 
-        def get_updated_account_data_txn(txn):
+        def get_updated_global_account_data_txn(txn):
             sql = (
                 "SELECT stream_id, user_id, account_data_type"
                 " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
-            txn.execute(sql, (last_global_id, current_id, limit))
-            global_results = txn.fetchall()
+            txn.execute(sql, (last_id, current_id, limit))
+            return txn.fetchall()
+
+        return await self.db.runInteraction(
+            "get_updated_global_account_data", get_updated_global_account_data_txn
+        )
+
+    async def get_updated_room_account_data(
+        self, last_id: int, current_id: int, limit: int
+    ) -> List[Tuple[int, str, str, str]]:
+        """Get the global account_data that has changed, for the account_data stream
 
+        Args:
+            last_id: the last stream_id from the previous batch.
+            current_id: the maximum stream_id to return up to
+            limit: the maximum number of rows to return
+
+        Returns:
+            A list of tuples of stream_id int, user_id string,
+            room_id string and type string.
+        """
+        if last_id == current_id:
+            return []
+
+        def get_updated_room_account_data_txn(txn):
             sql = (
                 "SELECT stream_id, user_id, room_id, account_data_type"
                 " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
-            txn.execute(sql, (last_room_id, current_id, limit))
-            room_results = txn.fetchall()
-            return global_results, room_results
+            txn.execute(sql, (last_id, current_id, limit))
+            return txn.fetchall()
 
-        return self.db.runInteraction(
-            "get_all_updated_account_data_txn", get_updated_account_data_txn
+        return await self.db.runInteraction(
+            "get_updated_room_account_data", get_updated_room_account_data_txn
         )
 
     def get_updated_account_data_for_user(self, user_id, stream_id):
diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py
index 9c52aa5340..7a1fe8cdd2 100644
--- a/synapse/storage/data_stores/main/appservice.py
+++ b/synapse/storage/data_stores/main/appservice.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
 
 
 def _make_exclusive_regex(services_cache):
-    # We precompie a regex constructed from all the regexes that the AS's
+    # We precompile a regex constructed from all the regexes that the AS's
     # have registered for exclusive users.
     exclusive_user_regexes = [
         regex.pattern
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index 4dc5da3fe8..eac5a4e55b 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -18,9 +18,13 @@ import itertools
 import logging
 from typing import Any, Iterable, Optional, Tuple
 
-from twisted.internet import defer
-
+from synapse.api.constants import EventTypes
+from synapse.replication.tcp.streams.events import (
+    EventsStreamCurrentStateRow,
+    EventsStreamEventRow,
+)
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.util.iterutils import batch_iter
 
@@ -33,28 +37,132 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
 
 
 class CacheInvalidationWorkerStore(SQLBaseStore):
-    def get_all_updated_caches(self, last_id, current_id, limit):
+    def __init__(self, database: Database, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        self._instance_name = hs.get_instance_name()
+
+    async def get_all_updated_caches(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ):
+        """Fetches cache invalidation rows between the two given IDs written
+        by the given instance. Returns at most `limit` rows.
+        """
+
         if last_id == current_id:
-            return defer.succeed([])
+            return []
 
         def get_all_updated_caches_txn(txn):
             # We purposefully don't bound by the current token, as we want to
             # send across cache invalidations as quickly as possible. Cache
             # invalidations are idempotent, so duplicates are fine.
-            sql = (
-                "SELECT stream_id, cache_func, keys, invalidation_ts"
-                " FROM cache_invalidation_stream"
-                " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
-            )
-            txn.execute(sql, (last_id, limit))
+            sql = """
+                SELECT stream_id, cache_func, keys, invalidation_ts
+                FROM cache_invalidation_stream_by_instance
+                WHERE stream_id > ? AND instance_name = ?
+                ORDER BY stream_id ASC
+                LIMIT ?
+            """
+            txn.execute(sql, (last_id, instance_name, limit))
             return txn.fetchall()
 
-        return self.db.runInteraction(
+        return await self.db.runInteraction(
             "get_all_updated_caches", get_all_updated_caches_txn
         )
 
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == "events":
+            for row in rows:
+                self._process_event_stream_row(token, row)
+        elif stream_name == "backfill":
+            for row in rows:
+                self._invalidate_caches_for_event(
+                    -token,
+                    row.event_id,
+                    row.room_id,
+                    row.type,
+                    row.state_key,
+                    row.redacts,
+                    row.relates_to,
+                    backfilled=True,
+                )
+        elif stream_name == "caches":
+            if self._cache_id_gen:
+                self._cache_id_gen.advance(instance_name, token)
+
+            for row in rows:
+                if row.cache_func == CURRENT_STATE_CACHE_NAME:
+                    if row.keys is None:
+                        raise Exception(
+                            "Can't send an 'invalidate all' for current state cache"
+                        )
+
+                    room_id = row.keys[0]
+                    members_changed = set(row.keys[1:])
+                    self._invalidate_state_caches(room_id, members_changed)
+                else:
+                    self._attempt_to_invalidate_cache(row.cache_func, row.keys)
+
+        super().process_replication_rows(stream_name, instance_name, token, rows)
+
+    def _process_event_stream_row(self, token, row):
+        data = row.data
+
+        if row.type == EventsStreamEventRow.TypeId:
+            self._invalidate_caches_for_event(
+                token,
+                data.event_id,
+                data.room_id,
+                data.type,
+                data.state_key,
+                data.redacts,
+                data.relates_to,
+                backfilled=False,
+            )
+        elif row.type == EventsStreamCurrentStateRow.TypeId:
+            self._curr_state_delta_stream_cache.entity_has_changed(
+                row.data.room_id, token
+            )
+
+            if data.type == EventTypes.Member:
+                self.get_rooms_for_user_with_stream_ordering.invalidate(
+                    (data.state_key,)
+                )
+        else:
+            raise Exception("Unknown events stream row type %s" % (row.type,))
+
+    def _invalidate_caches_for_event(
+        self,
+        stream_ordering,
+        event_id,
+        room_id,
+        etype,
+        state_key,
+        redacts,
+        relates_to,
+        backfilled,
+    ):
+        self._invalidate_get_event_cache(event_id)
+
+        self.get_latest_event_ids_in_room.invalidate((room_id,))
+
+        self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
+
+        if not backfilled:
+            self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
+
+        if redacts:
+            self._invalidate_get_event_cache(redacts)
+
+        if etype == EventTypes.Member:
+            self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
+            self.get_invited_rooms_for_local_user.invalidate((state_key,))
+
+        if relates_to:
+            self.get_relations_for_event.invalidate_many((relates_to,))
+            self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
+            self.get_applicable_edit.invalidate((relates_to,))
 
-class CacheInvalidationStore(CacheInvalidationWorkerStore):
     async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
         """Invalidates the cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
@@ -68,7 +176,7 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
             return
 
         cache_func.invalidate(keys)
-        await self.runInteraction(
+        await self.db.runInteraction(
             "invalidate_cache_and_stream",
             self._send_invalidation_to_replication,
             cache_func.__name__,
@@ -147,10 +255,7 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
             # the transaction. However, we want to only get an ID when we want
             # to use it, here, so we need to call __enter__ manually, and have
             # __exit__ called after the transaction finishes.
-            ctx = self._cache_id_gen.get_next()
-            stream_id = ctx.__enter__()
-            txn.call_on_exception(ctx.__exit__, None, None, None)
-            txn.call_after(ctx.__exit__, None, None, None)
+            stream_id = self._cache_id_gen.get_next_txn(txn)
             txn.call_after(self.hs.get_notifier().on_new_replication_data)
 
             if keys is not None:
@@ -158,17 +263,18 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
 
             self.db.simple_insert_txn(
                 txn,
-                table="cache_invalidation_stream",
+                table="cache_invalidation_stream_by_instance",
                 values={
                     "stream_id": stream_id,
+                    "instance_name": self._instance_name,
                     "cache_func": cache_name,
                     "keys": keys,
                     "invalidation_ts": self.clock.time_msec(),
                 },
             )
 
-    def get_cache_stream_token(self):
+    def get_cache_stream_token(self, instance_name):
         if self._cache_id_gen:
-            return self._cache_id_gen.get_current_token()
+            return self._cache_id_gen.get_current_token(instance_name)
         else:
             return 0
diff --git a/synapse/storage/data_stores/main/censor_events.py b/synapse/storage/data_stores/main/censor_events.py
new file mode 100644
index 0000000000..2d48261724
--- /dev/null
+++ b/synapse/storage/data_stores/main/censor_events.py
@@ -0,0 +1,208 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+from twisted.internet import defer
+
+from synapse.events.utils import prune_event_dict
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.data_stores.main.events import encode_json
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+logger = logging.getLogger(__name__)
+
+
+class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs: "HomeServer"):
+        super().__init__(database, db_conn, hs)
+
+        def _censor_redactions():
+            return run_as_background_process(
+                "_censor_redactions", self._censor_redactions
+            )
+
+        if self.hs.config.redaction_retention_period is not None:
+            hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
+
+    async def _censor_redactions(self):
+        """Censors all redactions older than the configured period that haven't
+        been censored yet.
+
+        By censor we mean update the event_json table with the redacted event.
+        """
+
+        if self.hs.config.redaction_retention_period is None:
+            return
+
+        if not (
+            await self.db.updates.has_completed_background_update(
+                "redactions_have_censored_ts_idx"
+            )
+        ):
+            # We don't want to run this until the appropriate index has been
+            # created.
+            return
+
+        before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
+
+        # We fetch all redactions that:
+        #   1. point to an event we have,
+        #   2. has a received_ts from before the cut off, and
+        #   3. we haven't yet censored.
+        #
+        # This is limited to 100 events to ensure that we don't try and do too
+        # much at once. We'll get called again so this should eventually catch
+        # up.
+        sql = """
+            SELECT redactions.event_id, redacts FROM redactions
+            LEFT JOIN events AS original_event ON (
+                redacts = original_event.event_id
+            )
+            WHERE NOT have_censored
+            AND redactions.received_ts <= ?
+            ORDER BY redactions.received_ts ASC
+            LIMIT ?
+        """
+
+        rows = await self.db.execute(
+            "_censor_redactions_fetch", None, sql, before_ts, 100
+        )
+
+        updates = []
+
+        for redaction_id, event_id in rows:
+            redaction_event = await self.get_event(redaction_id, allow_none=True)
+            original_event = await self.get_event(
+                event_id, allow_rejected=True, allow_none=True
+            )
+
+            # The SQL above ensures that we have both the redaction and
+            # original event, so if the `get_event` calls return None it
+            # means that the redaction wasn't allowed. Either way we know that
+            # the result won't change so we mark the fact that we've checked.
+            if (
+                redaction_event
+                and original_event
+                and original_event.internal_metadata.is_redacted()
+            ):
+                # Redaction was allowed
+                pruned_json = encode_json(
+                    prune_event_dict(
+                        original_event.room_version, original_event.get_dict()
+                    )
+                )
+            else:
+                # Redaction wasn't allowed
+                pruned_json = None
+
+            updates.append((redaction_id, event_id, pruned_json))
+
+        def _update_censor_txn(txn):
+            for redaction_id, event_id, pruned_json in updates:
+                if pruned_json:
+                    self._censor_event_txn(txn, event_id, pruned_json)
+
+                self.db.simple_update_one_txn(
+                    txn,
+                    table="redactions",
+                    keyvalues={"event_id": redaction_id},
+                    updatevalues={"have_censored": True},
+                )
+
+        await self.db.runInteraction("_update_censor_txn", _update_censor_txn)
+
+    def _censor_event_txn(self, txn, event_id, pruned_json):
+        """Censor an event by replacing its JSON in the event_json table with the
+        provided pruned JSON.
+
+        Args:
+            txn (LoggingTransaction): The database transaction.
+            event_id (str): The ID of the event to censor.
+            pruned_json (str): The pruned JSON
+        """
+        self.db.simple_update_one_txn(
+            txn,
+            table="event_json",
+            keyvalues={"event_id": event_id},
+            updatevalues={"json": pruned_json},
+        )
+
+    @defer.inlineCallbacks
+    def expire_event(self, event_id):
+        """Retrieve and expire an event that has expired, and delete its associated
+        expiry timestamp. If the event can't be retrieved, delete its associated
+        timestamp so we don't try to expire it again in the future.
+
+        Args:
+             event_id (str): The ID of the event to delete.
+        """
+        # Try to retrieve the event's content from the database or the event cache.
+        event = yield self.get_event(event_id)
+
+        def delete_expired_event_txn(txn):
+            # Delete the expiry timestamp associated with this event from the database.
+            self._delete_event_expiry_txn(txn, event_id)
+
+            if not event:
+                # If we can't find the event, log a warning and delete the expiry date
+                # from the database so that we don't try to expire it again in the
+                # future.
+                logger.warning(
+                    "Can't expire event %s because we don't have it.", event_id
+                )
+                return
+
+            # Prune the event's dict then convert it to JSON.
+            pruned_json = encode_json(
+                prune_event_dict(event.room_version, event.get_dict())
+            )
+
+            # Update the event_json table to replace the event's JSON with the pruned
+            # JSON.
+            self._censor_event_txn(txn, event.event_id, pruned_json)
+
+            # We need to invalidate the event cache entry for this event because we
+            # changed its content in the database. We can't call
+            # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
+            # right type.
+            txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
+            # Send that invalidation to replication so that other workers also invalidate
+            # the event cache.
+            self._send_invalidation_to_replication(
+                txn, "_get_event_cache", (event.event_id,)
+            )
+
+        yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn)
+
+    def _delete_event_expiry_txn(self, txn, event_id):
+        """Delete the expiry timestamp associated with an event ID without deleting the
+        actual event.
+
+        Args:
+            txn (LoggingTransaction): The transaction to use to perform the deletion.
+            event_id (str): The event ID to delete the associated expiry timestamp of.
+        """
+        return self.db.simple_delete_txn(
+            txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
+        )
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 92bc06919b..71f8d43a76 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -22,7 +22,6 @@ from twisted.internet import defer
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import Database, make_tuple_comparison_clause
-from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.caches.descriptors import Cache
 
 logger = logging.getLogger(__name__)
@@ -361,7 +360,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
     def __init__(self, database: Database, db_conn, hs):
 
         self.client_ip_last_seen = Cache(
-            name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
+            name="client_ip_last_seen", keylen=4, max_entries=50000
         )
 
         super(ClientIpStore, self).__init__(database, db_conn, hs)
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 03f5141e6c..fb9f798e29 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import List, Tuple
+from typing import List, Optional, Set, Tuple
 
 from six import iteritems
 
@@ -342,32 +342,23 @@ class DeviceWorkerStore(SQLBaseStore):
 
     def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
         # We update the device_lists_outbound_last_success with the successfully
-        # poked users. We do the join to see which users need to be inserted and
-        # which updated.
+        # poked users.
         sql = """
-            SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
+            SELECT user_id, coalesce(max(o.stream_id), 0)
             FROM device_lists_outbound_pokes as o
-            LEFT JOIN device_lists_outbound_last_success as s
-                USING (destination, user_id)
             WHERE destination = ? AND o.stream_id <= ?
             GROUP BY user_id
         """
         txn.execute(sql, (destination, stream_id))
         rows = txn.fetchall()
 
-        sql = """
-            UPDATE device_lists_outbound_last_success
-            SET stream_id = ?
-            WHERE destination = ? AND user_id = ?
-        """
-        txn.executemany(sql, ((row[1], destination, row[0]) for row in rows if row[2]))
-
-        sql = """
-            INSERT INTO device_lists_outbound_last_success
-            (destination, user_id, stream_id) VALUES (?, ?, ?)
-        """
-        txn.executemany(
-            sql, ((destination, row[0], row[1]) for row in rows if not row[2])
+        self.db.simple_upsert_many_txn(
+            txn=txn,
+            table="device_lists_outbound_last_success",
+            key_names=("destination", "user_id"),
+            key_values=((destination, user_id) for user_id, _ in rows),
+            value_names=("stream_id",),
+            value_values=((stream_id,) for _, stream_id in rows),
         )
 
         # Delete all sent outbound pokes
@@ -654,21 +645,31 @@ class DeviceWorkerStore(SQLBaseStore):
         return results
 
     @defer.inlineCallbacks
-    def get_user_ids_requiring_device_list_resync(self, user_ids: Collection[str]):
+    def get_user_ids_requiring_device_list_resync(
+        self, user_ids: Optional[Collection[str]] = None,
+    ) -> Set[str]:
         """Given a list of remote users return the list of users that we
-        should resync the device lists for.
+        should resync the device lists for. If None is given instead of a list,
+        return every user that we should resync the device lists for.
 
         Returns:
-            Deferred[Set[str]]
+            The IDs of users whose device lists need resync.
         """
-
-        rows = yield self.db.simple_select_many_batch(
-            table="device_lists_remote_resync",
-            column="user_id",
-            iterable=user_ids,
-            retcols=("user_id",),
-            desc="get_user_ids_requiring_device_list_resync",
-        )
+        if user_ids:
+            rows = yield self.db.simple_select_many_batch(
+                table="device_lists_remote_resync",
+                column="user_id",
+                iterable=user_ids,
+                retcols=("user_id",),
+                desc="get_user_ids_requiring_device_list_resync_with_iterable",
+            )
+        else:
+            rows = yield self.db.simple_select_list(
+                table="device_lists_remote_resync",
+                keyvalues=None,
+                retcols=("user_id",),
+                desc="get_user_ids_requiring_device_list_resync",
+            )
 
         return {row["user_id"] for row in rows}
 
@@ -684,6 +685,25 @@ class DeviceWorkerStore(SQLBaseStore):
             desc="make_remote_user_device_cache_as_stale",
         )
 
+    def mark_remote_user_device_list_as_unsubscribed(self, user_id):
+        """Mark that we no longer track device lists for remote user.
+        """
+
+        def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
+            self.db.simple_delete_txn(
+                txn,
+                table="device_lists_remote_extremeties",
+                keyvalues={"user_id": user_id},
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
+            )
+
+        return self.db.runInteraction(
+            "mark_remote_user_device_list_as_unsubscribed",
+            _mark_remote_user_device_list_as_unsubscribed_txn,
+        )
+
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
     def __init__(self, database: Database, db_conn, hs):
@@ -725,6 +745,15 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
             BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
         )
 
+        # a pair of background updates that were added during the 1.14 release cycle,
+        # but replaced with 58/06dlols_unique_idx.py
+        self.db.updates.register_noop_background_update(
+            "device_lists_outbound_last_success_unique_idx",
+        )
+        self.db.updates.register_noop_background_update(
+            "drop_device_lists_outbound_last_success_non_unique_idx",
+        )
+
     @defer.inlineCallbacks
     def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
         def f(conn):
@@ -935,17 +964,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             desc="update_device",
         )
 
-    @defer.inlineCallbacks
-    def mark_remote_user_device_list_as_unsubscribed(self, user_id):
-        """Mark that we no longer track device lists for remote user.
-        """
-        yield self.db.simple_delete(
-            table="device_lists_remote_extremeties",
-            keyvalues={"user_id": user_id},
-            desc="mark_remote_user_device_list_as_unsubscribed",
-        )
-        self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
-
     def update_remote_device_list_cache_entry(
         self, user_id, device_id, content, stream_id
     ):
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index bcf746b7ef..20698bfd16 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -25,7 +25,9 @@ from twisted.internet import defer
 
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import make_in_list_sql_clause
 from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
 
 
 class EndToEndKeyWorkerStore(SQLBaseStore):
@@ -268,53 +270,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             "count_e2e_one_time_keys", _count_e2e_one_time_keys
         )
 
-    def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None):
-        """Returns a user's cross-signing key.
-
-        Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            user_id (str): the user whose key is being requested
-            key_type (str): the type of key that is being requested: either 'master'
-                for a master key, 'self_signing' for a self-signing key, or
-                'user_signing' for a user-signing key
-            from_user_id (str): if specified, signatures made by this user on
-                the key will be included in the result
-
-        Returns:
-            dict of the key data or None if not found
-        """
-        sql = (
-            "SELECT keydata "
-            "  FROM e2e_cross_signing_keys "
-            " WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1"
-        )
-        txn.execute(sql, (user_id, key_type))
-        row = txn.fetchone()
-        if not row:
-            return None
-        key = json.loads(row[0])
-
-        device_id = None
-        for k in key["keys"].values():
-            device_id = k
-
-        if from_user_id is not None:
-            sql = (
-                "SELECT key_id, signature "
-                "  FROM e2e_cross_signing_signatures "
-                " WHERE user_id = ? "
-                "   AND target_user_id = ? "
-                "   AND target_device_id = ? "
-            )
-            txn.execute(sql, (from_user_id, user_id, device_id))
-            row = txn.fetchone()
-            if row:
-                key.setdefault("signatures", {}).setdefault(from_user_id, {})[
-                    row[0]
-                ] = row[1]
-
-        return key
-
+    @defer.inlineCallbacks
     def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
         """Returns a user's cross-signing key.
 
@@ -329,13 +285,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         Returns:
             dict of the key data or None if not found
         """
-        return self.db.runInteraction(
-            "get_e2e_cross_signing_key",
-            self._get_e2e_cross_signing_key_txn,
-            user_id,
-            key_type,
-            from_user_id,
-        )
+        res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
+        user_keys = res.get(user_id)
+        if not user_keys:
+            return None
+        return user_keys.get(key_type)
 
     @cached(num_args=1)
     def _get_bare_e2e_cross_signing_keys(self, user_id):
@@ -391,26 +345,24 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         """
         result = {}
 
-        batch_size = 100
-        chunks = [
-            user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size)
-        ]
-        for user_chunk in chunks:
-            sql = """
+        for user_chunk in batch_iter(user_ids, 100):
+            clause, params = make_in_list_sql_clause(
+                txn.database_engine, "k.user_id", user_chunk
+            )
+            sql = (
+                """
                 SELECT k.user_id, k.keytype, k.keydata, k.stream_id
                   FROM e2e_cross_signing_keys k
                   INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
                                 FROM e2e_cross_signing_keys
                                GROUP BY user_id, keytype) s
                  USING (user_id, stream_id, keytype)
-                 WHERE k.user_id IN (%s)
-            """ % (
-                ",".join("?" for u in user_chunk),
+                 WHERE
+            """
+                + clause
             )
-            query_params = []
-            query_params.extend(user_chunk)
 
-            txn.execute(sql, query_params)
+            txn.execute(sql, params)
             rows = self.db.cursor_to_dict(txn)
 
             for row in rows:
@@ -453,15 +405,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
                     device_id = k
                 devices[(user_id, device_id)] = key_type
 
-        device_list = list(devices)
-
-        # split into batches
-        batch_size = 100
-        chunks = [
-            device_list[i : i + batch_size]
-            for i in range(0, len(device_list), batch_size)
-        ]
-        for user_chunk in chunks:
+        for batch in batch_iter(devices.keys(), size=100):
             sql = """
                 SELECT target_user_id, target_device_id, key_id, signature
                   FROM e2e_cross_signing_signatures
@@ -469,11 +413,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
                    AND (%s)
             """ % (
                 " OR ".join(
-                    "(target_user_id = ? AND target_device_id = ?)" for d in devices
+                    "(target_user_id = ? AND target_device_id = ?)" for _ in batch
                 )
             )
             query_params = [from_user_id]
-            for item in devices:
+            for item in batch:
                 # item is a (user_id, device_id) tuple
                 query_params.extend(item)
 
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index b99439cc37..24ce8c4330 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -640,89 +640,6 @@ class EventFederationStore(EventFederationWorkerStore):
             self._delete_old_forward_extrem_cache, 60 * 60 * 1000
         )
 
-    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
-        min_depth = self._get_min_depth_interaction(txn, room_id)
-
-        if min_depth is not None and depth >= min_depth:
-            return
-
-        self.db.simple_upsert_txn(
-            txn,
-            table="room_depth",
-            keyvalues={"room_id": room_id},
-            values={"min_depth": depth},
-        )
-
-    def _handle_mult_prev_events(self, txn, events):
-        """
-        For the given event, update the event edges table and forward and
-        backward extremities tables.
-        """
-        self.db.simple_insert_many_txn(
-            txn,
-            table="event_edges",
-            values=[
-                {
-                    "event_id": ev.event_id,
-                    "prev_event_id": e_id,
-                    "room_id": ev.room_id,
-                    "is_state": False,
-                }
-                for ev in events
-                for e_id in ev.prev_event_ids()
-            ],
-        )
-
-        self._update_backward_extremeties(txn, events)
-
-    def _update_backward_extremeties(self, txn, events):
-        """Updates the event_backward_extremities tables based on the new/updated
-        events being persisted.
-
-        This is called for new events *and* for events that were outliers, but
-        are now being persisted as non-outliers.
-
-        Forward extremities are handled when we first start persisting the events.
-        """
-        events_by_room = {}
-        for ev in events:
-            events_by_room.setdefault(ev.room_id, []).append(ev)
-
-        query = (
-            "INSERT INTO event_backward_extremities (event_id, room_id)"
-            " SELECT ?, ? WHERE NOT EXISTS ("
-            " SELECT 1 FROM event_backward_extremities"
-            " WHERE event_id = ? AND room_id = ?"
-            " )"
-            " AND NOT EXISTS ("
-            " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
-            " AND outlier = ?"
-            " )"
-        )
-
-        txn.executemany(
-            query,
-            [
-                (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
-                for ev in events
-                for e_id in ev.prev_event_ids()
-                if not ev.internal_metadata.is_outlier()
-            ],
-        )
-
-        query = (
-            "DELETE FROM event_backward_extremities"
-            " WHERE event_id = ? AND room_id = ?"
-        )
-        txn.executemany(
-            query,
-            [
-                (ev.event_id, ev.room_id)
-                for ev in events
-                if not ev.internal_metadata.is_outlier()
-            ],
-        )
-
     def _delete_old_forward_extrem_cache(self):
         def _delete_old_forward_extrem_cache_txn(txn):
             # Delete entries older than a month, while making sure we don't delete
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index 8eed590929..0321274de2 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -652,69 +652,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             self._start_rotate_notifs, 30 * 60 * 1000
         )
 
-    def _set_push_actions_for_event_and_users_txn(
-        self, txn, events_and_contexts, all_events_and_contexts
-    ):
-        """Handles moving push actions from staging table to main
-        event_push_actions table for all events in `events_and_contexts`.
-
-        Also ensures that all events in `all_events_and_contexts` are removed
-        from the push action staging area.
-
-        Args:
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
-            all_events_and_contexts (list[(EventBase, EventContext)]): all
-                events that we were going to persist. This includes events
-                we've already persisted, etc, that wouldn't appear in
-                events_and_context.
-        """
-
-        sql = """
-            INSERT INTO event_push_actions (
-                room_id, event_id, user_id, actions, stream_ordering,
-                topological_ordering, notif, highlight
-            )
-            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
-            FROM event_push_actions_staging
-            WHERE event_id = ?
-        """
-
-        if events_and_contexts:
-            txn.executemany(
-                sql,
-                (
-                    (
-                        event.room_id,
-                        event.internal_metadata.stream_ordering,
-                        event.depth,
-                        event.event_id,
-                    )
-                    for event, _ in events_and_contexts
-                ),
-            )
-
-        for event, _ in events_and_contexts:
-            user_ids = self.db.simple_select_onecol_txn(
-                txn,
-                table="event_push_actions_staging",
-                keyvalues={"event_id": event.event_id},
-                retcol="user_id",
-            )
-
-            for uid in user_ids:
-                txn.call_after(
-                    self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
-                    (event.room_id, uid),
-                )
-
-        # Now we delete the staging area for *all* events that were being
-        # persisted.
-        txn.executemany(
-            "DELETE FROM event_push_actions_staging WHERE event_id = ?",
-            ((event.event_id,) for event, _ in all_events_and_contexts),
-        )
-
     @defer.inlineCallbacks
     def get_push_actions_for_user(
         self, user_id, before=None, limit=50, only_highlight=False
@@ -763,17 +700,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         )
         return result[0] or 0
 
-    def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
-        # Sad that we have to blow away the cache for the whole room here
-        txn.call_after(
-            self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
-            (room_id,),
-        )
-        txn.execute(
-            "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
-            (room_id, event_id),
-        )
-
     def _remove_old_push_actions_before_txn(
         self, txn, room_id, user_id, stream_ordering
     ):
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index e71c23541d..a6572571b4 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -17,39 +17,44 @@
 
 import itertools
 import logging
-from collections import Counter as c_counter, OrderedDict, namedtuple
+from collections import OrderedDict, namedtuple
 from functools import wraps
-from typing import Dict, List, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
 
-from six import iteritems, text_type
+from six import integer_types, iteritems, text_type
 from six.moves import range
 
+import attr
 from canonicaljson import json
 from prometheus_client import Counter
 
 from twisted.internet import defer
 
 import synapse.metrics
-from synapse.api.constants import EventContentFields, EventTypes
-from synapse.api.errors import SynapseError
+from synapse.api.constants import (
+    EventContentFields,
+    EventTypes,
+    Membership,
+    RelationTypes,
+)
 from synapse.api.room_versions import RoomVersions
+from synapse.crypto.event_signing import compute_event_reference_hash
 from synapse.events import EventBase  # noqa: F401
 from synapse.events.snapshot import EventContext  # noqa: F401
-from synapse.events.utils import prune_event_dict
 from synapse.logging.utils import log_function
-from synapse.metrics import BucketCollector
-from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import make_in_list_sql_clause
-from synapse.storage.data_stores.main.event_federation import EventFederationStore
-from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.data_stores.main.state import StateGroupWorkerStore
+from synapse.storage.data_stores.main.search import SearchEntry
 from synapse.storage.database import Database, LoggingTransaction
-from synapse.storage.persist_events import DeltaState
-from synapse.types import RoomStreamToken, StateMap, get_domain_from_id
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.types import StateMap, get_domain_from_id
 from synapse.util.frozenutils import frozendict_json_encoder
 from synapse.util.iterutils import batch_iter
 
+if TYPE_CHECKING:
+    from synapse.storage.data_stores.main import DataStore
+    from synapse.server import HomeServer
+
+
 logger = logging.getLogger(__name__)
 
 persist_event_counter = Counter("synapse_storage_events_persisted_events", "")
@@ -94,58 +99,49 @@ def _retry_on_integrity_error(func):
     return f
 
 
-# inherits from EventFederationStore so that we can call _update_backward_extremities
-# and _handle_mult_prev_events (though arguably those could both be moved in here)
-class EventsStore(
-    StateGroupWorkerStore, EventFederationStore, EventsWorkerStore,
-):
-    def __init__(self, database: Database, db_conn, hs):
-        super(EventsStore, self).__init__(database, db_conn, hs)
+@attr.s(slots=True)
+class DeltaState:
+    """Deltas to use to update the `current_state_events` table.
 
-        # Collect metrics on the number of forward extremities that exist.
-        # Counter of number of extremities to count
-        self._current_forward_extremities_amount = c_counter()
+    Attributes:
+        to_delete: List of type/state_keys to delete from current state
+        to_insert: Map of state to upsert into current state
+        no_longer_in_room: The server is not longer in the room, so the room
+            should e.g. be removed from `current_state_events` table.
+    """
 
-        BucketCollector(
-            "synapse_forward_extremities",
-            lambda: self._current_forward_extremities_amount,
-            buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"],
-        )
+    to_delete = attr.ib(type=List[Tuple[str, str]])
+    to_insert = attr.ib(type=StateMap[str])
+    no_longer_in_room = attr.ib(type=bool, default=False)
 
-        # Read the extrems every 60 minutes
-        def read_forward_extremities():
-            # run as a background process to make sure that the database transactions
-            # have a logcontext to report to
-            return run_as_background_process(
-                "read_forward_extremities", self._read_forward_extremities
-            )
 
-        hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
+class PersistEventsStore:
+    """Contains all the functions for writing events to the database.
 
-        def _censor_redactions():
-            return run_as_background_process(
-                "_censor_redactions", self._censor_redactions
-            )
+    Should only be instantiated on one process (when using a worker mode setup).
+
+    Note: This is not part of the `DataStore` mixin.
+    """
 
-        if self.hs.config.redaction_retention_period is not None:
-            hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
+    def __init__(self, hs: "HomeServer", db: Database, main_data_store: "DataStore"):
+        self.hs = hs
+        self.db = db
+        self.store = main_data_store
+        self.database_engine = db.engine
+        self._clock = hs.get_clock()
 
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
         self.is_mine_id = hs.is_mine_id
 
-    @defer.inlineCallbacks
-    def _read_forward_extremities(self):
-        def fetch(txn):
-            txn.execute(
-                """
-                select count(*) c from event_forward_extremities
-                group by room_id
-                """
-            )
-            return txn.fetchall()
+        # Ideally we'd move these ID gens here, unfortunately some other ID
+        # generators are chained off them so doing so is a bit of a PITA.
+        self._backfill_id_gen = self.store._backfill_id_gen  # type: StreamIdGenerator
+        self._stream_id_gen = self.store._stream_id_gen  # type: StreamIdGenerator
 
-        res = yield self.db.runInteraction("read_forward_extremities", fetch)
-        self._current_forward_extremities_amount = c_counter([x[0] for x in res])
+        # This should only exist on instances that are configured to write
+        assert (
+            hs.config.worker.writers.events == hs.get_instance_name()
+        ), "Can only instantiate EventsStore on master"
 
     @_retry_on_integrity_error
     @defer.inlineCallbacks
@@ -237,10 +233,10 @@ class EventsStore(
                 event_counter.labels(event.type, origin_type, origin_entity).inc()
 
             for room_id, new_state in iteritems(current_state_for_room):
-                self.get_current_state_ids.prefill((room_id,), new_state)
+                self.store.get_current_state_ids.prefill((room_id,), new_state)
 
             for room_id, latest_event_ids in iteritems(new_forward_extremeties):
-                self.get_latest_event_ids_in_room.prefill(
+                self.store.get_latest_event_ids_in_room.prefill(
                     (room_id,), list(latest_event_ids)
                 )
 
@@ -586,7 +582,7 @@ class EventsStore(
                 )
 
             txn.call_after(
-                self._curr_state_delta_stream_cache.entity_has_changed,
+                self.store._curr_state_delta_stream_cache.entity_has_changed,
                 room_id,
                 stream_id,
             )
@@ -606,10 +602,13 @@ class EventsStore(
 
             for member in members_changed:
                 txn.call_after(
-                    self.get_rooms_for_user_with_stream_ordering.invalidate, (member,)
+                    self.store.get_rooms_for_user_with_stream_ordering.invalidate,
+                    (member,),
                 )
 
-            self._invalidate_state_caches_and_stream(txn, room_id, members_changed)
+            self.store._invalidate_state_caches_and_stream(
+                txn, room_id, members_changed
+            )
 
     def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str):
         """Update the room version in the database based off current state
@@ -647,7 +646,9 @@ class EventsStore(
             self.db.simple_delete_txn(
                 txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
             )
-            txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
+            txn.call_after(
+                self.store.get_latest_event_ids_in_room.invalidate, (room_id,)
+            )
 
         self.db.simple_insert_many_txn(
             txn,
@@ -713,10 +714,10 @@ class EventsStore(
         depth_updates = {}
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
-            txn.call_after(self._invalidate_get_event_cache, event.event_id)
+            txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
             if not backfilled:
                 txn.call_after(
-                    self._events_stream_cache.entity_has_changed,
+                    self.store._events_stream_cache.entity_has_changed,
                     event.room_id,
                     event.internal_metadata.stream_ordering,
                 )
@@ -1088,13 +1089,15 @@ class EventsStore(
 
         def prefill():
             for cache_entry in to_prefill:
-                self._get_event_cache.prefill((cache_entry[0].event_id,), cache_entry)
+                self.store._get_event_cache.prefill(
+                    (cache_entry[0].event_id,), cache_entry
+                )
 
         txn.call_after(prefill)
 
     def _store_redaction(self, txn, event):
         # invalidate the cache for the redacted event
-        txn.call_after(self._invalidate_get_event_cache, event.redacts)
+        txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
 
         self.db.simple_insert_txn(
             txn,
@@ -1106,783 +1109,512 @@ class EventsStore(
             },
         )
 
-    async def _censor_redactions(self):
-        """Censors all redactions older than the configured period that haven't
-        been censored yet.
+    def insert_labels_for_event_txn(
+        self, txn, event_id, labels, room_id, topological_ordering
+    ):
+        """Store the mapping between an event's ID and its labels, with one row per
+        (event_id, label) tuple.
 
-        By censor we mean update the event_json table with the redacted event.
+        Args:
+            txn (LoggingTransaction): The transaction to execute.
+            event_id (str): The event's ID.
+            labels (list[str]): A list of text labels.
+            room_id (str): The ID of the room the event was sent to.
+            topological_ordering (int): The position of the event in the room's topology.
         """
+        return self.db.simple_insert_many_txn(
+            txn=txn,
+            table="event_labels",
+            values=[
+                {
+                    "event_id": event_id,
+                    "label": label,
+                    "room_id": room_id,
+                    "topological_ordering": topological_ordering,
+                }
+                for label in labels
+            ],
+        )
 
-        if self.hs.config.redaction_retention_period is None:
-            return
-
-        if not (
-            await self.db.updates.has_completed_background_update(
-                "redactions_have_censored_ts_idx"
-            )
-        ):
-            # We don't want to run this until the appropriate index has been
-            # created.
-            return
-
-        before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
+    def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
+        """Save the expiry timestamp associated with a given event ID.
 
-        # We fetch all redactions that:
-        #   1. point to an event we have,
-        #   2. has a received_ts from before the cut off, and
-        #   3. we haven't yet censored.
-        #
-        # This is limited to 100 events to ensure that we don't try and do too
-        # much at once. We'll get called again so this should eventually catch
-        # up.
-        sql = """
-            SELECT redactions.event_id, redacts FROM redactions
-            LEFT JOIN events AS original_event ON (
-                redacts = original_event.event_id
-            )
-            WHERE NOT have_censored
-            AND redactions.received_ts <= ?
-            ORDER BY redactions.received_ts ASC
-            LIMIT ?
+        Args:
+            txn (LoggingTransaction): The database transaction to use.
+            event_id (str): The event ID the expiry timestamp is associated with.
+            expiry_ts (int): The timestamp at which to expire (delete) the event.
         """
-
-        rows = await self.db.execute(
-            "_censor_redactions_fetch", None, sql, before_ts, 100
+        return self.db.simple_insert_txn(
+            txn=txn,
+            table="event_expiry",
+            values={"event_id": event_id, "expiry_ts": expiry_ts},
         )
 
-        updates = []
+    def _store_event_reference_hashes_txn(self, txn, events):
+        """Store a hash for a PDU
+        Args:
+            txn (cursor):
+            events (list): list of Events.
+        """
 
-        for redaction_id, event_id in rows:
-            redaction_event = await self.get_event(redaction_id, allow_none=True)
-            original_event = await self.get_event(
-                event_id, allow_rejected=True, allow_none=True
+        vals = []
+        for event in events:
+            ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
+            vals.append(
+                {
+                    "event_id": event.event_id,
+                    "algorithm": ref_alg,
+                    "hash": memoryview(ref_hash_bytes),
+                }
             )
 
-            # The SQL above ensures that we have both the redaction and
-            # original event, so if the `get_event` calls return None it
-            # means that the redaction wasn't allowed. Either way we know that
-            # the result won't change so we mark the fact that we've checked.
-            if (
-                redaction_event
-                and original_event
-                and original_event.internal_metadata.is_redacted()
-            ):
-                # Redaction was allowed
-                pruned_json = encode_json(
-                    prune_event_dict(
-                        original_event.room_version, original_event.get_dict()
-                    )
-                )
-            else:
-                # Redaction wasn't allowed
-                pruned_json = None
-
-            updates.append((redaction_id, event_id, pruned_json))
-
-        def _update_censor_txn(txn):
-            for redaction_id, event_id, pruned_json in updates:
-                if pruned_json:
-                    self._censor_event_txn(txn, event_id, pruned_json)
-
-                self.db.simple_update_one_txn(
-                    txn,
-                    table="redactions",
-                    keyvalues={"event_id": redaction_id},
-                    updatevalues={"have_censored": True},
-                )
-
-        await self.db.runInteraction("_update_censor_txn", _update_censor_txn)
+        self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
 
-    def _censor_event_txn(self, txn, event_id, pruned_json):
-        """Censor an event by replacing its JSON in the event_json table with the
-        provided pruned JSON.
-
-        Args:
-            txn (LoggingTransaction): The database transaction.
-            event_id (str): The ID of the event to censor.
-            pruned_json (str): The pruned JSON
+    def _store_room_members_txn(self, txn, events, backfilled):
+        """Store a room member in the database.
         """
-        self.db.simple_update_one_txn(
+        self.db.simple_insert_many_txn(
             txn,
-            table="event_json",
-            keyvalues={"event_id": event_id},
-            updatevalues={"json": pruned_json},
+            table="room_memberships",
+            values=[
+                {
+                    "event_id": event.event_id,
+                    "user_id": event.state_key,
+                    "sender": event.user_id,
+                    "room_id": event.room_id,
+                    "membership": event.membership,
+                    "display_name": event.content.get("displayname", None),
+                    "avatar_url": event.content.get("avatar_url", None),
+                }
+                for event in events
+            ],
         )
 
-    @defer.inlineCallbacks
-    def count_daily_messages(self):
-        """
-        Returns an estimate of the number of messages sent in the last day.
-
-        If it has been significantly less or more than one day since the last
-        call to this function, it will return None.
-        """
-
-        def _count_messages(txn):
-            sql = """
-                SELECT COALESCE(COUNT(*), 0) FROM events
-                WHERE type = 'm.room.message'
-                AND stream_ordering > ?
-            """
-            txn.execute(sql, (self.stream_ordering_day_ago,))
-            (count,) = txn.fetchone()
-            return count
-
-        ret = yield self.db.runInteraction("count_messages", _count_messages)
-        return ret
-
-    @defer.inlineCallbacks
-    def count_daily_sent_messages(self):
-        def _count_messages(txn):
-            # This is good enough as if you have silly characters in your own
-            # hostname then thats your own fault.
-            like_clause = "%:" + self.hs.hostname
-
-            sql = """
-                SELECT COALESCE(COUNT(*), 0) FROM events
-                WHERE type = 'm.room.message'
-                    AND sender LIKE ?
-                AND stream_ordering > ?
-            """
-
-            txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
-            (count,) = txn.fetchone()
-            return count
-
-        ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages)
-        return ret
-
-    @defer.inlineCallbacks
-    def count_daily_active_rooms(self):
-        def _count(txn):
-            sql = """
-                SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
-                WHERE type = 'm.room.message'
-                AND stream_ordering > ?
-            """
-            txn.execute(sql, (self.stream_ordering_day_ago,))
-            (count,) = txn.fetchone()
-            return count
-
-        ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
-        return ret
-
-    @cached(num_args=5, max_entries=10)
-    def get_all_new_events(
-        self,
-        last_backfill_id,
-        last_forward_id,
-        current_backfill_id,
-        current_forward_id,
-        limit,
-    ):
-        """Get all the new events that have arrived at the server either as
-        new events or as backfilled events"""
-        have_backfill_events = last_backfill_id != current_backfill_id
-        have_forward_events = last_forward_id != current_forward_id
-
-        if not have_backfill_events and not have_forward_events:
-            return defer.succeed(AllNewEventsResult([], [], [], [], []))
-
-        def get_all_new_events_txn(txn):
-            sql = (
-                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts"
-                " FROM events AS e"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " WHERE ? < stream_ordering AND stream_ordering <= ?"
-                " ORDER BY stream_ordering ASC"
-                " LIMIT ?"
+        for event in events:
+            txn.call_after(
+                self.store._membership_stream_cache.entity_has_changed,
+                event.state_key,
+                event.internal_metadata.stream_ordering,
             )
-            if have_forward_events:
-                txn.execute(sql, (last_forward_id, current_forward_id, limit))
-                new_forward_events = txn.fetchall()
-
-                if len(new_forward_events) == limit:
-                    upper_bound = new_forward_events[-1][0]
-                else:
-                    upper_bound = current_forward_id
-
-                sql = (
-                    "SELECT event_stream_ordering, event_id, state_group"
-                    " FROM ex_outlier_stream"
-                    " WHERE ? > event_stream_ordering"
-                    " AND event_stream_ordering >= ?"
-                    " ORDER BY event_stream_ordering DESC"
-                )
-                txn.execute(sql, (last_forward_id, upper_bound))
-                forward_ex_outliers = txn.fetchall()
-            else:
-                new_forward_events = []
-                forward_ex_outliers = []
-
-            sql = (
-                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts"
-                " FROM events AS e"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " WHERE ? > stream_ordering AND stream_ordering >= ?"
-                " ORDER BY stream_ordering DESC"
-                " LIMIT ?"
+            txn.call_after(
+                self.store.get_invited_rooms_for_local_user.invalidate,
+                (event.state_key,),
             )
-            if have_backfill_events:
-                txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
-                new_backfill_events = txn.fetchall()
 
-                if len(new_backfill_events) == limit:
-                    upper_bound = new_backfill_events[-1][0]
-                else:
-                    upper_bound = current_backfill_id
-
-                sql = (
-                    "SELECT -event_stream_ordering, event_id, state_group"
-                    " FROM ex_outlier_stream"
-                    " WHERE ? > event_stream_ordering"
-                    " AND event_stream_ordering >= ?"
-                    " ORDER BY event_stream_ordering DESC"
-                )
-                txn.execute(sql, (-last_backfill_id, -upper_bound))
-                backward_ex_outliers = txn.fetchall()
-            else:
-                new_backfill_events = []
-                backward_ex_outliers = []
-
-            return AllNewEventsResult(
-                new_forward_events,
-                new_backfill_events,
-                forward_ex_outliers,
-                backward_ex_outliers,
+            # We update the local_invites table only if the event is "current",
+            # i.e., its something that has just happened. If the event is an
+            # outlier it is only current if its an "out of band membership",
+            # like a remote invite or a rejection of a remote invite.
+            is_new_state = not backfilled and (
+                not event.internal_metadata.is_outlier()
+                or event.internal_metadata.is_out_of_band_membership()
             )
+            is_mine = self.is_mine_id(event.state_key)
+            if is_new_state and is_mine:
+                if event.membership == Membership.INVITE:
+                    self.db.simple_insert_txn(
+                        txn,
+                        table="local_invites",
+                        values={
+                            "event_id": event.event_id,
+                            "invitee": event.state_key,
+                            "inviter": event.sender,
+                            "room_id": event.room_id,
+                            "stream_id": event.internal_metadata.stream_ordering,
+                        },
+                    )
+                else:
+                    sql = (
+                        "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
+                        " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+                        " AND replaced_by is NULL"
+                    )
 
-        return self.db.runInteraction("get_all_new_events", get_all_new_events_txn)
-
-    def purge_history(self, room_id, token, delete_local_events):
-        """Deletes room history before a certain point
-
-        Args:
-            room_id (str):
+                    txn.execute(
+                        sql,
+                        (
+                            event.internal_metadata.stream_ordering,
+                            event.event_id,
+                            event.room_id,
+                            event.state_key,
+                        ),
+                    )
 
-            token (str): A topological token to delete events before
+                # We also update the `local_current_membership` table with
+                # latest invite info. This will usually get updated by the
+                # `current_state_events` handling, unless its an outlier.
+                if event.internal_metadata.is_outlier():
+                    # This should only happen for out of band memberships, so
+                    # we add a paranoia check.
+                    assert event.internal_metadata.is_out_of_band_membership()
+
+                    self.db.simple_upsert_txn(
+                        txn,
+                        table="local_current_membership",
+                        keyvalues={
+                            "room_id": event.room_id,
+                            "user_id": event.state_key,
+                        },
+                        values={
+                            "event_id": event.event_id,
+                            "membership": event.membership,
+                        },
+                    )
 
-            delete_local_events (bool):
-                if True, we will delete local events as well as remote ones
-                (instead of just marking them as outliers and deleting their
-                state groups).
+    def _handle_event_relations(self, txn, event):
+        """Handles inserting relation data during peristence of events
 
-        Returns:
-            Deferred[set[int]]: The set of state groups that are referenced by
-            deleted events.
+        Args:
+            txn
+            event (EventBase)
         """
+        relation = event.content.get("m.relates_to")
+        if not relation:
+            # No relations
+            return
 
-        return self.db.runInteraction(
-            "purge_history",
-            self._purge_history_txn,
-            room_id,
-            token,
-            delete_local_events,
-        )
+        rel_type = relation.get("rel_type")
+        if rel_type not in (
+            RelationTypes.ANNOTATION,
+            RelationTypes.REFERENCE,
+            RelationTypes.REPLACE,
+        ):
+            # Unknown relation type
+            return
 
-    def _purge_history_txn(self, txn, room_id, token_str, delete_local_events):
-        token = RoomStreamToken.parse(token_str)
-
-        # Tables that should be pruned:
-        #     event_auth
-        #     event_backward_extremities
-        #     event_edges
-        #     event_forward_extremities
-        #     event_json
-        #     event_push_actions
-        #     event_reference_hashes
-        #     event_search
-        #     event_to_state_groups
-        #     events
-        #     rejections
-        #     room_depth
-        #     state_groups
-        #     state_groups_state
-
-        # we will build a temporary table listing the events so that we don't
-        # have to keep shovelling the list back and forth across the
-        # connection. Annoyingly the python sqlite driver commits the
-        # transaction on CREATE, so let's do this first.
-        #
-        # furthermore, we might already have the table from a previous (failed)
-        # purge attempt, so let's drop the table first.
+        parent_id = relation.get("event_id")
+        if not parent_id:
+            # Invalid relation
+            return
 
-        txn.execute("DROP TABLE IF EXISTS events_to_purge")
+        aggregation_key = relation.get("key")
 
-        txn.execute(
-            "CREATE TEMPORARY TABLE events_to_purge ("
-            "    event_id TEXT NOT NULL,"
-            "    should_delete BOOLEAN NOT NULL"
-            ")"
+        self.db.simple_insert_txn(
+            txn,
+            table="event_relations",
+            values={
+                "event_id": event.event_id,
+                "relates_to_id": parent_id,
+                "relation_type": rel_type,
+                "aggregation_key": aggregation_key,
+            },
         )
 
-        # First ensure that we're not about to delete all the forward extremeties
-        txn.execute(
-            "SELECT e.event_id, e.depth FROM events as e "
-            "INNER JOIN event_forward_extremities as f "
-            "ON e.event_id = f.event_id "
-            "AND e.room_id = f.room_id "
-            "WHERE f.room_id = ?",
-            (room_id,),
+        txn.call_after(self.store.get_relations_for_event.invalidate_many, (parent_id,))
+        txn.call_after(
+            self.store.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
         )
-        rows = txn.fetchall()
-        max_depth = max(row[1] for row in rows)
-
-        if max_depth < token.topological:
-            # We need to ensure we don't delete all the events from the database
-            # otherwise we wouldn't be able to send any events (due to not
-            # having any backwards extremeties)
-            raise SynapseError(
-                400, "topological_ordering is greater than forward extremeties"
-            )
-
-        logger.info("[purge] looking for events to delete")
-
-        should_delete_expr = "state_key IS NULL"
-        should_delete_params = ()
-        if not delete_local_events:
-            should_delete_expr += " AND event_id NOT LIKE ?"
-
-            # We include the parameter twice since we use the expression twice
-            should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname)
 
-        should_delete_params += (room_id, token.topological)
+        if rel_type == RelationTypes.REPLACE:
+            txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
 
-        # Note that we insert events that are outliers and aren't going to be
-        # deleted, as nothing will happen to them.
-        txn.execute(
-            "INSERT INTO events_to_purge"
-            " SELECT event_id, %s"
-            " FROM events AS e LEFT JOIN state_events USING (event_id)"
-            " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?"
-            % (should_delete_expr, should_delete_expr),
-            should_delete_params,
-        )
+    def _handle_redaction(self, txn, redacted_event_id):
+        """Handles receiving a redaction and checking whether we need to remove
+        any redacted relations from the database.
 
-        # We create the indices *after* insertion as that's a lot faster.
+        Args:
+            txn
+            redacted_event_id (str): The event that was redacted.
+        """
 
-        # create an index on should_delete because later we'll be looking for
-        # the should_delete / shouldn't_delete subsets
-        txn.execute(
-            "CREATE INDEX events_to_purge_should_delete"
-            " ON events_to_purge(should_delete)"
+        self.db.simple_delete_txn(
+            txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
         )
 
-        # We do joins against events_to_purge for e.g. calculating state
-        # groups to purge, etc., so lets make an index.
-        txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)")
-
-        txn.execute("SELECT event_id, should_delete FROM events_to_purge")
-        event_rows = txn.fetchall()
-        logger.info(
-            "[purge] found %i events before cutoff, of which %i can be deleted",
-            len(event_rows),
-            sum(1 for e in event_rows if e[1]),
-        )
+    def _store_room_topic_txn(self, txn, event):
+        if hasattr(event, "content") and "topic" in event.content:
+            self.store_event_search_txn(
+                txn, event, "content.topic", event.content["topic"]
+            )
 
-        logger.info("[purge] Finding new backward extremities")
+    def _store_room_name_txn(self, txn, event):
+        if hasattr(event, "content") and "name" in event.content:
+            self.store_event_search_txn(
+                txn, event, "content.name", event.content["name"]
+            )
 
-        # We calculate the new entries for the backward extremeties by finding
-        # events to be purged that are pointed to by events we're not going to
-        # purge.
-        txn.execute(
-            "SELECT DISTINCT e.event_id FROM events_to_purge AS e"
-            " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
-            " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id"
-            " WHERE ep2.event_id IS NULL"
-        )
-        new_backwards_extrems = txn.fetchall()
+    def _store_room_message_txn(self, txn, event):
+        if hasattr(event, "content") and "body" in event.content:
+            self.store_event_search_txn(
+                txn, event, "content.body", event.content["body"]
+            )
 
-        logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
+    def _store_retention_policy_for_room_txn(self, txn, event):
+        if hasattr(event, "content") and (
+            "min_lifetime" in event.content or "max_lifetime" in event.content
+        ):
+            if (
+                "min_lifetime" in event.content
+                and not isinstance(event.content.get("min_lifetime"), integer_types)
+            ) or (
+                "max_lifetime" in event.content
+                and not isinstance(event.content.get("max_lifetime"), integer_types)
+            ):
+                # Ignore the event if one of the value isn't an integer.
+                return
 
-        txn.execute(
-            "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,)
-        )
+            self.db.simple_insert_txn(
+                txn=txn,
+                table="room_retention",
+                values={
+                    "room_id": event.room_id,
+                    "event_id": event.event_id,
+                    "min_lifetime": event.content.get("min_lifetime"),
+                    "max_lifetime": event.content.get("max_lifetime"),
+                },
+            )
 
-        # Update backward extremeties
-        txn.executemany(
-            "INSERT INTO event_backward_extremities (room_id, event_id)"
-            " VALUES (?, ?)",
-            [(room_id, event_id) for event_id, in new_backwards_extrems],
-        )
+            self.store._invalidate_cache_and_stream(
+                txn, self.store.get_retention_policy_for_room, (event.room_id,)
+            )
 
-        logger.info("[purge] finding state groups referenced by deleted events")
+    def store_event_search_txn(self, txn, event, key, value):
+        """Add event to the search table
 
-        # Get all state groups that are referenced by events that are to be
-        # deleted.
-        txn.execute(
-            """
-            SELECT DISTINCT state_group FROM events_to_purge
-            INNER JOIN event_to_state_groups USING (event_id)
+        Args:
+            txn (cursor):
+            event (EventBase):
+            key (str):
+            value (str):
         """
+        self.store.store_search_entries_txn(
+            txn,
+            (
+                SearchEntry(
+                    key=key,
+                    value=value,
+                    event_id=event.event_id,
+                    room_id=event.room_id,
+                    stream_ordering=event.internal_metadata.stream_ordering,
+                    origin_server_ts=event.origin_server_ts,
+                ),
+            ),
         )
 
-        referenced_state_groups = {sg for sg, in txn}
-        logger.info(
-            "[purge] found %i referenced state groups", len(referenced_state_groups)
-        )
+    def _set_push_actions_for_event_and_users_txn(
+        self, txn, events_and_contexts, all_events_and_contexts
+    ):
+        """Handles moving push actions from staging table to main
+        event_push_actions table for all events in `events_and_contexts`.
 
-        logger.info("[purge] removing events from event_to_state_groups")
-        txn.execute(
-            "DELETE FROM event_to_state_groups "
-            "WHERE event_id IN (SELECT event_id from events_to_purge)"
-        )
-        for event_id, _ in event_rows:
-            txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
+        Also ensures that all events in `all_events_and_contexts` are removed
+        from the push action staging area.
 
-        # Delete all remote non-state events
-        for table in (
-            "events",
-            "event_json",
-            "event_auth",
-            "event_edges",
-            "event_forward_extremities",
-            "event_reference_hashes",
-            "event_search",
-            "rejections",
-        ):
-            logger.info("[purge] removing events from %s", table)
+        Args:
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+            all_events_and_contexts (list[(EventBase, EventContext)]): all
+                events that we were going to persist. This includes events
+                we've already persisted, etc, that wouldn't appear in
+                events_and_context.
+        """
 
-            txn.execute(
-                "DELETE FROM %s WHERE event_id IN ("
-                "    SELECT event_id FROM events_to_purge WHERE should_delete"
-                ")" % (table,)
+        sql = """
+            INSERT INTO event_push_actions (
+                room_id, event_id, user_id, actions, stream_ordering,
+                topological_ordering, notif, highlight
             )
+            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
+            FROM event_push_actions_staging
+            WHERE event_id = ?
+        """
 
-        # event_push_actions lacks an index on event_id, and has one on
-        # (room_id, event_id) instead.
-        for table in ("event_push_actions",):
-            logger.info("[purge] removing events from %s", table)
+        if events_and_contexts:
+            txn.executemany(
+                sql,
+                (
+                    (
+                        event.room_id,
+                        event.internal_metadata.stream_ordering,
+                        event.depth,
+                        event.event_id,
+                    )
+                    for event, _ in events_and_contexts
+                ),
+            )
 
-            txn.execute(
-                "DELETE FROM %s WHERE room_id = ? AND event_id IN ("
-                "    SELECT event_id FROM events_to_purge WHERE should_delete"
-                ")" % (table,),
-                (room_id,),
+        for event, _ in events_and_contexts:
+            user_ids = self.db.simple_select_onecol_txn(
+                txn,
+                table="event_push_actions_staging",
+                keyvalues={"event_id": event.event_id},
+                retcol="user_id",
             )
 
-        # Mark all state and own events as outliers
-        logger.info("[purge] marking remaining events as outliers")
-        txn.execute(
-            "UPDATE events SET outlier = ?"
-            " WHERE event_id IN ("
-            "    SELECT event_id FROM events_to_purge "
-            "    WHERE NOT should_delete"
-            ")",
-            (True,),
+            for uid in user_ids:
+                txn.call_after(
+                    self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+                    (event.room_id, uid),
+                )
+
+        # Now we delete the staging area for *all* events that were being
+        # persisted.
+        txn.executemany(
+            "DELETE FROM event_push_actions_staging WHERE event_id = ?",
+            ((event.event_id,) for event, _ in all_events_and_contexts),
         )
 
-        # synapse tries to take out an exclusive lock on room_depth whenever it
-        # persists events (because upsert), and once we run this update, we
-        # will block that for the rest of our transaction.
-        #
-        # So, let's stick it at the end so that we don't block event
-        # persistence.
-        #
-        # We do this by calculating the minimum depth of the backwards
-        # extremities. However, the events in event_backward_extremities
-        # are ones we don't have yet so we need to look at the events that
-        # point to it via event_edges table.
-        txn.execute(
-            """
-            SELECT COALESCE(MIN(depth), 0)
-            FROM event_backward_extremities AS eb
-            INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id
-            INNER JOIN events AS e ON e.event_id = eg.event_id
-            WHERE eb.room_id = ?
-        """,
+    def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
+        # Sad that we have to blow away the cache for the whole room here
+        txn.call_after(
+            self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many,
             (room_id,),
         )
-        (min_depth,) = txn.fetchone()
-
-        logger.info("[purge] updating room_depth to %d", min_depth)
-
         txn.execute(
-            "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
-            (min_depth, room_id),
+            "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
+            (room_id, event_id),
         )
 
-        # finally, drop the temp table. this will commit the txn in sqlite,
-        # so make sure to keep this actually last.
-        txn.execute("DROP TABLE events_to_purge")
-
-        logger.info("[purge] done")
-
-        return referenced_state_groups
-
-    def purge_room(self, room_id):
-        """Deletes all record of a room
+    def _store_rejections_txn(self, txn, event_id, reason):
+        self.db.simple_insert_txn(
+            txn,
+            table="rejections",
+            values={
+                "event_id": event_id,
+                "reason": reason,
+                "last_check": self._clock.time_msec(),
+            },
+        )
 
-        Args:
-            room_id (str)
+    def _store_event_state_mappings_txn(
+        self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
+    ):
+        state_groups = {}
+        for event, context in events_and_contexts:
+            if event.internal_metadata.is_outlier():
+                continue
 
-        Returns:
-            Deferred[List[int]]: The list of state groups to delete.
-        """
+            # if the event was rejected, just give it the same state as its
+            # predecessor.
+            if context.rejected:
+                state_groups[event.event_id] = context.state_group_before_event
+                continue
 
-        return self.db.runInteraction("purge_room", self._purge_room_txn, room_id)
+            state_groups[event.event_id] = context.state_group
 
-    def _purge_room_txn(self, txn, room_id):
-        # First we fetch all the state groups that should be deleted, before
-        # we delete that information.
-        txn.execute(
-            """
-                SELECT DISTINCT state_group FROM events
-                INNER JOIN event_to_state_groups USING(event_id)
-                WHERE events.room_id = ?
-            """,
-            (room_id,),
+        self.db.simple_insert_many_txn(
+            txn,
+            table="event_to_state_groups",
+            values=[
+                {"state_group": state_group_id, "event_id": event_id}
+                for event_id, state_group_id in iteritems(state_groups)
+            ],
         )
 
-        state_groups = [row[0] for row in txn]
-
-        # Now we delete tables which lack an index on room_id but have one on event_id
-        for table in (
-            "event_auth",
-            "event_edges",
-            "event_push_actions_staging",
-            "event_reference_hashes",
-            "event_relations",
-            "event_to_state_groups",
-            "redactions",
-            "rejections",
-            "state_events",
-        ):
-            logger.info("[purge] removing %s from %s", room_id, table)
-
-            txn.execute(
-                """
-                DELETE FROM %s WHERE event_id IN (
-                  SELECT event_id FROM events WHERE room_id=?
-                )
-                """
-                % (table,),
-                (room_id,),
+        for event_id, state_group_id in iteritems(state_groups):
+            txn.call_after(
+                self.store._get_state_group_for_event.prefill,
+                (event_id,),
+                state_group_id,
             )
 
-        # and finally, the tables with an index on room_id (or no useful index)
-        for table in (
-            "current_state_events",
-            "event_backward_extremities",
-            "event_forward_extremities",
-            "event_json",
-            "event_push_actions",
-            "event_search",
-            "events",
-            "group_rooms",
-            "public_room_list_stream",
-            "receipts_graph",
-            "receipts_linearized",
-            "room_aliases",
-            "room_depth",
-            "room_memberships",
-            "room_stats_state",
-            "room_stats_current",
-            "room_stats_historical",
-            "room_stats_earliest_token",
-            "rooms",
-            "stream_ordering_to_exterm",
-            "users_in_public_rooms",
-            "users_who_share_private_rooms",
-            # no useful index, but let's clear them anyway
-            "appservice_room_list",
-            "e2e_room_keys",
-            "event_push_summary",
-            "pusher_throttle",
-            "group_summary_rooms",
-            "local_invites",
-            "room_account_data",
-            "room_tags",
-            "local_current_membership",
-        ):
-            logger.info("[purge] removing %s from %s", room_id, table)
-            txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
-
-        # Other tables we do NOT need to clear out:
-        #
-        #  - blocked_rooms
-        #    This is important, to make sure that we don't accidentally rejoin a blocked
-        #    room after it was purged
-        #
-        #  - user_directory
-        #    This has a room_id column, but it is unused
-        #
-
-        # Other tables that we might want to consider clearing out include:
-        #
-        #  - event_reports
-        #       Given that these are intended for abuse management my initial
-        #       inclination is to leave them in place.
-        #
-        #  - current_state_delta_stream
-        #  - ex_outlier_stream
-        #  - room_tags_revisions
-        #       The problem with these is that they are largeish and there is no room_id
-        #       index on them. In any case we should be clearing out 'stream' tables
-        #       periodically anyway (#5888)
-
-        # TODO: we could probably usefully do a bunch of cache invalidation here
-
-        logger.info("[purge] done")
-
-        return state_groups
+    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+        min_depth = self.store._get_min_depth_interaction(txn, room_id)
 
-    async def is_event_after(self, event_id1, event_id2):
-        """Returns True if event_id1 is after event_id2 in the stream
-        """
-        to_1, so_1 = await self._get_event_ordering(event_id1)
-        to_2, so_2 = await self._get_event_ordering(event_id2)
-        return (to_1, so_1) > (to_2, so_2)
+        if min_depth is not None and depth >= min_depth:
+            return
 
-    @cachedInlineCallbacks(max_entries=5000)
-    def _get_event_ordering(self, event_id):
-        res = yield self.db.simple_select_one(
-            table="events",
-            retcols=["topological_ordering", "stream_ordering"],
-            keyvalues={"event_id": event_id},
-            allow_none=True,
+        self.db.simple_upsert_txn(
+            txn,
+            table="room_depth",
+            keyvalues={"room_id": room_id},
+            values={"min_depth": depth},
         )
 
-        if not res:
-            raise SynapseError(404, "Could not find event %s" % (event_id,))
-
-        return (int(res["topological_ordering"]), int(res["stream_ordering"]))
-
-    def insert_labels_for_event_txn(
-        self, txn, event_id, labels, room_id, topological_ordering
-    ):
-        """Store the mapping between an event's ID and its labels, with one row per
-        (event_id, label) tuple.
-
-        Args:
-            txn (LoggingTransaction): The transaction to execute.
-            event_id (str): The event's ID.
-            labels (list[str]): A list of text labels.
-            room_id (str): The ID of the room the event was sent to.
-            topological_ordering (int): The position of the event in the room's topology.
+    def _handle_mult_prev_events(self, txn, events):
         """
-        return self.db.simple_insert_many_txn(
-            txn=txn,
-            table="event_labels",
+        For the given event, update the event edges table and forward and
+        backward extremities tables.
+        """
+        self.db.simple_insert_many_txn(
+            txn,
+            table="event_edges",
             values=[
                 {
-                    "event_id": event_id,
-                    "label": label,
-                    "room_id": room_id,
-                    "topological_ordering": topological_ordering,
+                    "event_id": ev.event_id,
+                    "prev_event_id": e_id,
+                    "room_id": ev.room_id,
+                    "is_state": False,
                 }
-                for label in labels
+                for ev in events
+                for e_id in ev.prev_event_ids()
             ],
         )
 
-    def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
-        """Save the expiry timestamp associated with a given event ID.
+        self._update_backward_extremeties(txn, events)
 
-        Args:
-            txn (LoggingTransaction): The database transaction to use.
-            event_id (str): The event ID the expiry timestamp is associated with.
-            expiry_ts (int): The timestamp at which to expire (delete) the event.
-        """
-        return self.db.simple_insert_txn(
-            txn=txn,
-            table="event_expiry",
-            values={"event_id": event_id, "expiry_ts": expiry_ts},
-        )
+    def _update_backward_extremeties(self, txn, events):
+        """Updates the event_backward_extremities tables based on the new/updated
+        events being persisted.
 
-    @defer.inlineCallbacks
-    def expire_event(self, event_id):
-        """Retrieve and expire an event that has expired, and delete its associated
-        expiry timestamp. If the event can't be retrieved, delete its associated
-        timestamp so we don't try to expire it again in the future.
+        This is called for new events *and* for events that were outliers, but
+        are now being persisted as non-outliers.
 
-        Args:
-             event_id (str): The ID of the event to delete.
+        Forward extremities are handled when we first start persisting the events.
         """
-        # Try to retrieve the event's content from the database or the event cache.
-        event = yield self.get_event(event_id)
-
-        def delete_expired_event_txn(txn):
-            # Delete the expiry timestamp associated with this event from the database.
-            self._delete_event_expiry_txn(txn, event_id)
-
-            if not event:
-                # If we can't find the event, log a warning and delete the expiry date
-                # from the database so that we don't try to expire it again in the
-                # future.
-                logger.warning(
-                    "Can't expire event %s because we don't have it.", event_id
-                )
-                return
-
-            # Prune the event's dict then convert it to JSON.
-            pruned_json = encode_json(
-                prune_event_dict(event.room_version, event.get_dict())
-            )
-
-            # Update the event_json table to replace the event's JSON with the pruned
-            # JSON.
-            self._censor_event_txn(txn, event.event_id, pruned_json)
-
-            # We need to invalidate the event cache entry for this event because we
-            # changed its content in the database. We can't call
-            # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
-            # right type.
-            txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
-            # Send that invalidation to replication so that other workers also invalidate
-            # the event cache.
-            self._send_invalidation_to_replication(
-                txn, "_get_event_cache", (event.event_id,)
-            )
-
-        yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn)
-
-    def _delete_event_expiry_txn(self, txn, event_id):
-        """Delete the expiry timestamp associated with an event ID without deleting the
-        actual event.
+        events_by_room = {}
+        for ev in events:
+            events_by_room.setdefault(ev.room_id, []).append(ev)
+
+        query = (
+            "INSERT INTO event_backward_extremities (event_id, room_id)"
+            " SELECT ?, ? WHERE NOT EXISTS ("
+            " SELECT 1 FROM event_backward_extremities"
+            " WHERE event_id = ? AND room_id = ?"
+            " )"
+            " AND NOT EXISTS ("
+            " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
+            " AND outlier = ?"
+            " )"
+        )
 
-        Args:
-            txn (LoggingTransaction): The transaction to use to perform the deletion.
-            event_id (str): The event ID to delete the associated expiry timestamp of.
-        """
-        return self.db.simple_delete_txn(
-            txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
+        txn.executemany(
+            query,
+            [
+                (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
+                for ev in events
+                for e_id in ev.prev_event_ids()
+                if not ev.internal_metadata.is_outlier()
+            ],
         )
 
-    def get_next_event_to_expire(self):
-        """Retrieve the entry with the lowest expiry timestamp in the event_expiry
-        table, or None if there's no more event to expire.
+        query = (
+            "DELETE FROM event_backward_extremities"
+            " WHERE event_id = ? AND room_id = ?"
+        )
+        txn.executemany(
+            query,
+            [
+                (ev.event_id, ev.room_id)
+                for ev in events
+                if not ev.internal_metadata.is_outlier()
+            ],
+        )
 
-        Returns: Deferred[Optional[Tuple[str, int]]]
-            A tuple containing the event ID as its first element and an expiry timestamp
-            as its second one, if there's at least one row in the event_expiry table.
-            None otherwise.
+    async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
+        """Mark the invite has having been rejected even though we failed to
+        create a leave event for it.
         """
 
-        def get_next_event_to_expire_txn(txn):
-            txn.execute(
-                """
-                SELECT event_id, expiry_ts FROM event_expiry
-                ORDER BY expiry_ts ASC LIMIT 1
-                """
-            )
+        sql = (
+            "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
+            " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+            " AND replaced_by is NULL"
+        )
 
-            return txn.fetchone()
+        def f(txn, stream_ordering):
+            txn.execute(sql, (stream_ordering, True, room_id, user_id))
 
-        return self.db.runInteraction(
-            desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
-        )
+            # We also clear this entry from `local_current_membership`.
+            # Ideally we'd point to a leave event, but we don't have one, so
+            # nevermind.
+            self.db.simple_delete_txn(
+                txn,
+                table="local_current_membership",
+                keyvalues={"room_id": room_id, "user_id": user_id},
+            )
 
+        with self._stream_id_gen.get_next() as stream_ordering:
+            await self.db.runInteraction("locally_reject_invite", f, stream_ordering)
 
-AllNewEventsResult = namedtuple(
-    "AllNewEventsResult",
-    [
-        "new_forward_events",
-        "new_backfill_events",
-        "forward_ex_outliers",
-        "backward_ex_outliers",
-    ],
-)
+        return stream_ordering
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 73df6b33ba..213d69100a 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -27,7 +27,7 @@ from constantly import NamedConstant, Names
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes
-from synapse.api.errors import NotFoundError
+from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.room_versions import (
     KNOWN_ROOM_VERSIONS,
     EventFormatVersions,
@@ -37,10 +37,12 @@ from synapse.events import make_event_from_dict
 from synapse.events.utils import prune_event
 from synapse.logging.context import PreserveLoggingContext, current_context
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.database import Database
+from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import Cache
+from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
@@ -74,14 +76,50 @@ class EventsWorkerStore(SQLBaseStore):
     def __init__(self, database: Database, db_conn, hs):
         super(EventsWorkerStore, self).__init__(database, db_conn, hs)
 
+        if hs.config.worker.writers.events == hs.get_instance_name():
+            # We are the process in charge of generating stream ids for events,
+            # so instantiate ID generators based on the database
+            self._stream_id_gen = StreamIdGenerator(
+                db_conn,
+                "events",
+                "stream_ordering",
+                extra_tables=[("local_invites", "stream_id")],
+            )
+            self._backfill_id_gen = StreamIdGenerator(
+                db_conn,
+                "events",
+                "stream_ordering",
+                step=-1,
+                extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+            )
+        else:
+            # Another process is in charge of persisting events and generating
+            # stream IDs: rely on the replication streams to let us know which
+            # IDs we can process.
+            self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
+            self._backfill_id_gen = SlavedIdTracker(
+                db_conn, "events", "stream_ordering", step=-1
+            )
+
         self._get_event_cache = Cache(
-            "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
+            "*getEvent*",
+            keylen=3,
+            max_entries=hs.config.caches.event_cache_size,
+            apply_cache_factor_from_config=False,
         )
 
         self._event_fetch_lock = threading.Condition()
         self._event_fetch_list = []
         self._event_fetch_ongoing = 0
 
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == "events":
+            self._stream_id_gen.advance(token)
+        elif stream_name == "backfill":
+            self._backfill_id_gen.advance(-token)
+
+        super().process_replication_rows(stream_name, instance_name, token, rows)
+
     def get_received_ts(self, event_id):
         """Get received_ts (when it was persisted) for the event.
 
@@ -1154,4 +1192,152 @@ class EventsWorkerStore(SQLBaseStore):
         rows = await self.db.runInteraction(
             "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token
         )
+
         return rows, to_token, True
+
+    @cached(num_args=5, max_entries=10)
+    def get_all_new_events(
+        self,
+        last_backfill_id,
+        last_forward_id,
+        current_backfill_id,
+        current_forward_id,
+        limit,
+    ):
+        """Get all the new events that have arrived at the server either as
+        new events or as backfilled events"""
+        have_backfill_events = last_backfill_id != current_backfill_id
+        have_forward_events = last_forward_id != current_forward_id
+
+        if not have_backfill_events and not have_forward_events:
+            return defer.succeed(AllNewEventsResult([], [], [], [], []))
+
+        def get_all_new_events_txn(txn):
+            sql = (
+                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? < stream_ordering AND stream_ordering <= ?"
+                " ORDER BY stream_ordering ASC"
+                " LIMIT ?"
+            )
+            if have_forward_events:
+                txn.execute(sql, (last_forward_id, current_forward_id, limit))
+                new_forward_events = txn.fetchall()
+
+                if len(new_forward_events) == limit:
+                    upper_bound = new_forward_events[-1][0]
+                else:
+                    upper_bound = current_forward_id
+
+                sql = (
+                    "SELECT event_stream_ordering, event_id, state_group"
+                    " FROM ex_outlier_stream"
+                    " WHERE ? > event_stream_ordering"
+                    " AND event_stream_ordering >= ?"
+                    " ORDER BY event_stream_ordering DESC"
+                )
+                txn.execute(sql, (last_forward_id, upper_bound))
+                forward_ex_outliers = txn.fetchall()
+            else:
+                new_forward_events = []
+                forward_ex_outliers = []
+
+            sql = (
+                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? > stream_ordering AND stream_ordering >= ?"
+                " ORDER BY stream_ordering DESC"
+                " LIMIT ?"
+            )
+            if have_backfill_events:
+                txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
+                new_backfill_events = txn.fetchall()
+
+                if len(new_backfill_events) == limit:
+                    upper_bound = new_backfill_events[-1][0]
+                else:
+                    upper_bound = current_backfill_id
+
+                sql = (
+                    "SELECT -event_stream_ordering, event_id, state_group"
+                    " FROM ex_outlier_stream"
+                    " WHERE ? > event_stream_ordering"
+                    " AND event_stream_ordering >= ?"
+                    " ORDER BY event_stream_ordering DESC"
+                )
+                txn.execute(sql, (-last_backfill_id, -upper_bound))
+                backward_ex_outliers = txn.fetchall()
+            else:
+                new_backfill_events = []
+                backward_ex_outliers = []
+
+            return AllNewEventsResult(
+                new_forward_events,
+                new_backfill_events,
+                forward_ex_outliers,
+                backward_ex_outliers,
+            )
+
+        return self.db.runInteraction("get_all_new_events", get_all_new_events_txn)
+
+    async def is_event_after(self, event_id1, event_id2):
+        """Returns True if event_id1 is after event_id2 in the stream
+        """
+        to_1, so_1 = await self.get_event_ordering(event_id1)
+        to_2, so_2 = await self.get_event_ordering(event_id2)
+        return (to_1, so_1) > (to_2, so_2)
+
+    @cachedInlineCallbacks(max_entries=5000)
+    def get_event_ordering(self, event_id):
+        res = yield self.db.simple_select_one(
+            table="events",
+            retcols=["topological_ordering", "stream_ordering"],
+            keyvalues={"event_id": event_id},
+            allow_none=True,
+        )
+
+        if not res:
+            raise SynapseError(404, "Could not find event %s" % (event_id,))
+
+        return (int(res["topological_ordering"]), int(res["stream_ordering"]))
+
+    def get_next_event_to_expire(self):
+        """Retrieve the entry with the lowest expiry timestamp in the event_expiry
+        table, or None if there's no more event to expire.
+
+        Returns: Deferred[Optional[Tuple[str, int]]]
+            A tuple containing the event ID as its first element and an expiry timestamp
+            as its second one, if there's at least one row in the event_expiry table.
+            None otherwise.
+        """
+
+        def get_next_event_to_expire_txn(txn):
+            txn.execute(
+                """
+                SELECT event_id, expiry_ts FROM event_expiry
+                ORDER BY expiry_ts ASC LIMIT 1
+                """
+            )
+
+            return txn.fetchone()
+
+        return self.db.runInteraction(
+            desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
+        )
+
+
+AllNewEventsResult = namedtuple(
+    "AllNewEventsResult",
+    [
+        "new_forward_events",
+        "new_backfill_events",
+        "forward_ex_outliers",
+        "backward_ex_outliers",
+    ],
+)
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py
index 0963e6c250..fb1361f1c1 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -68,24 +68,78 @@ class GroupServerWorkerStore(SQLBaseStore):
             desc="get_invited_users_in_group",
         )
 
-    def get_rooms_in_group(self, group_id, include_private=False):
+    def get_rooms_in_group(self, group_id: str, include_private: bool = False):
+        """Retrieve the rooms that belong to a given group. Does not return rooms that
+        lack members.
+
+        Args:
+            group_id: The ID of the group to query for rooms
+            include_private: Whether to return private rooms in results
+
+        Returns:
+            Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the
+            form of:
+
+            {
+              "room_id": "!a_room_id:example.com",  # The ID of the room
+              "is_public": False                    # Whether this is a public room or not
+            }
+        """
         # TODO: Pagination
 
-        keyvalues = {"group_id": group_id}
-        if not include_private:
-            keyvalues["is_public"] = True
+        def _get_rooms_in_group_txn(txn):
+            sql = """
+            SELECT room_id, is_public FROM group_rooms
+                WHERE group_id = ?
+                AND room_id IN (
+                    SELECT group_rooms.room_id FROM group_rooms
+                    LEFT JOIN room_stats_current ON
+                        group_rooms.room_id = room_stats_current.room_id
+                        AND joined_members > 0
+                        AND local_users_in_room > 0
+                    LEFT JOIN rooms ON
+                        group_rooms.room_id = rooms.room_id
+                        AND (room_version <> '') = ?
+                )
+            """
+            args = [group_id, False]
 
-        return self.db.simple_select_list(
-            table="group_rooms",
-            keyvalues=keyvalues,
-            retcols=("room_id", "is_public"),
-            desc="get_rooms_in_group",
-        )
+            if not include_private:
+                sql += " AND is_public = ?"
+                args += [True]
+
+            txn.execute(sql, args)
+
+            return [
+                {"room_id": room_id, "is_public": is_public}
+                for room_id, is_public in txn
+            ]
 
-    def get_rooms_for_summary_by_category(self, group_id, include_private=False):
+        return self.db.runInteraction("get_rooms_in_group", _get_rooms_in_group_txn)
+
+    def get_rooms_for_summary_by_category(
+        self, group_id: str, include_private: bool = False,
+    ):
         """Get the rooms and categories that should be included in a summary request
 
-        Returns ([rooms], [categories])
+        Args:
+            group_id: The ID of the group to query the summary for
+            include_private: Whether to return private rooms in results
+
+        Returns:
+            Deferred[Tuple[List, Dict]]: A tuple containing:
+
+                * A list of dictionaries with the keys:
+                    * "room_id": str, the room ID
+                    * "is_public": bool, whether the room is public
+                    * "category_id": str|None, the category ID if set, else None
+                    * "order": int, the sort order of rooms
+
+                * A dictionary with the key:
+                    * category_id (str): a dictionary with the keys:
+                        * "is_public": bool, whether the category is public
+                        * "profile": str, the category profile
+                        * "order": int, the sort order of rooms in this category
         """
 
         def _get_rooms_for_summary_txn(txn):
@@ -97,13 +151,23 @@ class GroupServerWorkerStore(SQLBaseStore):
                 SELECT room_id, is_public, category_id, room_order
                 FROM group_summary_rooms
                 WHERE group_id = ?
+                AND room_id IN (
+                    SELECT group_rooms.room_id FROM group_rooms
+                    LEFT JOIN room_stats_current ON
+                        group_rooms.room_id = room_stats_current.room_id
+                        AND joined_members > 0
+                        AND local_users_in_room > 0
+                    LEFT JOIN rooms ON
+                        group_rooms.room_id = rooms.room_id
+                        AND (room_version <> '') = ?
+                )
             """
 
             if not include_private:
                 sql += " AND is_public = ?"
-                txn.execute(sql, (group_id, True))
+                txn.execute(sql, (group_id, False, True))
             else:
-                txn.execute(sql, (group_id,))
+                txn.execute(sql, (group_id, False))
 
             rooms = [
                 {
diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py
index ba89c68c9f..4e1642a27a 100644
--- a/synapse/storage/data_stores/main/keys.py
+++ b/synapse/storage/data_stores/main/keys.py
@@ -17,8 +17,6 @@
 import itertools
 import logging
 
-import six
-
 from signedjson.key import decode_verify_key_bytes
 
 from synapse.storage._base import SQLBaseStore
@@ -28,12 +26,8 @@ from synapse.util.iterutils import batch_iter
 
 logger = logging.getLogger(__name__)
 
-# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
-# despite being deprecated and removed in favor of memoryview
-if six.PY2:
-    db_binary_type = six.moves.builtins.buffer
-else:
-    db_binary_type = memoryview
+
+db_binary_type = memoryview
 
 
 class KeyStore(SQLBaseStore):
diff --git a/synapse/storage/data_stores/main/metrics.py b/synapse/storage/data_stores/main/metrics.py
new file mode 100644
index 0000000000..dad5bbc602
--- /dev/null
+++ b/synapse/storage/data_stores/main/metrics.py
@@ -0,0 +1,128 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import typing
+from collections import Counter
+
+from twisted.internet import defer
+
+from synapse.metrics import BucketCollector
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.event_push_actions import (
+    EventPushActionsWorkerStore,
+)
+from synapse.storage.database import Database
+
+
+class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
+    """Functions to pull various metrics from the DB, for e.g. phone home
+    stats and prometheus metrics.
+    """
+
+    def __init__(self, database: Database, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        # Collect metrics on the number of forward extremities that exist.
+        # Counter of number of extremities to count
+        self._current_forward_extremities_amount = (
+            Counter()
+        )  # type: typing.Counter[int]
+
+        BucketCollector(
+            "synapse_forward_extremities",
+            lambda: self._current_forward_extremities_amount,
+            buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"],
+        )
+
+        # Read the extrems every 60 minutes
+        def read_forward_extremities():
+            # run as a background process to make sure that the database transactions
+            # have a logcontext to report to
+            return run_as_background_process(
+                "read_forward_extremities", self._read_forward_extremities
+            )
+
+        hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
+
+    async def _read_forward_extremities(self):
+        def fetch(txn):
+            txn.execute(
+                """
+                select count(*) c from event_forward_extremities
+                group by room_id
+                """
+            )
+            return txn.fetchall()
+
+        res = await self.db.runInteraction("read_forward_extremities", fetch)
+        self._current_forward_extremities_amount = Counter([x[0] for x in res])
+
+    @defer.inlineCallbacks
+    def count_daily_messages(self):
+        """
+        Returns an estimate of the number of messages sent in the last day.
+
+        If it has been significantly less or more than one day since the last
+        call to this function, it will return None.
+        """
+
+        def _count_messages(txn):
+            sql = """
+                SELECT COALESCE(COUNT(*), 0) FROM events
+                WHERE type = 'm.room.message'
+                AND stream_ordering > ?
+            """
+            txn.execute(sql, (self.stream_ordering_day_ago,))
+            (count,) = txn.fetchone()
+            return count
+
+        ret = yield self.db.runInteraction("count_messages", _count_messages)
+        return ret
+
+    @defer.inlineCallbacks
+    def count_daily_sent_messages(self):
+        def _count_messages(txn):
+            # This is good enough as if you have silly characters in your own
+            # hostname then thats your own fault.
+            like_clause = "%:" + self.hs.hostname
+
+            sql = """
+                SELECT COALESCE(COUNT(*), 0) FROM events
+                WHERE type = 'm.room.message'
+                    AND sender LIKE ?
+                AND stream_ordering > ?
+            """
+
+            txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
+            (count,) = txn.fetchone()
+            return count
+
+        ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages)
+        return ret
+
+    @defer.inlineCallbacks
+    def count_daily_active_rooms(self):
+        def _count(txn):
+            sql = """
+                SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+                WHERE type = 'm.room.message'
+                AND stream_ordering > ?
+            """
+            txn.execute(sql, (self.stream_ordering_day_ago,))
+            (count,) = txn.fetchone()
+            return count
+
+        ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
+        return ret
diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py
index 925bc5691b..1310d39069 100644
--- a/synapse/storage/data_stores/main/monthly_active_users.py
+++ b/synapse/storage/data_stores/main/monthly_active_users.py
@@ -17,7 +17,7 @@ import logging
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database
+from synapse.storage.database import Database, make_in_list_sql_clause
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
@@ -122,6 +122,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
     def __init__(self, database: Database, db_conn, hs):
         super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
 
+        self._limit_usage_by_mau = hs.config.limit_usage_by_mau
+        self._mau_stats_only = hs.config.mau_stats_only
+        self._max_mau_value = hs.config.max_mau_value
+
         # Do not add more reserved users than the total allowable number
         # cur = LoggingTransaction(
         self.db.new_transaction(
@@ -130,7 +134,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
             [],
             [],
             self._initialise_reserved_users,
-            hs.config.mau_limits_reserved_threepids[: self.hs.config.max_mau_value],
+            hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
         )
 
     def _initialise_reserved_users(self, txn, threepids):
@@ -142,6 +146,15 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
             threepids (list[dict]): List of threepid dicts to reserve
         """
 
+        # XXX what is this function trying to achieve?  It upserts into
+        # monthly_active_users for each *registered* reserved mau user, but why?
+        #
+        #  - shouldn't there already be an entry for each reserved user (at least
+        #    if they have been active recently)?
+        #
+        #  - if it's important that the timestamp is kept up to date, why do we only
+        #    run this at startup?
+
         for tp in threepids:
             user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
 
@@ -174,76 +187,57 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
             """
 
             thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
-            query_args = [thirty_days_ago]
-            base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?"
-
-            # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
-            # when len(reserved_users) == 0. Works fine on sqlite.
-            if len(reserved_users) > 0:
-                # questionmarks is a hack to overcome sqlite not supporting
-                # tuples in 'WHERE IN %s'
-                question_marks = ",".join("?" * len(reserved_users))
-
-                query_args.extend(reserved_users)
-                sql = base_sql + " AND user_id NOT IN ({})".format(question_marks)
-            else:
-                sql = base_sql
 
-            txn.execute(sql, query_args)
+            in_clause, in_clause_args = make_in_list_sql_clause(
+                self.database_engine, "user_id", reserved_users
+            )
+
+            txn.execute(
+                "DELETE FROM monthly_active_users WHERE timestamp < ? AND NOT %s"
+                % (in_clause,),
+                [thirty_days_ago] + in_clause_args,
+            )
 
-            max_mau_value = self.hs.config.max_mau_value
-            if self.hs.config.limit_usage_by_mau:
+            if self._limit_usage_by_mau:
                 # If MAU user count still exceeds the MAU threshold, then delete on
                 # a least recently active basis.
                 # Note it is not possible to write this query using OFFSET due to
                 # incompatibilities in how sqlite and postgres support the feature.
-                # sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present
-                # While Postgres does not require 'LIMIT', but also does not support
+                # Sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present,
+                # while Postgres does not require 'LIMIT', but also does not support
                 # negative LIMIT values. So there is no way to write it that both can
                 # support
-                if len(reserved_users) == 0:
-                    sql = """
-                        DELETE FROM monthly_active_users
-                        WHERE user_id NOT IN (
-                            SELECT user_id FROM monthly_active_users
-                            ORDER BY timestamp DESC
-                            LIMIT ?
-                        )
-                        """
-                    txn.execute(sql, (max_mau_value,))
-                # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
-                # when len(reserved_users) == 0. Works fine on sqlite.
-                else:
-                    # Must be >= 0 for postgres
-                    num_of_non_reserved_users_to_remove = max(
-                        max_mau_value - len(reserved_users), 0
-                    )
 
-                    # It is important to filter reserved users twice to guard
-                    # against the case where the reserved user is present in the
-                    # SELECT, meaning that a legitmate mau is deleted.
-                    sql = """
-                        DELETE FROM monthly_active_users
-                        WHERE user_id NOT IN (
-                            SELECT user_id FROM monthly_active_users
-                            WHERE user_id NOT IN ({})
-                            ORDER BY timestamp DESC
-                            LIMIT ?
-                        )
-                        AND user_id NOT IN ({})
-                    """.format(
-                        question_marks, question_marks
+                # Limit must be >= 0 for postgres
+                num_of_non_reserved_users_to_remove = max(
+                    self._max_mau_value - len(reserved_users), 0
+                )
+
+                # It is important to filter reserved users twice to guard
+                # against the case where the reserved user is present in the
+                # SELECT, meaning that a legitimate mau is deleted.
+                sql = """
+                    DELETE FROM monthly_active_users
+                    WHERE user_id NOT IN (
+                        SELECT user_id FROM monthly_active_users
+                        WHERE NOT %s
+                        ORDER BY timestamp DESC
+                        LIMIT ?
                     )
-
-                    query_args = [
-                        *reserved_users,
-                        num_of_non_reserved_users_to_remove,
-                        *reserved_users,
-                    ]
-
-                    txn.execute(sql, query_args)
-
-            # It seems poor to invalidate the whole cache, Postgres supports
+                    AND NOT %s
+                """ % (
+                    in_clause,
+                    in_clause,
+                )
+
+                query_args = (
+                    in_clause_args
+                    + [num_of_non_reserved_users_to_remove]
+                    + in_clause_args
+                )
+                txn.execute(sql, query_args)
+
+            # It seems poor to invalidate the whole cache. Postgres supports
             # 'Returning' which would allow me to invalidate only the
             # specific users, but sqlite has no way to do this and instead
             # I would need to SELECT and the DELETE which without locking
@@ -335,7 +329,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
         Args:
             user_id(str): the user_id to query
         """
-        if self.hs.config.limit_usage_by_mau or self.hs.config.mau_stats_only:
+        if self._limit_usage_by_mau or self._mau_stats_only:
             # Trial users and guests should not be included as part of MAU group
             is_guest = yield self.is_guest(user_id)
             if is_guest:
@@ -356,11 +350,11 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
                 # In the case where mau_stats_only is True and limit_usage_by_mau is
                 # False, there is no point in checking get_monthly_active_count - it
                 # adds no value and will break the logic if max_mau_value is exceeded.
-                if not self.hs.config.limit_usage_by_mau:
+                if not self._limit_usage_by_mau:
                     yield self.upsert_monthly_active_user(user_id)
                 else:
                     count = yield self.get_monthly_active_count()
-                    if count < self.hs.config.max_mau_value:
+                    if count < self._max_mau_value:
                         yield self.upsert_monthly_active_user(user_id)
             elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
                 yield self.upsert_monthly_active_user(user_id)
diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py
index 2a97991d23..7c69996041 100644
--- a/synapse/storage/data_stores/main/profile.py
+++ b/synapse/storage/data_stores/main/profile.py
@@ -199,7 +199,7 @@ class ProfileStore(ProfileWorkerStore):
         return self.db.simple_upsert(
             table="remote_profile_cache",
             keyvalues={"user_id": user_id},
-            values={
+            updatevalues={
                 "displayname": displayname,
                 "avatar_url": avatar_url,
                 "last_check": self._clock.time_msec(),
diff --git a/synapse/storage/data_stores/main/purge_events.py b/synapse/storage/data_stores/main/purge_events.py
new file mode 100644
index 0000000000..a93e1ef198
--- /dev/null
+++ b/synapse/storage/data_stores/main/purge_events.py
@@ -0,0 +1,399 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Any, Tuple
+
+from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.state import StateGroupWorkerStore
+from synapse.types import RoomStreamToken
+
+logger = logging.getLogger(__name__)
+
+
+class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
+    def purge_history(self, room_id, token, delete_local_events):
+        """Deletes room history before a certain point
+
+        Args:
+            room_id (str):
+
+            token (str): A topological token to delete events before
+
+            delete_local_events (bool):
+                if True, we will delete local events as well as remote ones
+                (instead of just marking them as outliers and deleting their
+                state groups).
+
+        Returns:
+            Deferred[set[int]]: The set of state groups that are referenced by
+            deleted events.
+        """
+
+        return self.db.runInteraction(
+            "purge_history",
+            self._purge_history_txn,
+            room_id,
+            token,
+            delete_local_events,
+        )
+
+    def _purge_history_txn(self, txn, room_id, token_str, delete_local_events):
+        token = RoomStreamToken.parse(token_str)
+
+        # Tables that should be pruned:
+        #     event_auth
+        #     event_backward_extremities
+        #     event_edges
+        #     event_forward_extremities
+        #     event_json
+        #     event_push_actions
+        #     event_reference_hashes
+        #     event_search
+        #     event_to_state_groups
+        #     events
+        #     rejections
+        #     room_depth
+        #     state_groups
+        #     state_groups_state
+
+        # we will build a temporary table listing the events so that we don't
+        # have to keep shovelling the list back and forth across the
+        # connection. Annoyingly the python sqlite driver commits the
+        # transaction on CREATE, so let's do this first.
+        #
+        # furthermore, we might already have the table from a previous (failed)
+        # purge attempt, so let's drop the table first.
+
+        txn.execute("DROP TABLE IF EXISTS events_to_purge")
+
+        txn.execute(
+            "CREATE TEMPORARY TABLE events_to_purge ("
+            "    event_id TEXT NOT NULL,"
+            "    should_delete BOOLEAN NOT NULL"
+            ")"
+        )
+
+        # First ensure that we're not about to delete all the forward extremeties
+        txn.execute(
+            "SELECT e.event_id, e.depth FROM events as e "
+            "INNER JOIN event_forward_extremities as f "
+            "ON e.event_id = f.event_id "
+            "AND e.room_id = f.room_id "
+            "WHERE f.room_id = ?",
+            (room_id,),
+        )
+        rows = txn.fetchall()
+        max_depth = max(row[1] for row in rows)
+
+        if max_depth < token.topological:
+            # We need to ensure we don't delete all the events from the database
+            # otherwise we wouldn't be able to send any events (due to not
+            # having any backwards extremeties)
+            raise SynapseError(
+                400, "topological_ordering is greater than forward extremeties"
+            )
+
+        logger.info("[purge] looking for events to delete")
+
+        should_delete_expr = "state_key IS NULL"
+        should_delete_params = ()  # type: Tuple[Any, ...]
+        if not delete_local_events:
+            should_delete_expr += " AND event_id NOT LIKE ?"
+
+            # We include the parameter twice since we use the expression twice
+            should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname)
+
+        should_delete_params += (room_id, token.topological)
+
+        # Note that we insert events that are outliers and aren't going to be
+        # deleted, as nothing will happen to them.
+        txn.execute(
+            "INSERT INTO events_to_purge"
+            " SELECT event_id, %s"
+            " FROM events AS e LEFT JOIN state_events USING (event_id)"
+            " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?"
+            % (should_delete_expr, should_delete_expr),
+            should_delete_params,
+        )
+
+        # We create the indices *after* insertion as that's a lot faster.
+
+        # create an index on should_delete because later we'll be looking for
+        # the should_delete / shouldn't_delete subsets
+        txn.execute(
+            "CREATE INDEX events_to_purge_should_delete"
+            " ON events_to_purge(should_delete)"
+        )
+
+        # We do joins against events_to_purge for e.g. calculating state
+        # groups to purge, etc., so lets make an index.
+        txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)")
+
+        txn.execute("SELECT event_id, should_delete FROM events_to_purge")
+        event_rows = txn.fetchall()
+        logger.info(
+            "[purge] found %i events before cutoff, of which %i can be deleted",
+            len(event_rows),
+            sum(1 for e in event_rows if e[1]),
+        )
+
+        logger.info("[purge] Finding new backward extremities")
+
+        # We calculate the new entries for the backward extremeties by finding
+        # events to be purged that are pointed to by events we're not going to
+        # purge.
+        txn.execute(
+            "SELECT DISTINCT e.event_id FROM events_to_purge AS e"
+            " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
+            " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id"
+            " WHERE ep2.event_id IS NULL"
+        )
+        new_backwards_extrems = txn.fetchall()
+
+        logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
+
+        txn.execute(
+            "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,)
+        )
+
+        # Update backward extremeties
+        txn.executemany(
+            "INSERT INTO event_backward_extremities (room_id, event_id)"
+            " VALUES (?, ?)",
+            [(room_id, event_id) for event_id, in new_backwards_extrems],
+        )
+
+        logger.info("[purge] finding state groups referenced by deleted events")
+
+        # Get all state groups that are referenced by events that are to be
+        # deleted.
+        txn.execute(
+            """
+            SELECT DISTINCT state_group FROM events_to_purge
+            INNER JOIN event_to_state_groups USING (event_id)
+        """
+        )
+
+        referenced_state_groups = {sg for sg, in txn}
+        logger.info(
+            "[purge] found %i referenced state groups", len(referenced_state_groups)
+        )
+
+        logger.info("[purge] removing events from event_to_state_groups")
+        txn.execute(
+            "DELETE FROM event_to_state_groups "
+            "WHERE event_id IN (SELECT event_id from events_to_purge)"
+        )
+        for event_id, _ in event_rows:
+            txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
+
+        # Delete all remote non-state events
+        for table in (
+            "events",
+            "event_json",
+            "event_auth",
+            "event_edges",
+            "event_forward_extremities",
+            "event_reference_hashes",
+            "event_search",
+            "rejections",
+        ):
+            logger.info("[purge] removing events from %s", table)
+
+            txn.execute(
+                "DELETE FROM %s WHERE event_id IN ("
+                "    SELECT event_id FROM events_to_purge WHERE should_delete"
+                ")" % (table,)
+            )
+
+        # event_push_actions lacks an index on event_id, and has one on
+        # (room_id, event_id) instead.
+        for table in ("event_push_actions",):
+            logger.info("[purge] removing events from %s", table)
+
+            txn.execute(
+                "DELETE FROM %s WHERE room_id = ? AND event_id IN ("
+                "    SELECT event_id FROM events_to_purge WHERE should_delete"
+                ")" % (table,),
+                (room_id,),
+            )
+
+        # Mark all state and own events as outliers
+        logger.info("[purge] marking remaining events as outliers")
+        txn.execute(
+            "UPDATE events SET outlier = ?"
+            " WHERE event_id IN ("
+            "    SELECT event_id FROM events_to_purge "
+            "    WHERE NOT should_delete"
+            ")",
+            (True,),
+        )
+
+        # synapse tries to take out an exclusive lock on room_depth whenever it
+        # persists events (because upsert), and once we run this update, we
+        # will block that for the rest of our transaction.
+        #
+        # So, let's stick it at the end so that we don't block event
+        # persistence.
+        #
+        # We do this by calculating the minimum depth of the backwards
+        # extremities. However, the events in event_backward_extremities
+        # are ones we don't have yet so we need to look at the events that
+        # point to it via event_edges table.
+        txn.execute(
+            """
+            SELECT COALESCE(MIN(depth), 0)
+            FROM event_backward_extremities AS eb
+            INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id
+            INNER JOIN events AS e ON e.event_id = eg.event_id
+            WHERE eb.room_id = ?
+        """,
+            (room_id,),
+        )
+        (min_depth,) = txn.fetchone()
+
+        logger.info("[purge] updating room_depth to %d", min_depth)
+
+        txn.execute(
+            "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
+            (min_depth, room_id),
+        )
+
+        # finally, drop the temp table. this will commit the txn in sqlite,
+        # so make sure to keep this actually last.
+        txn.execute("DROP TABLE events_to_purge")
+
+        logger.info("[purge] done")
+
+        return referenced_state_groups
+
+    def purge_room(self, room_id):
+        """Deletes all record of a room
+
+        Args:
+            room_id (str)
+
+        Returns:
+            Deferred[List[int]]: The list of state groups to delete.
+        """
+
+        return self.db.runInteraction("purge_room", self._purge_room_txn, room_id)
+
+    def _purge_room_txn(self, txn, room_id):
+        # First we fetch all the state groups that should be deleted, before
+        # we delete that information.
+        txn.execute(
+            """
+                SELECT DISTINCT state_group FROM events
+                INNER JOIN event_to_state_groups USING(event_id)
+                WHERE events.room_id = ?
+            """,
+            (room_id,),
+        )
+
+        state_groups = [row[0] for row in txn]
+
+        # Now we delete tables which lack an index on room_id but have one on event_id
+        for table in (
+            "event_auth",
+            "event_edges",
+            "event_push_actions_staging",
+            "event_reference_hashes",
+            "event_relations",
+            "event_to_state_groups",
+            "redactions",
+            "rejections",
+            "state_events",
+        ):
+            logger.info("[purge] removing %s from %s", room_id, table)
+
+            txn.execute(
+                """
+                DELETE FROM %s WHERE event_id IN (
+                  SELECT event_id FROM events WHERE room_id=?
+                )
+                """
+                % (table,),
+                (room_id,),
+            )
+
+        # and finally, the tables with an index on room_id (or no useful index)
+        for table in (
+            "current_state_events",
+            "event_backward_extremities",
+            "event_forward_extremities",
+            "event_json",
+            "event_push_actions",
+            "event_search",
+            "events",
+            "group_rooms",
+            "public_room_list_stream",
+            "receipts_graph",
+            "receipts_linearized",
+            "room_aliases",
+            "room_depth",
+            "room_memberships",
+            "room_stats_state",
+            "room_stats_current",
+            "room_stats_historical",
+            "room_stats_earliest_token",
+            "rooms",
+            "stream_ordering_to_exterm",
+            "users_in_public_rooms",
+            "users_who_share_private_rooms",
+            # no useful index, but let's clear them anyway
+            "appservice_room_list",
+            "e2e_room_keys",
+            "event_push_summary",
+            "pusher_throttle",
+            "group_summary_rooms",
+            "local_invites",
+            "room_account_data",
+            "room_tags",
+            "local_current_membership",
+        ):
+            logger.info("[purge] removing %s from %s", room_id, table)
+            txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
+
+        # Other tables we do NOT need to clear out:
+        #
+        #  - blocked_rooms
+        #    This is important, to make sure that we don't accidentally rejoin a blocked
+        #    room after it was purged
+        #
+        #  - user_directory
+        #    This has a room_id column, but it is unused
+        #
+
+        # Other tables that we might want to consider clearing out include:
+        #
+        #  - event_reports
+        #       Given that these are intended for abuse management my initial
+        #       inclination is to leave them in place.
+        #
+        #  - current_state_delta_stream
+        #  - ex_outlier_stream
+        #  - room_tags_revisions
+        #       The problem with these is that they are largeish and there is no room_id
+        #       index on them. In any case we should be clearing out 'stream' tables
+        #       periodically anyway (#5888)
+
+        # TODO: we could probably usefully do a bunch of cache invalidation here
+
+        logger.info("[purge] done")
+
+        return state_groups
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index b3faafa0a4..ef8f40959f 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -16,19 +16,23 @@
 
 import abc
 import logging
+from typing import Union
 
 from canonicaljson import json
 
 from twisted.internet import defer
 
 from synapse.push.baserules import list_with_base_rules
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
 from synapse.storage.data_stores.main.pusher import PusherWorkerStore
 from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
 from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
 from synapse.storage.database import Database
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
+from synapse.storage.util.id_generators import ChainedIdGenerator
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -64,6 +68,7 @@ class PushRulesWorkerStore(
     ReceiptsWorkerStore,
     PusherWorkerStore,
     RoomMemberWorkerStore,
+    EventsWorkerStore,
     SQLBaseStore,
 ):
     """This is an abstract base class where subclasses must implement
@@ -77,6 +82,15 @@ class PushRulesWorkerStore(
     def __init__(self, database: Database, db_conn, hs):
         super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
 
+        if hs.config.worker.worker_app is None:
+            self._push_rules_stream_id_gen = ChainedIdGenerator(
+                self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
+            )  # type: Union[ChainedIdGenerator, SlavedIdTracker]
+        else:
+            self._push_rules_stream_id_gen = SlavedIdTracker(
+                db_conn, "push_rules_stream", "stream_id"
+            )
+
         push_rules_prefill, push_rules_id = self.db.get_cache_dict(
             db_conn,
             "push_rules_stream",
diff --git a/synapse/storage/data_stores/main/rejections.py b/synapse/storage/data_stores/main/rejections.py
index 1c07c7a425..27e5a2084a 100644
--- a/synapse/storage/data_stores/main/rejections.py
+++ b/synapse/storage/data_stores/main/rejections.py
@@ -21,17 +21,6 @@ logger = logging.getLogger(__name__)
 
 
 class RejectionsStore(SQLBaseStore):
-    def _store_rejections_txn(self, txn, event_id, reason):
-        self.db.simple_insert_txn(
-            txn,
-            table="rejections",
-            values={
-                "event_id": event_id,
-                "reason": reason,
-                "last_check": self._clock.time_msec(),
-            },
-        )
-
     def get_rejection_reason(self, event_id):
         return self.db.simple_select_one_onecol(
             table="rejections",
diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py
index 046c2b4845..7d477f8d01 100644
--- a/synapse/storage/data_stores/main/relations.py
+++ b/synapse/storage/data_stores/main/relations.py
@@ -324,62 +324,4 @@ class RelationsWorkerStore(SQLBaseStore):
 
 
 class RelationsStore(RelationsWorkerStore):
-    def _handle_event_relations(self, txn, event):
-        """Handles inserting relation data during peristence of events
-
-        Args:
-            txn
-            event (EventBase)
-        """
-        relation = event.content.get("m.relates_to")
-        if not relation:
-            # No relations
-            return
-
-        rel_type = relation.get("rel_type")
-        if rel_type not in (
-            RelationTypes.ANNOTATION,
-            RelationTypes.REFERENCE,
-            RelationTypes.REPLACE,
-        ):
-            # Unknown relation type
-            return
-
-        parent_id = relation.get("event_id")
-        if not parent_id:
-            # Invalid relation
-            return
-
-        aggregation_key = relation.get("key")
-
-        self.db.simple_insert_txn(
-            txn,
-            table="event_relations",
-            values={
-                "event_id": event.event_id,
-                "relates_to_id": parent_id,
-                "relation_type": rel_type,
-                "aggregation_key": aggregation_key,
-            },
-        )
-
-        txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,))
-        txn.call_after(
-            self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
-        )
-
-        if rel_type == RelationTypes.REPLACE:
-            txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))
-
-    def _handle_redaction(self, txn, redacted_event_id):
-        """Handles receiving a redaction and checking whether we need to remove
-        any redacted relations from the database.
-
-        Args:
-            txn
-            redacted_event_id (str): The event that was redacted.
-        """
-
-        self.db.simple_delete_txn(
-            txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
-        )
+    pass
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index fc7e07a13d..55c1defc73 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -21,8 +21,6 @@ from abc import abstractmethod
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple
 
-from six import integer_types
-
 from canonicaljson import json
 
 from twisted.internet import defer
@@ -98,6 +96,37 @@ class RoomWorkerStore(SQLBaseStore):
             allow_none=True,
         )
 
+    def get_room_with_stats(self, room_id: str):
+        """Retrieve room with statistics.
+
+        Args:
+            room_id: The ID of the room to retrieve.
+        Returns:
+            A dict containing the room information, or None if the room is unknown.
+        """
+
+        def get_room_with_stats_txn(txn, room_id):
+            sql = """
+                SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
+                  curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
+                  rooms.creator, state.encryption, state.is_federatable AS federatable,
+                  rooms.is_public AS public, state.join_rules, state.guest_access,
+                  state.history_visibility, curr.current_state_events AS state_events
+                FROM rooms
+                LEFT JOIN room_stats_state state USING (room_id)
+                LEFT JOIN room_stats_current curr USING (room_id)
+                WHERE room_id = ?
+                """
+            txn.execute(sql, [room_id])
+            res = self.db.cursor_to_dict(txn)[0]
+            res["federatable"] = bool(res["federatable"])
+            res["public"] = bool(res["public"])
+            return res
+
+        return self.db.runInteraction(
+            "get_room_with_stats", get_room_with_stats_txn, room_id
+        )
+
     def get_public_room_ids(self):
         return self.db.simple_select_onecol(
             table="rooms",
@@ -1294,53 +1323,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
 
         return self.db.runInteraction("get_rooms", f)
 
-    def _store_room_topic_txn(self, txn, event):
-        if hasattr(event, "content") and "topic" in event.content:
-            self.store_event_search_txn(
-                txn, event, "content.topic", event.content["topic"]
-            )
-
-    def _store_room_name_txn(self, txn, event):
-        if hasattr(event, "content") and "name" in event.content:
-            self.store_event_search_txn(
-                txn, event, "content.name", event.content["name"]
-            )
-
-    def _store_room_message_txn(self, txn, event):
-        if hasattr(event, "content") and "body" in event.content:
-            self.store_event_search_txn(
-                txn, event, "content.body", event.content["body"]
-            )
-
-    def _store_retention_policy_for_room_txn(self, txn, event):
-        if hasattr(event, "content") and (
-            "min_lifetime" in event.content or "max_lifetime" in event.content
-        ):
-            if (
-                "min_lifetime" in event.content
-                and not isinstance(event.content.get("min_lifetime"), integer_types)
-            ) or (
-                "max_lifetime" in event.content
-                and not isinstance(event.content.get("max_lifetime"), integer_types)
-            ):
-                # Ignore the event if one of the value isn't an integer.
-                return
-
-            self.db.simple_insert_txn(
-                txn=txn,
-                table="room_retention",
-                values={
-                    "room_id": event.room_id,
-                    "event_id": event.event_id,
-                    "min_lifetime": event.content.get("min_lifetime"),
-                    "max_lifetime": event.content.get("max_lifetime"),
-                },
-            )
-
-            self._invalidate_cache_and_stream(
-                txn, self.get_retention_policy_for_room, (event.room_id,)
-            )
-
     def add_event_report(
         self, room_id, event_id, user_id, reason, content, received_ts
     ):
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index e626b7f6f7..137ebac833 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -45,7 +45,6 @@ from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 from synapse.util.metrics import Measure
-from synapse.util.stringutils import to_ascii
 
 logger = logging.getLogger(__name__)
 
@@ -153,16 +152,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 self._check_safe_current_state_events_membership_updated_txn,
             )
 
-    @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
-    def get_hosts_in_room(self, room_id, cache_context):
-        """Returns the set of all hosts currently in the room
-        """
-        user_ids = yield self.get_users_in_room(
-            room_id, on_invalidate=cache_context.invalidate
-        )
-        hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
-        return hosts
-
     @cached(max_entries=100000, iterable=True)
     def get_users_in_room(self, room_id):
         return self.db.runInteraction(
@@ -189,7 +178,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             """
 
         txn.execute(sql, (room_id, Membership.JOIN))
-        return [to_ascii(r[0]) for r in txn]
+        return [r[0] for r in txn]
 
     @cached(max_entries=100000)
     def get_room_summary(self, room_id):
@@ -233,7 +222,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             txn.execute(sql, (room_id,))
             res = {}
             for count, membership in txn:
-                summary = res.setdefault(to_ascii(membership), MemberSummary([], count))
+                summary = res.setdefault(membership, MemberSummary([], count))
 
             # we order by membership and then fairly arbitrarily by event_id so
             # heroes are consistent
@@ -265,11 +254,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
             txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
             for user_id, membership, event_id in txn:
-                summary = res[to_ascii(membership)]
+                summary = res[membership]
                 # we will always have a summary for this membership type at this
                 # point given the summary currently contains the counts.
                 members = summary.members
-                members.append((to_ascii(user_id), to_ascii(event_id)))
+                members.append((user_id, event_id))
 
             return res
 
@@ -594,13 +583,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             ev_entry = event_map.get(event_id)
             if ev_entry:
                 if ev_entry.event.membership == Membership.JOIN:
-                    users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo(
-                        display_name=to_ascii(
-                            ev_entry.event.content.get("displayname", None)
-                        ),
-                        avatar_url=to_ascii(
-                            ev_entry.event.content.get("avatar_url", None)
-                        ),
+                    users_in_room[ev_entry.event.state_key] = ProfileInfo(
+                        display_name=ev_entry.event.content.get("displayname", None),
+                        avatar_url=ev_entry.event.content.get("avatar_url", None),
                     )
             else:
                 missing_member_event_ids.append(event_id)
@@ -614,9 +599,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         if event is not None and event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
                 if event.event_id in member_event_ids:
-                    users_in_room[to_ascii(event.state_key)] = ProfileInfo(
-                        display_name=to_ascii(event.content.get("displayname", None)),
-                        avatar_url=to_ascii(event.content.get("avatar_url", None)),
+                    users_in_room[event.state_key] = ProfileInfo(
+                        display_name=event.content.get("displayname", None),
+                        avatar_url=event.content.get("avatar_url", None),
                     )
 
         return users_in_room
@@ -1061,119 +1046,6 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
     def __init__(self, database: Database, db_conn, hs):
         super(RoomMemberStore, self).__init__(database, db_conn, hs)
 
-    def _store_room_members_txn(self, txn, events, backfilled):
-        """Store a room member in the database.
-        """
-        self.db.simple_insert_many_txn(
-            txn,
-            table="room_memberships",
-            values=[
-                {
-                    "event_id": event.event_id,
-                    "user_id": event.state_key,
-                    "sender": event.user_id,
-                    "room_id": event.room_id,
-                    "membership": event.membership,
-                    "display_name": event.content.get("displayname", None),
-                    "avatar_url": event.content.get("avatar_url", None),
-                }
-                for event in events
-            ],
-        )
-
-        for event in events:
-            txn.call_after(
-                self._membership_stream_cache.entity_has_changed,
-                event.state_key,
-                event.internal_metadata.stream_ordering,
-            )
-            txn.call_after(
-                self.get_invited_rooms_for_local_user.invalidate, (event.state_key,)
-            )
-
-            # We update the local_invites table only if the event is "current",
-            # i.e., its something that has just happened. If the event is an
-            # outlier it is only current if its an "out of band membership",
-            # like a remote invite or a rejection of a remote invite.
-            is_new_state = not backfilled and (
-                not event.internal_metadata.is_outlier()
-                or event.internal_metadata.is_out_of_band_membership()
-            )
-            is_mine = self.hs.is_mine_id(event.state_key)
-            if is_new_state and is_mine:
-                if event.membership == Membership.INVITE:
-                    self.db.simple_insert_txn(
-                        txn,
-                        table="local_invites",
-                        values={
-                            "event_id": event.event_id,
-                            "invitee": event.state_key,
-                            "inviter": event.sender,
-                            "room_id": event.room_id,
-                            "stream_id": event.internal_metadata.stream_ordering,
-                        },
-                    )
-                else:
-                    sql = (
-                        "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
-                        " room_id = ? AND invitee = ? AND locally_rejected is NULL"
-                        " AND replaced_by is NULL"
-                    )
-
-                    txn.execute(
-                        sql,
-                        (
-                            event.internal_metadata.stream_ordering,
-                            event.event_id,
-                            event.room_id,
-                            event.state_key,
-                        ),
-                    )
-
-                # We also update the `local_current_membership` table with
-                # latest invite info. This will usually get updated by the
-                # `current_state_events` handling, unless its an outlier.
-                if event.internal_metadata.is_outlier():
-                    # This should only happen for out of band memberships, so
-                    # we add a paranoia check.
-                    assert event.internal_metadata.is_out_of_band_membership()
-
-                    self.db.simple_upsert_txn(
-                        txn,
-                        table="local_current_membership",
-                        keyvalues={
-                            "room_id": event.room_id,
-                            "user_id": event.state_key,
-                        },
-                        values={
-                            "event_id": event.event_id,
-                            "membership": event.membership,
-                        },
-                    )
-
-    @defer.inlineCallbacks
-    def locally_reject_invite(self, user_id, room_id):
-        sql = (
-            "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
-            " room_id = ? AND invitee = ? AND locally_rejected is NULL"
-            " AND replaced_by is NULL"
-        )
-
-        def f(txn, stream_ordering):
-            txn.execute(sql, (stream_ordering, True, room_id, user_id))
-
-            # We also clear this entry from `local_current_membership`.
-            # Ideally we'd point to a leave event, but we don't have one, so
-            # nevermind.
-            self.db.simple_delete_txn(
-                txn,
-                table="local_current_membership",
-                keyvalues={"room_id": room_id, "user_id": user_id},
-            )
-
-        with self._stream_id_gen.get_next() as stream_ordering:
-            yield self.db.runInteraction("locally_reject_invite", f, stream_ordering)
-
     def forget(self, user_id, room_id):
         """Indicate that user_id wishes to discard history for room_id."""
 
diff --git a/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres
new file mode 100644
index 0000000000..aa46eb0e10
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres
@@ -0,0 +1,30 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We keep the old table here to enable us to roll back. It doesn't matter
+-- that we have dropped all the data here.
+TRUNCATE cache_invalidation_stream;
+
+CREATE TABLE cache_invalidation_stream_by_instance (
+    stream_id       BIGINT NOT NULL,
+    instance_name   TEXT NOT NULL,
+    cache_func      TEXT NOT NULL,
+    keys            TEXT[],
+    invalidation_ts BIGINT
+);
+
+CREATE UNIQUE INDEX cache_invalidation_stream_by_instance_id ON cache_invalidation_stream_by_instance(stream_id);
+
+CREATE SEQUENCE cache_invalidation_stream_seq;
diff --git a/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py b/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py
new file mode 100644
index 0000000000..d353f2bcb3
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py
@@ -0,0 +1,80 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This migration rebuilds the device_lists_outbound_last_success table without duplicate
+entries, and with a UNIQUE index.
+"""
+
+import logging
+from io import StringIO
+
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.prepare_database import execute_statements_from_stream
+from synapse.storage.types import Cursor
+
+logger = logging.getLogger(__name__)
+
+
+def run_upgrade(*args, **kwargs):
+    pass
+
+
+def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
+    # some instances might already have this index, in which case we can skip this
+    if isinstance(database_engine, PostgresEngine):
+        cur.execute(
+            """
+            SELECT 1 FROM pg_class WHERE relkind = 'i'
+            AND relname = 'device_lists_outbound_last_success_unique_idx'
+            """
+        )
+
+        if cur.rowcount:
+            logger.info(
+                "Unique index exists on device_lists_outbound_last_success: "
+                "skipping rebuild"
+            )
+            return
+
+    logger.info("Rebuilding device_lists_outbound_last_success with unique index")
+    execute_statements_from_stream(cur, StringIO(_rebuild_commands))
+
+
+# there might be duplicates, so the easiest way to achieve this is to create a new
+# table with the right data, and renaming it into place
+
+_rebuild_commands = """
+DROP TABLE IF EXISTS device_lists_outbound_last_success_new;
+
+CREATE TABLE device_lists_outbound_last_success_new (
+    destination TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    stream_id BIGINT NOT NULL
+);
+
+-- this took about 30 seconds on matrix.org's 16 million rows.
+INSERT INTO device_lists_outbound_last_success_new
+    SELECT destination, user_id, MAX(stream_id) FROM device_lists_outbound_last_success
+    GROUP BY destination, user_id;
+
+-- and this another 30 seconds.
+CREATE UNIQUE INDEX device_lists_outbound_last_success_unique_idx
+    ON device_lists_outbound_last_success_new (destination, user_id);
+
+DROP TABLE device_lists_outbound_last_success;
+
+ALTER TABLE device_lists_outbound_last_success_new
+    RENAME TO device_lists_outbound_last_success;
+"""
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index 47ebb8a214..13f49d8060 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -37,7 +37,55 @@ SearchEntry = namedtuple(
 )
 
 
-class SearchBackgroundUpdateStore(SQLBaseStore):
+class SearchWorkerStore(SQLBaseStore):
+    def store_search_entries_txn(self, txn, entries):
+        """Add entries to the search table
+
+        Args:
+            txn (cursor):
+            entries (iterable[SearchEntry]):
+                entries to be added to the table
+        """
+        if not self.hs.config.enable_search:
+            return
+        if isinstance(self.database_engine, PostgresEngine):
+            sql = (
+                "INSERT INTO event_search"
+                " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
+                " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
+            )
+
+            args = (
+                (
+                    entry.event_id,
+                    entry.room_id,
+                    entry.key,
+                    entry.value,
+                    entry.stream_ordering,
+                    entry.origin_server_ts,
+                )
+                for entry in entries
+            )
+
+            txn.executemany(sql, args)
+
+        elif isinstance(self.database_engine, Sqlite3Engine):
+            sql = (
+                "INSERT INTO event_search (event_id, room_id, key, value)"
+                " VALUES (?,?,?,?)"
+            )
+            args = (
+                (entry.event_id, entry.room_id, entry.key, entry.value)
+                for entry in entries
+            )
+
+            txn.executemany(sql, args)
+        else:
+            # This should be unreachable.
+            raise Exception("Unrecognized database engine")
+
+
+class SearchBackgroundUpdateStore(SearchWorkerStore):
 
     EVENT_SEARCH_UPDATE_NAME = "event_search"
     EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
@@ -296,80 +344,11 @@ class SearchBackgroundUpdateStore(SQLBaseStore):
 
         return num_rows
 
-    def store_search_entries_txn(self, txn, entries):
-        """Add entries to the search table
-
-        Args:
-            txn (cursor):
-            entries (iterable[SearchEntry]):
-                entries to be added to the table
-        """
-        if not self.hs.config.enable_search:
-            return
-        if isinstance(self.database_engine, PostgresEngine):
-            sql = (
-                "INSERT INTO event_search"
-                " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
-                " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
-            )
-
-            args = (
-                (
-                    entry.event_id,
-                    entry.room_id,
-                    entry.key,
-                    entry.value,
-                    entry.stream_ordering,
-                    entry.origin_server_ts,
-                )
-                for entry in entries
-            )
-
-            txn.executemany(sql, args)
-
-        elif isinstance(self.database_engine, Sqlite3Engine):
-            sql = (
-                "INSERT INTO event_search (event_id, room_id, key, value)"
-                " VALUES (?,?,?,?)"
-            )
-            args = (
-                (entry.event_id, entry.room_id, entry.key, entry.value)
-                for entry in entries
-            )
-
-            txn.executemany(sql, args)
-        else:
-            # This should be unreachable.
-            raise Exception("Unrecognized database engine")
-
 
 class SearchStore(SearchBackgroundUpdateStore):
     def __init__(self, database: Database, db_conn, hs):
         super(SearchStore, self).__init__(database, db_conn, hs)
 
-    def store_event_search_txn(self, txn, event, key, value):
-        """Add event to the search table
-
-        Args:
-            txn (cursor):
-            event (EventBase):
-            key (str):
-            value (str):
-        """
-        self.store_search_entries_txn(
-            txn,
-            (
-                SearchEntry(
-                    key=key,
-                    value=value,
-                    event_id=event.event_id,
-                    room_id=event.room_id,
-                    stream_ordering=event.internal_metadata.stream_ordering,
-                    origin_server_ts=event.origin_server_ts,
-                ),
-            ),
-        )
-
     @defer.inlineCallbacks
     def search_msgs(self, room_ids, search_term, keys):
         """Performs a full text search over events with given keys.
diff --git a/synapse/storage/data_stores/main/signatures.py b/synapse/storage/data_stores/main/signatures.py
index 563216b63c..36244d9f5d 100644
--- a/synapse/storage/data_stores/main/signatures.py
+++ b/synapse/storage/data_stores/main/signatures.py
@@ -13,23 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import six
-
 from unpaddedbase64 import encode_base64
 
 from twisted.internet import defer
 
-from synapse.crypto.event_signing import compute_event_reference_hash
 from synapse.storage._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached, cachedList
 
-# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
-# despite being deprecated and removed in favor of memoryview
-if six.PY2:
-    db_binary_type = six.moves.builtins.buffer
-else:
-    db_binary_type = memoryview
-
 
 class SignatureWorkerStore(SQLBaseStore):
     @cached()
@@ -79,23 +69,3 @@ class SignatureWorkerStore(SQLBaseStore):
 
 class SignatureStore(SignatureWorkerStore):
     """Persistence for event signatures and hashes"""
-
-    def _store_event_reference_hashes_txn(self, txn, events):
-        """Store a hash for a PDU
-        Args:
-            txn (cursor):
-            events (list): list of Events.
-        """
-
-        vals = []
-        for event in events:
-            ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
-            vals.append(
-                {
-                    "event_id": event.event_id,
-                    "algorithm": ref_alg,
-                    "hash": db_binary_type(ref_hash_bytes),
-                }
-            )
-
-        self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 3a3b9a8e72..347cc50778 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -16,17 +16,12 @@
 import collections.abc
 import logging
 from collections import namedtuple
-from typing import Iterable, Tuple
-
-from six import iteritems
 
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
-from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
 from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
@@ -34,7 +29,6 @@ from synapse.storage.database import Database
 from synapse.storage.state import StateFilter
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedList
-from synapse.util.stringutils import to_ascii
 
 logger = logging.getLogger(__name__)
 
@@ -190,9 +184,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 (room_id,),
             )
 
-            return {
-                (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
-            }
+            return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
 
         return self.db.runInteraction(
             "get_current_state_ids", _get_current_state_ids_txn
@@ -473,33 +465,3 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
 
     def __init__(self, database: Database, db_conn, hs):
         super(StateStore, self).__init__(database, db_conn, hs)
-
-    def _store_event_state_mappings_txn(
-        self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
-    ):
-        state_groups = {}
-        for event, context in events_and_contexts:
-            if event.internal_metadata.is_outlier():
-                continue
-
-            # if the event was rejected, just give it the same state as its
-            # predecessor.
-            if context.rejected:
-                state_groups[event.event_id] = context.state_group_before_event
-                continue
-
-            state_groups[event.event_id] = context.state_group
-
-        self.db.simple_insert_many_txn(
-            txn,
-            table="event_to_state_groups",
-            values=[
-                {"state_group": state_group_id, "event_id": event_id}
-                for event_id, state_group_id in iteritems(state_groups)
-            ],
-        )
-
-        for event_id, state_group_id in iteritems(state_groups):
-            txn.call_after(
-                self._get_state_group_for_event.prefill, (event_id,), state_group_id
-            )
diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py
index 5b07c2fbc0..a9bf457939 100644
--- a/synapse/storage/data_stores/main/transactions.py
+++ b/synapse/storage/data_stores/main/transactions.py
@@ -16,8 +16,6 @@
 import logging
 from collections import namedtuple
 
-import six
-
 from canonicaljson import encode_canonical_json
 
 from twisted.internet import defer
@@ -27,12 +25,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import Database
 from synapse.util.caches.expiringcache import ExpiringCache
 
-# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
-# despite being deprecated and removed in favor of memoryview
-if six.PY2:
-    db_binary_type = six.moves.builtins.buffer
-else:
-    db_binary_type = memoryview
+db_binary_type = memoryview
 
 logger = logging.getLogger(__name__)
 
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
index 57a5267663..f3ad1e4369 100644
--- a/synapse/storage/data_stores/state/store.py
+++ b/synapse/storage/data_stores/state/store.py
@@ -28,7 +28,6 @@ from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateSt
 from synapse.storage.database import Database
 from synapse.storage.state import StateFilter
 from synapse.types import StateMap
-from synapse.util.caches import get_cache_factor_for
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.dictionary_cache import DictionaryCache
 
@@ -90,11 +89,10 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         self._state_group_cache = DictionaryCache(
             "*stateGroupCache*",
             # TODO: this hasn't been tuned yet
-            50000 * get_cache_factor_for("stateGroupCache"),
+            50000,
         )
         self._state_group_members_cache = DictionaryCache(
-            "*stateGroupMembersCache*",
-            500000 * get_cache_factor_for("stateGroupMembersCache"),
+            "*stateGroupMembersCache*", 500000,
         )
 
     @cached(max_entries=10000, iterable=True)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 50f475bfd3..b112ff3df2 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -49,7 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
-from synapse.util.stringutils import exception_to_unicode
+from synapse.types import Collection
 
 logger = logging.getLogger(__name__)
 
@@ -422,20 +422,14 @@ class Database(object):
                     # This can happen if the database disappears mid
                     # transaction.
                     logger.warning(
-                        "[TXN OPERROR] {%s} %s %d/%d",
-                        name,
-                        exception_to_unicode(e),
-                        i,
-                        N,
+                        "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
                     )
                     if i < N:
                         i += 1
                         try:
                             conn.rollback()
                         except self.engine.module.Error as e1:
-                            logger.warning(
-                                "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
-                            )
+                            logger.warning("[TXN EROLL] {%s} %s", name, e1)
                         continue
                     raise
                 except self.engine.module.DatabaseError as e:
@@ -447,9 +441,7 @@ class Database(object):
                                 conn.rollback()
                             except self.engine.module.Error as e1:
                                 logger.warning(
-                                    "[TXN EROLL] {%s} %s",
-                                    name,
-                                    exception_to_unicode(e1),
+                                    "[TXN EROLL] {%s} %s", name, e1,
                                 )
                             continue
                     raise
@@ -889,20 +881,24 @@ class Database(object):
         txn.execute(sql, list(allvalues.values()))
 
     def simple_upsert_many_txn(
-        self, txn, table, key_names, key_values, value_names, value_values
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        key_names: Collection[str],
+        key_values: Collection[Iterable[Any]],
+        value_names: Collection[str],
+        value_values: Iterable[Iterable[str]],
+    ) -> None:
         """
         Upsert, many times.
 
         Args:
-            table (str): The table to upsert into
-            key_names (list[str]): The key column names.
-            key_values (list[list]): A list of each row's key column values.
-            value_names (list[str]): The value column names. If empty, no
-                values will be used, even if value_values is provided.
-            value_values (list[list]): A list of each row's value column values.
-        Returns:
-            None
+            table: The table to upsert into
+            key_names: The key column names.
+            key_values: A list of each row's key column values.
+            value_names: The value column names
+            value_values: A list of each row's value column values.
+                Ignored if value_names is empty.
         """
         if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
             return self.simple_upsert_many_txn_native_upsert(
@@ -914,20 +910,24 @@ class Database(object):
             )
 
     def simple_upsert_many_txn_emulated(
-        self, txn, table, key_names, key_values, value_names, value_values
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        key_names: Iterable[str],
+        key_values: Collection[Iterable[Any]],
+        value_names: Collection[str],
+        value_values: Iterable[Iterable[str]],
+    ) -> None:
         """
         Upsert, many times, but without native UPSERT support or batching.
 
         Args:
-            table (str): The table to upsert into
-            key_names (list[str]): The key column names.
-            key_values (list[list]): A list of each row's key column values.
-            value_names (list[str]): The value column names. If empty, no
-                values will be used, even if value_values is provided.
-            value_values (list[list]): A list of each row's value column values.
-        Returns:
-            None
+            table: The table to upsert into
+            key_names: The key column names.
+            key_values: A list of each row's key column values.
+            value_names: The value column names
+            value_values: A list of each row's value column values.
+                Ignored if value_names is empty.
         """
         # No value columns, therefore make a blank list so that the following
         # zip() works correctly.
@@ -941,20 +941,24 @@ class Database(object):
             self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
 
     def simple_upsert_many_txn_native_upsert(
-        self, txn, table, key_names, key_values, value_names, value_values
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        key_names: Collection[str],
+        key_values: Collection[Iterable[Any]],
+        value_names: Collection[str],
+        value_values: Iterable[Iterable[Any]],
+    ) -> None:
         """
         Upsert, many times, using batching where possible.
 
         Args:
-            table (str): The table to upsert into
-            key_names (list[str]): The key column names.
-            key_values (list[list]): A list of each row's key column values.
-            value_names (list[str]): The value column names. If empty, no
-                values will be used, even if value_values is provided.
-            value_values (list[list]): A list of each row's value column values.
-        Returns:
-            None
+            table: The table to upsert into
+            key_names: The key column names.
+            key_values: A list of each row's key column values.
+            value_names: The value column names
+            value_values: A list of each row's value column values.
+                Ignored if value_names is empty.
         """
         allnames = []  # type: List[str]
         allnames.extend(key_names)
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 0f9ac1cf09..f159400a87 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -23,7 +23,6 @@ from typing import Iterable, List, Optional, Set, Tuple
 from six import iteritems
 from six.moves import range
 
-import attr
 from prometheus_client import Counter, Histogram
 
 from twisted.internet import defer
@@ -35,6 +34,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.state import StateResolutionStore
 from synapse.storage.data_stores import DataStores
+from synapse.storage.data_stores.main.events import DeltaState
 from synapse.types import StateMap
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.metrics import Measure
@@ -73,22 +73,6 @@ stale_forward_extremities_counter = Histogram(
 )
 
 
-@attr.s(slots=True)
-class DeltaState:
-    """Deltas to use to update the `current_state_events` table.
-
-    Attributes:
-        to_delete: List of type/state_keys to delete from current state
-        to_insert: Map of state to upsert into current state
-        no_longer_in_room: The server is not longer in the room, so the room
-            should e.g. be removed from `current_state_events` table.
-    """
-
-    to_delete = attr.ib(type=List[Tuple[str, str]])
-    to_insert = attr.ib(type=StateMap[str])
-    no_longer_in_room = attr.ib(type=bool, default=False)
-
-
 class _EventPeristenceQueue(object):
     """Queues up events so that they can be persisted in bulk with only one
     concurrent transaction per room.
@@ -205,6 +189,7 @@ class EventsPersistenceStorage(object):
         # store for now.
         self.main_store = stores.main
         self.state_store = stores.state
+        self.persist_events_store = stores.persist_events
 
         self._clock = hs.get_clock()
         self.is_mine_id = hs.is_mine_id
@@ -445,7 +430,7 @@ class EventsPersistenceStorage(object):
                         if current_state is not None:
                             current_state_for_room[room_id] = current_state
 
-            await self.main_store._persist_events_and_state_updates(
+            await self.persist_events_store._persist_events_and_state_updates(
                 chunk,
                 current_state_for_room=current_state_for_room,
                 state_delta_for_room=state_delta_for_room,
@@ -491,13 +476,15 @@ class EventsPersistenceStorage(object):
         )
 
         # Remove any events which are prev_events of any existing events.
-        existing_prevs = await self.main_store._get_events_which_are_prevs(result)
+        existing_prevs = await self.persist_events_store._get_events_which_are_prevs(
+            result
+        )
         result.difference_update(existing_prevs)
 
         # Finally handle the case where the new events have soft-failed prev
         # events. If they do we need to remove them and their prev events,
         # otherwise we end up with dangling extremities.
-        existing_prevs = await self.main_store._get_prevs_before_rejected(
+        existing_prevs = await self.persist_events_store._get_prevs_before_rejected(
             e_id for event in new_events for e_id in event.prev_event_ids()
         )
         result.difference_update(existing_prevs)
@@ -753,8 +740,8 @@ class EventsPersistenceStorage(object):
         # whose state has changed as we've already their new state above.
         users_to_ignore = [
             state_key
-            for _, state_key in itertools.chain(delta.to_insert, delta.to_delete)
-            if self.is_mine_id(state_key)
+            for typ, state_key in itertools.chain(delta.to_insert, delta.to_delete)
+            if typ == EventTypes.Member and self.is_mine_id(state_key)
         ]
 
         if await self.main_store.is_local_host_in_room_ignoring_users(
@@ -799,3 +786,9 @@ class EventsPersistenceStorage(object):
 
         for user_id in left_users:
             await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id)
+
+    async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
+        """Mark the invite has having been rejected even though we failed to
+        create a leave event for it.
+        """
+        return await self.persist_events_store.locally_reject_invite(user_id, room_id)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 1712932f31..9afc145340 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -19,16 +19,20 @@ import logging
 import os
 import re
 from collections import Counter
+from typing import TextIO
 
 import attr
 
 from synapse.storage.engines.postgres import PostgresEngine
+from synapse.storage.types import Cursor
 
 logger = logging.getLogger(__name__)
 
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
+# XXX: If you're about to bump this to 59 (or higher) please create an update
+# that drops the unused `cache_invalidation_stream` table, as per #7436!
 SCHEMA_VERSION = 58
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
@@ -477,8 +481,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
             )
 
         logger.info("applying schema %s for %s", name, modname)
-        for statement in get_statements(stream):
-            cur.execute(statement)
+        execute_statements_from_stream(cur, stream)
 
         # Mark as done.
         cur.execute(
@@ -536,8 +539,12 @@ def get_statements(f):
 
 def executescript(txn, schema_path):
     with open(schema_path, "r") as f:
-        for statement in get_statements(f):
-            txn.execute(statement)
+        execute_statements_from_stream(txn, f)
+
+
+def execute_statements_from_stream(cur: Cursor, f: TextIO):
+    for statement in get_statements(f):
+        cur.execute(statement)
 
 
 def _get_or_create_schema_state(txn, database_engine):
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 9d851beaa5..f89ce0bed2 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -16,6 +16,11 @@
 import contextlib
 import threading
 from collections import deque
+from typing import Dict, Set, Tuple
+
+from typing_extensions import Deque
+
+from synapse.storage.database import Database, LoggingTransaction
 
 
 class IdGenerator(object):
@@ -87,7 +92,7 @@ class StreamIdGenerator(object):
             self._current = (max if step > 0 else min)(
                 self._current, _load_current_id(db_conn, table, column, step)
             )
-        self._unfinished_ids = deque()
+        self._unfinished_ids = deque()  # type: Deque[int]
 
     def get_next(self):
         """
@@ -161,9 +166,10 @@ class ChainedIdGenerator(object):
 
     def __init__(self, chained_generator, db_conn, table, column):
         self.chained_generator = chained_generator
+        self._table = table
         self._lock = threading.Lock()
         self._current_max = _load_current_id(db_conn, table, column)
-        self._unfinished_ids = deque()
+        self._unfinished_ids = deque()  # type: Deque[Tuple[int, int]]
 
     def get_next(self):
         """
@@ -198,3 +204,173 @@ class ChainedIdGenerator(object):
                 return stream_id - 1, chained_id
 
             return self._current_max, self.chained_generator.get_current_token()
+
+    def advance(self, token: int):
+        """Stub implementation for advancing the token when receiving updates
+        over replication; raises an exception as this instance should be the
+        only source of updates.
+        """
+
+        raise Exception(
+            "Attempted to advance token on source for table %r", self._table
+        )
+
+
+class MultiWriterIdGenerator:
+    """An ID generator that tracks a stream that can have multiple writers.
+
+    Uses a Postgres sequence to coordinate ID assignment, but positions of other
+    writers will only get updated when `advance` is called (by replication).
+
+    Note: Only works with Postgres.
+
+    Args:
+        db_conn
+        db
+        instance_name: The name of this instance.
+        table: Database table associated with stream.
+        instance_column: Column that stores the row's writer's instance name
+        id_column: Column that stores the stream ID.
+        sequence_name: The name of the postgres sequence used to generate new
+            IDs.
+    """
+
+    def __init__(
+        self,
+        db_conn,
+        db: Database,
+        instance_name: str,
+        table: str,
+        instance_column: str,
+        id_column: str,
+        sequence_name: str,
+    ):
+        self._db = db
+        self._instance_name = instance_name
+        self._sequence_name = sequence_name
+
+        # We lock as some functions may be called from DB threads.
+        self._lock = threading.Lock()
+
+        self._current_positions = self._load_current_ids(
+            db_conn, table, instance_column, id_column
+        )
+
+        # Set of local IDs that we're still processing. The current position
+        # should be less than the minimum of this set (if not empty).
+        self._unfinished_ids = set()  # type: Set[int]
+
+    def _load_current_ids(
+        self, db_conn, table: str, instance_column: str, id_column: str
+    ) -> Dict[str, int]:
+        sql = """
+            SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
+            GROUP BY %(instance)s
+        """ % {
+            "instance": instance_column,
+            "id": id_column,
+            "table": table,
+        }
+
+        cur = db_conn.cursor()
+        cur.execute(sql)
+
+        # `cur` is an iterable over returned rows, which are 2-tuples.
+        current_positions = dict(cur)
+
+        cur.close()
+
+        return current_positions
+
+    def _load_next_id_txn(self, txn):
+        txn.execute("SELECT nextval(?)", (self._sequence_name,))
+        (next_id,) = txn.fetchone()
+        return next_id
+
+    async def get_next(self):
+        """
+        Usage:
+            with await stream_id_gen.get_next() as stream_id:
+                # ... persist event ...
+        """
+        next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
+
+        # Assert the fetched ID is actually greater than what we currently
+        # believe the ID to be. If not, then the sequence and table have got
+        # out of sync somehow.
+        assert self.get_current_token() < next_id
+
+        with self._lock:
+            self._unfinished_ids.add(next_id)
+
+        @contextlib.contextmanager
+        def manager():
+            try:
+                yield next_id
+            finally:
+                self._mark_id_as_finished(next_id)
+
+        return manager()
+
+    def get_next_txn(self, txn: LoggingTransaction):
+        """
+        Usage:
+
+            stream_id = stream_id_gen.get_next(txn)
+            # ... persist event ...
+        """
+
+        next_id = self._load_next_id_txn(txn)
+
+        with self._lock:
+            self._unfinished_ids.add(next_id)
+
+        txn.call_after(self._mark_id_as_finished, next_id)
+        txn.call_on_exception(self._mark_id_as_finished, next_id)
+
+        return next_id
+
+    def _mark_id_as_finished(self, next_id: int):
+        """The ID has finished being processed so we should advance the
+        current poistion if possible.
+        """
+
+        with self._lock:
+            self._unfinished_ids.discard(next_id)
+
+            # Figure out if its safe to advance the position by checking there
+            # aren't any lower allocated IDs that are yet to finish.
+            if all(c > next_id for c in self._unfinished_ids):
+                curr = self._current_positions.get(self._instance_name, 0)
+                self._current_positions[self._instance_name] = max(curr, next_id)
+
+    def get_current_token(self, instance_name: str = None) -> int:
+        """Gets the current position of a named writer (defaults to current
+        instance).
+
+        Returns 0 if we don't have a position for the named writer (likely due
+        to it being a new writer).
+        """
+
+        if instance_name is None:
+            instance_name = self._instance_name
+
+        with self._lock:
+            return self._current_positions.get(instance_name, 0)
+
+    def get_positions(self) -> Dict[str, int]:
+        """Get a copy of the current positon map.
+        """
+
+        with self._lock:
+            return dict(self._current_positions)
+
+    def advance(self, instance_name: str, new_id: int):
+        """Advance the postion of the named writer to the given ID, if greater
+        than existing entry.
+        """
+
+        with self._lock:
+            self._current_positions[instance_name] = max(
+                new_id, self._current_positions.get(instance_name, 0)
+            )