summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/__init__.py193
-rw-r--r--synapse/storage/databases/main/account_data.py9
-rw-r--r--synapse/storage/databases/main/appservice.py171
-rw-r--r--synapse/storage/databases/main/censor_events.py21
-rw-r--r--synapse/storage/databases/main/client_ips.py117
-rw-r--r--synapse/storage/databases/main/devices.py289
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py129
-rw-r--r--synapse/storage/databases/main/event_federation.py142
-rw-r--r--synapse/storage/databases/main/event_push_actions.py267
-rw-r--r--synapse/storage/databases/main/events.py73
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py7
-rw-r--r--synapse/storage/databases/main/events_worker.py231
-rw-r--r--synapse/storage/databases/main/keys.py5
-rw-r--r--synapse/storage/databases/main/media_repository.py131
-rw-r--r--synapse/storage/databases/main/metrics.py213
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py110
-rw-r--r--synapse/storage/databases/main/profile.py92
-rw-r--r--synapse/storage/databases/main/purge_events.py2
-rw-r--r--synapse/storage/databases/main/pusher.py2
-rw-r--r--synapse/storage/databases/main/receipts.py71
-rw-r--r--synapse/storage/databases/main/registration.py471
-rw-r--r--synapse/storage/databases/main/room.py306
-rw-r--r--synapse/storage/databases/main/roommember.py68
-rw-r--r--synapse/storage/databases/main/schema/delta/20/pushers.py19
-rw-r--r--synapse/storage/databases/main/schema/delta/25/fts.py2
-rw-r--r--synapse/storage/databases/main/schema/delta/27/ts.py2
-rw-r--r--synapse/storage/databases/main/schema/delta/30/as_users.py6
-rw-r--r--synapse/storage/databases/main/schema/delta/31/pushers.py19
-rw-r--r--synapse/storage/databases/main/schema/delta/31/search_update.py2
-rw-r--r--synapse/storage/databases/main/schema/delta/33/event_fields.py2
-rw-r--r--synapse/storage/databases/main/schema/delta/33/remote_media_ts.py5
-rw-r--r--synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py7
-rw-r--r--synapse/storage/databases/main/schema/delta/57/local_current_membership.py1
-rw-r--r--synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres12
-rw-r--r--synapse/storage/databases/main/schema/delta/58/11dehydration.sql20
-rw-r--r--synapse/storage/databases/main/schema/delta/58/11fallback.sql24
-rw-r--r--synapse/storage/databases/main/schema/delta/58/12room_stats.sql4
-rw-r--r--synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres25
-rw-r--r--synapse/storage/databases/main/schema/delta/58/19txn_id.sql40
-rw-r--r--synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql17
-rw-r--r--synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql18
-rw-r--r--synapse/storage/databases/main/schema/delta/58/21as_device_stream.sql17
-rw-r--r--synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql1
-rw-r--r--synapse/storage/databases/main/schema/delta/58/22puppet_token.sql17
-rw-r--r--synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql2
-rw-r--r--synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql17
-rw-r--r--synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql19
-rw-r--r--synapse/storage/databases/main/stats.py127
-rw-r--r--synapse/storage/databases/main/stream.py295
-rw-r--r--synapse/storage/databases/main/transactions.py118
-rw-r--r--synapse/storage/databases/main/ui_auth.py6
-rw-r--r--synapse/storage/databases/main/user_directory.py45
52 files changed, 2717 insertions, 1292 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 0cb12f4c61..43660ec4fb 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -15,9 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import calendar
 import logging
-import time
 from typing import Any, Dict, List, Optional, Tuple
 
 from synapse.api.constants import PresenceState
@@ -148,7 +146,6 @@ class DataStore(
             db_conn, "e2e_cross_signing_keys", "stream_id"
         )
 
-        self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
         self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
