author     Ben Banfield-Zanin <benbz@matrix.org>  2021-03-01 10:06:09 +0000
committer  Ben Banfield-Zanin <benbz@matrix.org>  2021-03-01 10:06:09 +0000
commit     b26bee9faf957643cd34c4146b250b0009be205d (patch)
tree       a7a7e29f30acb437d010bdf6116c0f2729f21a1b /synapse/storage/databases
parent     Merge remote-tracking branch 'origin/release-v1.26.0' into toml/keycloak_hints (diff)
parent     Fixup changelog (diff)
download   synapse-toml/keycloak_hints.tar.xz
Merge remote-tracking branch 'origin/release-v1.28.0' into toml/keycloak_hints (github/toml/keycloak_hints, toml/keycloak_hints)
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--  synapse/storage/databases/__init__.py | 5
-rw-r--r--  synapse/storage/databases/main/__init__.py | 6
-rw-r--r--  synapse/storage/databases/main/appservice.py | 3
-rw-r--r--  synapse/storage/databases/main/client_ips.py | 12
-rw-r--r--  synapse/storage/databases/main/deviceinbox.py | 2
-rw-r--r--  synapse/storage/databases/main/devices.py | 42
-rw-r--r--  synapse/storage/databases/main/directory.py | 7
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py | 15
-rw-r--r--  synapse/storage/databases/main/event_federation.py | 19
-rw-r--r--  synapse/storage/databases/main/event_push_actions.py | 22
-rw-r--r--  synapse/storage/databases/main/events.py | 262
-rw-r--r--  synapse/storage/databases/main/events_bg_updates.py | 41
-rw-r--r--  synapse/storage/databases/main/events_forward_extremities.py | 104
-rw-r--r--  synapse/storage/databases/main/events_worker.py | 16
-rw-r--r--  synapse/storage/databases/main/group_server.py | 40
-rw-r--r--  synapse/storage/databases/main/keys.py | 7
-rw-r--r--  synapse/storage/databases/main/media_repository.py | 25
-rw-r--r--  synapse/storage/databases/main/metrics.py | 58
-rw-r--r--  synapse/storage/databases/main/presence.py | 4
-rw-r--r--  synapse/storage/databases/main/profile.py | 6
-rw-r--r--  synapse/storage/databases/main/purge_events.py | 2
-rw-r--r--  synapse/storage/databases/main/push_rule.py | 8
-rw-r--r--  synapse/storage/databases/main/pusher.py | 11
-rw-r--r--  synapse/storage/databases/main/receipts.py | 13
-rw-r--r--  synapse/storage/databases/main/registration.py | 81
-rw-r--r--  synapse/storage/databases/main/room.py | 25
-rw-r--r--  synapse/storage/databases/main/roommember.py | 23
-rw-r--r--  synapse/storage/databases/main/schema/delta/33/remote_media_ts.py | 3
-rw-r--r--  synapse/storage/databases/main/schema/delta/59/01ignored_user.py | 2
-rw-r--r--  synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite | 10
-rw-r--r--  synapse/storage/databases/main/search.py | 7
-rw-r--r--  synapse/storage/databases/main/state.py | 11
-rw-r--r--  synapse/storage/databases/main/state_deltas.py | 4
-rw-r--r--  synapse/storage/databases/main/stats.py | 26
-rw-r--r--  synapse/storage/databases/main/stream.py | 42
-rw-r--r--  synapse/storage/databases/main/transactions.py | 21
-rw-r--r--  synapse/storage/databases/main/ui_auth.py | 22
-rw-r--r--  synapse/storage/databases/main/user_directory.py | 16
-rw-r--r--  synapse/storage/databases/state/bg_updates.py | 2
-rw-r--r--  synapse/storage/databases/state/store.py | 10
40 files changed, 709 insertions, 326 deletions
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 0c24325011..e84f8b42f7 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -56,7 +56,10 @@ class Databases:
                     database_config.databases,
                 )
                 prepare_database(
-                    db_conn, engine, hs.config, databases=database_config.databases,
+                    db_conn,
+                    engine,
+                    hs.config,
+                    databases=database_config.databases,
                 )
 
                 database = DatabasePool(hs, database_config, engine)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index ae561a2da3..70b49854cf 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -43,6 +43,7 @@ from .end_to_end_keys import EndToEndKeyStore
 from .event_federation import EventFederationStore
 from .event_push_actions import EventPushActionsStore
 from .events_bg_updates import EventsBackgroundUpdatesStore
+from .events_forward_extremities import EventForwardExtremitiesStore
 from .filtering import FilteringStore
 from .group_server import GroupServerStore
 from .keys import KeyStore
@@ -118,6 +119,7 @@ class DataStore(
     UIAuthStore,
     CacheInvalidationWorkerStore,
     ServerMetricsStore,
+    EventForwardExtremitiesStore,
 ):
     def __init__(self, database: DatabasePool, db_conn, hs):
         self.hs = hs
@@ -338,7 +340,7 @@ class DataStore(
             count = txn.fetchone()[0]
 
             sql = (
-                "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url "
+                "SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url "
                 + sql_base
                 + " ORDER BY u.name LIMIT ? OFFSET ?"
             )
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index e550cbc866..03a38422a1 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -73,8 +73,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
         return self.services_cache
 
     def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
-        """Check if the user is one associated with an app service (exclusively)
-        """
+        """Check if the user is one associated with an app service (exclusively)"""
         if self.exclusive_user_regex:
             return bool(self.exclusive_user_regex.match(user_id))
         else:
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index ea1e8fb580..6d18e692b0 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -280,8 +280,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
         return batch_size
 
     async def _devices_last_seen_update(self, progress, batch_size):
-        """Background update to insert last seen info into devices table
-        """
+        """Background update to insert last seen info into devices table"""
 
         last_user_id = progress.get("last_user_id", "")
         last_device_id = progress.get("last_device_id", "")
@@ -363,8 +362,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
 
     @wrap_as_background_process("prune_old_user_ips")
     async def _prune_old_user_ips(self):
-        """Removes entries in user IPs older than the configured period.
-        """
+        """Removes entries in user IPs older than the configured period."""
 
         if self.user_ips_max_age is None:
             # Nothing to do
@@ -565,7 +563,11 @@ class ClientIpStore(ClientIpWorkerStore):
         results = {}
 
         for key in self._batch_row_update:
-            uid, access_token, ip, = key
+            (
+                uid,
+                access_token,
+                ip,
+            ) = key
             if uid == user_id:
                 user_agent, _, last_seen = self._batch_row_update[key]
                 results[(access_token, ip)] = (user_agent, last_seen)
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 31f70ac5ef..45ca6620a8 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -450,7 +450,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 },
             )
 
-            # Add the messages to the approriate local device inboxes so that
+            # Add the messages to the appropriate local device inboxes so that
             # they'll be sent to the devices when they next sync.
             self._add_messages_to_local_device_inbox_txn(
                 txn, stream_id, local_messages_by_user_then_device
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9097677648..d327e9aa0b 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -315,7 +315,8 @@ class DeviceWorkerStore(SQLBaseStore):
 
             # make sure we go through the devices in stream order
             device_ids = sorted(
-                user_devices.keys(), key=lambda i: query_map[(user_id, i)][0],
+                user_devices.keys(),
+                key=lambda i: query_map[(user_id, i)][0],
             )
 
             for device_id in device_ids:
@@ -366,8 +367,7 @@ class DeviceWorkerStore(SQLBaseStore):
     async def mark_as_sent_devices_by_remote(
         self, destination: str, stream_id: int
     ) -> None:
-        """Mark that updates have successfully been sent to the destination.
-        """
+        """Mark that updates have successfully been sent to the destination."""
         await self.db_pool.runInteraction(
             "mark_as_sent_devices_by_remote",
             self._mark_as_sent_devices_by_remote_txn,
@@ -681,7 +681,8 @@ class DeviceWorkerStore(SQLBaseStore):
         return results
 
     async def get_user_ids_requiring_device_list_resync(
-        self, user_ids: Optional[Collection[str]] = None,
+        self,
+        user_ids: Optional[Collection[str]] = None,
     ) -> Set[str]:
         """Given a list of remote users return the list of users that we
         should resync the device lists for. If None is given instead of a list,
@@ -721,8 +722,7 @@ class DeviceWorkerStore(SQLBaseStore):
         )
 
     async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
-        """Mark that we no longer track device lists for remote user.
-        """
+        """Mark that we no longer track device lists for remote user."""
 
         def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
             self.db_pool.simple_delete_txn(
@@ -897,12 +897,13 @@ class DeviceWorkerStore(SQLBaseStore):
                 DELETE FROM device_lists_outbound_last_success
                 WHERE destination = ? AND user_id = ?
             """
-            txn.executemany(sql, ((row[0], row[1]) for row in rows))
+            txn.execute_batch(sql, ((row[0], row[1]) for row in rows))
 
             logger.info("Pruned %d device list outbound pokes", count)
 
         await self.db_pool.runInteraction(
-            "_prune_old_outbound_device_pokes", _prune_txn,
+            "_prune_old_outbound_device_pokes",
+            _prune_txn,
         )
 
 
@@ -943,7 +944,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
 
         # clear out duplicate device list outbound pokes
         self.db_pool.updates.register_background_update_handler(
-            BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
+            BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
+            self._remove_duplicate_outbound_pokes,
         )
 
         # a pair of background updates that were added during the 1.14 release cycle,
@@ -1004,17 +1006,23 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
             row = None
             for row in rows:
                 self.db_pool.simple_delete_txn(
-                    txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS},
+                    txn,
+                    "device_lists_outbound_pokes",
+                    {x: row[x] for x in KEY_COLS},
                 )
 
                 row["sent"] = False
                 self.db_pool.simple_insert_txn(
-                    txn, "device_lists_outbound_pokes", row,
+                    txn,
+                    "device_lists_outbound_pokes",
+                    row,
                 )
 
             if row:
                 self.db_pool.updates._background_update_progress_txn(
-                    txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row},
+                    txn,
+                    BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
+                    {"last_row": row},
                 )
 
             return len(rows)
@@ -1286,7 +1294,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         # we've done a full resync, so we remove the entry that says we need
         # to resync
         self.db_pool.simple_delete_txn(
-            txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id},
+            txn,
+            table="device_lists_remote_resync",
+            keyvalues={"user_id": user_id},
         )
 
     async def add_device_change_to_streams(
@@ -1336,14 +1346,16 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         stream_ids: List[str],
     ):
         txn.call_after(
-            self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
+            self._device_list_stream_cache.entity_has_changed,
+            user_id,
+            stream_ids[-1],
         )
 
         min_stream_id = stream_ids[0]
 
         # Delete older entries in the table, as we really only care about
         # when the latest change happened.
-        txn.executemany(
+        txn.execute_batch(
             """
             DELETE FROM device_lists_stream
             WHERE user_id = ? AND device_id = ? AND stream_id < ?
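
Several hunks in this file (and in event_push_actions.py, events.py and events_bg_updates.py below) switch txn.executemany(...) to txn.execute_batch(...). As a rough illustration only: execute_batch on Synapse's transaction wrapper batches many parameterized statements into fewer database round trips on PostgreSQL (via psycopg2's execute_batch helper) and falls back to plain executemany elsewhere. The sketch below shows the general shape of such a helper; it is not the actual LoggingTransaction implementation, and the cursor-type check is an assumption made for the example.

# Minimal sketch of an executemany -> execute_batch style helper; illustrative
# only, not Synapse's LoggingTransaction.execute_batch.
import sqlite3
from typing import Any, Iterable, Sequence

try:
    from psycopg2.extras import execute_batch as _pg_execute_batch
except ImportError:
    _pg_execute_batch = None  # e.g. SQLite-only installs


def execute_batch(cursor, sql: str, args: Iterable[Sequence[Any]]) -> None:
    """Run `sql` once per row of `args`, batching round trips where possible."""
    rows = list(args)
    if _pg_execute_batch is not None and not isinstance(cursor, sqlite3.Cursor):
        # psycopg2 groups the statements so many rows cost one network round trip.
        _pg_execute_batch(cursor, sql, rows)
    else:
        # DB-API fallback: the same behaviour the old executemany() calls relied on.
        cursor.executemany(sql, rows)

This is also why the manual INSERT_CLUMP_SIZE chunking loops are dropped from events_bg_updates.py further down: the batching helper takes over that job.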
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index e5060d4c46..267b948397 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -85,7 +85,7 @@ class DirectoryStore(DirectoryWorkerStore):
         servers: Iterable[str],
         creator: Optional[str] = None,
     ) -> None:
-        """ Creates an association between a room alias and room_id/servers
+        """Creates an association between a room alias and room_id/servers
 
         Args:
             room_alias: The alias to create.
@@ -160,7 +160,10 @@ class DirectoryStore(DirectoryWorkerStore):
         return room_id
 
     async def update_aliases_for_room(
-        self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
+        self,
+        old_room_id: str,
+        new_room_id: str,
+        creator: Optional[str] = None,
     ) -> None:
         """Repoint all of the aliases for a given room, to a different room.
 
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index c128889bf9..f1e7859d26 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -361,7 +361,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
     async def count_e2e_one_time_keys(
         self, user_id: str, device_id: str
     ) -> Dict[str, int]:
-        """ Count the number of one time keys the server has for a device
+        """Count the number of one time keys the server has for a device
         Returns:
             A mapping from algorithm to number of keys for that algorithm.
         """
@@ -494,7 +494,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         )
 
     def _get_bare_e2e_cross_signing_keys_bulk_txn(
-        self, txn: Connection, user_ids: List[str],
+        self,
+        txn: Connection,
+        user_ids: List[str],
     ) -> Dict[str, Dict[str, dict]]:
         """Returns the cross-signing keys for a set of users.  The output of this
         function should be passed to _get_e2e_cross_signing_signatures_txn if
@@ -556,7 +558,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         return result
 
     def _get_e2e_cross_signing_signatures_txn(
-        self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str,
+        self,
+        txn: Connection,
+        keys: Dict[str, Dict[str, dict]],
+        from_user_id: str,
     ) -> Dict[str, Dict[str, dict]]:
         """Returns the cross-signing signatures made by a user on a set of keys.
 
@@ -634,7 +639,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
     async def get_e2e_cross_signing_keys_bulk(
         self, user_ids: List[str], from_user_id: Optional[str] = None
-    ) -> Dict[str, Dict[str, dict]]:
+    ) -> Dict[str, Optional[Dict[str, dict]]]:
         """Returns the cross-signing keys for a set of users.
 
         Args:
@@ -724,7 +729,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
     async def claim_e2e_one_time_keys(
         self, query_list: Iterable[Tuple[str, str, str]]
-    ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
+    ) -> Dict[str, Dict[str, Dict[str, str]]]:
         """Take a list of one time keys out of the database.
 
         Args:
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 8326640d20..18ddb92fcc 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -71,7 +71,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         return await self.get_events_as_list(event_ids)
 
     async def get_auth_chain_ids(
-        self, event_ids: Collection[str], include_given: bool = False,
+        self,
+        event_ids: Collection[str],
+        include_given: bool = False,
     ) -> List[str]:
         """Get auth events for given event_ids. The events *must* be state events.
 
@@ -273,7 +275,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
                     # origin chain.
                     if origin_sequence_number <= chains.get(origin_chain_id, 0):
                         chains[target_chain_id] = max(
-                            target_sequence_number, chains.get(target_chain_id, 0),
+                            target_sequence_number,
+                            chains.get(target_chain_id, 0),
                         )
 
                 seen_chains.add(target_chain_id)
@@ -371,7 +374,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         # and state sets {A} and {B} then walking the auth chains of A and B
         # would immediately show that C is reachable by both. However, if we
         # stopped at C then we'd only reach E via the auth chain of B and so E
-        # would errornously get included in the returned difference.
+        # would erroneously get included in the returned difference.
         #
         # The other thing that we do is limit the number of auth chains we walk
         # at once, due to practical limits (i.e. we can only query the database
@@ -497,7 +500,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
                         a_ids = new_aids
 
-                # Mark that the auth event is reachable by the approriate sets.
+                # Mark that the auth event is reachable by the appropriate sets.
                 sets.intersection_update(event_to_missing_sets[event_id])
 
             search.sort()
@@ -632,8 +635,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         )
 
     async def get_min_depth(self, room_id: str) -> int:
-        """For the given room, get the minimum depth we have seen for it.
-        """
+        """For the given room, get the minimum depth we have seen for it."""
         return await self.db_pool.runInteraction(
             "get_min_depth", self._get_min_depth_interaction, room_id
         )
@@ -858,12 +860,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             )
 
         await self.db_pool.runInteraction(
-            "_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn,
+            "_delete_old_forward_extrem_cache",
+            _delete_old_forward_extrem_cache_txn,
         )
 
 
 class EventFederationStore(EventFederationWorkerStore):
-    """ Responsible for storing and serving up the various graphs associated
+    """Responsible for storing and serving up the various graphs associated
     with an event. Including the main event graph and the auth chains for an
     event.
 
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 1b657191a9..78245ad5bd 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -54,8 +54,7 @@ def _serialize_action(actions, is_highlight):
 
 
 def _deserialize_action(actions, is_highlight):
-    """Custom deserializer for actions. This allows us to "compress" common actions
-    """
+    """Custom deserializer for actions. This allows us to "compress" common actions"""
     if actions:
         return db_to_json(actions)
 
@@ -91,7 +90,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
     @cached(num_args=3, tree=True, max_entries=5000)
     async def get_unread_event_push_actions_by_room_for_user(
-        self, room_id: str, user_id: str, last_read_event_id: Optional[str],
+        self,
+        room_id: str,
+        user_id: str,
+        last_read_event_id: Optional[str],
     ) -> Dict[str, int]:
         """Get the notification count, the highlight count and the unread message count
         for a given user in a given room after the given read receipt.
@@ -120,13 +122,19 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         )
 
     def _get_unread_counts_by_receipt_txn(
-        self, txn, room_id, user_id, last_read_event_id,
+        self,
+        txn,
+        room_id,
+        user_id,
+        last_read_event_id,
     ):
         stream_ordering = None
 
         if last_read_event_id is not None:
             stream_ordering = self.get_stream_id_for_event_txn(
-                txn, last_read_event_id, allow_none=True,
+                txn,
+                last_read_event_id,
+                allow_none=True,
             )
 
         if stream_ordering is None:
@@ -487,7 +495,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 VALUES (?, ?, ?, ?, ?, ?)
             """
 
-            txn.executemany(
+            txn.execute_batch(
                 sql,
                 (
                     _gen_entry(user_id, actions)
@@ -803,7 +811,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             ],
         )
 
-        txn.executemany(
+        txn.execute_batch(
             """
                 UPDATE event_push_summary
                 SET notif_count = ?, unread_count = ?, stream_ordering = ?
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 3216b3f3c8..287606cb4f 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -399,7 +399,9 @@ class PersistEventsStore:
         self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
 
     def _persist_event_auth_chain_txn(
-        self, txn: LoggingTransaction, events: List[EventBase],
+        self,
+        txn: LoggingTransaction,
+        events: List[EventBase],
     ) -> None:
 
         # We only care about state events, so this if there are no state events.
@@ -470,11 +472,16 @@ class PersistEventsStore:
         event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
 
         self._add_chain_cover_index(
-            txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
+            txn,
+            self.db_pool,
+            event_to_room_id,
+            event_to_types,
+            event_to_auth_chain,
         )
 
-    @staticmethod
+    @classmethod
     def _add_chain_cover_index(
+        cls,
         txn,
         db_pool: DatabasePool,
         event_to_room_id: Dict[str, str],
@@ -516,7 +523,10 @@ class PersistEventsStore:
             # simple_select_many, but this case happens rarely and almost always
             # with a single row.)
             auth_events = db_pool.simple_select_onecol_txn(
-                txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id",
+                txn,
+                "event_auth",
+                keyvalues={"event_id": event_id},
+                retcol="auth_id",
             )
 
             events_to_calc_chain_id_for.add(event_id)
@@ -549,7 +559,9 @@ class PersistEventsStore:
                 WHERE
             """
             clause, args = make_in_list_sql_clause(
-                txn.database_engine, "event_id", missing_auth_chains,
+                txn.database_engine,
+                "event_id",
+                missing_auth_chains,
             )
             txn.execute(sql + clause, args)
 
@@ -614,60 +626,17 @@ class PersistEventsStore:
         if not events_to_calc_chain_id_for:
             return
 
-        # We now calculate the chain IDs/sequence numbers for the events. We
-        # do this by looking at the chain ID and sequence number of any auth
-        # event with the same type/state_key and incrementing the sequence
-        # number by one. If there was no match or the chain ID/sequence
-        # number is already taken we generate a new chain.
-        #
-        # We need to do this in a topologically sorted order as we want to
-        # generate chain IDs/sequence numbers of an event's auth events
-        # before the event itself.
-        chains_tuples_allocated = set()  # type: Set[Tuple[int, int]]
-        new_chain_tuples = {}  # type: Dict[str, Tuple[int, int]]
-        for event_id in sorted_topologically(
-            events_to_calc_chain_id_for, event_to_auth_chain
-        ):
-            existing_chain_id = None
-            for auth_id in event_to_auth_chain.get(event_id, []):
-                if event_to_types.get(event_id) == event_to_types.get(auth_id):
-                    existing_chain_id = chain_map[auth_id]
-                    break
-
-            new_chain_tuple = None
-            if existing_chain_id:
-                # We found a chain ID/sequence number candidate, check its
-                # not already taken.
-                proposed_new_id = existing_chain_id[0]
-                proposed_new_seq = existing_chain_id[1] + 1
-                if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
-                    already_allocated = db_pool.simple_select_one_onecol_txn(
-                        txn,
-                        table="event_auth_chains",
-                        keyvalues={
-                            "chain_id": proposed_new_id,
-                            "sequence_number": proposed_new_seq,
-                        },
-                        retcol="event_id",
-                        allow_none=True,
-                    )
-                    if already_allocated:
-                        # Mark it as already allocated so we don't need to hit
-                        # the DB again.
-                        chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
-                    else:
-                        new_chain_tuple = (
-                            proposed_new_id,
-                            proposed_new_seq,
-                        )
-
-            if not new_chain_tuple:
-                new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1)
-
-            chains_tuples_allocated.add(new_chain_tuple)
-
-            chain_map[event_id] = new_chain_tuple
-            new_chain_tuples[event_id] = new_chain_tuple
+        # Allocate chain ID/sequence numbers to each new event.
+        new_chain_tuples = cls._allocate_chain_ids(
+            txn,
+            db_pool,
+            event_to_room_id,
+            event_to_types,
+            event_to_auth_chain,
+            events_to_calc_chain_id_for,
+            chain_map,
+        )
+        chain_map.update(new_chain_tuples)
 
         db_pool.simple_insert_many_txn(
             txn,
@@ -746,7 +715,8 @@ class PersistEventsStore:
                 if chain_map[a_id][0] != chain_id
             }
             for start_auth_id, end_auth_id in itertools.permutations(
-                event_to_auth_chain.get(event_id, []), r=2,
+                event_to_auth_chain.get(event_id, []),
+                r=2,
             ):
                 if chain_links.exists_path_from(
                     chain_map[start_auth_id], chain_map[end_auth_id]
@@ -794,13 +764,143 @@ class PersistEventsStore:
             ],
         )
 
+    @staticmethod
+    def _allocate_chain_ids(
+        txn,
+        db_pool: DatabasePool,
+        event_to_room_id: Dict[str, str],
+        event_to_types: Dict[str, Tuple[str, str]],
+        event_to_auth_chain: Dict[str, List[str]],
+        events_to_calc_chain_id_for: Set[str],
+        chain_map: Dict[str, Tuple[int, int]],
+    ) -> Dict[str, Tuple[int, int]]:
+        """Allocates, but does not persist, chain ID/sequence numbers for the
+        events in `events_to_calc_chain_id_for`. (c.f. _add_chain_cover_index
+        for info on args)
+        """
+
+        # We now calculate the chain IDs/sequence numbers for the events. We do
+        # this by looking at the chain ID and sequence number of any auth event
+        # with the same type/state_key and incrementing the sequence number by
+        # one. If there was no match or the chain ID/sequence number is already
+        # taken we generate a new chain.
+        #
+        # We try to reduce the number of times that we hit the database by
+        # batching up calls, to make this more efficient when persisting large
+        # numbers of state events (e.g. during joins).
+        #
+        # We do this by:
+        #   1. Calculating for each event which auth event will be used to
+        #      inherit the chain ID, i.e. converting the auth chain graph to a
+        #      tree that we can allocate chains on. We also keep track of which
+        #      existing chain IDs have been referenced.
+        #   2. Fetching the max allocated sequence number for each referenced
+        #      existing chain ID, generating a map from chain ID to the max
+        #      allocated sequence number.
+        #   3. Iterating over the tree and allocating a chain ID/seq no. to the
+        #      new event, by incrementing the sequence number from the
+        #      referenced event's chain ID/seq no. and checking that the
+        #      incremented sequence number hasn't already been allocated (by
+        #      looking in the map generated in the previous step). We generate a
+        #      new chain if the sequence number has already been allocated.
+        #
+
+        existing_chains = set()  # type: Set[int]
+        tree = []  # type: List[Tuple[str, Optional[str]]]
+
+        # We need to do this in a topologically sorted order as we want to
+        # generate chain IDs/sequence numbers of an event's auth events before
+        # the event itself.
+        for event_id in sorted_topologically(
+            events_to_calc_chain_id_for, event_to_auth_chain
+        ):
+            for auth_id in event_to_auth_chain.get(event_id, []):
+                if event_to_types.get(event_id) == event_to_types.get(auth_id):
+                    existing_chain_id = chain_map.get(auth_id)
+                    if existing_chain_id:
+                        existing_chains.add(existing_chain_id[0])
+
+                    tree.append((event_id, auth_id))
+                    break
+            else:
+                tree.append((event_id, None))
+
+        # Fetch the current max sequence number for each existing referenced chain.
+        sql = """
+            SELECT chain_id, MAX(sequence_number) FROM event_auth_chains
+            WHERE %s
+            GROUP BY chain_id
+        """
+        clause, args = make_in_list_sql_clause(
+            db_pool.engine, "chain_id", existing_chains
+        )
+        txn.execute(sql % (clause,), args)
+
+        chain_to_max_seq_no = {row[0]: row[1] for row in txn}  # type: Dict[Any, int]
+
+        # Allocate the new events chain ID/sequence numbers.
+        #
+        # To reduce the number of calls to the database we don't allocate a
+        # chain ID number in the loop, instead we use a temporary `object()` for
+        # each new chain ID. Once we've done the loop we generate the necessary
+        # number of new chain IDs in one call, replacing all temporary
+        # objects with real allocated chain IDs.
+
+        unallocated_chain_ids = set()  # type: Set[object]
+        new_chain_tuples = {}  # type: Dict[str, Tuple[Any, int]]
+        for event_id, auth_event_id in tree:
+            # If we reference an auth_event_id we fetch the allocated chain ID,
+            # either from the existing `chain_map` or the newly generated
+            # `new_chain_tuples` map.
+            existing_chain_id = None
+            if auth_event_id:
+                existing_chain_id = new_chain_tuples.get(auth_event_id)
+                if not existing_chain_id:
+                    existing_chain_id = chain_map[auth_event_id]
+
+            new_chain_tuple = None  # type: Optional[Tuple[Any, int]]
+            if existing_chain_id:
+                # We found a chain ID/sequence number candidate, check its
+                # not already taken.
+                proposed_new_id = existing_chain_id[0]
+                proposed_new_seq = existing_chain_id[1] + 1
+
+                if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq:
+                    new_chain_tuple = (
+                        proposed_new_id,
+                        proposed_new_seq,
+                    )
+
+            # If we need to start a new chain we allocate a temporary chain ID.
+            if not new_chain_tuple:
+                new_chain_tuple = (object(), 1)
+                unallocated_chain_ids.add(new_chain_tuple[0])
+
+            new_chain_tuples[event_id] = new_chain_tuple
+            chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
+
+        # Generate new chain IDs for all unallocated chain IDs.
+        newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
+            txn, len(unallocated_chain_ids)
+        )
+
+        # Map from potentially temporary chain ID to real chain ID
+        chain_id_to_allocated_map = dict(
+            zip(unallocated_chain_ids, newly_allocated_chain_ids)
+        )  # type: Dict[Any, int]
+        chain_id_to_allocated_map.update((c, c) for c in existing_chains)
+
+        return {
+            event_id: (chain_id_to_allocated_map[chain_id], seq)
+            for event_id, (chain_id, seq) in new_chain_tuples.items()
+        }
+
     def _persist_transaction_ids_txn(
         self,
         txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
     ):
-        """Persist the mapping from transaction IDs to event IDs (if defined).
-        """
+        """Persist the mapping from transaction IDs to event IDs (if defined)."""
 
         to_insert = []
         for event, _ in events_and_contexts:
@@ -820,7 +920,9 @@ class PersistEventsStore:
 
         if to_insert:
             self.db_pool.simple_insert_many_txn(
-                txn, table="event_txn_id", values=to_insert,
+                txn,
+                table="event_txn_id",
+                values=to_insert,
             )
 
     def _update_current_state_txn(
@@ -852,7 +954,9 @@ class PersistEventsStore:
                 txn.execute(sql, (stream_id, self._instance_name, room_id))
 
                 self.db_pool.simple_delete_txn(
-                    txn, table="current_state_events", keyvalues={"room_id": room_id},
+                    txn,
+                    table="current_state_events",
+                    keyvalues={"room_id": room_id},
                 )
             else:
                 # We're still in the room, so we update the current state as normal.
@@ -876,7 +980,7 @@ class PersistEventsStore:
                         WHERE room_id = ? AND type = ? AND state_key = ?
                     )
                 """
-                txn.executemany(
+                txn.execute_batch(
                     sql,
                     (
                         (
@@ -895,7 +999,7 @@ class PersistEventsStore:
                 )
                 # Now we actually update the current_state_events table
 
-                txn.executemany(
+                txn.execute_batch(
                     "DELETE FROM current_state_events"
                     " WHERE room_id = ? AND type = ? AND state_key = ?",
                     (
@@ -907,7 +1011,7 @@ class PersistEventsStore:
                 # We include the membership in the current state table, hence we do
                 # a lookup when we insert. This assumes that all events have already
                 # been inserted into room_memberships.
-                txn.executemany(
+                txn.execute_batch(
                     """INSERT INTO current_state_events
                         (room_id, type, state_key, event_id, membership)
                     VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -927,7 +1031,7 @@ class PersistEventsStore:
             # we have no record of the fact the user *was* a member of the
             # room but got, say, state reset out of it.
             if to_delete or to_insert:
-                txn.executemany(
+                txn.execute_batch(
                     "DELETE FROM local_current_membership"
                     " WHERE room_id = ? AND user_id = ?",
                     (
@@ -938,7 +1042,7 @@ class PersistEventsStore:
                 )
 
             if to_insert:
-                txn.executemany(
+                txn.execute_batch(
                     """INSERT INTO local_current_membership
                         (room_id, user_id, event_id, membership)
                     VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -961,7 +1065,7 @@ class PersistEventsStore:
             # Figure out the changes of membership to invalidate the
             # `get_rooms_for_user` cache.
             # We find out which membership events we may have deleted
-            # and which we have added, then we invlidate the caches for all
+            # and which we have added, then we invalidate the caches for all
             # those users.
             members_changed = {
                 state_key
@@ -1519,8 +1623,7 @@ class PersistEventsStore:
         )
 
     def _store_room_members_txn(self, txn, events, backfilled):
-        """Store a room member in the database.
-        """
+        """Store a room member in the database."""
 
         def str_or_none(val: Any) -> Optional[str]:
             return val if isinstance(val, str) else None
@@ -1738,7 +1841,7 @@ class PersistEventsStore:
         """
 
         if events_and_contexts:
-            txn.executemany(
+            txn.execute_batch(
                 sql,
                 (
                     (
@@ -1767,7 +1870,7 @@ class PersistEventsStore:
 
         # Now we delete the staging area for *all* events that were being
         # persisted.
-        txn.executemany(
+        txn.execute_batch(
             "DELETE FROM event_push_actions_staging WHERE event_id = ?",
             ((event.event_id,) for event, _ in all_events_and_contexts),
         )
@@ -1886,7 +1989,7 @@ class PersistEventsStore:
             " )"
         )
 
-        txn.executemany(
+        txn.execute_batch(
             query,
             [
                 (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
@@ -1900,7 +2003,7 @@ class PersistEventsStore:
             "DELETE FROM event_backward_extremities"
             " WHERE event_id = ? AND room_id = ?"
         )
-        txn.executemany(
+        txn.execute_batch(
             query,
             [
                 (ev.event_id, ev.room_id)
@@ -1912,8 +2015,7 @@ class PersistEventsStore:
 
 @attr.s(slots=True)
 class _LinkMap:
-    """A helper type for tracking links between chains.
-    """
+    """A helper type for tracking links between chains."""
 
     # Stores the set of links as nested maps: source chain ID -> target chain ID
     # -> source sequence number -> target sequence number.
@@ -2019,7 +2121,9 @@ class _LinkMap:
                 yield (src_chain, src_seq, target_chain, target_seq)
 
     def exists_path_from(
-        self, src_tuple: Tuple[int, int], target_tuple: Tuple[int, int],
+        self,
+        src_tuple: Tuple[int, int],
+        target_tuple: Tuple[int, int],
     ) -> bool:
         """Checks if there is a path between the source chain ID/sequence and
         target chain ID/sequence.
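
The _allocate_chain_ids helper added above explains its batching strategy in the long comment: hand each brand-new chain a temporary object() token while iterating, then allocate all the real chain IDs in one call and substitute them at the end. Below is a toy, self-contained sketch of that placeholder pattern, assuming a hypothetical allocate_ids() in place of db_pool.event_chain_id_gen.get_next_mult_txn; it is not the production code.

from itertools import count
from typing import Any, Dict, List, Tuple

_counter = count(1)


def allocate_ids(n: int) -> List[int]:
    """Pretend batched ID allocator: one "database call" hands out n fresh IDs."""
    return [next(_counter) for _ in range(n)]


def assign_chains(events: List[Tuple[str, Any]]) -> Dict[str, Tuple[int, int]]:
    """events is a list of (event_id, inherited_chain) pairs, where
    inherited_chain is an existing (chain_id, seq) tuple or None."""
    unallocated = set()  # temporary tokens for brand-new chains
    assigned: Dict[str, Tuple[Any, int]] = {}

    for event_id, inherited in events:
        if inherited is not None:
            chain_id, seq = inherited[0], inherited[1] + 1
        else:
            chain_id, seq = object(), 1  # placeholder, resolved below
            unallocated.add(chain_id)
        assigned[event_id] = (chain_id, seq)

    # One batched allocation replaces a database round trip per new chain.
    real_ids = dict(zip(unallocated, allocate_ids(len(unallocated))))
    return {
        ev: (real_ids.get(chain, chain), seq) for ev, (chain, seq) in assigned.items()
    }


# Example: two events starting new chains, one extending existing chain 7.
print(assign_chains([("$a", None), ("$b", (7, 3)), ("$c", None)]))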
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index e46e44ba54..89274e75f7 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -32,8 +32,7 @@ logger = logging.getLogger(__name__)
 
 @attr.s(slots=True, frozen=True)
 class _CalculateChainCover:
-    """Return value for _calculate_chain_cover_txn.
-    """
+    """Return value for _calculate_chain_cover_txn."""
 
     # The last room_id/depth/stream processed.
     room_id = attr.ib(type=str)
@@ -127,11 +126,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         )
 
         self.db_pool.updates.register_background_update_handler(
-            "rejected_events_metadata", self._rejected_events_metadata,
+            "rejected_events_metadata",
+            self._rejected_events_metadata,
         )
 
         self.db_pool.updates.register_background_update_handler(
-            "chain_cover", self._chain_cover_index,
+            "chain_cover",
+            self._chain_cover_index,
         )
 
     async def _background_reindex_fields_sender(self, progress, batch_size):
@@ -139,8 +140,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
-        INSERT_CLUMP_SIZE = 1000
-
         def reindex_txn(txn):
             sql = (
                 "SELECT stream_ordering, event_id, json FROM events"
@@ -178,9 +177,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
 
-            for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
-                clump = update_rows[index : index + INSERT_CLUMP_SIZE]
-                txn.executemany(sql, clump)
+            txn.execute_batch(sql, update_rows)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
@@ -210,8 +207,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
-        INSERT_CLUMP_SIZE = 1000
-
         def reindex_search_txn(txn):
             sql = (
                 "SELECT stream_ordering, event_id FROM events"
@@ -256,9 +251,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
 
-            for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
-                clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
-                txn.executemany(sql, clump)
+            txn.execute_batch(sql, rows_to_update)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
@@ -470,8 +463,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         return num_handled
 
     async def _redactions_received_ts(self, progress, batch_size):
-        """Handles filling out the `received_ts` column in redactions.
-        """
+        """Handles filling out the `received_ts` column in redactions."""
         last_event_id = progress.get("last_event_id", "")
 
         def _redactions_received_ts_txn(txn):
@@ -526,8 +518,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         return count
 
     async def _event_fix_redactions_bytes(self, progress, batch_size):
-        """Undoes hex encoded censored redacted event JSON.
-        """
+        """Undoes hex encoded censored redacted event JSON."""
 
         def _event_fix_redactions_bytes_txn(txn):
             # This update is quite fast due to new index.
@@ -650,7 +641,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                 LIMIT ?
             """
 
-            txn.execute(sql, (last_event_id, batch_size,))
+            txn.execute(
+                sql,
+                (
+                    last_event_id,
+                    batch_size,
+                ),
+            )
 
             return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn]  # type: ignore
 
@@ -918,7 +915,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         # Annoyingly we need to gut wrench into the persit event store so that
         # we can reuse the function to calculate the chain cover for rooms.
         PersistEventsStore._add_chain_cover_index(
-            txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
+            txn,
+            self.db_pool,
+            event_to_room_id,
+            event_to_types,
+            event_to_auth_chain,
         )
 
         return _CalculateChainCover(
diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py
new file mode 100644
index 0000000000..b3703ae161
--- /dev/null
+++ b/synapse/storage/databases/main/events_forward_extremities.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Dict, List
+
+from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
+
+logger = logging.getLogger(__name__)
+
+
+class EventForwardExtremitiesStore(SQLBaseStore):
+    async def delete_forward_extremities_for_room(self, room_id: str) -> int:
+        """Delete any extra forward extremities for a room.
+
+        Invalidates the "get_latest_event_ids_in_room" cache if any forward
+        extremities were deleted.
+
+        Returns count deleted.
+        """
+
+        def delete_forward_extremities_for_room_txn(txn):
+            # First we need to get the event_id to not delete
+            sql = """
+                SELECT event_id FROM event_forward_extremities
+                INNER JOIN events USING (room_id, event_id)
+                WHERE room_id = ?
+                ORDER BY stream_ordering DESC
+                LIMIT 1
+            """
+            txn.execute(sql, (room_id,))
+            rows = txn.fetchall()
+            try:
+                event_id = rows[0][0]
+                logger.debug(
+                    "Found event_id %s as the forward extremity to keep for room %s",
+                    event_id,
+                    room_id,
+                )
+            except KeyError:
+                msg = "No forward extremity event found for room %s" % room_id
+                logger.warning(msg)
+                raise SynapseError(400, msg)
+
+            # Now delete the extra forward extremities
+            sql = """
+                DELETE FROM event_forward_extremities
+                WHERE event_id != ? AND room_id = ?
+            """
+
+            txn.execute(sql, (event_id, room_id))
+            logger.info(
+                "Deleted %s extra forward extremities for room %s",
+                txn.rowcount,
+                room_id,
+            )
+
+            if txn.rowcount > 0:
+                # Invalidate the cache
+                self._invalidate_cache_and_stream(
+                    txn,
+                    self.get_latest_event_ids_in_room,
+                    (room_id,),
+                )
+
+            return txn.rowcount
+
+        return await self.db_pool.runInteraction(
+            "delete_forward_extremities_for_room",
+            delete_forward_extremities_for_room_txn,
+        )
+
+    async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]:
+        """Get list of forward extremities for a room."""
+
+        def get_forward_extremities_for_room_txn(txn):
+            sql = """
+                SELECT event_id, state_group, depth, received_ts
+                FROM event_forward_extremities
+                INNER JOIN event_to_state_groups USING (event_id)
+                INNER JOIN events USING (room_id, event_id)
+                WHERE room_id = ?
+            """
+
+            txn.execute(sql, (room_id,))
+            return self.db_pool.cursor_to_dict(txn)
+
+        return await self.db_pool.runInteraction(
+            "get_forward_extremities_for_room",
+            get_forward_extremities_for_room_txn,
+        )
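
A minimal sketch of how a caller might drive the two new store methods above; the wrapper function and its name are hypothetical, only the store methods come from this file.

async def trim_room_extremities(store, room_id: str) -> dict:
    """List a room's forward extremities, then delete all but the newest one."""
    extremities = await store.get_forward_extremities_for_room(room_id)
    deleted = 0
    if len(extremities) > 1:
        # Keeps the extremity with the highest stream_ordering, removes the rest,
        # and invalidates get_latest_event_ids_in_room, per the docstrings above.
        deleted = await store.delete_forward_extremities_for_room(room_id)
    return {"count": len(extremities), "deleted": deleted}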
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 71d823be72..c8850a4707 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -120,7 +120,9 @@ class EventsWorkerStore(SQLBaseStore):
             # SQLite).
             if hs.get_instance_name() in hs.config.worker.writers.events:
                 self._stream_id_gen = StreamIdGenerator(
-                    db_conn, "events", "stream_ordering",
+                    db_conn,
+                    "events",
+                    "stream_ordering",
                 )
                 self._backfill_id_gen = StreamIdGenerator(
                     db_conn,
@@ -140,7 +142,8 @@ class EventsWorkerStore(SQLBaseStore):
         if hs.config.run_background_tasks:
             # We periodically clean out old transaction ID mappings
             self._clock.looping_call(
-                self._cleanup_old_transaction_ids, 5 * 60 * 1000,
+                self._cleanup_old_transaction_ids,
+                5 * 60 * 1000,
             )
 
         self._get_event_cache = LruCache(
@@ -1325,8 +1328,7 @@ class EventsWorkerStore(SQLBaseStore):
         return rows, to_token, True
 
     async def is_event_after(self, event_id1, event_id2):
-        """Returns True if event_id1 is after event_id2 in the stream
-        """
+        """Returns True if event_id1 is after event_id2 in the stream"""
         to_1, so_1 = await self.get_event_ordering(event_id1)
         to_2, so_2 = await self.get_event_ordering(event_id2)
         return (to_1, so_1) > (to_2, so_2)
@@ -1428,8 +1430,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     @wrap_as_background_process("_cleanup_old_transaction_ids")
     async def _cleanup_old_transaction_ids(self):
-        """Cleans out transaction id mappings older than 24hrs.
-        """
+        """Cleans out transaction id mappings older than 24hrs."""
 
         def _cleanup_old_transaction_ids_txn(txn):
             sql = """
@@ -1440,5 +1441,6 @@ class EventsWorkerStore(SQLBaseStore):
             txn.execute(sql, (one_day_ago,))
 
         return await self.db_pool.runInteraction(
-            "_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn,
+            "_cleanup_old_transaction_ids",
+            _cleanup_old_transaction_ids_txn,
         )
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 7218191965..ac07e0197b 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -14,7 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple
+
+from typing_extensions import TypedDict
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -26,6 +28,9 @@ from synapse.util import json_encoder
 _DEFAULT_CATEGORY_ID = ""
 _DEFAULT_ROLE_ID = ""
 
+# A room in a group.
+_RoomInGroup = TypedDict("_RoomInGroup", {"room_id": str, "is_public": bool})
+
 
 class GroupServerWorkerStore(SQLBaseStore):
     async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
@@ -72,7 +77,7 @@ class GroupServerWorkerStore(SQLBaseStore):
 
     async def get_rooms_in_group(
         self, group_id: str, include_private: bool = False
-    ) -> List[Dict[str, Union[str, bool]]]:
+    ) -> List[_RoomInGroup]:
         """Retrieve the rooms that belong to a given group. Does not return rooms that
         lack members.
 
@@ -123,7 +128,9 @@ class GroupServerWorkerStore(SQLBaseStore):
         )
 
     async def get_rooms_for_summary_by_category(
-        self, group_id: str, include_private: bool = False,
+        self,
+        group_id: str,
+        include_private: bool = False,
     ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
         """Get the rooms and categories that should be included in a summary request
 
@@ -368,8 +375,7 @@ class GroupServerWorkerStore(SQLBaseStore):
     async def is_user_invited_to_local_group(
         self, group_id: str, user_id: str
     ) -> Optional[bool]:
-        """Has the group server invited a user?
-        """
+        """Has the group server invited a user?"""
         return await self.db_pool.simple_select_one_onecol(
             table="group_invites",
             keyvalues={"group_id": group_id, "user_id": user_id},
@@ -427,8 +433,7 @@ class GroupServerWorkerStore(SQLBaseStore):
         )
 
     async def get_publicised_groups_for_user(self, user_id: str) -> List[str]:
-        """Get all groups a user is publicising
-        """
+        """Get all groups a user is publicising"""
         return await self.db_pool.simple_select_onecol(
             table="local_group_membership",
             keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
@@ -437,8 +442,7 @@ class GroupServerWorkerStore(SQLBaseStore):
         )
 
     async def get_attestations_need_renewals(self, valid_until_ms):
-        """Get all attestations that need to be renewed until givent time
-        """
+        """Get all attestations that need to be renewed until givent time"""
 
         def _get_attestations_need_renewals_txn(txn):
             sql = """
@@ -781,8 +785,7 @@ class GroupServerStore(GroupServerWorkerStore):
         profile: Optional[JsonDict],
         is_public: Optional[bool],
     ) -> None:
-        """Add/update room category for group
-        """
+        """Add/update room category for group"""
         insertion_values = {}
         update_values = {"category_id": category_id}  # This cannot be empty
 
@@ -818,8 +821,7 @@ class GroupServerStore(GroupServerWorkerStore):
         profile: Optional[JsonDict],
         is_public: Optional[bool],
     ) -> None:
-        """Add/remove user role
-        """
+        """Add/remove user role"""
         insertion_values = {}
         update_values = {"role_id": role_id}  # This cannot be empty
 
@@ -1012,8 +1014,7 @@ class GroupServerStore(GroupServerWorkerStore):
         )
 
     async def add_group_invite(self, group_id: str, user_id: str) -> None:
-        """Record that the group server has invited a user
-        """
+        """Record that the group server has invited a user"""
         await self.db_pool.simple_insert(
             table="group_invites",
             values={"group_id": group_id, "user_id": user_id},
@@ -1156,8 +1157,7 @@ class GroupServerStore(GroupServerWorkerStore):
     async def update_group_publicity(
         self, group_id: str, user_id: str, publicise: bool
     ) -> None:
-        """Update whether the user is publicising their membership of the group
-        """
+        """Update whether the user is publicising their membership of the group"""
         await self.db_pool.simple_update_one(
             table="local_group_membership",
             keyvalues={"group_id": group_id, "user_id": user_id},
@@ -1300,8 +1300,7 @@ class GroupServerStore(GroupServerWorkerStore):
     async def update_attestation_renewal(
         self, group_id: str, user_id: str, attestation: dict
     ) -> None:
-        """Update an attestation that we have renewed
-        """
+        """Update an attestation that we have renewed"""
         await self.db_pool.simple_update_one(
             table="group_attestations_renewals",
             keyvalues={"group_id": group_id, "user_id": user_id},
@@ -1312,8 +1311,7 @@ class GroupServerStore(GroupServerWorkerStore):
     async def update_remote_attestion(
         self, group_id: str, user_id: str, attestation: dict
     ) -> None:
-        """Update an attestation that a remote has renewed
-        """
+        """Update an attestation that a remote has renewed"""
         await self.db_pool.simple_update_one(
             table="group_attestations_remote",
             keyvalues={"group_id": group_id, "user_id": user_id},
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 04ac2d0ced..d504323b03 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -33,8 +33,7 @@ db_binary_type = memoryview
 
 
 class KeyStore(SQLBaseStore):
-    """Persistence for signature verification keys
-    """
+    """Persistence for signature verification keys"""
 
     @cached()
     def _get_server_verify_key(self, server_name_and_key_id):
@@ -155,7 +154,7 @@ class KeyStore(SQLBaseStore):
         (server_name, key_id, from_server) triplet if one already existed.
         Args:
             server_name: The name of the server.
-            key_id: The identifer of the key this JSON is for.
+            key_id: The identifier of the key this JSON is for.
             from_server: The server this JSON was fetched from.
             ts_now_ms: The time now in milliseconds.
             ts_valid_until_ms: The time when this json stops being valid.
@@ -182,7 +181,7 @@ class KeyStore(SQLBaseStore):
     async def get_server_keys_json(
         self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
     ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]:
-        """Retrive the key json for a list of server_keys and key ids.
+        """Retrieve the key json for a list of server_keys and key ids.
         If no keys are found for a given server, key_id and source then
         that server, key_id, and source triplet entry will be an empty list.
         The JSON is returned as a byte array so that it can be efficiently
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 283c8a5e22..a0313c3ccf 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -169,7 +169,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         )
 
     async def get_local_media_before(
-        self, before_ts: int, size_gt: int, keep_profiles: bool,
+        self,
+        before_ts: int,
+        size_gt: int,
+        keep_profiles: bool,
     ) -> List[str]:
 
         # to find files that have never been accessed (last_access_ts IS NULL)
@@ -417,7 +420,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 " WHERE media_origin = ? AND media_id = ?"
             )
 
-            txn.executemany(
+            txn.execute_batch(
                 sql,
                 (
                     (time_ms, media_origin, media_id)
@@ -430,7 +433,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 " WHERE media_id = ?"
             )
 
-            txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
+            txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))
 
         return await self.db_pool.runInteraction(
             "update_cached_last_access_time", update_cache_txn
@@ -454,10 +457,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         )
 
     async def get_remote_media_thumbnail(
-        self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str,
+        self,
+        origin: str,
+        media_id: str,
+        t_width: int,
+        t_height: int,
+        t_type: str,
     ) -> Optional[Dict[str, Any]]:
-        """Fetch the thumbnail info of given width, height and type.
-        """
+        """Fetch the thumbnail info of given width, height and type."""
 
         return await self.db_pool.simple_select_one(
             table="remote_media_cache_thumbnails",
@@ -557,7 +564,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
 
         def _delete_url_cache_txn(txn):
-            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
 
         return await self.db_pool.runInteraction(
             "delete_url_cache", _delete_url_cache_txn
@@ -586,11 +593,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         def _delete_url_cache_media_txn(txn):
             sql = "DELETE FROM local_media_repository WHERE media_id = ?"
 
-            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
 
             sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
 
-            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
 
         return await self.db_pool.runInteraction(
             "delete_url_cache_media", _delete_url_cache_media_txn
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index ab18cc4d79..614a418a15 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -88,6 +88,62 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             (x[0] - 1) * x[1] for x in res if x[1]
         )
 
+    async def count_daily_e2ee_messages(self):
+        """
+        Returns an estimate of the number of messages sent in the last day.
+
+        If it has been significantly less or more than one day since the last
+        call to this function, it will return None.
+        """
+
+        def _count_messages(txn):
+            sql = """
+                SELECT COALESCE(COUNT(*), 0) FROM events
+                WHERE type = 'm.room.encrypted'
+                AND stream_ordering > ?
+            """
+            txn.execute(sql, (self.stream_ordering_day_ago,))
+            (count,) = txn.fetchone()
+            return count
+
+        return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
+
+    async def count_daily_sent_e2ee_messages(self):
+        def _count_messages(txn):
+            # This is good enough as if you have silly characters in your own
+            # hostname then that's your own fault.
+            like_clause = "%:" + self.hs.hostname
+
+            sql = """
+                SELECT COALESCE(COUNT(*), 0) FROM events
+                WHERE type = 'm.room.encrypted'
+                    AND sender LIKE ?
+                AND stream_ordering > ?
+            """
+
+            txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
+            (count,) = txn.fetchone()
+            return count
+
+        return await self.db_pool.runInteraction(
+            "count_daily_sent_e2ee_messages", _count_messages
+        )
+
+    async def count_daily_active_e2ee_rooms(self):
+        def _count(txn):
+            sql = """
+                SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+                WHERE type = 'm.room.encrypted'
+                AND stream_ordering > ?
+            """
+            txn.execute(sql, (self.stream_ordering_day_ago,))
+            (count,) = txn.fetchone()
+            return count
+
+        return await self.db_pool.runInteraction(
+            "count_daily_active_e2ee_rooms", _count
+        )
+
     async def count_daily_messages(self):
         """
         Returns an estimate of the number of messages sent in the last day.
@@ -111,7 +167,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
     async def count_daily_sent_messages(self):
         def _count_messages(txn):
             # This is good enough as if you have silly characters in your own
-            # hostname then thats your own fault.
+            # hostname then that's your own fault.
             like_clause = "%:" + self.hs.hostname
 
             sql = """
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index dbbb99cb95..29edab34d4 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -130,7 +130,9 @@ class PresenceStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
+        cached_method_name="_get_presence_for_user",
+        list_name="user_ids",
+        num_args=1,
     )
     async def get_presence_for_users(self, user_ids):
         rows = await self.db_pool.simple_select_many_batch(
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 54ef0f1f54..ba01d3108a 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -118,8 +118,7 @@ class ProfileWorkerStore(SQLBaseStore):
             )
 
     async def is_subscribed_remote_profile_for_user(self, user_id):
-        """Check whether we are interested in a remote user's profile.
-        """
+        """Check whether we are interested in a remote user's profile."""
         res = await self.db_pool.simple_select_one_onecol(
             table="group_users",
             keyvalues={"user_id": user_id},
@@ -145,8 +144,7 @@ class ProfileWorkerStore(SQLBaseStore):
     async def get_remote_profile_cache_entries_that_expire(
         self, last_checked: int
     ) -> List[Dict[str, str]]:
-        """Get all users who haven't been checked since `last_checked`
-        """
+        """Get all users who haven't been checked since `last_checked`"""
 
         def _get_remote_profile_cache_entries_that_expire_txn(txn):
             sql = """
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 5d668aadb2..ecfc9f20b1 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -172,7 +172,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
         )
 
        # Update backward extremities
-        txn.executemany(
+        txn.execute_batch(
             "INSERT INTO event_backward_extremities (room_id, event_id)"
             " VALUES (?, ?)",
             [(room_id, event_id) for event_id, in new_backwards_extrems],
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 711d5aa23d..9e58dc0e6a 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -168,7 +168,9 @@ class PushRulesWorkerStore(
             )
 
     @cachedList(
-        cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
+        cached_method_name="get_push_rules_for_user",
+        list_name="user_ids",
+        num_args=1,
     )
     async def bulk_get_push_rules(self, user_ids):
         if not user_ids:
@@ -195,7 +197,9 @@ class PushRulesWorkerStore(
             use_new_defaults = user_id in self._users_new_default_push_rules
 
             results[user_id] = _load_rules(
-                rules, enabled_map_by_user.get(user_id, {}), use_new_defaults,
+                rules,
+                enabled_map_by_user.get(user_id, {}),
+                use_new_defaults,
             )
 
         return results
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index bc7621b8d6..7cb69dd6bd 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -179,7 +179,9 @@ class PusherWorkerStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
+        cached_method_name="get_if_user_has_pusher",
+        list_name="user_ids",
+        num_args=1,
     )
     async def get_if_users_have_pushers(
         self, user_ids: Iterable[str]
@@ -263,7 +265,8 @@ class PusherWorkerStore(SQLBaseStore):
         params_by_room = {}
         for row in res:
             params_by_room[row["room_id"]] = ThrottleParams(
-                row["last_sent_ts"], row["throttle_ms"],
+                row["last_sent_ts"],
+                row["throttle_ms"],
             )
 
         return params_by_room
@@ -344,7 +347,9 @@ class PusherStore(PusherWorkerStore):
                 txn, self.get_if_user_has_pusher, (user_id,)
             )
 
-            self.db_pool.simple_delete_one_txn(
+            # It is expected that there is exactly one pusher to delete, but
+            # if it isn't there (or there are multiple) delete them all.
+            self.db_pool.simple_delete_txn(
                 txn,
                 "pushers",
                 {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index e4843a202c..43c852c96c 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -160,7 +160,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         Args:
             room_id: List of room_ids.
-            to_key: Max stream id to fetch receipts upto.
+            to_key: Max stream id to fetch receipts up to.
             from_key: Min stream id to fetch receipts from. None fetches
                 from the start.
 
@@ -189,7 +189,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         Args:
             room_ids: The room id.
-            to_key: Max stream id to fetch receipts upto.
+            to_key: Max stream id to fetch receipts up to.
             from_key: Min stream id to fetch receipts from. None fetches
                 from the start.
 
@@ -208,8 +208,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     async def _get_linearized_receipts_for_room(
         self, room_id: str, to_key: int, from_key: Optional[int] = None
     ) -> List[dict]:
-        """See get_linearized_receipts_for_room
-        """
+        """See get_linearized_receipts_for_room"""
 
         def f(txn):
             if from_key:
@@ -304,7 +303,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
         }
         return results
 
-    @cached(num_args=2,)
+    @cached(
+        num_args=2,
+    )
     async def get_linearized_receipts_for_all_rooms(
         self, to_key: int, from_key: Optional[int] = None
     ) -> Dict[str, JsonDict]:
@@ -312,7 +313,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         to a limit of the latest 100 read receipts.
 
         Args:
-            to_key: Max stream id to fetch receipts upto.
+            to_key: Max stream id to fetch receipts up to.
             from_key: Min stream id to fetch receipts from. None fetches
                 from the start.
 
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 8d05288ed4..d5b5507815 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -79,13 +79,16 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         # call `find_max_generated_user_id_localpart` each time, which is
         # expensive if there are many entries.
         self._user_id_seq = build_sequence_generator(
-            database.engine, find_max_generated_user_id_localpart, "user_id_seq",
+            database.engine,
+            find_max_generated_user_id_localpart,
+            "user_id_seq",
         )
 
         self._account_validity = hs.config.account_validity
         if hs.config.run_background_tasks and self._account_validity.enabled:
             self._clock.call_later(
-                0.0, self._set_expiration_date_when_missing,
+                0.0,
+                self._set_expiration_date_when_missing,
             )
 
         # Create a background job for culling expired 3PID validity tokens
@@ -110,6 +113,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 "creation_ts",
                 "user_type",
                 "deactivated",
+                "shadow_banned",
             ],
             allow_none=True,
             desc="get_user_by_id",
@@ -360,6 +364,37 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
 
+    async def set_shadow_banned(self, user: UserID, shadow_banned: bool) -> None:
+        """Sets whether a user shadow-banned.
+
+        Args:
+            user: user ID of the user to update
+            shadow_banned: true iff the user is to be shadow-banned, false otherwise.
+        """
+
+        def set_shadow_banned_txn(txn):
+            user_id = user.to_string()
+            self.db_pool.simple_update_one_txn(
+                txn,
+                table="users",
+                keyvalues={"name": user_id},
+                updatevalues={"shadow_banned": shadow_banned},
+            )
+            # In order for this to apply immediately, clear the cache for this user.
+            tokens = self.db_pool.simple_select_onecol_txn(
+                txn,
+                table="access_tokens",
+                keyvalues={"user_id": user_id},
+                retcol="token",
+            )
+            for token in tokens:
+                self._invalidate_cache_and_stream(
+                    txn, self.get_user_by_access_token, (token,)
+                )
+            self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
+        await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
+
     def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
         sql = """
             SELECT users.name as user_id,
@@ -443,6 +478,26 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
 
+    async def record_user_external_id(
+        self, auth_provider: str, external_id: str, user_id: str
+    ) -> None:
+        """Record a mapping from an external user id to a mxid
+
+        Args:
+            auth_provider: identifier for the remote auth provider
+            external_id: id on that system
+            user_id: complete mxid that it is mapped to
+        """
+        await self.db_pool.simple_insert(
+            table="user_external_ids",
+            values={
+                "auth_provider": auth_provider,
+                "external_id": external_id,
+                "user_id": user_id,
+            },
+            desc="record_user_external_id",
+        )
+
     async def get_user_by_external_id(
         self, auth_provider: str, external_id: str
     ) -> Optional[str]:
@@ -1104,7 +1159,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
                 FROM user_threepids
             """
 
-            txn.executemany(sql, [(id_server,) for id_server in id_servers])
+            txn.execute_batch(sql, [(id_server,) for id_server in id_servers])
 
         if id_servers:
             await self.db_pool.runInteraction(
@@ -1371,26 +1426,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
         self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-    async def record_user_external_id(
-        self, auth_provider: str, external_id: str, user_id: str
-    ) -> None:
-        """Record a mapping from an external user id to a mxid
-
-        Args:
-            auth_provider: identifier for the remote auth provider
-            external_id: id on that system
-            user_id: complete mxid that it is mapped to
-        """
-        await self.db_pool.simple_insert(
-            table="user_external_ids",
-            values={
-                "auth_provider": auth_provider,
-                "external_id": external_id,
-                "user_id": user_id,
-            },
-            desc="record_user_external_id",
-        )
-
     async def user_set_password_hash(
         self, user_id: str, password_hash: Optional[str]
     ) -> None:
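
The new set_shadow_banned writes the flag and then invalidates every cached get_user_by_access_token entry for that user, so the ban takes effect on the next request rather than when the cache expires. A sketch of that update-then-invalidate pattern, with a plain dict standing in for Synapse's caches; names are illustrative.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (name TEXT PRIMARY KEY, shadow_banned BOOLEAN)")
conn.execute("CREATE TABLE access_tokens (user_id TEXT, token TEXT)")
conn.execute("INSERT INTO users VALUES ('@u:x', 0)")
conn.execute("INSERT INTO access_tokens VALUES ('@u:x', 'tok1')")

user_by_token_cache = {"tok1": {"user_id": "@u:x", "shadow_banned": False}}


def set_shadow_banned(user_id: str, shadow_banned: bool) -> None:
    conn.execute(
        "UPDATE users SET shadow_banned = ? WHERE name = ?", (shadow_banned, user_id)
    )
    # Drop every cached token lookup for this user so the new flag is
    # observed immediately.
    for (token,) in conn.execute(
        "SELECT token FROM access_tokens WHERE user_id = ?", (user_id,)
    ):
        user_by_token_cache.pop(token, None)


set_shadow_banned("@u:x", True)
print(user_by_token_cache)  # {}
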
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index a9fcb5f59c..9cbcd53026 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -193,8 +193,7 @@ class RoomWorkerStore(SQLBaseStore):
         )
 
     async def get_room_count(self) -> int:
-        """Retrieve the total number of rooms.
-        """
+        """Retrieve the total number of rooms."""
 
         def f(txn):
             sql = "SELECT count(*)  FROM rooms"
@@ -517,7 +516,8 @@ class RoomWorkerStore(SQLBaseStore):
             return rooms, room_count[0]
 
         return await self.db_pool.runInteraction(
-            "get_rooms_paginate", _get_rooms_paginate_txn,
+            "get_rooms_paginate",
+            _get_rooms_paginate_txn,
         )
 
     @cached(max_entries=10000)
@@ -578,7 +578,8 @@ class RoomWorkerStore(SQLBaseStore):
             return self.db_pool.cursor_to_dict(txn)
 
         ret = await self.db_pool.runInteraction(
-            "get_retention_policy_for_room", get_retention_policy_for_room_txn,
+            "get_retention_policy_for_room",
+            get_retention_policy_for_room_txn,
         )
 
         # If we don't know this room ID, ret will be None, in this case return the default
@@ -707,7 +708,10 @@ class RoomWorkerStore(SQLBaseStore):
         return local_media_mxcs, remote_media_mxcs
 
     async def quarantine_media_by_id(
-        self, server_name: str, media_id: str, quarantined_by: str,
+        self,
+        server_name: str,
+        media_id: str,
+        quarantined_by: str,
     ) -> int:
         """quarantines a single local or remote media id
 
@@ -961,7 +965,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         self.config = hs.config
 
         self.db_pool.updates.register_background_update_handler(
-            "insert_room_retention", self._background_insert_retention,
+            "insert_room_retention",
+            self._background_insert_retention,
         )
 
         self.db_pool.updates.register_background_update_handler(
@@ -1033,7 +1038,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
                 return False
 
         end = await self.db_pool.runInteraction(
-            "insert_room_retention", _background_insert_retention_txn,
+            "insert_room_retention",
+            _background_insert_retention_txn,
         )
 
         if end:
@@ -1044,7 +1050,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
     async def _background_add_rooms_room_version_column(
         self, progress: dict, batch_size: int
     ):
-        """Background update to go and add room version inforamtion to `rooms`
+        """Background update to go and add room version information to `rooms`
         table from `current_state_events` table.
         """
 
@@ -1588,7 +1594,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                 LIMIT ?
                 OFFSET ?
             """.format(
-                where_clause=where_clause, order=order,
+                where_clause=where_clause,
+                order=order,
             )
 
             args += [limit, start]
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index dcdaf09682..a9216ca9ae 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -70,10 +70,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         ):
             self._known_servers_count = 1
             self.hs.get_clock().looping_call(
-                self._count_known_servers, 60 * 1000,
+                self._count_known_servers,
+                60 * 1000,
             )
             self.hs.get_clock().call_later(
-                1000, self._count_known_servers,
+                1000,
+                self._count_known_servers,
             )
             LaterGauge(
                 "synapse_federation_known_servers",
@@ -174,7 +176,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
     @cached(max_entries=100000)
     async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
-        """ Get the details of a room roughly suitable for use by the room
+        """Get the details of a room roughly suitable for use by the room
         summary extension to /sync. Useful when lazy loading room members.
         Args:
             room_id: The room ID to query
@@ -488,8 +490,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
     async def get_users_who_share_room_with_user(
         self, user_id: str, cache_context: _CacheContext
     ) -> Set[str]:
-        """Returns the set of users who share a room with `user_id`
-        """
+        """Returns the set of users who share a room with `user_id`"""
         room_ids = await self.get_rooms_for_user(
             user_id, on_invalidate=cache_context.invalidate
         )
@@ -618,7 +619,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
+        cached_method_name="_get_joined_profile_from_event_id",
+        list_name="event_ids",
     )
     async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
         """For given set of member event_ids check if they point to a join
@@ -802,8 +804,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
     async def get_membership_from_event_ids(
         self, member_event_ids: Iterable[str]
     ) -> List[dict]:
-        """Get user_id and membership of a set of event IDs.
-        """
+        """Get user_id and membership of a set of event IDs."""
 
         return await self.db_pool.simple_select_many_batch(
             table="room_memberships",
@@ -873,8 +874,6 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
             "max_stream_id_exclusive", self._stream_order_on_start + 1
         )
 
-        INSERT_CLUMP_SIZE = 1000
-
         def add_membership_profile_txn(txn):
             sql = """
                 SELECT stream_ordering, event_id, events.room_id, event_json.json
@@ -915,9 +914,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
                 UPDATE room_memberships SET display_name = ?, avatar_url = ?
                 WHERE event_id = ? AND room_id = ?
             """
-            for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
-                clump = to_update[index : index + INSERT_CLUMP_SIZE]
-                txn.executemany(to_update_sql, clump)
+            txn.execute_batch(to_update_sql, to_update)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
diff --git a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
index ad875c733a..3907189e29 100644
--- a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
+++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
@@ -23,5 +23,6 @@ def run_create(cur, database_engine, *args, **kwargs):
 
 def run_upgrade(cur, database_engine, *args, **kwargs):
     cur.execute(
-        "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),),
+        "UPDATE remote_media_cache SET last_access_ts = ?",
+        (int(time.time() * 1000),),
     )
diff --git a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
index f35c70b699..9e8f35c1d2 100644
--- a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
+++ b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
@@ -55,7 +55,7 @@ def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs
         # { "ignored_users": "@someone:example.org": {} }
         ignored_users = content.get("ignored_users", {})
         if isinstance(ignored_users, dict) and ignored_users:
-            cur.executemany(insert_sql, [(user_id, u) for u in ignored_users])
+            cur.execute_batch(insert_sql, [(user_id, u) for u in ignored_users])
 
     # Add indexes after inserting data for efficiency.
     logger.info("Adding constraints to ignored_users table")
diff --git a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
index a0411ede7e..308124e531 100644
--- a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
+++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
@@ -67,11 +67,6 @@ CREATE TABLE IF NOT EXISTS "user_threepids" ( user_id TEXT NOT NULL, medium TEXT
 CREATE INDEX user_threepids_user_id ON user_threepids(user_id);
 CREATE VIRTUAL TABLE event_search USING fts4 ( event_id, room_id, sender, key, value )
 /* event_search(event_id,room_id,sender,"key",value) */;
-CREATE TABLE IF NOT EXISTS 'event_search_content'(docid INTEGER PRIMARY KEY, 'c0event_id', 'c1room_id', 'c2sender', 'c3key', 'c4value');
-CREATE TABLE IF NOT EXISTS 'event_search_segments'(blockid INTEGER PRIMARY KEY, block BLOB);
-CREATE TABLE IF NOT EXISTS 'event_search_segdir'(level INTEGER,idx INTEGER,start_block INTEGER,leaves_end_block INTEGER,end_block INTEGER,root BLOB,PRIMARY KEY(level, idx));
-CREATE TABLE IF NOT EXISTS 'event_search_docsize'(docid INTEGER PRIMARY KEY, size BLOB);
-CREATE TABLE IF NOT EXISTS 'event_search_stat'(id INTEGER PRIMARY KEY, value BLOB);
 CREATE TABLE guest_access( event_id TEXT NOT NULL, room_id TEXT NOT NULL, guest_access TEXT NOT NULL, UNIQUE (event_id) );
 CREATE TABLE history_visibility( event_id TEXT NOT NULL, room_id TEXT NOT NULL, history_visibility TEXT NOT NULL, UNIQUE (event_id) );
 CREATE TABLE room_tags( user_id TEXT NOT NULL, room_id TEXT NOT NULL, tag     TEXT NOT NULL, content TEXT NOT NULL, CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag) );
@@ -149,11 +144,6 @@ CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_las
 CREATE TABLE user_directory_stream_pos ( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_id BIGINT, CHECK (Lock='X') );
 CREATE VIRTUAL TABLE user_directory_search USING fts4 ( user_id, value )
 /* user_directory_search(user_id,value) */;
-CREATE TABLE IF NOT EXISTS 'user_directory_search_content'(docid INTEGER PRIMARY KEY, 'c0user_id', 'c1value');
-CREATE TABLE IF NOT EXISTS 'user_directory_search_segments'(blockid INTEGER PRIMARY KEY, block BLOB);
-CREATE TABLE IF NOT EXISTS 'user_directory_search_segdir'(level INTEGER,idx INTEGER,start_block INTEGER,leaves_end_block INTEGER,end_block INTEGER,root BLOB,PRIMARY KEY(level, idx));
-CREATE TABLE IF NOT EXISTS 'user_directory_search_docsize'(docid INTEGER PRIMARY KEY, size BLOB);
-CREATE TABLE IF NOT EXISTS 'user_directory_search_stat'(id INTEGER PRIMARY KEY, value BLOB);
 CREATE TABLE blocked_rooms ( room_id TEXT NOT NULL, user_id TEXT NOT NULL );
 CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id);
 CREATE TABLE IF NOT EXISTS "local_media_repository_url_cache"( url TEXT, response_code INTEGER, etag TEXT, expires_ts BIGINT, og TEXT, media_id TEXT, download_ts BIGINT );
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index e34fce6281..f5e7d9ef98 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -24,6 +24,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import Collection
 
 logger = logging.getLogger(__name__)
 
@@ -63,7 +64,7 @@ class SearchWorkerStore(SQLBaseStore):
                 for entry in entries
             )
 
-            txn.executemany(sql, args)
+            txn.execute_batch(sql, args)
 
         elif isinstance(self.database_engine, Sqlite3Engine):
             sql = (
@@ -75,7 +76,7 @@ class SearchWorkerStore(SQLBaseStore):
                 for entry in entries
             )
 
-            txn.executemany(sql, args)
+            txn.execute_batch(sql, args)
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
@@ -460,7 +461,7 @@ class SearchStore(SearchBackgroundUpdateStore):
 
     async def search_rooms(
         self,
-        room_ids: List[str],
+        room_ids: Collection[str],
         search_term: str,
         keys: List[str],
         limit,
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 3c1e33819b..a7f371732f 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -52,8 +52,7 @@ class _GetStateGroupDelta(
 
 # this inherits from EventsWorkerStore because it calls self.get_events
 class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
-    """The parts of StateGroupStore that can be called from workers.
-    """
+    """The parts of StateGroupStore that can be called from workers."""
 
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
@@ -276,8 +275,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         num_args=1,
     )
     async def _get_state_group_for_events(self, event_ids):
-        """Returns mapping event_id -> state_group
-        """
+        """Returns mapping event_id -> state_group"""
         rows = await self.db_pool.simple_select_many_batch(
             table="event_to_state_groups",
             column="event_id",
@@ -338,7 +336,8 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
             columns=["state_group"],
         )
         self.db_pool.updates.register_background_update_handler(
-            self.DELETE_CURRENT_STATE_UPDATE_NAME, self._background_remove_left_rooms,
+            self.DELETE_CURRENT_STATE_UPDATE_NAME,
+            self._background_remove_left_rooms,
         )
 
     async def _background_remove_left_rooms(self, progress, batch_size):
@@ -487,7 +486,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
 
 
 class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
-    """ Keeps track of the state at a given event.
+    """Keeps track of the state at a given event.
 
    This is done by the concept of `state groups`. Every event is assigned
     a state group (identified by an arbitrary string), which references a
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 356623fc6e..0dbb501f16 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -64,7 +64,7 @@ class StateDeltasStore(SQLBaseStore):
         def get_current_state_deltas_txn(txn):
             # First we calculate the max stream id that will give us less than
             # N results.
-            # We arbitarily limit to 100 stream_id entries to ensure we don't
+            # We arbitrarily limit to 100 stream_id entries to ensure we don't
             # select toooo many.
             sql = """
                 SELECT stream_id, count(*)
@@ -81,7 +81,7 @@ class StateDeltasStore(SQLBaseStore):
             for stream_id, count in txn:
                 total += count
                 if total > 100:
-                    # We arbitarily limit to 100 entries to ensure we don't
+                    # We arbitrarily limit to 100 entries to ensure we don't
                     # select toooo many.
                     logger.debug(
                         "Clipping current_state_delta_stream rows to stream_id %i",
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 0cdb3ec1f7..1c99393c65 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -15,11 +15,12 @@
 # limitations under the License.
 
 import logging
-from collections import Counter
 from enum import Enum
 from itertools import chain
 from typing import Any, Dict, List, Optional, Tuple
 
+from typing_extensions import Counter
+
 from twisted.internet.defer import DeferredLock
 
 from synapse.api.constants import EventTypes, Membership
@@ -319,7 +320,9 @@ class StatsStore(StateDeltasStore):
         return slice_list
 
     @cached()
-    async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
+    async def get_earliest_token_for_stats(
+        self, stats_type: str, id: str
+    ) -> Optional[int]:
         """
         Fetch the "earliest token". This is used by the room stats delta
         processor to ignore deltas that have been processed between the
@@ -339,7 +342,7 @@ class StatsStore(StateDeltasStore):
         )
 
     async def bulk_update_stats_delta(
-        self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int
+        self, ts: int, updates: Dict[str, Dict[str, Counter[str]]], stream_id: int
     ) -> None:
         """Bulk update stats tables for a given stream_id and updates the stats
         incremental position.
@@ -665,7 +668,7 @@ class StatsStore(StateDeltasStore):
 
     async def get_changes_room_total_events_and_bytes(
         self, min_pos: int, max_pos: int
-    ) -> Dict[str, Dict[str, int]]:
+    ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
         """Fetches the counts of events in the given range of stream IDs.
 
         Args:
@@ -683,18 +686,19 @@ class StatsStore(StateDeltasStore):
             max_pos,
         )
 
-    def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos):
+    def get_changes_room_total_events_and_bytes_txn(
+        self, txn, low_pos: int, high_pos: int
+    ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
         """Gets the total_events and total_event_bytes counts for rooms and
         senders, in a range of stream_orderings (including backfilled events).
 
         Args:
             txn
-            low_pos (int): Low stream ordering
-            high_pos (int): High stream ordering
+            low_pos: Low stream ordering
+            high_pos: High stream ordering
 
         Returns:
-            tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The
-            room and user deltas for total_events/total_event_bytes in the
+            The room and user deltas for total_events/total_event_bytes in the
             format of `stats_id` -> fields
         """
 
@@ -997,7 +1001,9 @@ class StatsStore(StateDeltasStore):
                 ORDER BY {order_by_column} {order}
                 LIMIT ? OFFSET ?
             """.format(
-                sql_base=sql_base, order_by_column=order_by_column, order=order,
+                sql_base=sql_base,
+                order_by_column=order_by_column,
+                order=order,
             )
 
             args += [limit, start]
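
The stats change swaps collections.Counter for typing_extensions.Counter so the annotations can say Counter[str] even on Python versions where collections.Counter is not subscriptable. A tiny sketch of how such per-room deltas combine; the merge logic here is illustrative only.

from collections import Counter

from typing_extensions import Counter as TypedCounter


def absorb_delta(total: TypedCounter[str], delta: TypedCounter[str]) -> TypedCounter[str]:
    # Counters merge field-by-field, which is how per-room stats deltas are combined.
    total.update(delta)
    return total


print(absorb_delta(Counter({"total_events": 10}), Counter(total_events=2, total_event_bytes=512)))
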
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index e3b9ff5ca6..91f8abb67d 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -565,7 +565,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
                     AND e.stream_ordering > ? AND e.stream_ordering <= ?
                 ORDER BY e.stream_ordering ASC
             """
-            txn.execute(sql, (user_id, min_from_id, max_to_id,))
+            txn.execute(
+                sql,
+                (
+                    user_id,
+                    min_from_id,
+                    max_to_id,
+                ),
+            )
 
             rows = [
                 _EventDictReturn(event_id, None, stream_ordering)
@@ -695,7 +702,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             return "t%d-%d" % (topo, token)
 
     def get_stream_id_for_event_txn(
-        self, txn: LoggingTransaction, event_id: str, allow_none=False,
+        self,
+        txn: LoggingTransaction,
+        event_id: str,
+        allow_none=False,
     ) -> int:
         return self.db_pool.simple_select_one_onecol_txn(
             txn=txn,
@@ -706,8 +716,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         )
 
     async def get_position_for_event(self, event_id: str) -> PersistedEventPosition:
-        """Get the persisted position for an event
-        """
+        """Get the persisted position for an event"""
         row = await self.db_pool.simple_select_one(
             table="events",
             keyvalues={"event_id": event_id},
@@ -897,19 +906,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
     ) -> Tuple[int, List[EventBase]]:
         """Get all new events
 
-         Returns all events with from_id < stream_ordering <= current_id.
+        Returns all events with from_id < stream_ordering <= current_id.
 
-         Args:
-             from_id:  the stream_ordering of the last event we processed
-             current_id:  the stream_ordering of the most recently processed event
-             limit: the maximum number of events to return
+        Args:
+            from_id:  the stream_ordering of the last event we processed
+            current_id:  the stream_ordering of the most recently processed event
+            limit: the maximum number of events to return
 
-         Returns:
-             A tuple of (next_id, events), where `next_id` is the next value to
-             pass as `from_id` (it will either be the stream_ordering of the
-             last returned event, or, if fewer than `limit` events were found,
-             the `current_id`).
-         """
+        Returns:
+            A tuple of (next_id, events), where `next_id` is the next value to
+            pass as `from_id` (it will either be the stream_ordering of the
+            last returned event, or, if fewer than `limit` events were found,
+            the `current_id`).
+        """
 
         def get_all_new_events_stream_txn(txn):
             sql = (
@@ -1238,8 +1247,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
 
     @cached()
     async def get_id_for_instance(self, instance_name: str) -> int:
-        """Get a unique, immutable ID that corresponds to the given Synapse worker instance.
-        """
+        """Get a unique, immutable ID that corresponds to the given Synapse worker instance."""
 
         def _get_id_for_instance_txn(txn):
             instance_id = self.db_pool.simple_select_one_onecol_txn(
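
The reflowed get_all_new_events_stream docstring above describes a cursor-style pagination contract: keep feeding the returned next_id back in as from_id until it stops advancing. A minimal self-contained sketch of that loop over an in-memory list, with a simplified fetch function standing in for the real query.

from typing import List, Tuple

EVENTS = list(range(1, 11))  # stand-in stream orderings 1..10


def get_all_new_events_stream(from_id: int, current_id: int, limit: int) -> Tuple[int, List[int]]:
    batch = [e for e in EVENTS if from_id < e <= current_id][:limit]
    # next_id is the last returned ordering, or current_id once fewer than
    # `limit` events remain.
    next_id = batch[-1] if len(batch) == limit else current_id
    return next_id, batch


from_id, current_id = 0, 10
while True:
    from_id, events = get_all_new_events_stream(from_id, current_id, limit=3)
    print(events)
    if from_id == current_id:
        break
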
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index cea595ff19..b921d63d30 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -64,8 +64,7 @@ class TransactionWorkerStore(SQLBaseStore):
 
 
 class TransactionStore(TransactionWorkerStore):
-    """A collection of queries for handling PDUs.
-    """
+    """A collection of queries for handling PDUs."""
 
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
@@ -198,7 +197,7 @@ class TransactionStore(TransactionWorkerStore):
         retry_interval: int,
     ) -> None:
         """Sets the current retry timings for a given destination.
-        Both timings should be zero if retrying is no longer occuring.
+        Both timings should be zero if retrying is no longer occurring.
 
         Args:
             destination
@@ -299,7 +298,10 @@ class TransactionStore(TransactionWorkerStore):
             )
 
     async def store_destination_rooms_entries(
-        self, destinations: Iterable[str], room_id: str, stream_ordering: int,
+        self,
+        destinations: Iterable[str],
+        room_id: str,
+        stream_ordering: int,
     ) -> None:
         """
         Updates or creates `destination_rooms` entries in batch for a single event.
@@ -394,7 +396,9 @@ class TransactionStore(TransactionWorkerStore):
         )
 
     async def get_catch_up_room_event_ids(
-        self, destination: str, last_successful_stream_ordering: int,
+        self,
+        destination: str,
+        last_successful_stream_ordering: int,
     ) -> List[str]:
         """
         Returns at most 50 event IDs and their corresponding stream_orderings
@@ -418,7 +422,9 @@ class TransactionStore(TransactionWorkerStore):
 
     @staticmethod
     def _get_catch_up_room_event_ids_txn(
-        txn: LoggingTransaction, destination: str, last_successful_stream_ordering: int,
+        txn: LoggingTransaction,
+        destination: str,
+        last_successful_stream_ordering: int,
     ) -> List[str]:
         q = """
                 SELECT event_id FROM destination_rooms
@@ -429,7 +435,8 @@ class TransactionStore(TransactionWorkerStore):
                 LIMIT 50
             """
         txn.execute(
-            q, (destination, last_successful_stream_ordering),
+            q,
+            (destination, last_successful_stream_ordering),
         )
         event_ids = [row[0] for row in txn]
         return event_ids
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 79b7ece330..5473ec1485 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -44,7 +44,11 @@ class UIAuthWorkerStore(SQLBaseStore):
     """
 
     async def create_ui_auth_session(
-        self, clientdict: JsonDict, uri: str, method: str, description: str,
+        self,
+        clientdict: JsonDict,
+        uri: str,
+        method: str,
+        description: str,
     ) -> UIAuthSessionData:
         """
         Creates a new user interactive authentication session.
@@ -123,7 +127,10 @@ class UIAuthWorkerStore(SQLBaseStore):
         return UIAuthSessionData(session_id, **result)
 
     async def mark_ui_auth_stage_complete(
-        self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict],
+        self,
+        session_id: str,
+        stage_type: str,
+        result: Union[str, bool, JsonDict],
     ):
         """
         Mark a session stage as completed.
@@ -261,10 +268,12 @@ class UIAuthWorkerStore(SQLBaseStore):
         return serverdict.get(key, default)
 
     async def add_user_agent_ip_to_ui_auth_session(
-        self, session_id: str, user_agent: str, ip: str,
+        self,
+        session_id: str,
+        user_agent: str,
+        ip: str,
     ):
-        """Add the given user agent / IP to the tracking table
-        """
+        """Add the given user agent / IP to the tracking table"""
         await self.db_pool.simple_upsert(
             table="ui_auth_sessions_ips",
             keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
@@ -273,7 +282,8 @@ class UIAuthWorkerStore(SQLBaseStore):
         )
 
     async def get_user_agents_ips_to_ui_auth_session(
-        self, session_id: str,
+        self,
+        session_id: str,
     ) -> List[Tuple[str, str]]:
         """Get the given user agents / IPs used during the ui auth process
 
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index ef11f1c3b3..63f88eac51 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -336,8 +336,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         return len(users_to_work_on)
 
     async def is_room_world_readable_or_publicly_joinable(self, room_id):
-        """Check if the room is either world_readable or publically joinable
-        """
+        """Check if the room is either world_readable or publically joinable"""
 
         # Create a state filter that only queries join and history state event
         types_to_filter = (
@@ -516,8 +515,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         )
 
     async def delete_all_from_user_dir(self) -> None:
-        """Delete the entire user directory
-        """
+        """Delete the entire user directory"""
 
         def _delete_all_from_user_dir_txn(txn):
             txn.execute("DELETE FROM user_directory")
@@ -540,7 +538,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             desc="get_user_in_directory",
         )
 
-    async def update_user_directory_stream_pos(self, stream_id: str) -> None:
+    async def update_user_directory_stream_pos(self, stream_id: int) -> None:
         await self.db_pool.simple_update_one(
             table="user_directory_stream_pos",
             keyvalues={},
@@ -709,7 +707,13 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
 
         return {row["room_id"] for row in rows}
 
-    async def get_user_directory_stream_pos(self) -> int:
+    async def get_user_directory_stream_pos(self) -> Optional[int]:
+        """
+        Get the stream ID of the user directory stream.
+
+        Returns:
+            The stream token or None if the initial background update hasn't happened yet.
+        """
         return await self.db_pool.simple_select_one_onecol(
             table="user_directory_stream_pos",
             keyvalues={},
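
get_user_directory_stream_pos now advertises that it can return None: the stored stream_id is only populated once the initial directory build has run. A short runnable sketch of a caller guarding for that case; the store and method names here are stand-ins.

import asyncio
from typing import Optional


class FakeStore:
    # Stand-in for the real store; returns None until the background
    # update has populated user_directory_stream_pos.
    def __init__(self, pos: Optional[int]):
        self._pos = pos

    async def get_user_directory_stream_pos(self) -> Optional[int]:
        return self._pos


async def advance_user_directory(store: FakeStore) -> str:
    pos = await store.get_user_directory_stream_pos()
    if pos is None:
        return "directory not initialised yet, nothing to catch up on"
    return "processing deltas after stream_id %d" % pos


print(asyncio.run(advance_user_directory(FakeStore(None))))
print(asyncio.run(advance_user_directory(FakeStore(1234))))
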
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index acb24e33af..1fd333b707 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -27,7 +27,7 @@ MAX_STATE_DELTA_HOPS = 100
 
 
 class StateGroupBackgroundUpdateStore(SQLBaseStore):
-    """Defines functions related to state groups needed to run the state backgroud
+    """Defines functions related to state groups needed to run the state background
     updates.
     """
 
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 0e31cc811a..b16b9905d8 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -48,8 +48,7 @@ class _GetStateGroupDelta(
 
 
 class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
-    """A data store for fetching/storing state groups.
-    """
+    """A data store for fetching/storing state groups."""
 
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
@@ -89,7 +88,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             50000,
         )
         self._state_group_members_cache = DictionaryCache(
-            "*stateGroupMembersCache*", 500000,
+            "*stateGroupMembersCache*",
+            500000,
         )
 
         def get_max_state_group_txn(txn: Cursor):
@@ -565,11 +565,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             )
 
         logger.info("[purge] removing redundant state groups")
-        txn.executemany(
+        txn.execute_batch(
             "DELETE FROM state_groups_state WHERE state_group = ?",
             ((sg,) for sg in state_groups_to_delete),
         )
-        txn.executemany(
+        txn.execute_batch(
             "DELETE FROM state_groups WHERE id = ?",
             ((sg,) for sg in state_groups_to_delete),
         )