@@ -268,9 +265,6 @@ class DataStore(
         self._stream_order_on_start = self.get_room_max_stream_ordering()
         self._min_stream_order_on_start = self.get_room_min_stream_ordering()
 
-        # Used in _generate_user_daily_visits to keep track of progress
-        self._last_user_visit_update = self._get_start_of_day()
-
     def get_device_stream_token(self) -> int:
         return self._device_list_id_gen.get_current_token()
 
@@ -289,7 +283,6 @@ class DataStore(
             " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
             " WHERE state != ?"
         )
-        sql = self.database_engine.convert_param_style(sql)
 
         txn = db_conn.cursor()
         txn.execute(sql, (PresenceState.OFFLINE,))
@@ -301,192 +294,6 @@ class DataStore(
 
         return [UserPresenceState(**row) for row in rows]
 
-    async def count_daily_users(self) -> int:
-        """
-        Counts the number of users who used this homeserver in the last 24 hours.
-        """
-        yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
-        return await self.db_pool.runInteraction(
-            "count_daily_users", self._count_users, yesterday
-        )
-
-    async def count_monthly_users(self) -> int:
-        """
-        Counts the number of users who used this homeserver in the last 30 days.
-        Note this method is intended for phonehome metrics only and is different
-        from the mau figure in synapse.storage.monthly_active_users which,
-        amongst other things, includes a 3 day grace period before a user counts.
-        """
-        thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
-        return await self.db_pool.runInteraction(
-            "count_monthly_users", self._count_users, thirty_days_ago
-        )
-
-    def _count_users(self, txn, time_from):
-        """
-        Returns number of users seen in the past time_from period
-        """
-        sql = """
-            SELECT COALESCE(count(*), 0) FROM (
-                SELECT user_id FROM user_ips
-                WHERE last_seen > ?
-                GROUP BY user_id
-            ) u
-        """
-        txn.execute(sql, (time_from,))
-        (count,) = txn.fetchone()
-        return count
-
-    async def count_r30_users(self) -> Dict[str, int]:
-        """
-        Counts the number of 30 day retained users, defined as:-
-         * Users who have created their accounts more than 30 days ago
-         * Where last seen at most 30 days ago
-         * Where account creation and last_seen are > 30 days apart
-
-        Returns:
-             A mapping of counts globally as well as broken out by platform.
-        """
-
-        def _count_r30_users(txn):
-            thirty_days_in_secs = 86400 * 30
-            now = int(self._clock.time())
-            thirty_days_ago_in_secs = now - thirty_days_in_secs
-
-            sql = """
-                SELECT platform, COALESCE(count(*), 0) FROM (
-                     SELECT
-                        users.name, platform, users.creation_ts * 1000,
-                        MAX(uip.last_seen)
-                     FROM users
-                     INNER JOIN (
-                         SELECT
-                         user_id,
-                         last_seen,
-                         CASE
-                             WHEN user_agent LIKE '%%Android%%' THEN 'android'
-                             WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
-                             WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
-                             WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
-                             WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
-                             ELSE 'unknown'
-                         END
-                         AS platform
-                         FROM user_ips
-                     ) uip
-                     ON users.name = uip.user_id
-                     AND users.appservice_id is NULL
-                     AND users.creation_ts < ?
-                     AND uip.last_seen/1000 > ?
-                     AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
-                     GROUP BY users.name, platform, users.creation_ts
-                ) u GROUP BY platform
-            """
-
-            results = {}
-            txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
-
-            for row in txn:
-                if row[0] == "unknown":
-                    pass
-                results[row[0]] = row[1]
-
-            sql = """
-                SELECT COALESCE(count(*), 0) FROM (
-                    SELECT users.name, users.creation_ts * 1000,
-                                                        MAX(uip.last_seen)
-                    FROM users
-                    INNER JOIN (
-                        SELECT
-                        user_id,
-                        last_seen
-                        FROM user_ips
-                    ) uip
-                    ON users.name = uip.user_id
-                    AND appservice_id is NULL
-                    AND users.creation_ts < ?
-                    AND uip.last_seen/1000 > ?
-                    AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
-                    GROUP BY users.name, users.creation_ts
-                ) u
-            """
-
-            txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
-
-            (count,) = txn.fetchone()
-            results["all"] = count
-
-            return results
-
-        return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
-
-    def _get_start_of_day(self):
-        """
-        Returns millisecond unixtime for start of UTC day.
-        """
-        now = time.gmtime()
-        today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
-        return today_start * 1000
-
-    async def generate_user_daily_visits(self) -> None:
-        """
-        Generates daily visit data for use in cohort/ retention analysis
-        """
-
-        def _generate_user_daily_visits(txn):
-            logger.info("Calling _generate_user_daily_visits")
-            today_start = self._get_start_of_day()
-            a_day_in_milliseconds = 24 * 60 * 60 * 1000
-            now = self.clock.time_msec()
-
-            sql = """
-                INSERT INTO user_daily_visits (user_id, device_id, timestamp)
-                    SELECT u.user_id, u.device_id, ?
-                    FROM user_ips AS u
-                    LEFT JOIN (
-                      SELECT user_id, device_id, timestamp FROM user_daily_visits
-                      WHERE timestamp = ?
-                    ) udv
-                    ON u.user_id = udv.user_id AND u.device_id=udv.device_id
-                    INNER JOIN users ON users.name=u.user_id
-                    WHERE last_seen > ? AND last_seen <= ?
-                    AND udv.timestamp IS NULL AND users.is_guest=0
-                    AND users.appservice_id IS NULL
-                    GROUP BY u.user_id, u.device_id
-            """
-
-            # This means that the day has rolled over but there could still
-            # be entries from the previous day. There is an edge case
-            # where if the user logs in at 23:59 and overwrites their
-            # last_seen at 00:01 then they will not be counted in the
-            # previous day's stats - it is important that the query is run
-            # often to minimise this case.
-            if today_start > self._last_user_visit_update:
-                yesterday_start = today_start - a_day_in_milliseconds
-                txn.execute(
-                    sql,
-                    (
-                        yesterday_start,
-                        yesterday_start,
-                        self._last_user_visit_update,
-                        today_start,
-                    ),
-                )
-                self._last_user_visit_update = today_start
-
-            txn.execute(
-                sql, (today_start, today_start, self._last_user_visit_update, now)
-            )
-            # Update _last_user_visit_update to now. The reason to do this
-            # rather just clamping to the beginning of the day is to limit
-            # the size of the join - meaning that the query can be run more
-            # frequently
-            self._last_user_visit_update = now
-
-        await self.db_pool.runInteraction(
-            "generate_user_daily_visits", _generate_user_daily_visits
-        )
-
     async def get_users(self) -> List[Dict[str, Any]]:
         """Function to retrieve a list of users in users table.
 
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index ef81d73573..49ee23470d 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -18,6 +18,7 @@ import abc
 import logging
 from typing import Dict, List, Optional, Tuple
 
+from synapse.api.constants import AccountDataTypes
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
 from synapse.storage.util.id_generators import StreamIdGenerator
@@ -291,14 +292,18 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
         self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext
     ) -> bool:
         ignored_account_data = await self.get_global_account_data_by_type_for_user(
-            "m.ignored_user_list",
+            AccountDataTypes.IGNORED_USER_LIST,
             ignorer_user_id,
             on_invalidate=cache_context.invalidate,
         )
         if not ignored_account_data:
             return False
 
-        return ignored_user_id in ignored_account_data.get("ignored_users", {})
+        try:
+            return ignored_user_id in ignored_account_data.get("ignored_users", {})
+        except TypeError:
+            # The type of the ignored_users field is invalid.
+            return False
 
 
 class AccountDataStore(AccountDataWorkerStore):
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 85f6b1e3fd..e550cbc866 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -15,18 +15,31 @@
 # limitations under the License.
 import logging
 import re
+from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
 
-from synapse.appservice import AppServiceTransaction
+from synapse.appservice import (
+    ApplicationService,
+    ApplicationServiceState,
+    AppServiceTransaction,
+)
 from synapse.config.appservice import load_appservices
+from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.types import Connection
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
-def _make_exclusive_regex(services_cache):
+def _make_exclusive_regex(
+    services_cache: List[ApplicationService],
+) -> Optional[Pattern]:
     # We precompile a regex constructed from all the regexes that the AS's
     # have registered for exclusive users.
     exclusive_user_regexes = [
@@ -36,17 +49,19 @@ def _make_exclusive_regex(services_cache):
     ]
     if exclusive_user_regexes:
         exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
-        exclusive_user_regex = re.compile(exclusive_user_regex)
+        exclusive_user_pattern = re.compile(
+            exclusive_user_regex
+        )  # type: Optional[Pattern]
     else:
         # We handle this case specially otherwise the constructed regex
         # will always match
-        exclusive_user_regex = None
+        exclusive_user_pattern = None
 
-    return exclusive_user_regex
+    return exclusive_user_pattern
 
 
 class ApplicationServiceWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         self.services_cache = load_appservices(
             hs.hostname, hs.config.app_service_config_files
         )
@@ -57,7 +72,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
     def get_app_services(self):
         return self.services_cache
 
-    def get_if_app_services_interested_in_user(self, user_id):
+    def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
         """Check if the user is one associated with an app service (exclusively)
         """
         if self.exclusive_user_regex:
@@ -65,7 +80,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
         else:
             return False
 
-    def get_app_service_by_user_id(self, user_id):
+    def get_app_service_by_user_id(self, user_id: str) -> Optional[ApplicationService]:
         """Retrieve an application service from their user ID.
 
         All application services have associated with them a particular user ID.
@@ -74,35 +89,35 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
         a user ID to an application service.
 
         Args:
-            user_id(str): The user ID to see if it is an application service.
+            user_id: The user ID to see if it is an application service.
         Returns:
-            synapse.appservice.ApplicationService or None.
+            The application service or None.
         """
         for service in self.services_cache:
             if service.sender == user_id:
                 return service
         return None
 
-    def get_app_service_by_token(self, token):
+    def get_app_service_by_token(self, token: str) -> Optional[ApplicationService]:
         """Get the application service with the given appservice token.
 
         Args:
-            token (str): The application service token.
+            token: The application service token.
         Returns:
-            synapse.appservice.ApplicationService or None.
+            The application service or None.
         """
         for service in self.services_cache:
             if service.token == token:
                 return service
         return None
 
-    def get_app_service_by_id(self, as_id):
+    def get_app_service_by_id(self, as_id: str) -> Optional[ApplicationService]:
         """Get the application service with the given appservice ID.
 
         Args:
-            as_id (str): The application service ID.
+            as_id: The application service ID.
         Returns:
-            synapse.appservice.ApplicationService or None.
+            The application service or None.
         """
         for service in self.services_cache:
             if service.id == as_id:
@@ -121,11 +136,13 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
 class ApplicationServiceTransactionWorkerStore(
     ApplicationServiceWorkerStore, EventsWorkerStore
 ):
-    async def get_appservices_by_state(self, state):
+    async def get_appservices_by_state(
+        self, state: ApplicationServiceState
+    ) -> List[ApplicationService]:
         """Get a list of application services based on their state.
 
         Args:
-            state(ApplicationServiceState): The state to filter on.
+            state: The state to filter on.
         Returns:
             A list of ApplicationServices, which may be empty.
         """
@@ -142,13 +159,15 @@ class ApplicationServiceTransactionWorkerStore(
                     services.append(service)
         return services
 
-    async def get_appservice_state(self, service):
+    async def get_appservice_state(
+        self, service: ApplicationService
+    ) -> Optional[ApplicationServiceState]:
         """Get the application service state.
 
         Args:
-            service(ApplicationService): The service whose state to set.
+            service: The service whose state to set.
         Returns:
-            An ApplicationServiceState.
+            An ApplicationServiceState or none.
         """
         result = await self.db_pool.simple_select_one(
             "application_services_state",
@@ -161,26 +180,36 @@ class ApplicationServiceTransactionWorkerStore(
             return result.get("state")
         return None
 
-    async def set_appservice_state(self, service, state) -> None:
+    async def set_appservice_state(
+        self, service: ApplicationService, state: ApplicationServiceState
+    ) -> None:
         """Set the application service state.
 
         Args:
-            service(ApplicationService): The service whose state to set.
-            state(ApplicationServiceState): The connectivity state to apply.
+            service: The service whose state to set.
+            state: The connectivity state to apply.
         """
         await self.db_pool.simple_upsert(
             "application_services_state", {"as_id": service.id}, {"state": state}
         )
 
-    async def create_appservice_txn(self, service, events):
+    async def create_appservice_txn(
+        self,
+        service: ApplicationService,
+        events: List[EventBase],
+        ephemeral: List[JsonDict],
+    ) -> AppServiceTransaction:
         """Atomically creates a new transaction for this application service
-        with the given list of events.
+        with the given list of events. Ephemeral events are NOT persisted to the
+        database and are not resent if a transaction is retried.
 
         Args:
-            service(ApplicationService): The service who the transaction is for.
-            events(list<Event>): A list of events to put in the transaction.
+            service: The service who the transaction is for.
+            events: A list of persistent events to put in the transaction.
+            ephemeral: A list of ephemeral events to put in the transaction.
+
         Returns:
-            AppServiceTransaction: A new transaction.
+            A new transaction.
         """
 
         def _create_appservice_txn(txn):
@@ -207,19 +236,22 @@ class ApplicationServiceTransactionWorkerStore(
                 "VALUES(?,?,?)",
                 (service.id, new_txn_id, event_ids),
             )
-            return AppServiceTransaction(service=service, id=new_txn_id, events=events)
+            return AppServiceTransaction(
+                service=service, id=new_txn_id, events=events, ephemeral=ephemeral
+            )
 
         return await self.db_pool.runInteraction(
             "create_appservice_txn", _create_appservice_txn
         )
 
-    async def complete_appservice_txn(self, txn_id, service) -> None:
+    async def complete_appservice_txn(
+        self, txn_id: int, service: ApplicationService
+    ) -> None:
         """Completes an application service transaction.
 
         Args:
-            txn_id(str): The transaction ID being completed.
-            service(ApplicationService): The application service which was sent
-            this transaction.
+            txn_id: The transaction ID being completed.
+            service: The application service which was sent this transaction.
         """
         txn_id = int(txn_id)
 
@@ -259,12 +291,13 @@ class ApplicationServiceTransactionWorkerStore(
             "complete_appservice_txn", _complete_appservice_txn
         )
 
-    async def get_oldest_unsent_txn(self, service):
-        """Get the oldest transaction which has not been sent for this
-        service.
+    async def get_oldest_unsent_txn(
+        self, service: ApplicationService
+    ) -> Optional[AppServiceTransaction]:
+        """Get the oldest transaction which has not been sent for this service.
 
         Args:
-            service(ApplicationService): The app service to get the oldest txn.
+            service: The app service to get the oldest txn.
         Returns:
             An AppServiceTransaction or None.
         """
@@ -296,9 +329,11 @@ class ApplicationServiceTransactionWorkerStore(
 
         events = await self.get_events_as_list(event_ids)
 
-        return AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
+        return AppServiceTransaction(
+            service=service, id=entry["txn_id"], events=events, ephemeral=[]
+        )
 
-    def _get_last_txn(self, txn, service_id):
+    def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
         txn.execute(
             "SELECT last_txn FROM application_services_state WHERE as_id=?",
             (service_id,),
@@ -309,7 +344,7 @@ class ApplicationServiceTransactionWorkerStore(
         else:
             return int(last_txn_id[0])  # select 'last_txn' col
 
-    async def set_appservice_last_pos(self, pos) -> None:
+    async def set_appservice_last_pos(self, pos: int) -> None:
         def set_appservice_last_pos_txn(txn):
             txn.execute(
                 "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
@@ -319,8 +354,10 @@ class ApplicationServiceTransactionWorkerStore(
             "set_appservice_last_pos", set_appservice_last_pos_txn
         )
 
-    async def get_new_events_for_appservice(self, current_id, limit):
-        """Get all new evnets"""
+    async def get_new_events_for_appservice(
+        self, current_id: int, limit: int
+    ) -> Tuple[int, List[EventBase]]:
+        """Get all new events for an appservice"""
 
         def get_new_events_for_appservice_txn(txn):
             sql = (
@@ -351,6 +388,54 @@ class ApplicationServiceTransactionWorkerStore(
 
         return upper_bound, events
 
+    async def get_type_stream_id_for_appservice(
+        self, service: ApplicationService, type: str
+    ) -> int:
+        if type not in ("read_receipt", "presence"):
+            raise ValueError(
+                "Expected type to be a valid application stream id type, got %s"
+                % (type,)
+            )
+
+        def get_type_stream_id_for_appservice_txn(txn):
+            stream_id_type = "%s_stream_id" % type
+            txn.execute(
+                # We do NOT want to escape `stream_id_type`.
+                "SELECT %s FROM application_services_state WHERE as_id=?"
+                % stream_id_type,
+                (service.id,),
+            )
+            last_stream_id = txn.fetchone()
+            if last_stream_id is None or last_stream_id[0] is None:  # no row exists
+                return 0
+            else:
+                return int(last_stream_id[0])
+
+        return await self.db_pool.runInteraction(
+            "get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn
+        )
+
+    async def set_type_stream_id_for_appservice(
+        self, service: ApplicationService, type: str, pos: Optional[int]
+    ) -> None:
+        if type not in ("read_receipt", "presence"):
+            raise ValueError(
+                "Expected type to be a valid application stream id type, got %s"
+                % (type,)
+            )
+
+        def set_type_stream_id_for_appservice_txn(txn):
+            stream_id_type = "%s_stream_id" % type
+            txn.execute(
+                "UPDATE application_services_state SET %s = ? WHERE as_id=?"
+                % stream_id_type,
+                (pos, service.id),
+            )
+
+        await self.db_pool.runInteraction(
+            "set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn
+        )
+
 
 class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore):
     # This is currently empty due to there not being any AS storage functions
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index f211ddbaf8..3e26d5ba87 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -17,12 +17,12 @@ import logging
 from typing import TYPE_CHECKING
 
 from synapse.events.utils import prune_event_dict
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.databases.main.events import encode_json
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.util import json_encoder
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -35,14 +35,13 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
     def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
-        def _censor_redactions():
-            return run_as_background_process(
-                "_censor_redactions", self._censor_redactions
-            )
-
-        if self.hs.config.redaction_retention_period is not None:
-            hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
+        if (
+            hs.config.run_background_tasks
+            and self.hs.config.redaction_retention_period is not None
+        ):
+            hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000)
 
+    @wrap_as_background_process("_censor_redactions")
     async def _censor_redactions(self):
         """Censors all redactions older than the configured period that haven't
         been censored yet.
@@ -105,7 +104,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
                 and original_event.internal_metadata.is_redacted()
             ):
                 # Redaction was allowed
-                pruned_json = encode_json(
+                pruned_json = json_encoder.encode(
                     prune_event_dict(
                         original_event.room_version, original_event.get_dict()
                     )
@@ -171,7 +170,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
                 return
 
             # Prune the event's dict then convert it to JSON.
-            pruned_json = encode_json(
+            pruned_json = json_encoder.encode(
                 prune_event_dict(event.room_version, event.get_dict())
             )
 
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 239c7a949c..339bd691a4 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -19,7 +19,7 @@ from typing import Dict, Optional, Tuple
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
-from synapse.util.caches.descriptors import Cache
+from synapse.util.caches.lrucache import LruCache
 
 logger = logging.getLogger(__name__)
 
@@ -351,16 +351,70 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
         return updated
 
 
-class ClientIpStore(ClientIpBackgroundUpdateStore):
+class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        self.user_ips_max_age = hs.config.user_ips_max_age
+
+        if hs.config.run_background_tasks and self.user_ips_max_age:
+            self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
+
+    @wrap_as_background_process("prune_old_user_ips")
+    async def _prune_old_user_ips(self):
+        """Removes entries in user IPs older than the configured period.
+        """
+
+        if self.user_ips_max_age is None:
+            # Nothing to do
+            return
+
+        if not await self.db_pool.updates.has_completed_background_update(
+            "devices_last_seen"
+        ):
+            # Only start pruning if we have finished populating the devices
+            # last seen info.
+            return
+
+        # We do a slightly funky SQL delete to ensure we don't try and delete
+        # too much at once (as the table may be very large from before we
+        # started pruning).
+        #
+        # This works by finding the max last_seen that is less than the given
+        # time, but has no more than N rows before it, deleting all rows with
+        # a lesser last_seen time. (We COALESCE so that the sub-SELECT always
+        # returns exactly one row).
+        sql = """
+            DELETE FROM user_ips
+            WHERE last_seen <= (
+                SELECT COALESCE(MAX(last_seen), -1)
+                FROM (
+                    SELECT last_seen FROM user_ips
+                    WHERE last_seen <= ?
+                    ORDER BY last_seen ASC
+                    LIMIT 5000
+                ) AS u
+            )
+        """
 
-        self.client_ip_last_seen = Cache(
-            name="client_ip_last_seen", keylen=4, max_entries=50000
+        timestamp = self.clock.time_msec() - self.user_ips_max_age
+
+        def _prune_old_user_ips_txn(txn):
+            txn.execute(sql, (timestamp,))
+
+        await self.db_pool.runInteraction(
+            "_prune_old_user_ips", _prune_old_user_ips_txn
         )
 
-        super().__init__(database, db_conn, hs)
 
-        self.user_ips_max_age = hs.config.user_ips_max_age
+class ClientIpStore(ClientIpWorkerStore):
+    def __init__(self, database: DatabasePool, db_conn, hs):
+
+        self.client_ip_last_seen = LruCache(
+            cache_name="client_ip_last_seen", keylen=4, max_size=50000
+        )
+
+        super().__init__(database, db_conn, hs)
 
         # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
         self._batch_row_update = {}
@@ -372,9 +426,6 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
             "before", "shutdown", self._update_client_ips_batch
         )
 
-        if self.user_ips_max_age:
-            self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
-
     async def insert_client_ip(
         self, user_id, access_token, ip, user_agent, device_id, now=None
     ):
@@ -391,7 +442,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
         if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
             return
 
-        self.client_ip_last_seen.prefill(key, now)
+        self.client_ip_last_seen.set(key, now)
 
         self._batch_row_update[key] = (user_agent, device_id, now)
 
@@ -525,49 +576,3 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
             }
             for (access_token, ip), (user_agent, last_seen) in results.items()
         ]
-
-    @wrap_as_background_process("prune_old_user_ips")
-    async def _prune_old_user_ips(self):
-        """Removes entries in user IPs older than the configured period.
-        """
-
-        if self.user_ips_max_age is None:
-            # Nothing to do
-            return
-
-        if not await self.db_pool.updates.has_completed_background_update(
-            "devices_last_seen"
-        ):
-            # Only start pruning if we have finished populating the devices
-            # last seen info.
-            return
-
-        # We do a slightly funky SQL delete to ensure we don't try and delete
-        # too much at once (as the table may be very large from before we
-        # started pruning).
-        #
-        # This works by finding the max last_seen that is less than the given
-        # time, but has no more than N rows before it, deleting all rows with
-        # a lesser last_seen time. (We COALESCE so that the sub-SELECT always
-        # returns exactly one row).
-        sql = """
-            DELETE FROM user_ips
-            WHERE last_seen <= (
-                SELECT COALESCE(MAX(last_seen), -1)
-                FROM (
-                    SELECT last_seen FROM user_ips
-                    WHERE last_seen <= ?
-                    ORDER BY last_seen ASC
-                    LIMIT 5000
-                ) AS u
-            )
-        """
-
-        timestamp = self.clock.time_msec() - self.user_ips_max_age
-
-        def _prune_old_user_ips_txn(txn):
-            txn.execute(sql, (timestamp,))
-
-        await self.db_pool.runInteraction(
-            "_prune_old_user_ips", _prune_old_user_ips_txn
-        )
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index fdf394c612..dfb4f87b8f 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
 # Copyright 2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -25,7 +25,7 @@ from synapse.logging.opentracing import (
     trace,
     whitelisted_homeserver,
 )
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
@@ -33,8 +33,9 @@ from synapse.storage.database import (
     make_tuple_comparison_clause,
 )
 from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
-from synapse.util import json_encoder
-from synapse.util.caches.descriptors import Cache, cached, cachedList
+from synapse.util import json_decoder, json_encoder
+from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
 from synapse.util.stringutils import shortstr
 
@@ -48,6 +49,14 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
 
 class DeviceWorkerStore(SQLBaseStore):
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        if hs.config.run_background_tasks:
+            self._clock.looping_call(
+                self._prune_old_outbound_device_pokes, 60 * 60 * 1000
+            )
+
     async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
         """Retrieve a device. Only returns devices that are not marked as
         hidden.
@@ -698,6 +707,172 @@ class DeviceWorkerStore(SQLBaseStore):
             _mark_remote_user_device_list_as_unsubscribed_txn,
         )
 
+    async def get_dehydrated_device(
+        self, user_id: str
+    ) -> Optional[Tuple[str, JsonDict]]:
+        """Retrieve the information for a dehydrated device.
+
+        Args:
+            user_id: the user whose dehydrated device we are looking for
+        Returns:
+            a tuple whose first item is the device ID, and the second item is
+            the dehydrated device information
+        """
+        # FIXME: make sure device ID still exists in devices table
+        row = await self.db_pool.simple_select_one(
+            table="dehydrated_devices",
+            keyvalues={"user_id": user_id},
+            retcols=["device_id", "device_data"],
+            allow_none=True,
+        )
+        return (
+            (row["device_id"], json_decoder.decode(row["device_data"])) if row else None
+        )
+
+    def _store_dehydrated_device_txn(
+        self, txn, user_id: str, device_id: str, device_data: str
+    ) -> Optional[str]:
+        old_device_id = self.db_pool.simple_select_one_onecol_txn(
+            txn,
+            table="dehydrated_devices",
+            keyvalues={"user_id": user_id},
+            retcol="device_id",
+            allow_none=True,
+        )
+        self.db_pool.simple_upsert_txn(
+            txn,
+            table="dehydrated_devices",
+            keyvalues={"user_id": user_id},
+            values={"device_id": device_id, "device_data": device_data},
+        )
+        return old_device_id
+
+    async def store_dehydrated_device(
+        self, user_id: str, device_id: str, device_data: JsonDict
+    ) -> Optional[str]:
+        """Store a dehydrated device for a user.
+
+        Args:
+            user_id: the user that we are storing the device for
+            device_id: the ID of the dehydrated device
+            device_data: the dehydrated device information
+        Returns:
+            device id of the user's previous dehydrated device, if any
+        """
+        return await self.db_pool.runInteraction(
+            "store_dehydrated_device_txn",
+            self._store_dehydrated_device_txn,
+            user_id,
+            device_id,
+            json_encoder.encode(device_data),
+        )
+
+    async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
+        """Remove a dehydrated device.
+
+        Args:
+            user_id: the user that the dehydrated device belongs to
+            device_id: the ID of the dehydrated device
+        """
+        count = await self.db_pool.simple_delete(
+            "dehydrated_devices",
+            {"user_id": user_id, "device_id": device_id},
+            desc="remove_dehydrated_device",
+        )
+        return count >= 1
+
+    @wrap_as_background_process("prune_old_outbound_device_pokes")
+    async def _prune_old_outbound_device_pokes(
+        self, prune_age: int = 24 * 60 * 60 * 1000
+    ) -> None:
+        """Delete old entries out of the device_lists_outbound_pokes to ensure
+        that we don't fill up due to dead servers.
+
+        Normally, we try to send device updates as a delta since a previous known point:
+        this is done by setting the prev_id in the m.device_list_update EDU. However,
+        for that to work, we have to have a complete record of each change to
+        each device, which can add up to quite a lot of data.
+
+        An alternative mechanism is that, if the remote server sees that it has missed
+        an entry in the stream_id sequence for a given user, it will request a full
+        list of that user's devices. Hence, we can reduce the amount of data we have to
+        store (and transmit in some future transaction), by clearing almost everything
+        for a given destination out of the database, and having the remote server
+        resync.
+
+        All we need to do is make sure we keep at least one row for each
+        (user, destination) pair, to remind us to send a m.device_list_update EDU for
+        that user when the destination comes back. It doesn't matter which device
+        we keep.
+        """
+        yesterday = self._clock.time_msec() - prune_age
+
+        def _prune_txn(txn):
+            # look for (user, destination) pairs which have an update older than
+            # the cutoff.
+            #
+            # For each pair, we also need to know the most recent stream_id, and
+            # an arbitrary device_id at that stream_id.
+            select_sql = """
+            SELECT
+                dlop1.destination,
+                dlop1.user_id,
+                MAX(dlop1.stream_id) AS stream_id,
+                (SELECT MIN(dlop2.device_id) AS device_id FROM
+                    device_lists_outbound_pokes dlop2
+                    WHERE dlop2.destination = dlop1.destination AND
+                      dlop2.user_id=dlop1.user_id AND
+                      dlop2.stream_id=MAX(dlop1.stream_id)
+                )
+            FROM device_lists_outbound_pokes dlop1
+                GROUP BY destination, user_id
+                HAVING min(ts) < ? AND count(*) > 1
+            """
+
+            txn.execute(select_sql, (yesterday,))
+            rows = txn.fetchall()
+
+            if not rows:
+                return
+
+            logger.info(
+                "Pruning old outbound device list updates for %i users/destinations: %s",
+                len(rows),
+                shortstr((row[0], row[1]) for row in rows),
+            )
+
+            # we want to keep the update with the highest stream_id for each user.
+            #
+            # there might be more than one update (with different device_ids) with the
+            # same stream_id, so we also delete all but one rows with the max stream id.
+            delete_sql = """
+                DELETE FROM device_lists_outbound_pokes
+                WHERE destination = ? AND user_id = ? AND (
+                    stream_id < ? OR
+                    (stream_id = ? AND device_id != ?)
+                )
+            """
+            count = 0
+            for (destination, user_id, stream_id, device_id) in rows:
+                txn.execute(
+                    delete_sql, (destination, user_id, stream_id, stream_id, device_id)
+                )
+                count += txn.rowcount
+
+            # Since we've deleted unsent deltas, we need to remove the entry
+            # of last successful sent so that the prev_ids are correctly set.
+            sql = """
+                DELETE FROM device_lists_outbound_last_success
+                WHERE destination = ? AND user_id = ?
+            """
+            txn.executemany(sql, ((row[0], row[1]) for row in rows))
+
+            logger.info("Pruned %d device list outbound pokes", count)
+
+        await self.db_pool.runInteraction(
+            "_prune_old_outbound_device_pokes", _prune_txn,
+        )
+
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
@@ -830,14 +1005,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
         # Map of (user_id, device_id) -> bool. If there is an entry that implies
         # the device exists.
-        self.device_id_exists_cache = Cache(
-            name="device_id_exists", keylen=2, max_entries=10000
+        self.device_id_exists_cache = LruCache(
+            cache_name="device_id_exists", keylen=2, max_size=10000
         )
 
-        self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
-
     async def store_device(
-        self, user_id: str, device_id: str, initial_device_display_name: str
+        self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
     ) -> bool:
         """Ensure the given device is known; add it to the store if not
 
@@ -879,7 +1052,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 )
                 if hidden:
                     raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
-            self.device_id_exists_cache.prefill(key, True)
+            self.device_id_exists_cache.set(key, True)
             return inserted
         except StoreError:
             raise
@@ -955,7 +1128,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         )
 
     async def update_remote_device_list_cache_entry(
-        self, user_id: str, device_id: str, content: JsonDict, stream_id: int
+        self, user_id: str, device_id: str, content: JsonDict, stream_id: str
     ) -> None:
         """Updates a single device in the cache of a remote user's devicelist.
 
@@ -983,7 +1156,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         user_id: str,
         device_id: str,
         content: JsonDict,
-        stream_id: int,
+        stream_id: str,
     ) -> None:
         if content.get("deleted"):
             self.db_pool.simple_delete_txn(
@@ -1193,95 +1366,3 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 for device_id in device_ids
             ],
         )
-
-    def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000):
-        """Delete old entries out of the device_lists_outbound_pokes to ensure
-        that we don't fill up due to dead servers.
-
-        Normally, we try to send device updates as a delta since a previous known point:
-        this is done by setting the prev_id in the m.device_list_update EDU. However,
-        for that to work, we have to have a complete record of each change to
-        each device, which can add up to quite a lot of data.
-
-        An alternative mechanism is that, if the remote server sees that it has missed
-        an entry in the stream_id sequence for a given user, it will request a full
-        list of that user's devices. Hence, we can reduce the amount of data we have to
-        store (and transmit in some future transaction), by clearing almost everything
-        for a given destination out of the database, and having the remote server
-        resync.
-
-        All we need to do is make sure we keep at least one row for each
-        (user, destination) pair, to remind us to send a m.device_list_update EDU for
-        that user when the destination comes back. It doesn't matter which device
-        we keep.
-        """
-        yesterday = self._clock.time_msec() - prune_age
-
-        def _prune_txn(txn):
-            # look for (user, destination) pairs which have an update older than
-            # the cutoff.
-            #
-            # For each pair, we also need to know the most recent stream_id, and
-            # an arbitrary device_id at that stream_id.
-            select_sql = """
-            SELECT
-                dlop1.destination,
-                dlop1.user_id,
-                MAX(dlop1.stream_id) AS stream_id,
-                (SELECT MIN(dlop2.device_id) AS device_id FROM
-                    device_lists_outbound_pokes dlop2
-                    WHERE dlop2.destination = dlop1.destination AND
-                      dlop2.user_id=dlop1.user_id AND
-                      dlop2.stream_id=MAX(dlop1.stream_id)
-                )
-            FROM device_lists_outbound_pokes dlop1
-                GROUP BY destination, user_id
-                HAVING min(ts) < ? AND count(*) > 1
-            """
-
-            txn.execute(select_sql, (yesterday,))
-            rows = txn.fetchall()
-
-            if not rows:
-                return
-
-            logger.info(
-                "Pruning old outbound device list updates for %i users/destinations: %s",
-                len(rows),
-                shortstr((row[0], row[1]) for row in rows),
-            )
-
-            # we want to keep the update with the highest stream_id for each user.
-            #
-            # there might be more than one update (with different device_ids) with the
-            # same stream_id, so we also delete all but one rows with the max stream id.
-            delete_sql = """
-                DELETE FROM device_lists_outbound_pokes
-                WHERE destination = ? AND user_id = ? AND (
-                    stream_id < ? OR
-                    (stream_id = ? AND device_id != ?)
-                )
-            """
-            count = 0
-            for (destination, user_id, stream_id, device_id) in rows:
-                txn.execute(
-                    delete_sql, (destination, user_id, stream_id, stream_id, device_id)
-                )
-                count += txn.rowcount
-
-            # Since we've deleted unsent deltas, we need to remove the entry
-            # of last successful sent so that the prev_ids are correctly set.
-            sql = """
-                DELETE FROM device_lists_outbound_last_success
-                WHERE destination = ? AND user_id = ?
-            """
-            txn.executemany(sql, ((row[0], row[1]) for row in rows))
-
-            logger.info("Pruned %d device list outbound pokes", count)
-
-        return run_as_background_process(
-            "prune_old_outbound_device_pokes",
-            self.db_pool.runInteraction,
-            "_prune_old_outbound_device_pokes",
-            _prune_txn,
-        )
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 22e1ed15d0..4d1b92d1aa 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
 # Copyright 2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@ from twisted.enterprise.adbapi import Connection
 
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import make_in_list_sql_clause
+from synapse.storage.database import DatabasePool, make_in_list_sql_clause
 from synapse.storage.types import Cursor
 from synapse.types import JsonDict
 from synapse.util import json_encoder
@@ -33,6 +33,7 @@ from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
     from synapse.handlers.e2e_keys import SignatureListItem
+    from synapse.server import HomeServer
 
 
 @attr.s(slots=True)
@@ -47,7 +48,20 @@ class DeviceKeyLookupResult:
     keys = attr.ib(type=Optional[JsonDict])
 
 
-class EndToEndKeyWorkerStore(SQLBaseStore):
+class EndToEndKeyBackgroundStore(SQLBaseStore):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+        super().__init__(database, db_conn, hs)
+
+        self.db_pool.updates.register_background_index_update(
+            "e2e_cross_signing_keys_idx",
+            index_name="e2e_cross_signing_keys_stream_idx",
+            table="e2e_cross_signing_keys",
+            columns=["stream_id"],
+            unique=True,
+        )
+
+
+class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
     async def get_e2e_device_keys_for_federation_query(
         self, user_id: str
     ) -> Tuple[int, List[JsonDict]]:
@@ -367,6 +381,61 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             "count_e2e_one_time_keys", _count_e2e_one_time_keys
         )
 
+    async def set_e2e_fallback_keys(
+        self, user_id: str, device_id: str, fallback_keys: JsonDict
+    ) -> None:
+        """Set the user's e2e fallback keys.
+
+        Args:
+            user_id: the user whose keys are being set
+            device_id: the device whose keys are being set
+            fallback_keys: the keys to set.  This is a map from key ID (which is
+                of the form "algorithm:id") to key data.
+        """
+        # fallback_keys will usually only have one item in it, so using a for
+        # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
+        # FIXME: make sure that only one key per algorithm is uploaded
+        for key_id, fallback_key in fallback_keys.items():
+            algorithm, key_id = key_id.split(":", 1)
+            await self.db_pool.simple_upsert(
+                "e2e_fallback_keys_json",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "algorithm": algorithm,
+                },
+                values={
+                    "key_id": key_id,
+                    "key_json": json_encoder.encode(fallback_key),
+                    "used": False,
+                },
+                desc="set_e2e_fallback_key",
+            )
+
+        await self.invalidate_cache_and_stream(
+            "get_e2e_unused_fallback_key_types", (user_id, device_id)
+        )
+
+    @cached(max_entries=10000)
+    async def get_e2e_unused_fallback_key_types(
+        self, user_id: str, device_id: str
+    ) -> List[str]:
+        """Returns the fallback key types that have an unused key.
+
+        Args:
+            user_id: the user whose keys are being queried
+            device_id: the device whose keys are being queried
+
+        Returns:
+            a list of key types
+        """
+        return await self.db_pool.simple_select_onecol(
+            "e2e_fallback_keys_json",
+            keyvalues={"user_id": user_id, "device_id": device_id, "used": False},
+            retcol="algorithm",
+            desc="get_e2e_unused_fallback_key_types",
+        )
+
     async def get_e2e_cross_signing_key(
         self, user_id: str, key_type: str, from_user_id: Optional[str] = None
     ) -> Optional[dict]:
@@ -701,15 +770,37 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
                 " LIMIT 1"
             )
+            fallback_sql = (
+                "SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
+                " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
+                " LIMIT 1"
+            )
             result = {}
             delete = []
+            used_fallbacks = []
             for user_id, device_id, algorithm in query_list:
                 user_result = result.setdefault(user_id, {})
                 device_result = user_result.setdefault(device_id, {})
                 txn.execute(sql, (user_id, device_id, algorithm))
-                for key_id, key_json in txn:
+                otk_row = txn.fetchone()
+                if otk_row is not None:
+                    key_id, key_json = otk_row
                     device_result[algorithm + ":" + key_id] = key_json
                     delete.append((user_id, device_id, algorithm, key_id))
+                else:
+                    # no one-time key available, so see if there's a fallback
+                    # key
+                    txn.execute(fallback_sql, (user_id, device_id, algorithm))
+                    fallback_row = txn.fetchone()
+                    if fallback_row is not None:
+                        key_id, key_json, used = fallback_row
+                        device_result[algorithm + ":" + key_id] = key_json
+                        if not used:
+                            used_fallbacks.append(
+                                (user_id, device_id, algorithm, key_id)
+                            )
+
+            # drop any one-time keys that were claimed
             sql = (
                 "DELETE FROM e2e_one_time_keys_json"
                 " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
@@ -726,6 +817,23 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 self._invalidate_cache_and_stream(
                     txn, self.count_e2e_one_time_keys, (user_id, device_id)
                 )
+            # mark fallback keys as used
+            for user_id, device_id, algorithm, key_id in used_fallbacks:
+                self.db_pool.simple_update_txn(
+                    txn,
+                    "e2e_fallback_keys_json",
+                    {
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "algorithm": algorithm,
+                        "key_id": key_id,
+                    },
+                    {"used": True},
+                )
+                self._invalidate_cache_and_stream(
+                    txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+                )
+
             return result
 
         return await self.db_pool.runInteraction(
@@ -754,6 +862,19 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             self._invalidate_cache_and_stream(
                 txn, self.count_e2e_one_time_keys, (user_id, device_id)
             )
+            self.db_pool.simple_delete_txn(
+                txn,
+                table="dehydrated_devices",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+            )
+            self.db_pool.simple_delete_txn(
+                txn,
+                table="e2e_fallback_keys_json",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+            )
 
         await self.db_pool.runInteraction(
             "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 6d3689c09e..2e07c37340 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -19,19 +19,33 @@ from typing import Dict, Iterable, List, Set, Tuple
 
 from synapse.api.errors import StoreError
 from synapse.events import EventBase
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.signatures import SignatureWorkerStore
 from synapse.types import Collection
 from synapse.util.caches.descriptors import cached
+from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
 
 logger = logging.getLogger(__name__)
 
 
 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        if hs.config.run_background_tasks:
+            hs.get_clock().looping_call(
+                self._delete_old_forward_extrem_cache, 60 * 60 * 1000
+            )
+
+        # Cache of event ID to list of auth event IDs and their depths.
+        self._event_auth_cache = LruCache(
+            500000, "_event_auth_cache", size_callback=len
+        )  # type: LruCache[str, List[Tuple[str, int]]]
+
     async def get_auth_chain(
         self, event_ids: Collection[str], include_given: bool = False
     ) -> List[EventBase]:
@@ -76,17 +90,45 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         else:
             results = set()
 
-        base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE "
+        # We pull out the depth simply so that we can populate the
+        # `_event_auth_cache` cache.
+        base_sql = """
+            SELECT a.event_id, auth_id, depth
+            FROM event_auth AS a
+            INNER JOIN events AS e ON (e.event_id = a.auth_id)
+            WHERE
+        """
 
         front = set(event_ids)
         while front:
             new_front = set()
             for chunk in batch_iter(front, 100):
-                clause, args = make_in_list_sql_clause(
-                    txn.database_engine, "event_id", chunk
-                )
-                txn.execute(base_sql + clause, args)
-                new_front.update(r[0] for r in txn)
+                # Pull the auth events either from the cache or DB.
+                to_fetch = []  # Event IDs to fetch from DB  # type: List[str]
+                for event_id in chunk:
+                    res = self._event_auth_cache.get(event_id)
+                    if res is None:
+                        to_fetch.append(event_id)
+                    else:
+                        new_front.update(auth_id for auth_id, depth in res)
+
+                if to_fetch:
+                    clause, args = make_in_list_sql_clause(
+                        txn.database_engine, "a.event_id", to_fetch
+                    )
+                    txn.execute(base_sql + clause, args)
+
+                    # Note we need to batch up the results by event ID before
+                    # adding to the cache.
+                    to_cache = {}
+                    for event_id, auth_event_id, auth_event_depth in txn:
+                        to_cache.setdefault(event_id, []).append(
+                            (auth_event_id, auth_event_depth)
+                        )
+                        new_front.add(auth_event_id)
+
+                    for event_id, auth_events in to_cache.items():
+                        self._event_auth_cache.set(event_id, auth_events)
 
             new_front -= results
 
@@ -205,14 +247,38 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
                 break
 
             # Fetch the auth events and their depths of the N last events we're
-            # currently walking
+            # currently walking, either from cache or DB.
             search, chunk = search[:-100], search[-100:]
-            clause, args = make_in_list_sql_clause(
-                txn.database_engine, "a.event_id", [e_id for _, e_id in chunk]
-            )
-            txn.execute(base_sql + clause, args)
 
-            for event_id, auth_event_id, auth_event_depth in txn:
+            found = []  # Results found  # type: List[Tuple[str, str, int]]
+            to_fetch = []  # Event IDs to fetch from DB  # type: List[str]
+            for _, event_id in chunk:
+                res = self._event_auth_cache.get(event_id)
+                if res is None:
+                    to_fetch.append(event_id)
+                else:
+                    found.extend((event_id, auth_id, depth) for auth_id, depth in res)
+
+            if to_fetch:
+                clause, args = make_in_list_sql_clause(
+                    txn.database_engine, "a.event_id", to_fetch
+                )
+                txn.execute(base_sql + clause, args)
+
+                # We parse the results and add the to the `found` set and the
+                # cache (note we need to batch up the results by event ID before
+                # adding to the cache).
+                to_cache = {}
+                for event_id, auth_event_id, auth_event_depth in txn:
+                    to_cache.setdefault(event_id, []).append(
+                        (auth_event_id, auth_event_depth)
+                    )
+                    found.append((event_id, auth_event_id, auth_event_depth))
+
+                for event_id, auth_events in to_cache.items():
+                    self._event_auth_cache.set(event_id, auth_events)
+
+            for event_id, auth_event_id, auth_event_depth in found:
                 event_to_auth_events.setdefault(event_id, set()).add(auth_event_id)
 
                 sets = event_to_missing_sets.get(auth_event_id)
@@ -586,6 +652,28 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         return [row["event_id"] for row in rows]
 
+    @wrap_as_background_process("delete_old_forward_extrem_cache")
+    async def _delete_old_forward_extrem_cache(self) -> None:
+        def _delete_old_forward_extrem_cache_txn(txn):
+            # Delete entries older than a month, while making sure we don't delete
+            # the only entries for a room.
+            sql = """
+                DELETE FROM stream_ordering_to_exterm
+                WHERE
+                room_id IN (
+                    SELECT room_id
+                    FROM stream_ordering_to_exterm
+                    WHERE stream_ordering > ?
+                ) AND stream_ordering < ?
+            """
+            txn.execute(
+                sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
+            )
+
+        await self.db_pool.runInteraction(
+            "_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn,
+        )
+
 
 class EventFederationStore(EventFederationWorkerStore):
     """ Responsible for storing and serving up the various graphs associated
@@ -606,34 +694,6 @@ class EventFederationStore(EventFederationWorkerStore):
             self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
         )
 
-        hs.get_clock().looping_call(
-            self._delete_old_forward_extrem_cache, 60 * 60 * 1000
-        )
-
-    def _delete_old_forward_extrem_cache(self):
-        def _delete_old_forward_extrem_cache_txn(txn):
-            # Delete entries older than a month, while making sure we don't delete
-            # the only entries for a room.
-            sql = """
-                DELETE FROM stream_ordering_to_exterm
-                WHERE
-                room_id IN (
-                    SELECT room_id
-                    FROM stream_ordering_to_exterm
-                    WHERE stream_ordering > ?
-                ) AND stream_ordering < ?
-            """
-            txn.execute(
-                sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
-            )
-
-        return run_as_background_process(
-            "delete_old_forward_extrem_cache",
-            self.db_pool.runInteraction,
-            "_delete_old_forward_extrem_cache",
-            _delete_old_forward_extrem_cache_txn,
-        )
-
     async def clean_room_for_join(self, room_id):
         return await self.db_pool.runInteraction(
             "clean_room_for_join", self._clean_room_for_join_txn, room_id
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 62f1738732..2e56dfaf31 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -13,15 +13,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 from typing import Dict, List, Optional, Tuple, Union
 
 import attr
 
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
@@ -74,19 +73,21 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         self.stream_ordering_month_ago = None
         self.stream_ordering_day_ago = None
 
-        cur = LoggingTransaction(
-            db_conn.cursor(),
-            name="_find_stream_orderings_for_times_txn",
-            database_engine=self.database_engine,
-        )
+        cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn")
         self._find_stream_orderings_for_times_txn(cur)
         cur.close()
 
         self.find_stream_orderings_looping_call = self._clock.looping_call(
             self._find_stream_orderings_for_times, 10 * 60 * 1000
         )
+
         self._rotate_delay = 3
         self._rotate_count = 10000
+        self._doing_notif_rotation = False
+        if hs.config.run_background_tasks:
+            self._rotate_notif_loop = self._clock.looping_call(
+                self._rotate_notifs, 30 * 60 * 1000
+            )
 
     @cached(num_args=3, tree=True, max_entries=5000)
     async def get_unread_event_push_actions_by_room_for_user(
@@ -518,15 +519,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "Error removing push actions after event persistence failure"
             )
 
-    def _find_stream_orderings_for_times(self):
-        return run_as_background_process(
-            "event_push_action_stream_orderings",
-            self.db_pool.runInteraction,
+    @wrap_as_background_process("event_push_action_stream_orderings")
+    async def _find_stream_orderings_for_times(self) -> None:
+        await self.db_pool.runInteraction(
             "_find_stream_orderings_for_times",
             self._find_stream_orderings_for_times_txn,
         )
 
-    def _find_stream_orderings_for_times_txn(self, txn):
+    def _find_stream_orderings_for_times_txn(self, txn: LoggingTransaction) -> None:
         logger.info("Searching for stream ordering 1 month ago")
         self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
             txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
@@ -656,129 +656,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         )
         return result[0] if result else None
 
-
-class EventPushActionsStore(EventPushActionsWorkerStore):
-    EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
-
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-
-        self.db_pool.updates.register_background_index_update(
-            self.EPA_HIGHLIGHT_INDEX,
-            index_name="event_push_actions_u_highlight",
-            table="event_push_actions",
-            columns=["user_id", "stream_ordering"],
-        )
-
-        self.db_pool.updates.register_background_index_update(
-            "event_push_actions_highlights_index",
-            index_name="event_push_actions_highlights_index",
-            table="event_push_actions",
-            columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
-            where_clause="highlight=1",
-        )
-
-        self._doing_notif_rotation = False
-        self._rotate_notif_loop = self._clock.looping_call(
-            self._start_rotate_notifs, 30 * 60 * 1000
-        )
-
-    async def get_push_actions_for_user(
-        self, user_id, before=None, limit=50, only_highlight=False
-    ):
-        def f(txn):
-            before_clause = ""
-            if before:
-                before_clause = "AND epa.stream_ordering < ?"
-                args = [user_id, before, limit]
-            else:
-                args = [user_id, limit]
-
-            if only_highlight:
-                if len(before_clause) > 0:
-                    before_clause += " "
-                before_clause += "AND epa.highlight = 1"
-
-            # NB. This assumes event_ids are globally unique since
-            # it makes the query easier to index
-            sql = (
-                "SELECT epa.event_id, epa.room_id,"
-                " epa.stream_ordering, epa.topological_ordering,"
-                " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
-                " FROM event_push_actions epa, events e"
-                " WHERE epa.event_id = e.event_id"
-                " AND epa.user_id = ? %s"
-                " AND epa.notif = 1"
-                " ORDER BY epa.stream_ordering DESC"
-                " LIMIT ?" % (before_clause,)
-            )
-            txn.execute(sql, args)
-            return self.db_pool.cursor_to_dict(txn)
-
-        push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
-        for pa in push_actions:
-            pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
-        return push_actions
-
-    async def get_latest_push_action_stream_ordering(self):
-        def f(txn):
-            txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
-            return txn.fetchone()
-
-        result = await self.db_pool.runInteraction(
-            "get_latest_push_action_stream_ordering", f
-        )
-        return result[0] or 0
-
-    def _remove_old_push_actions_before_txn(
-        self, txn, room_id, user_id, stream_ordering
-    ):
-        """
-        Purges old push actions for a user and room before a given
-        stream_ordering.
-
-        We however keep a months worth of highlighted notifications, so that
-        users can still get a list of recent highlights.
-
-        Args:
-            txn: The transcation
-            room_id: Room ID to delete from
-            user_id: user ID to delete for
-            stream_ordering: The lowest stream ordering which will
-                                  not be deleted.
-        """
-        txn.call_after(
-            self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
-            (room_id, user_id),
-        )
-
-        # We need to join on the events table to get the received_ts for
-        # event_push_actions and sqlite won't let us use a join in a delete so
-        # we can't just delete where received_ts < x. Furthermore we can
-        # only identify event_push_actions by a tuple of room_id, event_id
-        # we we can't use a subquery.
-        # Instead, we look up the stream ordering for the last event in that
-        # room received before the threshold time and delete event_push_actions
-        # in the room with a stream_odering before that.
-        txn.execute(
-            "DELETE FROM event_push_actions "
-            " WHERE user_id = ? AND room_id = ? AND "
-            " stream_ordering <= ?"
-            " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
-            (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
-        )
-
-        txn.execute(
-            """
-            DELETE FROM event_push_summary
-            WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
-        """,
-            (room_id, user_id, stream_ordering),
-        )
-
-    def _start_rotate_notifs(self):
-        return run_as_background_process("rotate_notifs", self._rotate_notifs)
-
+    @wrap_as_background_process("rotate_notifs")
     async def _rotate_notifs(self):
         if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
             return
@@ -958,6 +836,121 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         )
 
 
+class EventPushActionsStore(EventPushActionsWorkerStore):
+    EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
+
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        self.db_pool.updates.register_background_index_update(
+            self.EPA_HIGHLIGHT_INDEX,
+            index_name="event_push_actions_u_highlight",
+            table="event_push_actions",
+            columns=["user_id", "stream_ordering"],
+        )
+
+        self.db_pool.updates.register_background_index_update(
+            "event_push_actions_highlights_index",
+            index_name="event_push_actions_highlights_index",
+            table="event_push_actions",
+            columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
+            where_clause="highlight=1",
+        )
+
+    async def get_push_actions_for_user(
+        self, user_id, before=None, limit=50, only_highlight=False
+    ):
+        def f(txn):
+            before_clause = ""
+            if before:
+                before_clause = "AND epa.stream_ordering < ?"
+                args = [user_id, before, limit]
+            else:
+                args = [user_id, limit]
+
+            if only_highlight:
+                if len(before_clause) > 0:
+                    before_clause += " "
+                before_clause += "AND epa.highlight = 1"
+
+            # NB. This assumes event_ids are globally unique since
+            # it makes the query easier to index
+            sql = (
+                "SELECT epa.event_id, epa.room_id,"
+                " epa.stream_ordering, epa.topological_ordering,"
+                " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
+                " FROM event_push_actions epa, events e"
+                " WHERE epa.event_id = e.event_id"
+                " AND epa.user_id = ? %s"
+                " AND epa.notif = 1"
+                " ORDER BY epa.stream_ordering DESC"
+                " LIMIT ?" % (before_clause,)
+            )
+            txn.execute(sql, args)
+            return self.db_pool.cursor_to_dict(txn)
+
+        push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
+        for pa in push_actions:
+            pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
+        return push_actions
+
+    async def get_latest_push_action_stream_ordering(self):
+        def f(txn):
+            txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
+            return txn.fetchone()
+
+        result = await self.db_pool.runInteraction(
+            "get_latest_push_action_stream_ordering", f
+        )
+        return result[0] or 0
+
+    def _remove_old_push_actions_before_txn(
+        self, txn, room_id, user_id, stream_ordering
+    ):
+        """
+        Purges old push actions for a user and room before a given
+        stream_ordering.
+
+        We however keep a months worth of highlighted notifications, so that
+        users can still get a list of recent highlights.
+
+        Args:
+            txn: The transcation
+            room_id: Room ID to delete from
+            user_id: user ID to delete for
+            stream_ordering: The lowest stream ordering which will
+                                  not be deleted.
+        """
+        txn.call_after(
+            self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+            (room_id, user_id),
+        )
+
+        # We need to join on the events table to get the received_ts for
+        # event_push_actions and sqlite won't let us use a join in a delete so
+        # we can't just delete where received_ts < x. Furthermore we can
+        # only identify event_push_actions by a tuple of room_id, event_id
+        # we we can't use a subquery.
+        # Instead, we look up the stream ordering for the last event in that
+        # room received before the threshold time and delete event_push_actions
+        # in the room with a stream_odering before that.
+        txn.execute(
+            "DELETE FROM event_push_actions "
+            " WHERE user_id = ? AND room_id = ? AND "
+            " stream_ordering <= ?"
+            " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
+            (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
+        )
+
+        txn.execute(
+            """
+            DELETE FROM event_push_summary
+            WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
+        """,
+            (room_id, user_id, stream_ordering),
+        )
+
+
 def _action_has_highlight(actions):
     for action in actions:
         try:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 18def01f50..90fb1a1f00 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -34,7 +34,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import StateMap, get_domain_from_id
-from synapse.util.frozenutils import frozendict_json_encoder
+from synapse.util import json_encoder
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -52,16 +52,6 @@ event_counter = Counter(
 )
 
 
-def encode_json(json_object):
-    """
-    Encode a Python object as JSON and return it in a Unicode string.
-    """
-    out = frozendict_json_encoder.encode(json_object)
-    if isinstance(out, bytes):
-        out = out.decode("utf8")
-    return out
-
-
 _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
 
 
@@ -341,6 +331,10 @@ class PersistEventsStore:
         min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
         max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
 
+        # stream orderings should have been assigned by now
+        assert min_stream_order
+        assert max_stream_order
+
         self._update_forward_extremities_txn(
             txn,
             new_forward_extremities=new_forward_extremeties,
@@ -367,6 +361,8 @@ class PersistEventsStore:
 
         self._store_event_txn(txn, events_and_contexts=events_and_contexts)
 
+        self._persist_transaction_ids_txn(txn, events_and_contexts)
+
         # Insert into event_to_state_groups.
         self._store_event_state_mappings_txn(txn, events_and_contexts)
 
@@ -411,6 +407,35 @@ class PersistEventsStore:
         # room_memberships, where applicable.
         self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
 
+    def _persist_transaction_ids_txn(
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ):
+        """Persist the mapping from transaction IDs to event IDs (if defined).
+        """
+
+        to_insert = []
+        for event, _ in events_and_contexts:
+            token_id = getattr(event.internal_metadata, "token_id", None)
+            txn_id = getattr(event.internal_metadata, "txn_id", None)
+            if token_id and txn_id:
+                to_insert.append(
+                    {
+                        "event_id": event.event_id,
+                        "room_id": event.room_id,
+                        "user_id": event.sender,
+                        "token_id": token_id,
+                        "txn_id": txn_id,
+                        "inserted_ts": self._clock.time_msec(),
+                    }
+                )
+
+        if to_insert:
+            self.db_pool.simple_insert_many_txn(
+                txn, table="event_txn_id", values=to_insert,
+            )
+
     def _update_current_state_txn(
         self,
         txn: LoggingTransaction,
@@ -432,12 +457,12 @@ class PersistEventsStore:
                 # so that async background tasks get told what happened.
                 sql = """
                     INSERT INTO current_state_delta_stream
-                        (stream_id, room_id, type, state_key, event_id, prev_event_id)
-                    SELECT ?, room_id, type, state_key, null, event_id
+                        (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
+                    SELECT ?, ?, room_id, type, state_key, null, event_id
                         FROM current_state_events
                         WHERE room_id = ?
                 """
-                txn.execute(sql, (stream_id, room_id))
+                txn.execute(sql, (stream_id, self._instance_name, room_id))
 
                 self.db_pool.simple_delete_txn(
                     txn, table="current_state_events", keyvalues={"room_id": room_id},
@@ -458,8 +483,8 @@ class PersistEventsStore:
                 #
                 sql = """
                     INSERT INTO current_state_delta_stream
-                    (stream_id, room_id, type, state_key, event_id, prev_event_id)
-                    SELECT ?, ?, ?, ?, ?, (
+                    (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
+                    SELECT ?, ?, ?, ?, ?, ?, (
                         SELECT event_id FROM current_state_events
                         WHERE room_id = ? AND type = ? AND state_key = ?
                     )
@@ -469,6 +494,7 @@ class PersistEventsStore:
                     (
                         (
                             stream_id,
+                            self._instance_name,
                             room_id,
                             etype,
                             state_key,
@@ -743,7 +769,7 @@ class PersistEventsStore:
                     logger.exception("")
                     raise
 
-                metadata_json = encode_json(event.internal_metadata.get_dict())
+                metadata_json = json_encoder.encode(event.internal_metadata.get_dict())
 
                 sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
                 txn.execute(sql, (metadata_json, event.event_id))
@@ -759,6 +785,7 @@ class PersistEventsStore:
                         "event_stream_ordering": stream_order,
                         "event_id": event.event_id,
                         "state_group": state_group_id,
+                        "instance_name": self._instance_name,
                     },
                 )
 
@@ -797,10 +824,10 @@ class PersistEventsStore:
                 {
                     "event_id": event.event_id,
                     "room_id": event.room_id,
-                    "internal_metadata": encode_json(
+                    "internal_metadata": json_encoder.encode(
                         event.internal_metadata.get_dict()
                     ),
-                    "json": encode_json(event_dict(event)),
+                    "json": json_encoder.encode(event_dict(event)),
                     "format_version": event.format_version,
                 }
                 for event, _ in events_and_contexts
@@ -1022,9 +1049,7 @@ class PersistEventsStore:
 
         def prefill():
             for cache_entry in to_prefill:
-                self.store._get_event_cache.prefill(
-                    (cache_entry[0].event_id,), cache_entry
-                )
+                self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
 
         txn.call_after(prefill)
 
@@ -1241,6 +1266,10 @@ class PersistEventsStore:
             )
 
     def _store_retention_policy_for_room_txn(self, txn, event):
+        if not event.is_state():
+            logger.debug("Ignoring non-state m.room.retention event")
+            return
+
         if hasattr(event, "content") and (
             "min_lifetime" in event.content or "max_lifetime" in event.content
         ):
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 5e4af2eb51..97b6754846 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -92,6 +92,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             where_clause="NOT have_censored",
         )
 
+        self.db_pool.updates.register_background_index_update(
+            "users_have_local_media",
+            index_name="users_have_local_media",
+            table="local_media_repository",
+            columns=["user_id", "created_ts"],
+        )
+
     async def _background_reindex_fields_sender(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index f95679ebc4..4732685f6e 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import itertools
 import logging
 import threading
@@ -32,9 +31,13 @@ from synapse.api.room_versions import (
     RoomVersions,
 )
 from synapse.events import EventBase, make_event_from_dict
+from synapse.events.snapshot import EventContext
 from synapse.events.utils import prune_event
 from synapse.logging.context import PreserveLoggingContext, current_context
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+    run_as_background_process,
+    wrap_as_background_process,
+)
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import BackfillStream
 from synapse.replication.tcp.streams.events import EventsStream
@@ -42,8 +45,9 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
 from synapse.storage.database import DatabasePool
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
-from synapse.types import Collection, get_domain_from_id
-from synapse.util.caches.descriptors import Cache, cached
+from synapse.types import Collection, JsonDict, get_domain_from_id
+from synapse.util.caches.descriptors import cached
+from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
@@ -74,6 +78,13 @@ class EventRedactBehaviour(Names):
 
 
 class EventsWorkerStore(SQLBaseStore):
+    # Whether to use dedicated DB threads for event fetching. This is only used
+    # if there are multiple DB threads available. When used will lock the DB
+    # thread for periods of time (so unit tests want to disable this when they
+    # run DB transactions on the main thread). See EVENT_QUEUE_* for more
+    # options controlling this.
+    USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
+
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
 
@@ -130,11 +141,16 @@ class EventsWorkerStore(SQLBaseStore):
                     db_conn, "events", "stream_ordering", step=-1
                 )
 
-        self._get_event_cache = Cache(
-            "*getEvent*",
+        if hs.config.run_background_tasks:
+            # We periodically clean out old transaction ID mappings
+            self._clock.looping_call(
+                self._cleanup_old_transaction_ids, 5 * 60 * 1000,
+            )
+
+        self._get_event_cache = LruCache(
+            cache_name="*getEvent*",
             keylen=3,
-            max_entries=hs.config.caches.event_cache_size,
-            apply_cache_factor_from_config=False,
+            max_size=hs.config.caches.event_cache_size,
         )
 
         self._event_fetch_lock = threading.Condition()
@@ -510,6 +526,57 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_map
 
+    async def get_stripped_room_state_from_event_context(
+        self,
+        context: EventContext,
+        state_types_to_include: List[EventTypes],
+        membership_user_id: Optional[str] = None,
+    ) -> List[JsonDict]:
+        """
+        Retrieve the stripped state from a room, given an event context to retrieve state
+        from as well as the state types to include. Optionally, include the membership
+        events from a specific user.
+
+        "Stripped" state means that only the `type`, `state_key`, `content` and `sender` keys
+        are included from each state event.
+
+        Args:
+            context: The event context to retrieve state of the room from.
+            state_types_to_include: The type of state events to include.
+            membership_user_id: An optional user ID to include the stripped membership state
+                events of. This is useful when generating the stripped state of a room for
+                invites. We want to send membership events of the inviter, so that the
+                invitee can display the inviter's profile information if the room lacks any.
+
+        Returns:
+            A list of dictionaries, each representing a stripped state event from the room.
+        """
+        current_state_ids = await context.get_current_state_ids()
+
+        # We know this event is not an outlier, so this must be
+        # non-None.
+        assert current_state_ids is not None
+
+        # The state to include
+        state_to_include_ids = [
+            e_id
+            for k, e_id in current_state_ids.items()
+            if k[0] in state_types_to_include
+            or (membership_user_id and k == (EventTypes.Member, membership_user_id))
+        ]
+
+        state_to_include = await self.get_events(state_to_include_ids)
+
+        return [
+            {
+                "type": e.type,
+                "state_key": e.state_key,
+                "content": e.content,
+                "sender": e.sender,
+            }
+            for e in state_to_include.values()
+        ]
+
     def _do_fetch(self, conn):
         """Takes a database connection and waits for requests for events from
         the _event_fetch_list queue.
@@ -522,7 +589,11 @@ class EventsWorkerStore(SQLBaseStore):
 
                 if not event_list:
                     single_threaded = self.database_engine.single_threaded
-                    if single_threaded or i > EVENT_QUEUE_ITERATIONS:
+                    if (
+                        not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+                        or single_threaded
+                        or i > EVENT_QUEUE_ITERATIONS
+                    ):
                         self._event_fetch_ongoing -= 1
                         return
                     else:
@@ -712,6 +783,7 @@ class EventsWorkerStore(SQLBaseStore):
                 internal_metadata_dict=internal_metadata,
                 rejected_reason=rejected_reason,
             )
+            original_ev.internal_metadata.stream_ordering = row["stream_ordering"]
 
             event_map[event_id] = original_ev
 
@@ -728,7 +800,7 @@ class EventsWorkerStore(SQLBaseStore):
                 event=original_ev, redacted_event=redacted_event
             )
 
-            self._get_event_cache.prefill((event_id,), cache_entry)
+            self._get_event_cache.set((event_id,), cache_entry)
             result_map[event_id] = cache_entry
 
         return result_map
@@ -779,6 +851,8 @@ class EventsWorkerStore(SQLBaseStore):
 
          * event_id (str)
 
+         * stream_ordering (int): stream ordering for this event
+
          * json (str): json-encoded event structure
 
          * internal_metadata (str): json-encoded internal metadata dict
@@ -811,13 +885,15 @@ class EventsWorkerStore(SQLBaseStore):
             sql = """\
                 SELECT
                   e.event_id,
-                  e.internal_metadata,
-                  e.json,
-                  e.format_version,
+                  e.stream_ordering,
+                  ej.internal_metadata,
+                  ej.json,
+                  ej.format_version,
                   r.room_version,
                   rej.reason
-                FROM event_json as e
-                  LEFT JOIN rooms r USING (room_id)
+                FROM events AS e
+                  JOIN event_json AS ej USING (event_id)
+                  LEFT JOIN rooms r ON r.room_id = e.room_id
                   LEFT JOIN rejections as rej USING (event_id)
                 WHERE """
 
@@ -831,11 +907,12 @@ class EventsWorkerStore(SQLBaseStore):
                 event_id = row[0]
                 event_dict[event_id] = {
                     "event_id": event_id,
-                    "internal_metadata": row[1],
-                    "json": row[2],
-                    "format_version": row[3],
-                    "room_version_id": row[4],
-                    "rejected_reason": row[5],
+                    "stream_ordering": row[1],
+                    "internal_metadata": row[2],
+                    "json": row[3],
+                    "format_version": row[4],
+                    "room_version_id": row[5],
+                    "rejected_reason": row[6],
                     "redactions": [],
                 }
 
@@ -1017,16 +1094,12 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {"v1": complexity_v1}
 
-    def get_current_backfill_token(self):
-        """The current minimum token that backfilled events have reached"""
-        return -self._backfill_id_gen.get_current_token()
-
     def get_current_events_token(self):
         """The current maximum token that events have reached"""
         return self._stream_id_gen.get_current_token()
 
     async def get_all_new_forward_event_rows(
-        self, last_id: int, current_id: int, limit: int
+        self, instance_name: str, last_id: int, current_id: int, limit: int
     ) -> List[Tuple]:
         """Returns new events, for the Events replication stream
 
@@ -1044,16 +1117,19 @@ class EventsWorkerStore(SQLBaseStore):
         def get_all_new_forward_event_rows(txn):
             sql = (
                 "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
+                " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
                 " FROM events AS e"
                 " LEFT JOIN redactions USING (event_id)"
                 " LEFT JOIN state_events USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
+                " LEFT JOIN room_memberships USING (event_id)"
+                " LEFT JOIN rejections USING (event_id)"
                 " WHERE ? < stream_ordering AND stream_ordering <= ?"
+                " AND instance_name = ?"
                 " ORDER BY stream_ordering ASC"
                 " LIMIT ?"
             )
-            txn.execute(sql, (last_id, current_id, limit))
+            txn.execute(sql, (last_id, current_id, instance_name, limit))
             return txn.fetchall()
 
         return await self.db_pool.runInteraction(
@@ -1061,7 +1137,7 @@ class EventsWorkerStore(SQLBaseStore):
         )
 
     async def get_ex_outlier_stream_rows(
-        self, last_id: int, current_id: int
+        self, instance_name: str, last_id: int, current_id: int
     ) -> List[Tuple]:
         """Returns de-outliered events, for the Events replication stream
 
@@ -1078,18 +1154,21 @@ class EventsWorkerStore(SQLBaseStore):
         def get_ex_outlier_stream_rows_txn(txn):
             sql = (
                 "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
+                " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
                 " FROM events AS e"
-                " INNER JOIN ex_outlier_stream USING (event_id)"
+                " INNER JOIN ex_outlier_stream AS out USING (event_id)"
                 " LEFT JOIN redactions USING (event_id)"
                 " LEFT JOIN state_events USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
+                " LEFT JOIN room_memberships USING (event_id)"
+                " LEFT JOIN rejections USING (event_id)"
                 " WHERE ? < event_stream_ordering"
                 " AND event_stream_ordering <= ?"
+                " AND out.instance_name = ?"
                 " ORDER BY event_stream_ordering ASC"
             )
 
-            txn.execute(sql, (last_id, current_id))
+            txn.execute(sql, (last_id, current_id, instance_name))
             return txn.fetchall()
 
         return await self.db_pool.runInteraction(
@@ -1102,6 +1181,9 @@ class EventsWorkerStore(SQLBaseStore):
         """Get updates for backfill replication stream, including all new
         backfilled events and events that have gone from being outliers to not.
 
+        NOTE: The IDs given here are from replication, and so should be
+        *positive*.
+
         Args:
             instance_name: The writer we want to fetch updates from. Unused
                 here since there is only ever one writer.
@@ -1132,10 +1214,11 @@ class EventsWorkerStore(SQLBaseStore):
                 " LEFT JOIN state_events USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
                 " WHERE ? > stream_ordering AND stream_ordering >= ?"
+                "  AND instance_name = ?"
                 " ORDER BY stream_ordering ASC"
                 " LIMIT ?"
             )
-            txn.execute(sql, (-last_id, -current_id, limit))
+            txn.execute(sql, (-last_id, -current_id, instance_name, limit))
             new_event_updates = [(row[0], row[1:]) for row in txn]
 
             limited = False
@@ -1149,15 +1232,16 @@ class EventsWorkerStore(SQLBaseStore):
                 "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
                 " state_key, redacts, relates_to_id"
                 " FROM events AS e"
-                " INNER JOIN ex_outlier_stream USING (event_id)"
+                " INNER JOIN ex_outlier_stream AS out USING (event_id)"
                 " LEFT JOIN redactions USING (event_id)"
                 " LEFT JOIN state_events USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
                 " WHERE ? > event_stream_ordering"
                 " AND event_stream_ordering >= ?"
+                " AND out.instance_name = ?"
                 " ORDER BY event_stream_ordering DESC"
             )
-            txn.execute(sql, (-last_id, -upper_bound))
+            txn.execute(sql, (-last_id, -upper_bound, instance_name))
             new_event_updates.extend((row[0], row[1:]) for row in txn)
 
             if len(new_event_updates) >= limit:
@@ -1171,7 +1255,7 @@ class EventsWorkerStore(SQLBaseStore):
         )
 
     async def get_all_updated_current_state_deltas(
-        self, from_token: int, to_token: int, target_row_count: int
+        self, instance_name: str, from_token: int, to_token: int, target_row_count: int
     ) -> Tuple[List[Tuple], int, bool]:
         """Fetch updates from current_state_delta_stream
 
@@ -1197,9 +1281,10 @@ class EventsWorkerStore(SQLBaseStore):
                 SELECT stream_id, room_id, type, state_key, event_id
                 FROM current_state_delta_stream
                 WHERE ? < stream_id AND stream_id <= ?
+                    AND instance_name = ?
                 ORDER BY stream_id ASC LIMIT ?
             """
-            txn.execute(sql, (from_token, to_token, target_row_count))
+            txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
             return txn.fetchall()
 
         def get_deltas_for_stream_id_txn(txn, stream_id):
@@ -1287,3 +1372,77 @@ class EventsWorkerStore(SQLBaseStore):
         return await self.db_pool.runInteraction(
             desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
         )
+
+    async def get_event_id_from_transaction_id(
+        self, room_id: str, user_id: str, token_id: int, txn_id: str
+    ) -> Optional[str]:
+        """Look up if we have already persisted an event for the transaction ID,
+        returning the event ID if so.
+        """
+        return await self.db_pool.simple_select_one_onecol(
+            table="event_txn_id",
+            keyvalues={
+                "room_id": room_id,
+                "user_id": user_id,
+                "token_id": token_id,
+                "txn_id": txn_id,
+            },
+            retcol="event_id",
+            allow_none=True,
+            desc="get_event_id_from_transaction_id",
+        )
+
+    async def get_already_persisted_events(
+        self, events: Iterable[EventBase]
+    ) -> Dict[str, str]:
+        """Look up if we have already persisted an event for the transaction ID,
+        returning a mapping from event ID in the given list to the event ID of
+        an existing event.
+
+        Also checks if there are duplicates in the given events, if there are
+        will map duplicates to the *first* event.
+        """
+
+        mapping = {}
+        txn_id_to_event = {}  # type: Dict[Tuple[str, int, str], str]
+
+        for event in events:
+            token_id = getattr(event.internal_metadata, "token_id", None)
+            txn_id = getattr(event.internal_metadata, "txn_id", None)
+
+            if token_id and txn_id:
+                # Check if this is a duplicate of an event in the given events.
+                existing = txn_id_to_event.get((event.room_id, token_id, txn_id))
+                if existing:
+                    mapping[event.event_id] = existing
+                    continue
+
+                # Check if this is a duplicate of an event we've already
+                # persisted.
+                existing = await self.get_event_id_from_transaction_id(
+                    event.room_id, event.sender, token_id, txn_id
+                )
+                if existing:
+                    mapping[event.event_id] = existing
+                    txn_id_to_event[(event.room_id, token_id, txn_id)] = existing
+                else:
+                    txn_id_to_event[(event.room_id, token_id, txn_id)] = event.event_id
+
+        return mapping
+
+    @wrap_as_background_process("_cleanup_old_transaction_ids")
+    async def _cleanup_old_transaction_ids(self):
+        """Cleans out transaction id mappings older than 24hrs.
+        """
+
+        def _cleanup_old_transaction_ids_txn(txn):
+            sql = """
+                DELETE FROM event_txn_id
+                WHERE inserted_ts < ?
+            """
+            one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
+            txn.execute(sql, (one_day_ago,))
+
+        return await self.db_pool.runInteraction(
+            "_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn,
+        )
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index ad43bb05ab..f8f4bb9b3f 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -122,9 +122,7 @@ class KeyStore(SQLBaseStore):
             # param, which is itself the 2-tuple (server_name, key_id).
             invalidations.append((server_name, key_id))
 
-        await self.db_pool.runInteraction(
-            "store_server_verify_keys",
-            self.db_pool.simple_upsert_many_txn,
+        await self.db_pool.simple_upsert_many(
             table="server_signature_keys",
             key_names=("server_name", "key_id"),
             key_values=key_values,
@@ -135,6 +133,7 @@ class KeyStore(SQLBaseStore):
                 "verify_key",
             ),
             value_values=value_values,
+            desc="store_server_verify_keys",
         )
 
         invalidate = self._get_server_verify_key.invalidate
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index cc538c5c10..4b2f224718 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -93,6 +93,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
+        self.server_name = hs.hostname
 
     async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
         """Get the metadata for a local piece of media
@@ -115,6 +116,109 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="get_local_media",
         )
 
+    async def get_local_media_by_user_paginate(
+        self, start: int, limit: int, user_id: str
+    ) -> Tuple[List[Dict[str, Any]], int]:
+        """Get a paginated list of metadata for a local piece of media
+        which an user_id has uploaded
+
+        Args:
+            start: offset in the list
+            limit: maximum amount of media_ids to retrieve
+            user_id: fully-qualified user id
+        Returns:
+            A paginated list of all metadata of user's media,
+            plus the total count of all the user's media
+        """
+
+        def get_local_media_by_user_paginate_txn(txn):
+
+            args = [user_id]
+            sql = """
+                SELECT COUNT(*) as total_media
+                FROM local_media_repository
+                WHERE user_id = ?
+            """
+            txn.execute(sql, args)
+            count = txn.fetchone()[0]
+
+            sql = """
+                SELECT
+                    "media_id",
+                    "media_type",
+                    "media_length",
+                    "upload_name",
+                    "created_ts",
+                    "last_access_ts",
+                    "quarantined_by",
+                    "safe_from_quarantine"
+                FROM local_media_repository
+                WHERE user_id = ?
+                ORDER BY created_ts DESC, media_id DESC
+                LIMIT ? OFFSET ?
+            """
+
+            args += [limit, start]
+            txn.execute(sql, args)
+            media = self.db_pool.cursor_to_dict(txn)
+            return media, count
+
+        return await self.db_pool.runInteraction(
+            "get_local_media_by_user_paginate_txn", get_local_media_by_user_paginate_txn
+        )
+
+    async def get_local_media_before(
+        self, before_ts: int, size_gt: int, keep_profiles: bool,
+    ) -> Optional[List[str]]:
+
+        # to find files that have never been accessed (last_access_ts IS NULL)
+        # compare with `created_ts`
+        sql = """
+            SELECT media_id
+            FROM local_media_repository AS lmr
+            WHERE
+                ( last_access_ts < ?
+                OR ( created_ts < ? AND last_access_ts IS NULL ) )
+                AND media_length > ?
+        """
+
+        if keep_profiles:
+            sql_keep = """
+                AND (
+                    NOT EXISTS
+                        (SELECT 1
+                         FROM profiles
+                         WHERE profiles.avatar_url = '{media_prefix}' || lmr.media_id)
+                    AND NOT EXISTS
+                        (SELECT 1
+                         FROM groups
+                         WHERE groups.avatar_url = '{media_prefix}' || lmr.media_id)
+                    AND NOT EXISTS
+                        (SELECT 1
+                         FROM room_memberships
+                         WHERE room_memberships.avatar_url = '{media_prefix}' || lmr.media_id)
+                    AND NOT EXISTS
+                        (SELECT 1
+                         FROM user_directory
+                         WHERE user_directory.avatar_url = '{media_prefix}' || lmr.media_id)
+                    AND NOT EXISTS
+                        (SELECT 1
+                         FROM room_stats_state
+                         WHERE room_stats_state.avatar = '{media_prefix}' || lmr.media_id)
+                )
+            """.format(
+                media_prefix="mxc://%s/" % (self.server_name,),
+            )
+            sql += sql_keep
+
+        def _get_local_media_before_txn(txn):
+            txn.execute(sql, (before_ts, before_ts, size_gt))
+            return [row[0] for row in txn]
+
+        return await self.db_pool.runInteraction(
+            "get_local_media_before", _get_local_media_before_txn
+        )
+
     async def store_local_media(
         self,
         media_id,
@@ -348,6 +452,33 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="get_remote_media_thumbnails",
         )
 
+    async def get_remote_media_thumbnail(
+        self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str,
+    ) -> Optional[Dict[str, Any]]:
+        """Fetch the thumbnail info of given width, height and type.
+        """
+
+        return await self.db_pool.simple_select_one(
+            table="remote_media_cache_thumbnails",
+            keyvalues={
+                "media_origin": origin,
+                "media_id": media_id,
+                "thumbnail_width": t_width,
+                "thumbnail_height": t_height,
+                "thumbnail_type": t_type,
+            },
+            retcols=(
+                "thumbnail_width",
+                "thumbnail_height",
+                "thumbnail_method",
+                "thumbnail_type",
+                "thumbnail_length",
+                "filesystem_id",
+            ),
+            allow_none=True,
+            desc="get_remote_media_thumbnail",
+        )
+
     async def store_remote_media_thumbnail(
         self,
         origin,
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index 92099f95ce..ab18cc4d79 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -12,15 +12,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import calendar
+import logging
+import time
+from typing import Dict
 
 from synapse.metrics import GaugeBucketCollector
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
 )
 
+logger = logging.getLogger(__name__)
+
 # Collect metrics on the number of forward extremities that exist.
 _extremities_collecter = GaugeBucketCollector(
     "synapse_forward_extremities",
@@ -51,15 +57,13 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         super().__init__(database, db_conn, hs)
 
         # Read the extrems every 60 minutes
-        def read_forward_extremities():
-            # run as a background process to make sure that the database transactions
-            # have a logcontext to report to
-            return run_as_background_process(
-                "read_forward_extremities", self._read_forward_extremities
-            )
+        if hs.config.run_background_tasks:
+            self._clock.looping_call(self._read_forward_extremities, 60 * 60 * 1000)
 
-        hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
+        # Used in _generate_user_daily_visits to keep track of progress
+        self._last_user_visit_update = self._get_start_of_day()
 
+    @wrap_as_background_process("read_forward_extremities")
     async def _read_forward_extremities(self):
         def fetch(txn):
             txn.execute(
@@ -137,3 +141,196 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             return count
 
         return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
+
+    async def count_daily_users(self) -> int:
+        """
+        Counts the number of users who used this homeserver in the last 24 hours.
+        """
+        yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
+        return await self.db_pool.runInteraction(
+            "count_daily_users", self._count_users, yesterday
+        )
+
+    async def count_monthly_users(self) -> int:
+        """
+        Counts the number of users who used this homeserver in the last 30 days.
+        Note this method is intended for phonehome metrics only and is different
+        from the mau figure in synapse.storage.monthly_active_users which,
+        amongst other things, includes a 3 day grace period before a user counts.
+        """
+        thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
+        return await self.db_pool.runInteraction(
+            "count_monthly_users", self._count_users, thirty_days_ago
+        )
+
+    def _count_users(self, txn, time_from):
+        """
+        Returns number of users seen in the past time_from period
+        """
+        sql = """
+            SELECT COALESCE(count(*), 0) FROM (
+                SELECT user_id FROM user_ips
+                WHERE last_seen > ?
+                GROUP BY user_id
+            ) u
+        """
+        txn.execute(sql, (time_from,))
+        (count,) = txn.fetchone()
+        return count
+
+    async def count_r30_users(self) -> Dict[str, int]:
+        """
+        Counts the number of 30 day retained users, defined as:-
+         * Users who have created their accounts more than 30 days ago
+         * Where last seen at most 30 days ago
+         * Where account creation and last_seen are > 30 days apart
+
+        Returns:
+             A mapping of counts globally as well as broken out by platform.
+        """
+
+        def _count_r30_users(txn):
+            thirty_days_in_secs = 86400 * 30
+            now = int(self._clock.time())
+            thirty_days_ago_in_secs = now - thirty_days_in_secs
+
+            sql = """
+                SELECT platform, COALESCE(count(*), 0) FROM (
+                     SELECT
+                        users.name, platform, users.creation_ts * 1000,
+                        MAX(uip.last_seen)
+                     FROM users
+                     INNER JOIN (
+                         SELECT
+                         user_id,
+                         last_seen,
+                         CASE
+                             WHEN user_agent LIKE '%%Android%%' THEN 'android'
+                             WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
+                             WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
+                             WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
+                             WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
+                             ELSE 'unknown'
+                         END
+                         AS platform
+                         FROM user_ips
+                     ) uip
+                     ON users.name = uip.user_id
+                     AND users.appservice_id is NULL
+                     AND users.creation_ts < ?
+                     AND uip.last_seen/1000 > ?
+                     AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+                     GROUP BY users.name, platform, users.creation_ts
+                ) u GROUP BY platform
+            """
+
+            results = {}
+            txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
+
+            for row in txn:
+                if row[0] == "unknown":
+                    pass
+                results[row[0]] = row[1]
+
+            sql = """
+                SELECT COALESCE(count(*), 0) FROM (
+                    SELECT users.name, users.creation_ts * 1000,
+                                                        MAX(uip.last_seen)
+                    FROM users
+                    INNER JOIN (
+                        SELECT
+                        user_id,
+                        last_seen
+                        FROM user_ips
+                    ) uip
+                    ON users.name = uip.user_id
+                    AND appservice_id is NULL
+                    AND users.creation_ts < ?
+                    AND uip.last_seen/1000 > ?
+                    AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+                    GROUP BY users.name, users.creation_ts
+                ) u
+            """
+
+            txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
+
+            (count,) = txn.fetchone()
+            results["all"] = count
+
+            return results
+
+        return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
+
+    def _get_start_of_day(self):
+        """
+        Returns millisecond unixtime for start of UTC day.
+        """
+        now = time.gmtime()
+        today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
+        return today_start * 1000
+
+    @wrap_as_background_process("generate_user_daily_visits")
+    async def generate_user_daily_visits(self) -> None:
+        """
+        Generates daily visit data for use in cohort/ retention analysis
+        """
+
+        def _generate_user_daily_visits(txn):
+            logger.info("Calling _generate_user_daily_visits")
+            today_start = self._get_start_of_day()
+            a_day_in_milliseconds = 24 * 60 * 60 * 1000
+            now = self._clock.time_msec()
+
+            # A note on user_agent. Technically a given device can have multiple
+            # user agents, so we need to decide which one to pick. We could have
+            # handled this in number of ways, but given that we don't care
+            # _that_ much we have gone for MAX(). For more details of the other
+            # options considered see
+            # https://github.com/matrix-org/synapse/pull/8503#discussion_r502306111
+            sql = """
+                INSERT INTO user_daily_visits (user_id, device_id, timestamp, user_agent)
+                    SELECT u.user_id, u.device_id, ?, MAX(u.user_agent)
+                    FROM user_ips AS u
+                    LEFT JOIN (
+                      SELECT user_id, device_id, timestamp FROM user_daily_visits
+                      WHERE timestamp = ?
+                    ) udv
+                    ON u.user_id = udv.user_id AND u.device_id=udv.device_id
+                    INNER JOIN users ON users.name=u.user_id
+                    WHERE last_seen > ? AND last_seen <= ?
+                    AND udv.timestamp IS NULL AND users.is_guest=0
+                    AND users.appservice_id IS NULL
+                    GROUP BY u.user_id, u.device_id
+            """
+
+            # This means that the day has rolled over but there could still
+            # be entries from the previous day. There is an edge case
+            # where if the user logs in at 23:59 and overwrites their
+            # last_seen at 00:01 then they will not be counted in the
+            # previous day's stats - it is important that the query is run
+            # often to minimise this case.
+            if today_start > self._last_user_visit_update:
+                yesterday_start = today_start - a_day_in_milliseconds
+                txn.execute(
+                    sql,
+                    (
+                        yesterday_start,
+                        yesterday_start,
+                        self._last_user_visit_update,
+                        today_start,
+                    ),
+                )
+                self._last_user_visit_update = today_start
+
+            txn.execute(
+                sql, (today_start, today_start, self._last_user_visit_update, now)
+            )
+            # Update _last_user_visit_update to now. The reason to do this
+            # rather just clamping to the beginning of the day is to limit
+            # the size of the join - meaning that the query can be run more
+            # frequently
+            self._last_user_visit_update = now
+
+        await self.db_pool.runInteraction(
+            "generate_user_daily_visits", _generate_user_daily_visits
+        )
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index e93aad33cd..d788dc0fc6 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -15,6 +15,7 @@
 import logging
 from typing import Dict, List
 
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool, make_in_list_sql_clause
 from synapse.util.caches.descriptors import cached
@@ -32,6 +33,9 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
         self._clock = hs.get_clock()
         self.hs = hs
 
+        self._limit_usage_by_mau = hs.config.limit_usage_by_mau
+        self._max_mau_value = hs.config.max_mau_value
+
     @cached(num_args=0)
     async def get_monthly_active_count(self) -> int:
         """Generates current count of monthly active users
@@ -124,60 +128,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
             desc="user_last_seen_monthly_active",
         )
 
-
-class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-
-        self._limit_usage_by_mau = hs.config.limit_usage_by_mau
-        self._mau_stats_only = hs.config.mau_stats_only
-        self._max_mau_value = hs.config.max_mau_value
-
-        # Do not add more reserved users than the total allowable number
-        # cur = LoggingTransaction(
-        self.db_pool.new_transaction(
-            db_conn,
-            "initialise_mau_threepids",
-            [],
-            [],
-            self._initialise_reserved_users,
-            hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
-        )
-
-    def _initialise_reserved_users(self, txn, threepids):
-        """Ensures that reserved threepids are accounted for in the MAU table, should
-        be called on start up.
-
-        Args:
-            txn (cursor):
-            threepids (list[dict]): List of threepid dicts to reserve
-        """
-
-        # XXX what is this function trying to achieve?  It upserts into
-        # monthly_active_users for each *registered* reserved mau user, but why?
-        #
-        #  - shouldn't there already be an entry for each reserved user (at least
-        #    if they have been active recently)?
-        #
-        #  - if it's important that the timestamp is kept up to date, why do we only
-        #    run this at startup?
-
-        for tp in threepids:
-            user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
-
-            if user_id:
-                is_support = self.is_support_user_txn(txn, user_id)
-                if not is_support:
-                    # We do this manually here to avoid hitting #6791
-                    self.db_pool.simple_upsert_txn(
-                        txn,
-                        table="monthly_active_users",
-                        keyvalues={"user_id": user_id},
-                        values={"timestamp": int(self._clock.time_msec())},
-                    )
-            else:
-                logger.warning("mau limit reserved threepid %s not found in db" % tp)
-
+    @wrap_as_background_process("reap_monthly_active_users")
     async def reap_monthly_active_users(self):
         """Cleans out monthly active user table to ensure that no stale
         entries exist.
@@ -257,6 +208,57 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
             "reap_monthly_active_users", _reap_users, reserved_users
         )
 
+
+class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        self._mau_stats_only = hs.config.mau_stats_only
+
+        # Do not add more reserved users than the total allowable number
+        self.db_pool.new_transaction(
+            db_conn,
+            "initialise_mau_threepids",
+            [],
+            [],
+            self._initialise_reserved_users,
+            hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
+        )
+
+    def _initialise_reserved_users(self, txn, threepids):
+        """Ensures that reserved threepids are accounted for in the MAU table, should
+        be called on start up.
+
+        Args:
+            txn (cursor):
+            threepids (list[dict]): List of threepid dicts to reserve
+        """
+
+        # XXX what is this function trying to achieve?  It upserts into
+        # monthly_active_users for each *registered* reserved mau user, but why?
+        #
+        #  - shouldn't there already be an entry for each reserved user (at least
+        #    if they have been active recently)?
+        #
+        #  - if it's important that the timestamp is kept up to date, why do we only
+        #    run this at startup?
+
+        for tp in threepids:
+            user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
+
+            if user_id:
+                is_support = self.is_support_user_txn(txn, user_id)
+                if not is_support:
+                    # We do this manually here to avoid hitting #6791
+                    self.db_pool.simple_upsert_txn(
+                        txn,
+                        table="monthly_active_users",
+                        keyvalues={"user_id": user_id},
+                        values={"timestamp": int(self._clock.time_msec())},
+                    )
+            else:
+                logger.warning("mau limit reserved threepid %s not found in db" % tp)
+
     async def upsert_monthly_active_user(self, user_id: str) -> None:
         """Updates or inserts the user into the monthly active user table, which
         is used to track the current MAU usage of the server
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index d2e0685e9e..0e25ca3d7a 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore
@@ -39,7 +39,7 @@ class ProfileWorkerStore(SQLBaseStore):
             avatar_url=profile["avatar_url"], display_name=profile["displayname"]
         )
 
-    async def get_profile_displayname(self, user_localpart: str) -> str:
+    async def get_profile_displayname(self, user_localpart: str) -> Optional[str]:
         return await self.db_pool.simple_select_one_onecol(
             table="profiles",
             keyvalues={"user_id": user_localpart},
@@ -47,7 +47,7 @@ class ProfileWorkerStore(SQLBaseStore):
             desc="get_profile_displayname",
         )
 
-    async def get_profile_avatar_url(self, user_localpart: str) -> str:
+    async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]:
         return await self.db_pool.simple_select_one_onecol(
             table="profiles",
             keyvalues={"user_id": user_localpart},
@@ -72,7 +72,7 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     async def set_profile_displayname(
-        self, user_localpart: str, new_displayname: str
+        self, user_localpart: str, new_displayname: Optional[str]
     ) -> None:
         await self.db_pool.simple_update_one(
             table="profiles",
@@ -91,27 +91,6 @@ class ProfileWorkerStore(SQLBaseStore):
             desc="set_profile_avatar_url",
         )
 
-
-class ProfileStore(ProfileWorkerStore):
-    async def add_remote_profile_cache(
-        self, user_id: str, displayname: str, avatar_url: str
-    ) -> None:
-        """Ensure we are caching the remote user's profiles.
-
-        This should only be called when `is_subscribed_remote_profile_for_user`
-        would return true for the user.
-        """
-        await self.db_pool.simple_upsert(
-            table="remote_profile_cache",
-            keyvalues={"user_id": user_id},
-            values={
-                "displayname": displayname,
-                "avatar_url": avatar_url,
-                "last_check": self._clock.time_msec(),
-            },
-            desc="add_remote_profile_cache",
-        )
-
     async def update_remote_profile_cache(
         self, user_id: str, displayname: str, avatar_url: str
     ) -> int:
@@ -138,9 +117,34 @@ class ProfileStore(ProfileWorkerStore):
                 desc="delete_remote_profile_cache",
             )
 
+    async def is_subscribed_remote_profile_for_user(self, user_id):
+        """Check whether we are interested in a remote user's profile.
+        """
+        res = await self.db_pool.simple_select_one_onecol(
+            table="group_users",
+            keyvalues={"user_id": user_id},
+            retcol="user_id",
+            allow_none=True,
+            desc="should_update_remote_profile_cache_for_user",
+        )
+
+        if res:
+            return True
+
+        res = await self.db_pool.simple_select_one_onecol(
+            table="group_invites",
+            keyvalues={"user_id": user_id},
+            retcol="user_id",
+            allow_none=True,
+            desc="should_update_remote_profile_cache_for_user",
+        )
+
+        if res:
+            return True
+
     async def get_remote_profile_cache_entries_that_expire(
         self, last_checked: int
-    ) -> Dict[str, str]:
+    ) -> List[Dict[str, str]]:
         """Get all users who haven't been checked since `last_checked`
         """
 
@@ -160,27 +164,23 @@ class ProfileStore(ProfileWorkerStore):
             _get_remote_profile_cache_entries_that_expire_txn,
         )
 
-    async def is_subscribed_remote_profile_for_user(self, user_id):
-        """Check whether we are interested in a remote user's profile.
-        """
-        res = await self.db_pool.simple_select_one_onecol(
-            table="group_users",
-            keyvalues={"user_id": user_id},
-            retcol="user_id",
-            allow_none=True,
-            desc="should_update_remote_profile_cache_for_user",
-        )
 
-        if res:
-            return True
+class ProfileStore(ProfileWorkerStore):
+    async def add_remote_profile_cache(
+        self, user_id: str, displayname: str, avatar_url: str
+    ) -> None:
+        """Ensure we are caching the remote user's profiles.
 
-        res = await self.db_pool.simple_select_one_onecol(
-            table="group_invites",
+        This should only be called when `is_subscribed_remote_profile_for_user`
+        would return true for the user.
+        """
+        await self.db_pool.simple_upsert(
+            table="remote_profile_cache",
             keyvalues={"user_id": user_id},
-            retcol="user_id",
-            allow_none=True,
-            desc="should_update_remote_profile_cache_for_user",
+            values={
+                "displayname": displayname,
+                "avatar_url": avatar_url,
+                "last_check": self._clock.time_msec(),
+            },
+            desc="add_remote_profile_cache",
         )
-
-        if res:
-            return True
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index ecfc6717b3..5d668aadb2 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -314,6 +314,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
         for table in (
             "event_auth",
             "event_edges",
+            "event_json",
             "event_push_actions_staging",
             "event_reference_hashes",
             "event_relations",
@@ -340,7 +341,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
             "destination_rooms",
             "event_backward_extremities",
             "event_forward_extremities",
-            "event_json",
             "event_push_actions",
             "event_search",
             "events",
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index df8609b97b..7997242d90 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -303,7 +303,7 @@ class PusherStore(PusherWorkerStore):
                 lock=False,
             )
 
-            user_has_pusher = self.get_if_user_has_pusher.cache.get(
+            user_has_pusher = self.get_if_user_has_pusher.cache.get_immediate(
                 (user_id,), None, update_metrics=False
             )
 
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index c79ddff680..1e7949a323 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -23,8 +23,8 @@ from twisted.internet import defer
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
 from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.types import JsonDict
 from synapse.util import json_encoder
-from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -274,6 +274,65 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
         }
         return results
 
+    @cached(num_args=2,)
+    async def get_linearized_receipts_for_all_rooms(
+        self, to_key: int, from_key: Optional[int] = None
+    ) -> Dict[str, JsonDict]:
+        """Get receipts for all rooms between two stream_ids, up
+        to a limit of the latest 100 read receipts.
+
+        Args:
+            to_key: Max stream id to fetch receipts upto.
+            from_key: Min stream id to fetch receipts from. None fetches
+                from the start.
+
+        Returns:
+            A dictionary of roomids to a list of receipts.
+        """
+
+        def f(txn):
+            if from_key:
+                sql = """
+                    SELECT * FROM receipts_linearized WHERE
+                    stream_id > ? AND stream_id <= ?
+                    ORDER BY stream_id DESC
+                    LIMIT 100
+                """
+                txn.execute(sql, [from_key, to_key])
+            else:
+                sql = """
+                    SELECT * FROM receipts_linearized WHERE
+                    stream_id <= ?
+                    ORDER BY stream_id DESC
+                    LIMIT 100
+                """
+
+                txn.execute(sql, [to_key])
+
+            return self.db_pool.cursor_to_dict(txn)
+
+        txn_results = await self.db_pool.runInteraction(
+            "get_linearized_receipts_for_all_rooms", f
+        )
+
+        results = {}
+        for row in txn_results:
+            # We want a single event per room, since we want to batch the
+            # receipts by room, event and type.
+            room_event = results.setdefault(
+                row["room_id"],
+                {"type": "m.receipt", "room_id": row["room_id"], "content": {}},
+            )
+
+            # The content is of the form:
+            # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
+            event_entry = room_event["content"].setdefault(row["event_id"], {})
+            receipt_type = event_entry.setdefault(row["receipt_type"], {})
+
+            receipt_type[row["user_id"]] = db_to_json(row["data"])
+
+        return results
+
     async def get_users_sent_receipts_between(
         self, last_id: int, current_id: int
     ) -> List[str]:
@@ -358,18 +417,10 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
         if receipt_type != "m.read":
             return
 
-        # Returns either an ObservableDeferred or the raw result
-        res = self.get_users_with_read_receipts_in_room.cache.get(
+        res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
             room_id, None, update_metrics=False
         )
 
-        # first handle the ObservableDeferred case
-        if isinstance(res, ObservableDeferred):
-            if res.has_called():
-                res = res.get_result()
-            else:
-                res = None
-
         if res and user_id in res:
             # We'd only be adding to the set, so no point invalidating if the
             # user is already there
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index a83df7759d..fedb8a6c26 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,32 +14,66 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 import re
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
+import attr
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import SQLBaseStore
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage.database import DatabasePool
-from synapse.storage.types import Cursor
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.databases.main.stats import StatsStore
+from synapse.storage.types import Connection, Cursor
+from synapse.storage.util.id_generators import IdGenerator
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import UserID
 from synapse.util.caches.descriptors import cached
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
 
 logger = logging.getLogger(__name__)
 
 
-class RegistrationWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+@attr.s(frozen=True, slots=True)
+class TokenLookupResult:
+    """Result of looking up an access token.
+
+    Attributes:
+        user_id: The user that this token authenticates as
+        is_guest
+        shadow_banned
+        token_id: The ID of the access token looked up
+        device_id: The device associated with the token, if any.
+        valid_until_ms: The timestamp the token expires, if any.
+        token_owner: The "owner" of the token. This is either the same as the
+            user, or a server admin who is logged in as the user.
+    """
+
+    user_id = attr.ib(type=str)
+    is_guest = attr.ib(type=bool, default=False)
+    shadow_banned = attr.ib(type=bool, default=False)
+    token_id = attr.ib(type=Optional[int], default=None)
+    device_id = attr.ib(type=Optional[str], default=None)
+    valid_until_ms = attr.ib(type=Optional[int], default=None)
+    token_owner = attr.ib(type=str)
+
+    # Make the token owner default to the user ID, which is the common case.
+    @token_owner.default
+    def _default_token_owner(self):
+        return self.user_id
+
+
+class RegistrationWorkerStore(CacheInvalidationWorkerStore):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.config = hs.config
-        self.clock = hs.get_clock()
 
         # Note: we don't check this sequence for consistency as we'd have to
         # call `find_max_generated_user_id_localpart` each time, which is
@@ -48,6 +82,18 @@ class RegistrationWorkerStore(SQLBaseStore):
             database.engine, find_max_generated_user_id_localpart, "user_id_seq",
         )
 
+        self._account_validity = hs.config.account_validity
+        if hs.config.run_background_tasks and self._account_validity.enabled:
+            self._clock.call_later(
+                0.0, self._set_expiration_date_when_missing,
+            )
+
+        # Create a background job for culling expired 3PID validity tokens
+        if hs.config.run_background_tasks:
+            self._clock.looping_call(
+                self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
+            )
+
     @cached()
     async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
         return await self.db_pool.simple_select_one(
@@ -81,21 +127,19 @@ class RegistrationWorkerStore(SQLBaseStore):
         if not info:
             return False
 
-        now = self.clock.time_msec()
+        now = self._clock.time_msec()
         trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
         is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
         return is_trial
 
     @cached()
-    async def get_user_by_access_token(self, token: str) -> Optional[dict]:
+    async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]:
         """Get a user from the given access token.
 
         Args:
             token: The access token of a user.
         Returns:
-            None, if the token did not match, otherwise dict
-            including the keys `name`, `is_guest`, `device_id`, `token_id`,
-            `valid_until_ms`.
+            None, if the token did not match, otherwise a `TokenLookupResult`
         """
         return await self.db_pool.runInteraction(
             "get_user_by_access_token", self._query_for_auth, token
@@ -225,13 +269,13 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="get_renewal_token_for_user",
         )
 
-    async def get_users_expiring_soon(self) -> List[Dict[str, int]]:
+    async def get_users_expiring_soon(self) -> List[Dict[str, Any]]:
         """Selects users whose account will expire in the [now, now + renew_at] time
         window (see configuration for account_validity for information on what renew_at
         refers to).
 
         Returns:
-            A list of dictionaries mapping user ID to expiration time (in milliseconds).
+            A list of dictionaries, each with a user ID and expiration time (in milliseconds).
         """
 
         def select_users_txn(txn, now_ms, renew_at):
@@ -246,7 +290,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         return await self.db_pool.runInteraction(
             "get_users_expiring_soon",
             select_users_txn,
-            self.clock.time_msec(),
+            self._clock.time_msec(),
             self.config.account_validity.renew_at,
         )
 
@@ -316,19 +360,24 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
 
-    def _query_for_auth(self, txn, token):
-        sql = (
-            "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
-            " access_tokens.device_id, access_tokens.valid_until_ms"
-            " FROM users"
-            " INNER JOIN access_tokens on users.name = access_tokens.user_id"
-            " WHERE token = ?"
-        )
+    def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
+        sql = """
+            SELECT users.name as user_id,
+                users.is_guest,
+                users.shadow_banned,
+                access_tokens.id as token_id,
+                access_tokens.device_id,
+                access_tokens.valid_until_ms,
+                access_tokens.user_id as token_owner
+            FROM users
+            INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
+            WHERE token = ?
+        """
 
         txn.execute(sql, (token,))
         rows = self.db_pool.cursor_to_dict(txn)
         if rows:
-            return rows[0]
+            return TokenLookupResult(**rows[0])
 
         return None
 
@@ -778,12 +827,111 @@ class RegistrationWorkerStore(SQLBaseStore):
             "delete_threepid_session", delete_threepid_session_txn
         )
 
+    @wrap_as_background_process("cull_expired_threepid_validation_tokens")
+    async def cull_expired_threepid_validation_tokens(self) -> None:
+        """Remove threepid validation tokens with expiry dates that have passed"""
+
+        def cull_expired_threepid_validation_tokens_txn(txn, ts):
+            sql = """
+            DELETE FROM threepid_validation_token WHERE
+            expires < ?
+            """
+            txn.execute(sql, (ts,))
+
+        await self.db_pool.runInteraction(
+            "cull_expired_threepid_validation_tokens",
+            cull_expired_threepid_validation_tokens_txn,
+            self._clock.time_msec(),
+        )
+
+    @wrap_as_background_process("account_validity_set_expiration_dates")
+    async def _set_expiration_date_when_missing(self):
+        """
+        Retrieves the list of registered users that don't have an expiration date, and
+        adds an expiration date for each of them.
+        """
+
+        def select_users_with_no_expiration_date_txn(txn):
+            """Retrieves the list of registered users with no expiration date from the
+            database, filtering out deactivated users.
+            """
+            sql = (
+                "SELECT users.name FROM users"
+                " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
+                " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
+            )
+            txn.execute(sql, [])
+
+            res = self.db_pool.cursor_to_dict(txn)
+            if res:
+                for user in res:
+                    self.set_expiration_date_for_user_txn(
+                        txn, user["name"], use_delta=True
+                    )
+
+        await self.db_pool.runInteraction(
+            "get_users_with_no_expiration_date",
+            select_users_with_no_expiration_date_txn,
+        )
+
+    def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+        """Sets an expiration date to the account with the given user ID.
+
+        Args:
+             user_id (str): User ID to set an expiration date for.
+             use_delta (bool): If set to False, the expiration date for the user will be
+                now + validity period. If set to True, this expiration date will be a
+                random value in the [now + period - d ; now + period] range, d being a
+                delta equal to 10% of the validity period.
+        """
+        now_ms = self._clock.time_msec()
+        expiration_ts = now_ms + self._account_validity.period
+
+        if use_delta:
+            expiration_ts = self.rand.randrange(
+                expiration_ts - self._account_validity.startup_job_max_delta,
+                expiration_ts,
+            )
+
+        self.db_pool.simple_upsert_txn(
+            txn,
+            "account_validity",
+            keyvalues={"user_id": user_id},
+            values={"expiration_ts_ms": expiration_ts, "email_sent": False},
+        )
+
+    async def get_user_pending_deactivation(self) -> Optional[str]:
+        """
+        Gets one user from the table of users waiting to be parted from all the rooms
+        they're in.
+        """
+        return await self.db_pool.simple_select_one_onecol(
+            "users_pending_deactivation",
+            keyvalues={},
+            retcol="user_id",
+            allow_none=True,
+            desc="get_users_pending_deactivation",
+        )
+
+    async def del_user_pending_deactivation(self, user_id: str) -> None:
+        """
+        Removes the given user to the table of users who need to be parted from all the
+        rooms they're in, effectively marking that user as fully deactivated.
+        """
+        # XXX: This should be simple_delete_one but we failed to put a unique index on
+        # the table, so somehow duplicate entries have ended up in it.
+        await self.db_pool.simple_delete(
+            "users_pending_deactivation",
+            keyvalues={"user_id": user_id},
+            desc="del_user_pending_deactivation",
+        )
+
 
 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
-        self.clock = hs.get_clock()
+        self._clock = hs.get_clock()
         self.config = hs.config
 
         self.db_pool.updates.register_background_index_update(
@@ -906,32 +1054,55 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
 
         return 1
 
+    async def set_user_deactivated_status(
+        self, user_id: str, deactivated: bool
+    ) -> None:
+        """Set the `deactivated` property for the provided user to the provided value.
 
-class RegistrationStore(RegistrationBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
+        Args:
+            user_id: The ID of the user to set the status for.
+            deactivated: The value to set for `deactivated`.
+        """
 
-        self._account_validity = hs.config.account_validity
-        self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
+        await self.db_pool.runInteraction(
+            "set_user_deactivated_status",
+            self.set_user_deactivated_status_txn,
+            user_id,
+            deactivated,
+        )
 
-        if self._account_validity.enabled:
-            self._clock.call_later(
-                0.0,
-                run_as_background_process,
-                "account_validity_set_expiration_dates",
-                self._set_expiration_date_when_missing,
-            )
+    def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool):
+        self.db_pool.simple_update_one_txn(
+            txn=txn,
+            table="users",
+            keyvalues={"name": user_id},
+            updatevalues={"deactivated": 1 if deactivated else 0},
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.get_user_deactivated_status, (user_id,)
+        )
+        txn.call_after(self.is_guest.invalidate, (user_id,))
+
+    @cached()
+    async def is_guest(self, user_id: str) -> bool:
+        res = await self.db_pool.simple_select_one_onecol(
+            table="users",
+            keyvalues={"name": user_id},
+            retcol="is_guest",
+            allow_none=True,
+            desc="is_guest",
+        )
+
+        return res if res else False
 
-        # Create a background job for culling expired 3PID validity tokens
-        def start_cull():
-            # run as a background process to make sure that the database transactions
-            # have a logcontext to report to
-            return run_as_background_process(
-                "cull_expired_threepid_validation_tokens",
-                self.cull_expired_threepid_validation_tokens,
-            )
 
-        hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
+class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+        super().__init__(database, db_conn, hs)
+
+        self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
+
+        self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
 
     async def add_access_token_to_user(
         self,
@@ -939,7 +1110,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         token: str,
         device_id: Optional[str],
         valid_until_ms: Optional[int],
-    ) -> None:
+        puppets_user_id: Optional[str] = None,
+    ) -> int:
         """Adds an access token for the given user.
 
         Args:
@@ -949,6 +1121,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             valid_until_ms: when the token is valid until. None for no expiry.
         Raises:
             StoreError if there was a problem adding this.
+        Returns:
+            The token ID
         """
         next_id = self._access_tokens_id_gen.get_next()
 
@@ -960,10 +1134,43 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                 "token": token,
                 "device_id": device_id,
                 "valid_until_ms": valid_until_ms,
+                "puppets_user_id": puppets_user_id,
             },
             desc="add_access_token_to_user",
         )
 
+        return next_id
+
+    def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
+        old_device_id = self.db_pool.simple_select_one_onecol_txn(
+            txn, "access_tokens", {"token": token}, "device_id"
+        )
+
+        self.db_pool.simple_update_txn(
+            txn, "access_tokens", {"token": token}, {"device_id": device_id}
+        )
+
+        self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,))
+
+        return old_device_id
+
+    async def set_device_for_access_token(self, token: str, device_id: str) -> str:
+        """Sets the device ID associated with an access token.
+
+        Args:
+            token: The access token to modify.
+            device_id: The new device ID.
+        Returns:
+            The old device ID associated with the access token.
+        """
+
+        return await self.db_pool.runInteraction(
+            "set_device_for_access_token",
+            self._set_device_for_access_token_txn,
+            token,
+            device_id,
+        )
+
     async def register_user(
         self,
         user_id: str,
@@ -1014,19 +1221,19 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
     def _register_user(
         self,
         txn,
-        user_id,
-        password_hash,
-        was_guest,
-        make_guest,
-        appservice_id,
-        create_profile_with_displayname,
-        admin,
-        user_type,
-        shadow_banned,
+        user_id: str,
+        password_hash: Optional[str],
+        was_guest: bool,
+        make_guest: bool,
+        appservice_id: Optional[str],
+        create_profile_with_displayname: Optional[str],
+        admin: bool,
+        user_type: Optional[str],
+        shadow_banned: bool,
     ):
         user_id_obj = UserID.from_string(user_id)
 
-        now = int(self.clock.time())
+        now = int(self._clock.time())
 
         try:
             if was_guest:
@@ -1121,7 +1328,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             desc="record_user_external_id",
         )
 
-    async def user_set_password_hash(self, user_id: str, password_hash: str) -> None:
+    async def user_set_password_hash(
+        self, user_id: str, password_hash: Optional[str]
+    ) -> None:
         """
         NB. This does *not* evict any cache because the one use for this
             removes most of the entries subsequently anyway so it would be
@@ -1248,18 +1457,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
         await self.db_pool.runInteraction("delete_access_token", f)
 
-    @cached()
-    async def is_guest(self, user_id: str) -> bool:
-        res = await self.db_pool.simple_select_one_onecol(
-            table="users",
-            keyvalues={"name": user_id},
-            retcol="is_guest",
-            allow_none=True,
-            desc="is_guest",
-        )
-
-        return res if res else False
-
     async def add_user_pending_deactivation(self, user_id: str) -> None:
         """
         Adds a user to the table of users who need to be parted from all the rooms they're
@@ -1271,32 +1468,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             desc="add_user_pending_deactivation",
         )
 
-    async def del_user_pending_deactivation(self, user_id: str) -> None:
-        """
-        Removes the given user to the table of users who need to be parted from all the
-        rooms they're in, effectively marking that user as fully deactivated.
-        """
-        # XXX: This should be simple_delete_one but we failed to put a unique index on
-        # the table, so somehow duplicate entries have ended up in it.
-        await self.db_pool.simple_delete(
-            "users_pending_deactivation",
-            keyvalues={"user_id": user_id},
-            desc="del_user_pending_deactivation",
-        )
-
-    async def get_user_pending_deactivation(self) -> Optional[str]:
-        """
-        Gets one user from the table of users waiting to be parted from all the rooms
-        they're in.
-        """
-        return await self.db_pool.simple_select_one_onecol(
-            "users_pending_deactivation",
-            keyvalues={},
-            retcol="user_id",
-            allow_none=True,
-            desc="get_users_pending_deactivation",
-        )
-
     async def validate_threepid_session(
         self, session_id: str, client_secret: str, token: str, current_ts: int
     ) -> Optional[str]:
@@ -1379,7 +1550,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                 txn,
                 table="threepid_validation_session",
                 keyvalues={"session_id": session_id},
-                updatevalues={"validated_at": self.clock.time_msec()},
+                updatevalues={"validated_at": self._clock.time_msec()},
             )
 
             return next_link
@@ -1447,106 +1618,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             start_or_continue_validation_session_txn,
         )
 
-    async def cull_expired_threepid_validation_tokens(self) -> None:
-        """Remove threepid validation tokens with expiry dates that have passed"""
-
-        def cull_expired_threepid_validation_tokens_txn(txn, ts):
-            sql = """
-            DELETE FROM threepid_validation_token WHERE
-            expires < ?
-            """
-            txn.execute(sql, (ts,))
-
-        await self.db_pool.runInteraction(
-            "cull_expired_threepid_validation_tokens",
-            cull_expired_threepid_validation_tokens_txn,
-            self.clock.time_msec(),
-        )
-
-    async def set_user_deactivated_status(
-        self, user_id: str, deactivated: bool
-    ) -> None:
-        """Set the `deactivated` property for the provided user to the provided value.
-
-        Args:
-            user_id: The ID of the user to set the status for.
-            deactivated: The value to set for `deactivated`.
-        """
-
-        await self.db_pool.runInteraction(
-            "set_user_deactivated_status",
-            self.set_user_deactivated_status_txn,
-            user_id,
-            deactivated,
-        )
-
-    def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
-        self.db_pool.simple_update_one_txn(
-            txn=txn,
-            table="users",
-            keyvalues={"name": user_id},
-            updatevalues={"deactivated": 1 if deactivated else 0},
-        )
-        self._invalidate_cache_and_stream(
-            txn, self.get_user_deactivated_status, (user_id,)
-        )
-        txn.call_after(self.is_guest.invalidate, (user_id,))
-
-    async def _set_expiration_date_when_missing(self):
-        """
-        Retrieves the list of registered users that don't have an expiration date, and
-        adds an expiration date for each of them.
-        """
-
-        def select_users_with_no_expiration_date_txn(txn):
-            """Retrieves the list of registered users with no expiration date from the
-            database, filtering out deactivated users.
-            """
-            sql = (
-                "SELECT users.name FROM users"
-                " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
-                " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
-            )
-            txn.execute(sql, [])
-
-            res = self.db_pool.cursor_to_dict(txn)
-            if res:
-                for user in res:
-                    self.set_expiration_date_for_user_txn(
-                        txn, user["name"], use_delta=True
-                    )
-
-        await self.db_pool.runInteraction(
-            "get_users_with_no_expiration_date",
-            select_users_with_no_expiration_date_txn,
-        )
-
-    def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
-        """Sets an expiration date to the account with the given user ID.
-
-        Args:
-             user_id (str): User ID to set an expiration date for.
-             use_delta (bool): If set to False, the expiration date for the user will be
-                now + validity period. If set to True, this expiration date will be a
-                random value in the [now + period - d ; now + period] range, d being a
-                delta equal to 10% of the validity period.
-        """
-        now_ms = self._clock.time_msec()
-        expiration_ts = now_ms + self._account_validity.period
-
-        if use_delta:
-            expiration_ts = self.rand.randrange(
-                expiration_ts - self._account_validity.startup_job_max_delta,
-                expiration_ts,
-            )
-
-        self.db_pool.simple_upsert_txn(
-            txn,
-            "account_validity",
-            keyvalues={"user_id": user_id},
-            values={"expiration_ts_ms": expiration_ts, "email_sent": False},
-        )
-
 
 def find_max_generated_user_id_localpart(cur: Cursor) -> int:
     """
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 3c7630857f..6b89db15c9 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -192,6 +192,18 @@ class RoomWorkerStore(SQLBaseStore):
             "count_public_rooms", _count_public_rooms_txn
         )
 
+    async def get_room_count(self) -> int:
+        """Retrieve the total number of rooms.
+        """
+
+        def f(txn):
+            sql = "SELECT count(*)  FROM rooms"
+            txn.execute(sql)
+            row = txn.fetchone()
+            return row[0] or 0
+
+        return await self.db_pool.runInteraction("get_rooms", f)
+
     async def get_largest_public_rooms(
         self,
         network_tuple: Optional[ThirdPartyInstanceID],
@@ -857,6 +869,89 @@ class RoomWorkerStore(SQLBaseStore):
             "get_all_new_public_rooms", get_all_new_public_rooms
         )
 
+    async def get_rooms_for_retention_period_in_range(
+        self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
+    ) -> Dict[str, dict]:
+        """Retrieves all of the rooms within the given retention range.
+
+        Optionally includes the rooms which don't have a retention policy.
+
+        Args:
+            min_ms: Duration in milliseconds that define the lower limit of
+                the range to handle (exclusive). If None, doesn't set a lower limit.
+            max_ms: Duration in milliseconds that define the upper limit of
+                the range to handle (inclusive). If None, doesn't set an upper limit.
+            include_null: Whether to include rooms which retention policy is NULL
+                in the returned set.
+
+        Returns:
+            The rooms within this range, along with their retention
+            policy. The key is "room_id", and maps to a dict describing the retention
+            policy associated with this room ID. The keys for this nested dict are
+            "min_lifetime" (int|None), and "max_lifetime" (int|None).
+        """
+
+        def get_rooms_for_retention_period_in_range_txn(txn):
+            range_conditions = []
+            args = []
+
+            if min_ms is not None:
+                range_conditions.append("max_lifetime > ?")
+                args.append(min_ms)
+
+            if max_ms is not None:
+                range_conditions.append("max_lifetime <= ?")
+                args.append(max_ms)
+
+            # Do a first query which will retrieve the rooms that have a retention policy
+            # in their current state.
+            sql = """
+                SELECT room_id, min_lifetime, max_lifetime FROM room_retention
+                INNER JOIN current_state_events USING (event_id, room_id)
+                """
+
+            if len(range_conditions):
+                sql += " WHERE (" + " AND ".join(range_conditions) + ")"
+
+                if include_null:
+                    sql += " OR max_lifetime IS NULL"
+
+            txn.execute(sql, args)
+
+            rows = self.db_pool.cursor_to_dict(txn)
+            rooms_dict = {}
+
+            for row in rows:
+                rooms_dict[row["room_id"]] = {
+                    "min_lifetime": row["min_lifetime"],
+                    "max_lifetime": row["max_lifetime"],
+                }
+
+            if include_null:
+                # If required, do a second query that retrieves all of the rooms we know
+                # of so we can handle rooms with no retention policy.
+                sql = "SELECT DISTINCT room_id FROM current_state_events"
+
+                txn.execute(sql)
+
+                rows = self.db_pool.cursor_to_dict(txn)
+
+                # If a room isn't already in the dict (i.e. it doesn't have a retention
+                # policy in its state), add it with a null policy.
+                for row in rows:
+                    if row["room_id"] not in rooms_dict:
+                        rooms_dict[row["room_id"]] = {
+                            "min_lifetime": None,
+                            "max_lifetime": None,
+                        }
+
+            return rooms_dict
+
+        return await self.db_pool.runInteraction(
+            "get_rooms_for_retention_period_in_range",
+            get_rooms_for_retention_period_in_range_txn,
+        )
+
 
 class RoomBackgroundUpdateStore(SQLBaseStore):
     REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -1145,13 +1240,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             logger.error("store_room with room_id=%s failed: %s", room_id, e)
             raise StoreError(500, "Problem creating room.")
 
-    async def maybe_store_room_on_invite(self, room_id: str, room_version: RoomVersion):
+    async def maybe_store_room_on_outlier_membership(
+        self, room_id: str, room_version: RoomVersion
+    ):
         """
-        When we receive an invite over federation, store the version of the room if we
-        don't already know the room version.
+        When we receive an invite or any other event over federation that may relate to a room
+        we are not in, store the version of the room if we don't already know the room version.
         """
         await self.db_pool.simple_upsert(
-            desc="maybe_store_room_on_invite",
+            desc="maybe_store_room_on_outlier_membership",
             table="rooms",
             keyvalues={"room_id": room_id},
             values={},
@@ -1292,18 +1389,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             )
         self.hs.get_notifier().on_new_replication_data()
 
-    async def get_room_count(self) -> int:
-        """Retrieve the total number of rooms.
-        """
-
-        def f(txn):
-            sql = "SELECT count(*)  FROM rooms"
-            txn.execute(sql)
-            row = txn.fetchone()
-            return row[0] or 0
-
-        return await self.db_pool.runInteraction("get_rooms", f)
-
     async def add_event_report(
         self,
         room_id: str,
@@ -1328,6 +1413,65 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             desc="add_event_report",
         )
 
+    async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]:
+        """Retrieve an event report
+
+        Args:
+            report_id: ID of reported event in database
+        Returns:
+            event_report: json list of information from event report
+        """
+
+        def _get_event_report_txn(txn, report_id):
+
+            sql = """
+                SELECT
+                    er.id,
+                    er.received_ts,
+                    er.room_id,
+                    er.event_id,
+                    er.user_id,
+                    er.content,
+                    events.sender,
+                    room_stats_state.canonical_alias,
+                    room_stats_state.name,
+                    event_json.json AS event_json
+                FROM event_reports AS er
+                LEFT JOIN events
+                    ON events.event_id = er.event_id
+                JOIN event_json
+                    ON event_json.event_id = er.event_id
+                JOIN room_stats_state
+                    ON room_stats_state.room_id = er.room_id
+                WHERE er.id = ?
+            """
+
+            txn.execute(sql, [report_id])
+            row = txn.fetchone()
+
+            if not row:
+                return None
+
+            event_report = {
+                "id": row[0],
+                "received_ts": row[1],
+                "room_id": row[2],
+                "event_id": row[3],
+                "user_id": row[4],
+                "score": db_to_json(row[5]).get("score"),
+                "reason": db_to_json(row[5]).get("reason"),
+                "sender": row[6],
+                "canonical_alias": row[7],
+                "name": row[8],
+                "event_json": db_to_json(row[9]),
+            }
+
+            return event_report
+
+        return await self.db_pool.runInteraction(
+            "get_event_report", _get_event_report_txn, report_id
+        )
+
     async def get_event_reports_paginate(
         self,
         start: int,
@@ -1385,18 +1529,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                     er.room_id,
                     er.event_id,
                     er.user_id,
-                    er.reason,
                     er.content,
                     events.sender,
-                    room_aliases.room_alias,
-                    event_json.json AS event_json
+                    room_stats_state.canonical_alias,
+                    room_stats_state.name
                 FROM event_reports AS er
-                LEFT JOIN room_aliases
-                    ON room_aliases.room_id = er.room_id
-                JOIN events
+                LEFT JOIN events
                     ON events.event_id = er.event_id
-                JOIN event_json
-                    ON event_json.event_id = er.event_id
+                JOIN room_stats_state
+                    ON room_stats_state.room_id = er.room_id
                 {where_clause}
                 ORDER BY er.received_ts {order}
                 LIMIT ?
@@ -1407,15 +1548,29 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
 
             args += [limit, start]
             txn.execute(sql, args)
-            event_reports = self.db_pool.cursor_to_dict(txn)
-
-            if count > 0:
-                for row in event_reports:
-                    try:
-                        row["content"] = db_to_json(row["content"])
-                        row["event_json"] = db_to_json(row["event_json"])
-                    except Exception:
-                        continue
+
+            event_reports = []
+            for row in txn:
+                try:
+                    s = db_to_json(row[5]).get("score")
+                    r = db_to_json(row[5]).get("reason")
+                except Exception:
+                    logger.error("Unable to parse json from event_reports: %s", row[0])
+                    continue
+                event_reports.append(
+                    {
+                        "id": row[0],
+                        "received_ts": row[1],
+                        "room_id": row[2],
+                        "event_id": row[3],
+                        "user_id": row[4],
+                        "score": s,
+                        "reason": r,
+                        "sender": row[6],
+                        "canonical_alias": row[7],
+                        "name": row[8],
+                    }
+                )
 
             return event_reports, count
 
@@ -1446,88 +1601,3 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             self.is_room_blocked,
             (room_id,),
         )
-
-    async def get_rooms_for_retention_period_in_range(
-        self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
-    ) -> Dict[str, dict]:
-        """Retrieves all of the rooms within the given retention range.
-
-        Optionally includes the rooms which don't have a retention policy.
-
-        Args:
-            min_ms: Duration in milliseconds that define the lower limit of
-                the range to handle (exclusive). If None, doesn't set a lower limit.
-            max_ms: Duration in milliseconds that define the upper limit of
-                the range to handle (inclusive). If None, doesn't set an upper limit.
-            include_null: Whether to include rooms which retention policy is NULL
-                in the returned set.
-
-        Returns:
-            The rooms within this range, along with their retention
-            policy. The key is "room_id", and maps to a dict describing the retention
-            policy associated with this room ID. The keys for this nested dict are
-            "min_lifetime" (int|None), and "max_lifetime" (int|None).
-        """
-
-        def get_rooms_for_retention_period_in_range_txn(txn):
-            range_conditions = []
-            args = []
-
-            if min_ms is not None:
-                range_conditions.append("max_lifetime > ?")
-                args.append(min_ms)
-
-            if max_ms is not None:
-                range_conditions.append("max_lifetime <= ?")
-                args.append(max_ms)
-
-            # Do a first query which will retrieve the rooms that have a retention policy
-            # in their current state.
-            sql = """
-                SELECT room_id, min_lifetime, max_lifetime FROM room_retention
-                INNER JOIN current_state_events USING (event_id, room_id)
-                """
-
-            if len(range_conditions):
-                sql += " WHERE (" + " AND ".join(range_conditions) + ")"
-
-                if include_null:
-                    sql += " OR max_lifetime IS NULL"
-
-            txn.execute(sql, args)
-
-            rows = self.db_pool.cursor_to_dict(txn)
-            rooms_dict = {}
-
-            for row in rows:
-                rooms_dict[row["room_id"]] = {
-                    "min_lifetime": row["min_lifetime"],
-                    "max_lifetime": row["max_lifetime"],
-                }
-
-            if include_null:
-                # If required, do a second query that retrieves all of the rooms we know
-                # of so we can handle rooms with no retention policy.
-                sql = "SELECT DISTINCT room_id FROM current_state_events"
-
-                txn.execute(sql)
-
-                rows = self.db_pool.cursor_to_dict(txn)
-
-                # If a room isn't already in the dict (i.e. it doesn't have a retention
-                # policy in its state), add it with a null policy.
-                for row in rows:
-                    if row["room_id"] not in rooms_dict:
-                        rooms_dict[row["room_id"]] = {
-                            "min_lifetime": None,
-                            "max_lifetime": None,
-                        }
-
-            return rooms_dict
-
-        rooms = await self.db_pool.runInteraction(
-            "get_rooms_for_retention_period_in_range",
-            get_rooms_for_retention_period_in_range_txn,
-        )
-
-        return rooms
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 86ffe2479e..dcdaf09682 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -14,19 +14,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.metrics import LaterGauge
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import (
-    LoggingTransaction,
-    SQLBaseStore,
-    db_to_json,
-    make_in_list_sql_clause,
+from synapse.metrics.background_process_metrics import (
+    run_as_background_process,
+    wrap_as_background_process,
 )
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import Sqlite3Engine
@@ -60,27 +58,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # background update still running?
         self._current_state_events_membership_up_to_date = False
 
-        txn = LoggingTransaction(
-            db_conn.cursor(),
-            name="_check_safe_current_state_events_membership_updated",
-            database_engine=self.database_engine,
+        txn = db_conn.cursor(
+            txn_name="_check_safe_current_state_events_membership_updated"
         )
         self._check_safe_current_state_events_membership_updated_txn(txn)
         txn.close()
 
-        if self.hs.config.metrics_flags.known_servers:
+        if (
+            self.hs.config.run_background_tasks
+            and self.hs.config.metrics_flags.known_servers
+        ):
             self._known_servers_count = 1
             self.hs.get_clock().looping_call(
-                run_as_background_process,
-                60 * 1000,
-                "_count_known_servers",
-                self._count_known_servers,
+                self._count_known_servers, 60 * 1000,
             )
             self.hs.get_clock().call_later(
-                1000,
-                run_as_background_process,
-                "_count_known_servers",
-                self._count_known_servers,
+                1000, self._count_known_servers,
             )
             LaterGauge(
                 "synapse_federation_known_servers",
@@ -89,6 +82,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 lambda: self._known_servers_count,
             )
 
+    @wrap_as_background_process("_count_known_servers")
     async def _count_known_servers(self):
         """
         Count the servers that this server knows about.
@@ -356,6 +350,38 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return results
 
+    async def get_local_current_membership_for_user_in_room(
+        self, user_id: str, room_id: str
+    ) -> Tuple[Optional[str], Optional[str]]:
+        """Retrieve the current local membership state and event ID for a user in a room.
+
+        Args:
+            user_id: The ID of the user.
+            room_id: The ID of the room.
+
+        Returns:
+            A tuple of (membership_type, event_id). Both will be None if a
+                room_id/user_id pair is not found.
+        """
+        # Paranoia check.
+        if not self.hs.is_mine_id(user_id):
+            raise Exception(
+                "Cannot call 'get_local_current_membership_for_user_in_room' on "
+                "non-local user %s" % (user_id,),
+            )
+
+        results_dict = await self.db_pool.simple_select_one(
+            "local_current_membership",
+            {"room_id": room_id, "user_id": user_id},
+            ("membership", "event_id"),
+            allow_none=True,
+            desc="get_local_current_membership_for_user_in_room",
+        )
+        if not results_dict:
+            return None, None
+
+        return results_dict.get("membership"), results_dict.get("event_id")
+
     @cached(max_entries=500000, iterable=True)
     async def get_rooms_for_user_with_stream_ordering(
         self, user_id: str
@@ -535,7 +561,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # If we do then we can reuse that result and simply update it with
             # any membership changes in `delta_ids`
             if context.prev_group and context.delta_ids:
-                prev_res = self._get_joined_users_from_context.cache.get(
+                prev_res = self._get_joined_users_from_context.cache.get_immediate(
                     (room_id, context.prev_group), None
                 )
                 if prev_res and isinstance(prev_res, dict):
diff --git a/synapse/storage/databases/main/schema/delta/20/pushers.py b/synapse/storage/databases/main/schema/delta/20/pushers.py
index 3edfcfd783..45b846e6a7 100644
--- a/synapse/storage/databases/main/schema/delta/20/pushers.py
+++ b/synapse/storage/databases/main/schema/delta/20/pushers.py
@@ -66,16 +66,15 @@ def run_create(cur, database_engine, *args, **kwargs):
         row[8] = bytes(row[8]).decode("utf-8")
         row[11] = bytes(row[11]).decode("utf-8")
         cur.execute(
-            database_engine.convert_param_style(
-                """
-            INSERT into pushers2 (
-            id, user_name, access_token, profile_tag, kind,
-            app_id, app_display_name, device_display_name,
-            pushkey, ts, lang, data, last_token, last_success,
-            failing_since
-            ) values (%s)"""
-                % (",".join(["?" for _ in range(len(row))]))
-            ),
+            """
+                INSERT into pushers2 (
+                id, user_name, access_token, profile_tag, kind,
+                app_id, app_display_name, device_display_name,
+                pushkey, ts, lang, data, last_token, last_success,
+                failing_since
+                ) values (%s)
+            """
+            % (",".join(["?" for _ in range(len(row))])),
             row,
         )
         count += 1
diff --git a/synapse/storage/databases/main/schema/delta/25/fts.py b/synapse/storage/databases/main/schema/delta/25/fts.py
index ee675e71ff..21f57825d4 100644
--- a/synapse/storage/databases/main/schema/delta/25/fts.py
+++ b/synapse/storage/databases/main/schema/delta/25/fts.py
@@ -71,8 +71,6 @@ def run_create(cur, database_engine, *args, **kwargs):
             " VALUES (?, ?)"
         )
 
-        sql = database_engine.convert_param_style(sql)
-
         cur.execute(sql, ("event_search", progress_json))
 
 
diff --git a/synapse/storage/databases/main/schema/delta/27/ts.py b/synapse/storage/databases/main/schema/delta/27/ts.py
index b7972cfa8e..1c6058063f 100644
--- a/synapse/storage/databases/main/schema/delta/27/ts.py
+++ b/synapse/storage/databases/main/schema/delta/27/ts.py
@@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs):
             " VALUES (?, ?)"
         )
 
-        sql = database_engine.convert_param_style(sql)
-
         cur.execute(sql, ("event_origin_server_ts", progress_json))
 
 
diff --git a/synapse/storage/databases/main/schema/delta/30/as_users.py b/synapse/storage/databases/main/schema/delta/30/as_users.py
index b42c02710a..7f08fabe9f 100644
--- a/synapse/storage/databases/main/schema/delta/30/as_users.py
+++ b/synapse/storage/databases/main/schema/delta/30/as_users.py
@@ -59,9 +59,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
         user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n))
         for chunk in user_chunks:
             cur.execute(
-                database_engine.convert_param_style(
-                    "UPDATE users SET appservice_id = ? WHERE name IN (%s)"
-                    % (",".join("?" for _ in chunk),)
-                ),
+                "UPDATE users SET appservice_id = ? WHERE name IN (%s)"
+                % (",".join("?" for _ in chunk),),
                 [as_id] + chunk,
             )
diff --git a/synapse/storage/databases/main/schema/delta/31/pushers.py b/synapse/storage/databases/main/schema/delta/31/pushers.py
index 9bb504aad5..5be81c806a 100644
--- a/synapse/storage/databases/main/schema/delta/31/pushers.py
+++ b/synapse/storage/databases/main/schema/delta/31/pushers.py
@@ -65,16 +65,15 @@ def run_create(cur, database_engine, *args, **kwargs):
         row = list(row)
         row[12] = token_to_stream_ordering(row[12])
         cur.execute(
-            database_engine.convert_param_style(
-                """
-            INSERT into pushers2 (
-            id, user_name, access_token, profile_tag, kind,
-            app_id, app_display_name, device_display_name,
-            pushkey, ts, lang, data, last_stream_ordering, last_success,
-            failing_since
-            ) values (%s)"""
-                % (",".join(["?" for _ in range(len(row))]))
-            ),
+            """
+                INSERT into pushers2 (
+                id, user_name, access_token, profile_tag, kind,
+                app_id, app_display_name, device_display_name,
+                pushkey, ts, lang, data, last_stream_ordering, last_success,
+                failing_since
+                ) values (%s)
+            """
+            % (",".join(["?" for _ in range(len(row))])),
             row,
         )
         count += 1
diff --git a/synapse/storage/databases/main/schema/delta/31/search_update.py b/synapse/storage/databases/main/schema/delta/31/search_update.py
index 63b757ade6..b84c844e3a 100644
--- a/synapse/storage/databases/main/schema/delta/31/search_update.py
+++ b/synapse/storage/databases/main/schema/delta/31/search_update.py
@@ -55,8 +55,6 @@ def run_create(cur, database_engine, *args, **kwargs):
             " VALUES (?, ?)"
         )
 
-        sql = database_engine.convert_param_style(sql)
-
         cur.execute(sql, ("event_search_order", progress_json))
 
 
diff --git a/synapse/storage/databases/main/schema/delta/33/event_fields.py b/synapse/storage/databases/main/schema/delta/33/event_fields.py
index a3e81eeac7..e928c66a8f 100644
--- a/synapse/storage/databases/main/schema/delta/33/event_fields.py
+++ b/synapse/storage/databases/main/schema/delta/33/event_fields.py
@@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs):
             " VALUES (?, ?)"
         )
 
-        sql = database_engine.convert_param_style(sql)
-
         cur.execute(sql, ("event_fields_sender_url", progress_json))
 
 
diff --git a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
index a26057dfb6..ad875c733a 100644
--- a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
+++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
@@ -23,8 +23,5 @@ def run_create(cur, database_engine, *args, **kwargs):
 
 def run_upgrade(cur, database_engine, *args, **kwargs):
     cur.execute(
-        database_engine.convert_param_style(
-            "UPDATE remote_media_cache SET last_access_ts = ?"
-        ),
-        (int(time.time() * 1000),),
+        "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),),
     )
diff --git a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
index 1de8b54961..bb7296852a 100644
--- a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
+++ b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
@@ -1,6 +1,8 @@
 import logging
+from io import StringIO
 
 from synapse.storage.engines import PostgresEngine
+from synapse.storage.prepare_database import execute_statements_from_stream
 
 logger = logging.getLogger(__name__)
 
@@ -46,7 +48,4 @@ def run_create(cur, database_engine, *args, **kwargs):
         select_clause,
     )
 
-    if isinstance(database_engine, PostgresEngine):
-        cur.execute(sql)
-    else:
-        cur.executescript(sql)
+    execute_statements_from_stream(cur, StringIO(sql))
diff --git a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
index 63b5acdcf7..44917f0a2e 100644
--- a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
+++ b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
@@ -68,7 +68,6 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
                 INNER JOIN room_memberships AS r USING (event_id)
                 WHERE type = 'm.room.member' AND state_key LIKE ?
         """
-    sql = database_engine.convert_param_style(sql)
     cur.execute(sql, ("%:" + config.server_name,))
 
     cur.execute(
diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
index b64926e9c9..3275ae2b20 100644
--- a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
+++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
@@ -20,14 +20,14 @@
  */
 
 -- add new index that includes method to local media
-INSERT INTO background_updates (update_name, progress_json) VALUES
-  ('local_media_repository_thumbnails_method_idx', '{}');
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (5807, 'local_media_repository_thumbnails_method_idx', '{}');
 
 -- add new index that includes method to remote media
-INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
-  ('remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+  (5807, 'remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
 
 -- drop old index
-INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
-  ('media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+  (5807, 'media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');
 
diff --git a/synapse/storage/databases/main/schema/delta/58/11dehydration.sql b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql
new file mode 100644
index 0000000000..7851a0a825
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql
@@ -0,0 +1,20 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS dehydrated_devices(
+    user_id TEXT NOT NULL PRIMARY KEY,
+    device_id TEXT NOT NULL,
+    device_data TEXT NOT NULL -- JSON-encoded client-defined data
+);
diff --git a/synapse/storage/databases/main/schema/delta/58/11fallback.sql b/synapse/storage/databases/main/schema/delta/58/11fallback.sql
new file mode 100644
index 0000000000..4ed981dbf8
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/11fallback.sql
@@ -0,0 +1,24 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json (
+    user_id TEXT NOT NULL, -- The user this fallback key is for.
+    device_id TEXT NOT NULL, -- The device this fallback key is for.
+    algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for.
+    key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
+    key_json TEXT NOT NULL, -- The key as a JSON blob.
+    used BOOLEAN NOT NULL DEFAULT FALSE, -- Whether the key has been used or not.
+    CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm)
+);
diff --git a/synapse/storage/databases/main/schema/delta/58/12room_stats.sql b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql
index cade5dcca8..fd733adf13 100644
--- a/synapse/storage/databases/main/schema/delta/58/12room_stats.sql
+++ b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql
@@ -28,5 +28,5 @@
 -- functionality as the old one. This effectively restarts the background job
 -- from the beginning, without running it twice in a row, supporting both
 -- upgrade usecases.
-INSERT INTO background_updates (update_name, progress_json) VALUES
-    ('populate_stats_process_rooms_2', '{}');
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+    (5812, 'populate_stats_process_rooms_2', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres
new file mode 100644
index 0000000000..841186b826
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres
@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- A unique and immutable mapping between instance name and an integer ID. This
+-- lets us refer to instances via a small ID in e.g. stream tokens, without
+-- having to encode the full name.
+CREATE TABLE IF NOT EXISTS instance_map (
+    instance_id SERIAL PRIMARY KEY,
+    instance_name TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS instance_map_idx ON instance_map(instance_name);
diff --git a/synapse/storage/databases/main/schema/delta/58/19txn_id.sql b/synapse/storage/databases/main/schema/delta/58/19txn_id.sql
new file mode 100644
index 0000000000..b2454121a8
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/19txn_id.sql
@@ -0,0 +1,40 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- A map of recent events persisted with transaction IDs. Used to deduplicate
+-- send event requests with the same transaction ID.
+--
+-- Note: transaction IDs are scoped to the room ID/user ID/access token that was
+-- used to make the request.
+--
+-- Note: The foreign key constraints are ON DELETE CASCADE, as if we delete the
+-- events or access token we don't want to try and de-duplicate the event.
+CREATE TABLE IF NOT EXISTS event_txn_id (
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    token_id BIGINT NOT NULL,
+    txn_id TEXT NOT NULL,
+    inserted_ts BIGINT NOT NULL,
+    FOREIGN KEY (event_id)
+        REFERENCES events (event_id) ON DELETE CASCADE,
+    FOREIGN KEY (token_id)
+        REFERENCES access_tokens (id) ON DELETE CASCADE
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_event_id ON event_txn_id(event_id);
+CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_txn_id ON event_txn_id(room_id, user_id, token_id, txn_id);
+CREATE INDEX IF NOT EXISTS event_txn_id_ts ON event_txn_id(inserted_ts);
diff --git a/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql b/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql
new file mode 100644
index 0000000000..ad1f481428
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE current_state_delta_stream ADD COLUMN instance_name TEXT;
+ALTER TABLE ex_outlier_stream ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql b/synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql
new file mode 100644
index 0000000000..b0b5dcddce
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ -- Add new column to user_daily_visits to track user agent
+ALTER TABLE user_daily_visits
+    ADD COLUMN user_agent TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/58/21as_device_stream.sql b/synapse/storage/databases/main/schema/delta/58/21as_device_stream.sql
new file mode 100644
index 0000000000..7b84a207fd
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/21as_device_stream.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE application_services_state ADD COLUMN read_receipt_stream_id INT;
+ALTER TABLE application_services_state ADD COLUMN presence_stream_id INT;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql b/synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql
new file mode 100644
index 0000000000..01ea6eddcf
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql
@@ -0,0 +1 @@
+DROP TABLE device_max_stream_id;
diff --git a/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql
new file mode 100644
index 0000000000..00a9431a97
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Whether the access token is an admin token for controlling another user.
+ALTER TABLE access_tokens ADD COLUMN puppets_user_id TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql b/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql
new file mode 100644
index 0000000000..e1a35be831
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql
@@ -0,0 +1,2 @@
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (5822, 'users_have_local_media', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql b/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql
new file mode 100644
index 0000000000..75c3915a94
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (5823, 'e2e_cross_signing_keys_idx', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql b/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql
new file mode 100644
index 0000000000..8a39d54aed
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql
@@ -0,0 +1,19 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- this index is essentially redundant. The only time it was ever used was when purging
+-- rooms - and Synapse 1.24 will change that.
+
+DROP INDEX IF EXISTS event_json_room_id;
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 5beb302be3..0cdb3ec1f7 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -16,15 +16,18 @@
 
 import logging
 from collections import Counter
+from enum import Enum
 from itertools import chain
 from typing import Any, Dict, List, Optional, Tuple
 
 from twisted.internet.defer import DeferredLock
 
 from synapse.api.constants import EventTypes, Membership
+from synapse.api.errors import StoreError
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.state_deltas import StateDeltasStore
 from synapse.storage.engines import PostgresEngine
+from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
@@ -59,6 +62,23 @@ TYPE_TO_TABLE = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user
 TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")}
 
 
+class UserSortOrder(Enum):
+    """
+    Enum to define the sorting method used when returning users
+    with get_users_media_usage_paginate
+
+    MEDIA_LENGTH = ordered by size of uploaded media. Smallest to largest.
+    MEDIA_COUNT = ordered by number of uploaded media. Smallest to largest.
+    USER_ID = ordered alphabetically by `user_id`.
+    DISPLAYNAME = ordered alphabetically by `displayname`
+    """
+
+    MEDIA_LENGTH = "media_length"
+    MEDIA_COUNT = "media_count"
+    USER_ID = "user_id"
+    DISPLAYNAME = "displayname"
+
+
 class StatsStore(StateDeltasStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
@@ -882,3 +902,110 @@ class StatsStore(StateDeltasStore):
             complete_with_stream_id=pos,
             absolute_field_overrides={"joined_rooms": joined_rooms},
         )
+
+    async def get_users_media_usage_paginate(
+        self,
+        start: int,
+        limit: int,
+        from_ts: Optional[int] = None,
+        until_ts: Optional[int] = None,
+        order_by: Optional[UserSortOrder] = UserSortOrder.USER_ID.value,
+        direction: Optional[str] = "f",
+        search_term: Optional[str] = None,
+    ) -> Tuple[List[JsonDict], Dict[str, int]]:
+        """Function to retrieve a paginated list of users and their uploaded local media
+        (size and number). This will return a json list of users and the
+        total number of users matching the filter criteria.
+
+        Args:
+            start: offset to begin the query from
+            limit: number of rows to retrieve
+            from_ts: request only media that are created later than this timestamp (ms)
+            until_ts: request only media that are created earlier than this timestamp (ms)
+            order_by: the sort order of the returned list
+            direction: sort ascending or descending
+            search_term: a string to filter user names by
+        Returns:
+            A list of user dicts and an integer representing the total number of
+            users that exist given this query
+        """
+
+        def get_users_media_usage_paginate_txn(txn):
+            filters = []
+            args = [self.hs.config.server_name]
+
+            if search_term:
+                filters.append("(lmr.user_id LIKE ? OR displayname LIKE ?)")
+                args.extend(["@%" + search_term + "%:%", "%" + search_term + "%"])
+
+            if from_ts:
+                filters.append("created_ts >= ?")
+                args.extend([from_ts])
+            if until_ts:
+                filters.append("created_ts <= ?")
+                args.extend([until_ts])
+
+            # Set ordering
+            if UserSortOrder(order_by) == UserSortOrder.MEDIA_LENGTH:
+                order_by_column = "media_length"
+            elif UserSortOrder(order_by) == UserSortOrder.MEDIA_COUNT:
+                order_by_column = "media_count"
+            elif UserSortOrder(order_by) == UserSortOrder.USER_ID:
+                order_by_column = "lmr.user_id"
+            elif UserSortOrder(order_by) == UserSortOrder.DISPLAYNAME:
+                order_by_column = "displayname"
+            else:
+                raise StoreError(
+                    500, "Incorrect value for order_by provided: %s" % order_by
+                )
+
+            if direction == "b":
+                order = "DESC"
+            else:
+                order = "ASC"
+
+            where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+
+            sql_base = """
+                FROM local_media_repository as lmr
+                LEFT JOIN profiles AS p ON lmr.user_id = '@' || p.user_id || ':' || ?
+                {}
+                GROUP BY lmr.user_id, displayname
+            """.format(
+                where_clause
+            )
+
+            # SQLite does not support SELECT COUNT(*) OVER()
+            sql = """
+                SELECT COUNT(*) FROM (
+                    SELECT lmr.user_id
+                    {sql_base}
+                ) AS count_user_ids
+            """.format(
+                sql_base=sql_base,
+            )
+            txn.execute(sql, args)
+            count = txn.fetchone()[0]
+
+            sql = """
+                SELECT
+                    lmr.user_id,
+                    displayname,
+                    COUNT(lmr.user_id) as media_count,
+                    SUM(media_length) as media_length
+                    {sql_base}
+                ORDER BY {order_by_column} {order}
+                LIMIT ? OFFSET ?
+            """.format(
+                sql_base=sql_base, order_by_column=order_by_column, order=order,
+            )
+
+            args += [limit, start]
+            txn.execute(sql, args)
+            users = self.db_pool.cursor_to_dict(txn)
+
+            return users, count
+
+        return await self.db_pool.runInteraction(
+            "get_users_media_usage_paginate_txn", get_users_media_usage_paginate_txn
+        )
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 37249f1e3f..e3b9ff5ca6 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -53,7 +53,9 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import Collection, PersistedEventPosition, RoomStreamToken
+from synapse.util.caches.descriptors import cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 if TYPE_CHECKING:
@@ -208,6 +210,55 @@ def _make_generic_sql_bound(
     )
 
 
+def _filter_results(
+    lower_token: Optional[RoomStreamToken],
+    upper_token: Optional[RoomStreamToken],
+    instance_name: str,
+    topological_ordering: int,
+    stream_ordering: int,
+) -> bool:
+    """Returns True if the event persisted by the given instance at the given
+    topological/stream_ordering falls between the two tokens (taking a None
+    token to mean unbounded).
+
+    Used to filter results from fetching events in the DB against the given
+    tokens. This is necessary to handle the case where the tokens include
+    position maps, which we handle by fetching more than necessary from the DB
+    and then filtering (rather than attempting to construct a complicated SQL
+    query).
+    """
+
+    event_historical_tuple = (
+        topological_ordering,
+        stream_ordering,
+    )
+
+    if lower_token:
+        if lower_token.topological is not None:
+            # If these are historical tokens we compare the `(topological, stream)`
+            # tuples.
+            if event_historical_tuple <= lower_token.as_historical_tuple():
+                return False
+
+        else:
+            # If these are live tokens we compare the stream ordering against the
+            # writers stream position.
+            if stream_ordering <= lower_token.get_stream_pos_for_instance(
+                instance_name
+            ):
+                return False
+
+    if upper_token:
+        if upper_token.topological is not None:
+            if upper_token.as_historical_tuple() < event_historical_tuple:
+                return False
+        else:
+            if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering:
+                return False
+
+    return True
+
+
 def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
     # NB: This may create SQL clauses that don't optimise well (and we don't
     # have indices on all possible clauses). E.g. it may create
@@ -305,7 +356,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         raise NotImplementedError()
 
     def get_room_max_token(self) -> RoomStreamToken:
-        return RoomStreamToken(None, self.get_room_max_stream_ordering())
+        """Get a `RoomStreamToken` that marks the current maximum persisted
+        position of the events stream. Useful to get a token that represents
+        "now".
+
+        The token returned is a "live" token that may have an instance_map
+        component.
+        """
+
+        min_pos = self._stream_id_gen.get_current_token()
+
+        positions = {}
+        if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
+            # The `min_pos` is the minimum position that we know all instances
+            # have finished persisting to, so we only care about instances whose
+            # positions are ahead of that. (Instance positions can be behind the
+            # min position as there are times we can work out that the minimum
+            # position is ahead of the naive minimum across all current
+            # positions. See MultiWriterIdGenerator for details)
+            positions = {
+                i: p
+                for i, p in self._stream_id_gen.get_positions().items()
+                if p > min_pos
+            }
+
+        return RoomStreamToken(None, min_pos, positions)
 
     async def get_room_events_stream_for_rooms(
         self,
@@ -404,25 +479,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         if from_key == to_key:
             return [], from_key
 
-        from_id = from_key.stream
-        to_id = to_key.stream
-
-        has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
+        has_changed = self._events_stream_cache.has_entity_changed(
+            room_id, from_key.stream
+        )
 
         if not has_changed:
             return [], from_key
 
         def f(txn):
-            sql = (
-                "SELECT event_id, stream_ordering FROM events WHERE"
-                " room_id = ?"
-                " AND not outlier"
-                " AND stream_ordering > ? AND stream_ordering <= ?"
-                " ORDER BY stream_ordering %s LIMIT ?"
-            ) % (order,)
-            txn.execute(sql, (room_id, from_id, to_id, limit))
-
-            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+            # To handle tokens with a non-empty instance_map we fetch more
+            # results than necessary and then filter down
+            min_from_id = from_key.stream
+            max_to_id = to_key.get_max_stream_pos()
+
+            sql = """
+                SELECT event_id, instance_name, topological_ordering, stream_ordering
+                FROM events
+                WHERE
+                    room_id = ?
+                    AND not outlier
+                    AND stream_ordering > ? AND stream_ordering <= ?
+                ORDER BY stream_ordering %s LIMIT ?
+            """ % (
+                order,
+            )
+            txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit))
+
+            rows = [
+                _EventDictReturn(event_id, None, stream_ordering)
+                for event_id, instance_name, topological_ordering, stream_ordering in txn
+                if _filter_results(
+                    from_key,
+                    to_key,
+                    instance_name,
+                    topological_ordering,
+                    stream_ordering,
+                )
+            ][:limit]
             return rows
 
         rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
@@ -431,7 +524,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             [r.event_id for r in rows], get_prev_content=True
         )
 
-        self._set_before_and_after(ret, rows, topo_order=from_id is None)
+        self._set_before_and_after(ret, rows, topo_order=False)
 
         if order.lower() == "desc":
             ret.reverse()
@@ -448,31 +541,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
     async def get_membership_changes_for_user(
         self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
     ) -> List[EventBase]:
-        from_id = from_key.stream
-        to_id = to_key.stream
-
         if from_key == to_key:
             return []
 
-        if from_id:
+        if from_key:
             has_changed = self._membership_stream_cache.has_entity_changed(
-                user_id, int(from_id)
+                user_id, int(from_key.stream)
             )
             if not has_changed:
                 return []
 
         def f(txn):
-            sql = (
-                "SELECT m.event_id, stream_ordering FROM events AS e,"
-                " room_memberships AS m"
-                " WHERE e.event_id = m.event_id"
-                " AND m.user_id = ?"
-                " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
-                " ORDER BY e.stream_ordering ASC"
-            )
-            txn.execute(sql, (user_id, from_id, to_id))
-
-            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+            # To handle tokens with a non-empty instance_map we fetch more
+            # results than necessary and then filter down
+            min_from_id = from_key.stream
+            max_to_id = to_key.get_max_stream_pos()
+
+            sql = """
+                SELECT m.event_id, instance_name, topological_ordering, stream_ordering
+                FROM events AS e, room_memberships AS m
+                WHERE e.event_id = m.event_id
+                    AND m.user_id = ?
+                    AND e.stream_ordering > ? AND e.stream_ordering <= ?
+                ORDER BY e.stream_ordering ASC
+            """
+            txn.execute(sql, (user_id, min_from_id, max_to_id,))
+
+            rows = [
+                _EventDictReturn(event_id, None, stream_ordering)
+                for event_id, instance_name, topological_ordering, stream_ordering in txn
+                if _filter_results(
+                    from_key,
+                    to_key,
+                    instance_name,
+                    topological_ordering,
+                    stream_ordering,
+                )
+            ]
 
             return rows
 
@@ -546,7 +651,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
 
     async def get_room_event_before_stream_ordering(
         self, room_id: str, stream_ordering: int
-    ) -> Tuple[int, int, str]:
+    ) -> Optional[Tuple[int, int, str]]:
         """Gets details of the first event in a room at or before a stream ordering
 
         Args:
@@ -589,19 +694,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             )
             return "t%d-%d" % (topo, token)
 
-    async def get_stream_id_for_event(self, event_id: str) -> int:
-        """The stream ID for an event
-        Args:
-            event_id: The id of the event to look up a stream token for.
-        Raises:
-            StoreError if the event wasn't in the database.
-        Returns:
-            A stream ID.
-        """
-        return await self.db_pool.runInteraction(
-            "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id,
-        )
-
     def get_stream_id_for_event_txn(
         self, txn: LoggingTransaction, event_id: str, allow_none=False,
     ) -> int:
@@ -979,11 +1071,46 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         else:
             order = "ASC"
 
+        # The bounds for the stream tokens are complicated by the fact
+        # that we need to handle the instance_map part of the tokens. We do this
+        # by fetching all events between the min stream token and the maximum
+        # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
+        # then filtering the results.
+        if from_token.topological is not None:
+            from_bound = (
+                from_token.as_historical_tuple()
+            )  # type: Tuple[Optional[int], int]
+        elif direction == "b":
+            from_bound = (
+                None,
+                from_token.get_max_stream_pos(),
+            )
+        else:
+            from_bound = (
+                None,
+                from_token.stream,
+            )
+
+        to_bound = None  # type: Optional[Tuple[Optional[int], int]]
+        if to_token:
+            if to_token.topological is not None:
+                to_bound = to_token.as_historical_tuple()
+            elif direction == "b":
+                to_bound = (
+                    None,
+                    to_token.stream,
+                )
+            else:
+                to_bound = (
+                    None,
+                    to_token.get_max_stream_pos(),
+                )
+
         bounds = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
-            from_token=from_token.as_tuple(),
-            to_token=to_token.as_tuple() if to_token else None,
+            from_token=from_bound,
+            to_token=to_bound,
             engine=self.database_engine,
         )
 
@@ -993,7 +1120,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             bounds += " AND " + filter_clause
             args.extend(filter_args)
 
-        args.append(int(limit))
+        # We fetch more events as we'll filter the result set
+        args.append(int(limit) * 2)
 
         select_keywords = "SELECT"
         join_clause = ""
@@ -1015,7 +1143,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
                 select_keywords += "DISTINCT"
 
         sql = """
-            %(select_keywords)s event_id, topological_ordering, stream_ordering
+            %(select_keywords)s
+                event_id, instance_name,
+                topological_ordering, stream_ordering
             FROM events
             %(join_clause)s
             WHERE outlier = ? AND room_id = ? AND %(bounds)s
@@ -1030,7 +1160,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
 
         txn.execute(sql, args)
 
-        rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
+        # Filter the result set.
+        rows = [
+            _EventDictReturn(event_id, topological_ordering, stream_ordering)
+            for event_id, instance_name, topological_ordering, stream_ordering in txn
+            if _filter_results(
+                lower_token=to_token if direction == "b" else from_token,
+                upper_token=from_token if direction == "b" else to_token,
+                instance_name=instance_name,
+                topological_ordering=topological_ordering,
+                stream_ordering=stream_ordering,
+            )
+        ][:limit]
 
         if rows:
             topo = rows[-1].topological_ordering
@@ -1095,6 +1236,58 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
 
         return (events, token)
 
+    @cached()
+    async def get_id_for_instance(self, instance_name: str) -> int:
+        """Get a unique, immutable ID that corresponds to the given Synapse worker instance.
+        """
+
+        def _get_id_for_instance_txn(txn):
+            instance_id = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="instance_map",
+                keyvalues={"instance_name": instance_name},
+                retcol="instance_id",
+                allow_none=True,
+            )
+            if instance_id is not None:
+                return instance_id
+
+            # If we don't have an entry upsert one.
+            #
+            # We could do this before the first check, and rely on the cache for
+            # efficiency, but each UPSERT causes the next ID to increment which
+            # can quickly bloat the size of the generated IDs for new instances.
+            self.db_pool.simple_upsert_txn(
+                txn,
+                table="instance_map",
+                keyvalues={"instance_name": instance_name},
+                values={},
+            )
+
+            return self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="instance_map",
+                keyvalues={"instance_name": instance_name},
+                retcol="instance_id",
+            )
+
+        return await self.db_pool.runInteraction(
+            "get_id_for_instance", _get_id_for_instance_txn
+        )
+
+    @cached()
+    async def get_name_from_instance_id(self, instance_id: int) -> str:
+        """Get the instance name from an ID previously returned by
+        `get_id_for_instance`.
+        """
+
+        return await self.db_pool.simple_select_one_onecol(
+            table="instance_map",
+            keyvalues={"instance_id": instance_id},
+            retcol="instance_name",
+            desc="get_name_from_instance_id",
+        )
+
 
 class StreamStore(StreamWorkerStore):
     def get_room_max_stream_ordering(self) -> int:
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 97aed1500e..59207cadd4 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -19,7 +19,7 @@ from typing import Iterable, List, Optional, Tuple
 
 from canonicaljson import encode_canonical_json
 
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@@ -43,15 +43,33 @@ _UpdateTransactionRow = namedtuple(
 SENTINEL = object()
 
 
-class TransactionStore(SQLBaseStore):
+class TransactionWorkerStore(SQLBaseStore):
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        if hs.config.run_background_tasks:
+            self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
+
+    @wrap_as_background_process("cleanup_transactions")
+    async def _cleanup_transactions(self) -> None:
+        now = self._clock.time_msec()
+        month_ago = now - 30 * 24 * 60 * 60 * 1000
+
+        def _cleanup_transactions_txn(txn):
+            txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
+
+        await self.db_pool.runInteraction(
+            "_cleanup_transactions", _cleanup_transactions_txn
+        )
+
+
+class TransactionStore(TransactionWorkerStore):
     """A collection of queries for handling PDUs.
     """
 
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
 
-        self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
-
         self._destination_retry_cache = ExpiringCache(
             cache_name="get_destination_retry_timings",
             clock=self._clock,
@@ -190,42 +208,56 @@ class TransactionStore(SQLBaseStore):
         """
 
         self._destination_retry_cache.pop(destination, None)
-        return await self.db_pool.runInteraction(
-            "set_destination_retry_timings",
-            self._set_destination_retry_timings,
-            destination,
-            failure_ts,
-            retry_last_ts,
-            retry_interval,
-        )
+        if self.database_engine.can_native_upsert:
+            return await self.db_pool.runInteraction(
+                "set_destination_retry_timings",
+                self._set_destination_retry_timings_native,
+                destination,
+                failure_ts,
+                retry_last_ts,
+                retry_interval,
+                db_autocommit=True,  # Safe as its a single upsert
+            )
+        else:
+            return await self.db_pool.runInteraction(
+                "set_destination_retry_timings",
+                self._set_destination_retry_timings_emulated,
+                destination,
+                failure_ts,
+                retry_last_ts,
+                retry_interval,
+            )
 
-    def _set_destination_retry_timings(
+    def _set_destination_retry_timings_native(
         self, txn, destination, failure_ts, retry_last_ts, retry_interval
     ):
+        assert self.database_engine.can_native_upsert
+
+        # Upsert retry time interval if retry_interval is zero (i.e. we're
+        # resetting it) or greater than the existing retry interval.
+        #
+        # WARNING: This is executed in autocommit, so we shouldn't add any more
+        # SQL calls in here (without being very careful).
+        sql = """
+            INSERT INTO destinations (
+                destination, failure_ts, retry_last_ts, retry_interval
+            )
+                VALUES (?, ?, ?, ?)
+            ON CONFLICT (destination) DO UPDATE SET
+                    failure_ts = EXCLUDED.failure_ts,
+                    retry_last_ts = EXCLUDED.retry_last_ts,
+                    retry_interval = EXCLUDED.retry_interval
+                WHERE
+                    EXCLUDED.retry_interval = 0
+                    OR destinations.retry_interval IS NULL
+                    OR destinations.retry_interval < EXCLUDED.retry_interval
+        """
 
-        if self.database_engine.can_native_upsert:
-            # Upsert retry time interval if retry_interval is zero (i.e. we're
-            # resetting it) or greater than the existing retry interval.
-
-            sql = """
-                INSERT INTO destinations (
-                    destination, failure_ts, retry_last_ts, retry_interval
-                )
-                    VALUES (?, ?, ?, ?)
-                ON CONFLICT (destination) DO UPDATE SET
-                        failure_ts = EXCLUDED.failure_ts,
-                        retry_last_ts = EXCLUDED.retry_last_ts,
-                        retry_interval = EXCLUDED.retry_interval
-                    WHERE
-                        EXCLUDED.retry_interval = 0
-                        OR destinations.retry_interval IS NULL
-                        OR destinations.retry_interval < EXCLUDED.retry_interval
-            """
-
-            txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
-
-            return
+        txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
 
+    def _set_destination_retry_timings_emulated(
+        self, txn, destination, failure_ts, retry_last_ts, retry_interval
+    ):
         self.database_engine.lock_table(txn, "destinations")
 
         # We need to be careful here as the data may have changed from under us
@@ -266,22 +298,6 @@ class TransactionStore(SQLBaseStore):
                 },
             )
 
-    def _start_cleanup_transactions(self):
-        return run_as_background_process(
-            "cleanup_transactions", self._cleanup_transactions
-        )
-
-    async def _cleanup_transactions(self) -> None:
-        now = self._clock.time_msec()
-        month_ago = now - 30 * 24 * 60 * 60 * 1000
-
-        def _cleanup_transactions_txn(txn):
-            txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
-
-        await self.db_pool.runInteraction(
-            "_cleanup_transactions", _cleanup_transactions_txn
-        )
-
     async def store_destination_rooms_entries(
         self, destinations: Iterable[str], room_id: str, stream_ordering: int,
     ) -> None:
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 3b9211a6d2..79b7ece330 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -288,8 +288,6 @@ class UIAuthWorkerStore(SQLBaseStore):
         )
         return [(row["user_agent"], row["ip"]) for row in rows]
 
-
-class UIAuthStore(UIAuthWorkerStore):
     async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
         """
         Remove sessions which were last used earlier than the expiration time.
@@ -339,3 +337,7 @@ class UIAuthStore(UIAuthWorkerStore):
             iterable=session_ids,
             keyvalues={},
         )
+
+
+class UIAuthStore(UIAuthWorkerStore):
+    pass
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 5a390ff2f6..d87ceec6da 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -480,21 +480,16 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             user_id_tuples: iterable of 2-tuple of user IDs.
         """
 
-        def _add_users_who_share_room_txn(txn):
-            self.db_pool.simple_upsert_many_txn(
-                txn,
-                table="users_who_share_private_rooms",
-                key_names=["user_id", "other_user_id", "room_id"],
-                key_values=[
-                    (user_id, other_user_id, room_id)
-                    for user_id, other_user_id in user_id_tuples
-                ],
-                value_names=(),
-                value_values=None,
-            )
-
-        await self.db_pool.runInteraction(
-            "add_users_who_share_room", _add_users_who_share_room_txn
+        await self.db_pool.simple_upsert_many(
+            table="users_who_share_private_rooms",
+            key_names=["user_id", "other_user_id", "room_id"],
+            key_values=[
+                (user_id, other_user_id, room_id)
+                for user_id, other_user_id in user_id_tuples
+            ],
+            value_names=(),
+            value_values=None,
+            desc="add_users_who_share_room",
         )
 
     async def add_users_in_public_rooms(
@@ -508,19 +503,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             user_ids
         """
 
-        def _add_users_in_public_rooms_txn(txn):
-
-            self.db_pool.simple_upsert_many_txn(
-                txn,
-                table="users_in_public_rooms",
-                key_names=["user_id", "room_id"],
-                key_values=[(user_id, room_id) for user_id in user_ids],
-                value_names=(),
-                value_values=None,
-            )
-
-        await self.db_pool.runInteraction(
-            "add_users_in_public_rooms", _add_users_in_public_rooms_txn
+        await self.db_pool.simple_upsert_many(
+            table="users_in_public_rooms",
+            key_names=["user_id", "room_id"],
+            key_values=[(user_id, room_id) for user_id in user_ids],
+            value_names=(),
+            value_values=None,
+            desc="add_users_in_public_rooms",
         )
 
     async def delete_all_from_user_dir(self) -> None: