summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/__init__.py4
-rw-r--r--synapse/storage/databases/main/account_data.py2
-rw-r--r--synapse/storage/databases/main/devices.py4
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py4
-rw-r--r--synapse/storage/databases/main/event_push_actions.py4
-rw-r--r--synapse/storage/databases/main/events.py217
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py12
-rw-r--r--synapse/storage/databases/main/events_forward_extremities.py101
-rw-r--r--synapse/storage/databases/main/media_repository.py10
-rw-r--r--synapse/storage/databases/main/metrics.py56
-rw-r--r--synapse/storage/databases/main/purge_events.py2
-rw-r--r--synapse/storage/databases/main/pusher.py4
-rw-r--r--synapse/storage/databases/main/receipts.py4
-rw-r--r--synapse/storage/databases/main/registration.py71
-rw-r--r--synapse/storage/databases/main/room.py6
-rw-r--r--synapse/storage/databases/main/roommember.py6
-rw-r--r--synapse/storage/databases/main/schema/delta/59/01ignored_user.py2
-rw-r--r--synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql18
-rw-r--r--synapse/storage/databases/main/search.py7
-rw-r--r--synapse/storage/databases/main/stats.py22
-rw-r--r--synapse/storage/databases/main/user_directory.py2
-rw-r--r--synapse/storage/databases/state/store.py4
22 files changed, 425 insertions, 137 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index ae561a2da3..5d0845588c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -43,6 +43,7 @@ from .end_to_end_keys import EndToEndKeyStore
 from .event_federation import EventFederationStore
 from .event_push_actions import EventPushActionsStore
 from .events_bg_updates import EventsBackgroundUpdatesStore
+from .events_forward_extremities import EventForwardExtremitiesStore
 from .filtering import FilteringStore
 from .group_server import GroupServerStore
 from .keys import KeyStore
@@ -118,6 +119,7 @@ class DataStore(
     UIAuthStore,
     CacheInvalidationWorkerStore,
     ServerMetricsStore,
+    EventForwardExtremitiesStore,
 ):
     def __init__(self, database: DatabasePool, db_conn, hs):
         self.hs = hs
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 68896f34af..a277a1ef13 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -68,7 +68,7 @@ class AccountDataWorkerStore(SQLBaseStore):
             # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
             # updated over replication. (Multiple writers are not supported for
             # SQLite).
-            if hs.get_instance_name() in hs.config.worker.writers.events:
+            if hs.get_instance_name() in hs.config.worker.writers.account_data:
                 self._account_data_id_gen = StreamIdGenerator(
                     db_conn,
                     "room_account_data",
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9097677648..659d8f245f 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -897,7 +897,7 @@ class DeviceWorkerStore(SQLBaseStore):
                 DELETE FROM device_lists_outbound_last_success
                 WHERE destination = ? AND user_id = ?
             """
-            txn.executemany(sql, ((row[0], row[1]) for row in rows))
+            txn.execute_batch(sql, ((row[0], row[1]) for row in rows))
 
             logger.info("Pruned %d device list outbound pokes", count)
 
@@ -1343,7 +1343,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
         # Delete older entries in the table, as we really only care about
         # when the latest change happened.
-        txn.executemany(
+        txn.execute_batch(
             """
             DELETE FROM device_lists_stream
             WHERE user_id = ? AND device_id = ? AND stream_id < ?
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index c128889bf9..309f1e865b 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -634,7 +634,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
     async def get_e2e_cross_signing_keys_bulk(
         self, user_ids: List[str], from_user_id: Optional[str] = None
-    ) -> Dict[str, Dict[str, dict]]:
+    ) -> Dict[str, Optional[Dict[str, dict]]]:
         """Returns the cross-signing keys for a set of users.
 
         Args:
@@ -724,7 +724,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
     async def claim_e2e_one_time_keys(
         self, query_list: Iterable[Tuple[str, str, str]]
-    ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
+    ) -> Dict[str, Dict[str, Dict[str, str]]]:
         """Take a list of one time keys out of the database.
 
         Args:
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 1b657191a9..438383abe1 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -487,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 VALUES (?, ?, ?, ?, ?, ?)
             """
 
-            txn.executemany(
+            txn.execute_batch(
                 sql,
                 (
                     _gen_entry(user_id, actions)
@@ -803,7 +803,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             ],
         )
 
-        txn.executemany(
+        txn.execute_batch(
             """
                 UPDATE event_push_summary
                 SET notif_count = ?, unread_count = ?, stream_ordering = ?
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 3216b3f3c8..ccda9f1caa 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -473,8 +473,9 @@ class PersistEventsStore:
             txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
         )
 
-    @staticmethod
+    @classmethod
     def _add_chain_cover_index(
+        cls,
         txn,
         db_pool: DatabasePool,
         event_to_room_id: Dict[str, str],
@@ -614,60 +615,17 @@ class PersistEventsStore:
         if not events_to_calc_chain_id_for:
             return
 
-        # We now calculate the chain IDs/sequence numbers for the events. We
-        # do this by looking at the chain ID and sequence number of any auth
-        # event with the same type/state_key and incrementing the sequence
-        # number by one. If there was no match or the chain ID/sequence
-        # number is already taken we generate a new chain.
-        #
-        # We need to do this in a topologically sorted order as we want to
-        # generate chain IDs/sequence numbers of an event's auth events
-        # before the event itself.
-        chains_tuples_allocated = set()  # type: Set[Tuple[int, int]]
-        new_chain_tuples = {}  # type: Dict[str, Tuple[int, int]]
-        for event_id in sorted_topologically(
-            events_to_calc_chain_id_for, event_to_auth_chain
-        ):
-            existing_chain_id = None
-            for auth_id in event_to_auth_chain.get(event_id, []):
-                if event_to_types.get(event_id) == event_to_types.get(auth_id):
-                    existing_chain_id = chain_map[auth_id]
-                    break
-
-            new_chain_tuple = None
-            if existing_chain_id:
-                # We found a chain ID/sequence number candidate, check its
-                # not already taken.
-                proposed_new_id = existing_chain_id[0]
-                proposed_new_seq = existing_chain_id[1] + 1
-                if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
-                    already_allocated = db_pool.simple_select_one_onecol_txn(
-                        txn,
-                        table="event_auth_chains",
-                        keyvalues={
-                            "chain_id": proposed_new_id,
-                            "sequence_number": proposed_new_seq,
-                        },
-                        retcol="event_id",
-                        allow_none=True,
-                    )
-                    if already_allocated:
-                        # Mark it as already allocated so we don't need to hit
-                        # the DB again.
-                        chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
-                    else:
-                        new_chain_tuple = (
-                            proposed_new_id,
-                            proposed_new_seq,
-                        )
-
-            if not new_chain_tuple:
-                new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1)
-
-            chains_tuples_allocated.add(new_chain_tuple)
-
-            chain_map[event_id] = new_chain_tuple
-            new_chain_tuples[event_id] = new_chain_tuple
+        # Allocate chain ID/sequence numbers to each new event.
+        new_chain_tuples = cls._allocate_chain_ids(
+            txn,
+            db_pool,
+            event_to_room_id,
+            event_to_types,
+            event_to_auth_chain,
+            events_to_calc_chain_id_for,
+            chain_map,
+        )
+        chain_map.update(new_chain_tuples)
 
         db_pool.simple_insert_many_txn(
             txn,
@@ -794,6 +752,137 @@ class PersistEventsStore:
             ],
         )
 
+    @staticmethod
+    def _allocate_chain_ids(
+        txn,
+        db_pool: DatabasePool,
+        event_to_room_id: Dict[str, str],
+        event_to_types: Dict[str, Tuple[str, str]],
+        event_to_auth_chain: Dict[str, List[str]],
+        events_to_calc_chain_id_for: Set[str],
+        chain_map: Dict[str, Tuple[int, int]],
+    ) -> Dict[str, Tuple[int, int]]:
+        """Allocates, but does not persist, chain ID/sequence numbers for the
+        events in `events_to_calc_chain_id_for`. (c.f. _add_chain_cover_index
+        for info on args)
+        """
+
+        # We now calculate the chain IDs/sequence numbers for the events. We do
+        # this by looking at the chain ID and sequence number of any auth event
+        # with the same type/state_key and incrementing the sequence number by
+        # one. If there was no match or the chain ID/sequence number is already
+        # taken we generate a new chain.
+        #
+        # We try to reduce the number of times that we hit the database by
+        # batching up calls, to make this more efficient when persisting large
+        # numbers of state events (e.g. during joins).
+        #
+        # We do this by:
+        #   1. Calculating for each event which auth event will be used to
+        #      inherit the chain ID, i.e. converting the auth chain graph to a
+        #      tree that we can allocate chains on. We also keep track of which
+        #      existing chain IDs have been referenced.
+        #   2. Fetching the max allocated sequence number for each referenced
+        #      existing chain ID, generating a map from chain ID to the max
+        #      allocated sequence number.
+        #   3. Iterating over the tree and allocating a chain ID/seq no. to the
+        #      new event, by incrementing the sequence number from the
+        #      referenced event's chain ID/seq no. and checking that the
+        #      incremented sequence number hasn't already been allocated (by
+        #      looking in the map generated in the previous step). We generate a
+        #      new chain if the sequence number has already been allocated.
+        #
+
+        existing_chains = set()  # type: Set[int]
+        tree = []  # type: List[Tuple[str, Optional[str]]]
+
+        # We need to do this in a topologically sorted order as we want to
+        # generate chain IDs/sequence numbers of an event's auth events before
+        # the event itself.
+        for event_id in sorted_topologically(
+            events_to_calc_chain_id_for, event_to_auth_chain
+        ):
+            for auth_id in event_to_auth_chain.get(event_id, []):
+                if event_to_types.get(event_id) == event_to_types.get(auth_id):
+                    existing_chain_id = chain_map.get(auth_id)
+                    if existing_chain_id:
+                        existing_chains.add(existing_chain_id[0])
+
+                    tree.append((event_id, auth_id))
+                    break
+            else:
+                tree.append((event_id, None))
+
+        # Fetch the current max sequence number for each existing referenced chain.
+        sql = """
+            SELECT chain_id, MAX(sequence_number) FROM event_auth_chains
+            WHERE %s
+            GROUP BY chain_id
+        """
+        clause, args = make_in_list_sql_clause(
+            db_pool.engine, "chain_id", existing_chains
+        )
+        txn.execute(sql % (clause,), args)
+
+        chain_to_max_seq_no = {row[0]: row[1] for row in txn}  # type: Dict[Any, int]
+
+        # Allocate the new events chain ID/sequence numbers.
+        #
+        # To reduce the number of calls to the database we don't allocate a
+        # chain ID number in the loop, instead we use a temporary `object()` for
+        # each new chain ID. Once we've done the loop we generate the necessary
+        # number of new chain IDs in one call, replacing all temporary
+        # objects with real allocated chain IDs.
+
+        unallocated_chain_ids = set()  # type: Set[object]
+        new_chain_tuples = {}  # type: Dict[str, Tuple[Any, int]]
+        for event_id, auth_event_id in tree:
+            # If we reference an auth_event_id we fetch the allocated chain ID,
+            # either from the existing `chain_map` or the newly generated
+            # `new_chain_tuples` map.
+            existing_chain_id = None
+            if auth_event_id:
+                existing_chain_id = new_chain_tuples.get(auth_event_id)
+                if not existing_chain_id:
+                    existing_chain_id = chain_map[auth_event_id]
+
+            new_chain_tuple = None  # type: Optional[Tuple[Any, int]]
+            if existing_chain_id:
+                # We found a chain ID/sequence number candidate, check its
+                # not already taken.
+                proposed_new_id = existing_chain_id[0]
+                proposed_new_seq = existing_chain_id[1] + 1
+
+                if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq:
+                    new_chain_tuple = (
+                        proposed_new_id,
+                        proposed_new_seq,
+                    )
+
+            # If we need to start a new chain we allocate a temporary chain ID.
+            if not new_chain_tuple:
+                new_chain_tuple = (object(), 1)
+                unallocated_chain_ids.add(new_chain_tuple[0])
+
+            new_chain_tuples[event_id] = new_chain_tuple
+            chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
+
+        # Generate new chain IDs for all unallocated chain IDs.
+        newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
+            txn, len(unallocated_chain_ids)
+        )
+
+        # Map from potentially temporary chain ID to real chain ID
+        chain_id_to_allocated_map = dict(
+            zip(unallocated_chain_ids, newly_allocated_chain_ids)
+        )  # type: Dict[Any, int]
+        chain_id_to_allocated_map.update((c, c) for c in existing_chains)
+
+        return {
+            event_id: (chain_id_to_allocated_map[chain_id], seq)
+            for event_id, (chain_id, seq) in new_chain_tuples.items()
+        }
+
     def _persist_transaction_ids_txn(
         self,
         txn: LoggingTransaction,
@@ -876,7 +965,7 @@ class PersistEventsStore:
                         WHERE room_id = ? AND type = ? AND state_key = ?
                     )
                 """
-                txn.executemany(
+                txn.execute_batch(
                     sql,
                     (
                         (
@@ -895,7 +984,7 @@ class PersistEventsStore:
                 )
                 # Now we actually update the current_state_events table
 
-                txn.executemany(
+                txn.execute_batch(
                     "DELETE FROM current_state_events"
                     " WHERE room_id = ? AND type = ? AND state_key = ?",
                     (
@@ -907,7 +996,7 @@ class PersistEventsStore:
                 # We include the membership in the current state table, hence we do
                 # a lookup when we insert. This assumes that all events have already
                 # been inserted into room_memberships.
-                txn.executemany(
+                txn.execute_batch(
                     """INSERT INTO current_state_events
                         (room_id, type, state_key, event_id, membership)
                     VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -927,7 +1016,7 @@ class PersistEventsStore:
             # we have no record of the fact the user *was* a member of the
             # room but got, say, state reset out of it.
             if to_delete or to_insert:
-                txn.executemany(
+                txn.execute_batch(
                     "DELETE FROM local_current_membership"
                     " WHERE room_id = ? AND user_id = ?",
                     (
@@ -938,7 +1027,7 @@ class PersistEventsStore:
                 )
 
             if to_insert:
-                txn.executemany(
+                txn.execute_batch(
                     """INSERT INTO local_current_membership
                         (room_id, user_id, event_id, membership)
                     VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -1738,7 +1827,7 @@ class PersistEventsStore:
         """
 
         if events_and_contexts:
-            txn.executemany(
+            txn.execute_batch(
                 sql,
                 (
                     (
@@ -1767,7 +1856,7 @@ class PersistEventsStore:
 
         # Now we delete the staging area for *all* events that were being
         # persisted.
-        txn.executemany(
+        txn.execute_batch(
             "DELETE FROM event_push_actions_staging WHERE event_id = ?",
             ((event.event_id,) for event, _ in all_events_and_contexts),
         )
@@ -1886,7 +1975,7 @@ class PersistEventsStore:
             " )"
         )
 
-        txn.executemany(
+        txn.execute_batch(
             query,
             [
                 (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
@@ -1900,7 +1989,7 @@ class PersistEventsStore:
             "DELETE FROM event_backward_extremities"
             " WHERE event_id = ? AND room_id = ?"
         )
-        txn.executemany(
+        txn.execute_batch(
             query,
             [
                 (ev.event_id, ev.room_id)
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index e46e44ba54..5ca4fa6817 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -139,8 +139,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
-        INSERT_CLUMP_SIZE = 1000
-
         def reindex_txn(txn):
             sql = (
                 "SELECT stream_ordering, event_id, json FROM events"
@@ -178,9 +176,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
 
-            for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
-                clump = update_rows[index : index + INSERT_CLUMP_SIZE]
-                txn.executemany(sql, clump)
+            txn.execute_batch(sql, update_rows)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
@@ -210,8 +206,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
-        INSERT_CLUMP_SIZE = 1000
-
         def reindex_search_txn(txn):
             sql = (
                 "SELECT stream_ordering, event_id FROM events"
@@ -256,9 +250,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
 
-            for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
-                clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
-                txn.executemany(sql, clump)
+            txn.execute_batch(sql, rows_to_update)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py
new file mode 100644
index 0000000000..0ac1da9c35
--- /dev/null
+++ b/synapse/storage/databases/main/events_forward_extremities.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Dict, List
+
+from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
+
+logger = logging.getLogger(__name__)
+
+
+class EventForwardExtremitiesStore(SQLBaseStore):
+    async def delete_forward_extremities_for_room(self, room_id: str) -> int:
+        """Delete any extra forward extremities for a room.
+
+        Invalidates the "get_latest_event_ids_in_room" cache if any forward
+        extremities were deleted.
+
+        Returns count deleted.
+        """
+
+        def delete_forward_extremities_for_room_txn(txn):
+            # First we need to get the event_id to not delete
+            sql = """
+                SELECT event_id FROM event_forward_extremities
+                INNER JOIN events USING (room_id, event_id)
+                WHERE room_id = ?
+                ORDER BY stream_ordering DESC
+                LIMIT 1
+            """
+            txn.execute(sql, (room_id,))
+            rows = txn.fetchall()
+            try:
+                event_id = rows[0][0]
+                logger.debug(
+                    "Found event_id %s as the forward extremity to keep for room %s",
+                    event_id,
+                    room_id,
+                )
+            except KeyError:
+                msg = "No forward extremity event found for room %s" % room_id
+                logger.warning(msg)
+                raise SynapseError(400, msg)
+
+            # Now delete the extra forward extremities
+            sql = """
+                DELETE FROM event_forward_extremities
+                WHERE event_id != ? AND room_id = ?
+            """
+
+            txn.execute(sql, (event_id, room_id))
+            logger.info(
+                "Deleted %s extra forward extremities for room %s",
+                txn.rowcount,
+                room_id,
+            )
+
+            if txn.rowcount > 0:
+                # Invalidate the cache
+                self._invalidate_cache_and_stream(
+                    txn, self.get_latest_event_ids_in_room, (room_id,),
+                )
+
+            return txn.rowcount
+
+        return await self.db_pool.runInteraction(
+            "delete_forward_extremities_for_room",
+            delete_forward_extremities_for_room_txn,
+        )
+
+    async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]:
+        """Get list of forward extremities for a room."""
+
+        def get_forward_extremities_for_room_txn(txn):
+            sql = """
+                SELECT event_id, state_group, depth, received_ts
+                FROM event_forward_extremities
+                INNER JOIN event_to_state_groups USING (event_id)
+                INNER JOIN events USING (room_id, event_id)
+                WHERE room_id = ?
+            """
+
+            txn.execute(sql, (room_id,))
+            return self.db_pool.cursor_to_dict(txn)
+
+        return await self.db_pool.runInteraction(
+            "get_forward_extremities_for_room", get_forward_extremities_for_room_txn,
+        )
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 283c8a5e22..e017177655 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -417,7 +417,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 " WHERE media_origin = ? AND media_id = ?"
             )
 
-            txn.executemany(
+            txn.execute_batch(
                 sql,
                 (
                     (time_ms, media_origin, media_id)
@@ -430,7 +430,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 " WHERE media_id = ?"
             )
 
-            txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
+            txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))
 
         return await self.db_pool.runInteraction(
             "update_cached_last_access_time", update_cache_txn
@@ -557,7 +557,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
 
         def _delete_url_cache_txn(txn):
-            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
 
         return await self.db_pool.runInteraction(
             "delete_url_cache", _delete_url_cache_txn
@@ -586,11 +586,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         def _delete_url_cache_media_txn(txn):
             sql = "DELETE FROM local_media_repository WHERE media_id = ?"
 
-            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
 
             sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
 
-            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
 
         return await self.db_pool.runInteraction(
             "delete_url_cache_media", _delete_url_cache_media_txn
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index ab18cc4d79..92e65aa640 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -88,6 +88,62 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             (x[0] - 1) * x[1] for x in res if x[1]
         )
 
+    async def count_daily_e2ee_messages(self):
+        """
+        Returns an estimate of the number of messages sent in the last day.
+
+        If it has been significantly less or more than one day since the last
+        call to this function, it will return None.
+        """
+
+        def _count_messages(txn):
+            sql = """
+                SELECT COALESCE(COUNT(*), 0) FROM events
+                WHERE type = 'm.room.encrypted'
+                AND stream_ordering > ?
+            """
+            txn.execute(sql, (self.stream_ordering_day_ago,))
+            (count,) = txn.fetchone()
+            return count
+
+        return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
+
+    async def count_daily_sent_e2ee_messages(self):
+        def _count_messages(txn):
+            # This is good enough as if you have silly characters in your own
+            # hostname then thats your own fault.
+            like_clause = "%:" + self.hs.hostname
+
+            sql = """
+                SELECT COALESCE(COUNT(*), 0) FROM events
+                WHERE type = 'm.room.encrypted'
+                    AND sender LIKE ?
+                AND stream_ordering > ?
+            """
+
+            txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
+            (count,) = txn.fetchone()
+            return count
+
+        return await self.db_pool.runInteraction(
+            "count_daily_sent_e2ee_messages", _count_messages
+        )
+
+    async def count_daily_active_e2ee_rooms(self):
+        def _count(txn):
+            sql = """
+                SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+                WHERE type = 'm.room.encrypted'
+                AND stream_ordering > ?
+            """
+            txn.execute(sql, (self.stream_ordering_day_ago,))
+            (count,) = txn.fetchone()
+            return count
+
+        return await self.db_pool.runInteraction(
+            "count_daily_active_e2ee_rooms", _count
+        )
+
     async def count_daily_messages(self):
         """
         Returns an estimate of the number of messages sent in the last day.
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 5d668aadb2..ecfc9f20b1 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -172,7 +172,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
         )
 
         # Update backward extremeties
-        txn.executemany(
+        txn.execute_batch(
             "INSERT INTO event_backward_extremities (room_id, event_id)"
             " VALUES (?, ?)",
             [(room_id, event_id) for event_id, in new_backwards_extrems],
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index bc7621b8d6..2687ef3e43 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -344,7 +344,9 @@ class PusherStore(PusherWorkerStore):
                 txn, self.get_if_user_has_pusher, (user_id,)
             )
 
-            self.db_pool.simple_delete_one_txn(
+            # It is expected that there is exactly one pusher to delete, but
+            # if it isn't there (or there are multiple) delete them all.
+            self.db_pool.simple_delete_txn(
                 txn,
                 "pushers",
                 {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index e0e57f0578..e4843a202c 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -45,7 +45,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             self._receipts_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
                 db=database,
-                stream_name="account_data",
+                stream_name="receipts",
                 instance_name=self._instance_name,
                 tables=[("receipts_linearized", "instance_name", "stream_id")],
                 sequence_name="receipts_sequence",
@@ -61,7 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
             # updated over replication. (Multiple writers are not supported for
             # SQLite).
-            if hs.get_instance_name() in hs.config.worker.writers.events:
+            if hs.get_instance_name() in hs.config.worker.writers.receipts:
                 self._receipts_id_gen = StreamIdGenerator(
                     db_conn, "receipts_linearized", "stream_id"
                 )
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 8d05288ed4..8405dd460f 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -360,6 +360,35 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
 
+    async def set_shadow_banned(self, user: UserID, shadow_banned: bool) -> None:
+        """Sets whether a user shadow-banned.
+
+        Args:
+            user: user ID of the user to test
+            shadow_banned: true iff the user is to be shadow-banned, false otherwise.
+        """
+
+        def set_shadow_banned_txn(txn):
+            self.db_pool.simple_update_one_txn(
+                txn,
+                table="users",
+                keyvalues={"name": user.to_string()},
+                updatevalues={"shadow_banned": shadow_banned},
+            )
+            # In order for this to apply immediately, clear the cache for this user.
+            tokens = self.db_pool.simple_select_onecol_txn(
+                txn,
+                table="access_tokens",
+                keyvalues={"user_id": user.to_string()},
+                retcol="token",
+            )
+            for token in tokens:
+                self._invalidate_cache_and_stream(
+                    txn, self.get_user_by_access_token, (token,)
+                )
+
+        await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
+
     def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
         sql = """
             SELECT users.name as user_id,
@@ -443,6 +472,26 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
 
+    async def record_user_external_id(
+        self, auth_provider: str, external_id: str, user_id: str
+    ) -> None:
+        """Record a mapping from an external user id to a mxid
+
+        Args:
+            auth_provider: identifier for the remote auth provider
+            external_id: id on that system
+            user_id: complete mxid that it is mapped to
+        """
+        await self.db_pool.simple_insert(
+            table="user_external_ids",
+            values={
+                "auth_provider": auth_provider,
+                "external_id": external_id,
+                "user_id": user_id,
+            },
+            desc="record_user_external_id",
+        )
+
     async def get_user_by_external_id(
         self, auth_provider: str, external_id: str
     ) -> Optional[str]:
@@ -1104,7 +1153,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
                 FROM user_threepids
             """
 
-            txn.executemany(sql, [(id_server,) for id_server in id_servers])
+            txn.execute_batch(sql, [(id_server,) for id_server in id_servers])
 
         if id_servers:
             await self.db_pool.runInteraction(
@@ -1371,26 +1420,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
         self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-    async def record_user_external_id(
-        self, auth_provider: str, external_id: str, user_id: str
-    ) -> None:
-        """Record a mapping from an external user id to a mxid
-
-        Args:
-            auth_provider: identifier for the remote auth provider
-            external_id: id on that system
-            user_id: complete mxid that it is mapped to
-        """
-        await self.db_pool.simple_insert(
-            table="user_external_ids",
-            values={
-                "auth_provider": auth_provider,
-                "external_id": external_id,
-                "user_id": user_id,
-            },
-            desc="record_user_external_id",
-        )
-
     async def user_set_password_hash(
         self, user_id: str, password_hash: Optional[str]
     ) -> None:
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 284f2ce77c..a9fcb5f59c 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -16,7 +16,6 @@
 
 import collections
 import logging
-import re
 from abc import abstractmethod
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple
@@ -30,6 +29,7 @@ from synapse.storage.databases.main.search import SearchStore
 from synapse.types import JsonDict, ThirdPartyInstanceID
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
+from synapse.util.stringutils import MXC_REGEX
 
 logger = logging.getLogger(__name__)
 
@@ -660,8 +660,6 @@ class RoomWorkerStore(SQLBaseStore):
             The local and remote media as a lists of tuples where the key is
             the hostname and the value is the media ID.
         """
-        mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
-
         sql = """
             SELECT stream_ordering, json FROM events
             JOIN event_json USING (room_id, event_id)
@@ -688,7 +686,7 @@ class RoomWorkerStore(SQLBaseStore):
                 for url in (content_url, thumbnail_url):
                     if not url:
                         continue
-                    matches = mxc_re.match(url)
+                    matches = MXC_REGEX.match(url)
                     if matches:
                         hostname = matches.group(1)
                         media_id = matches.group(2)
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index dcdaf09682..92382bed28 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -873,8 +873,6 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
             "max_stream_id_exclusive", self._stream_order_on_start + 1
         )
 
-        INSERT_CLUMP_SIZE = 1000
-
         def add_membership_profile_txn(txn):
             sql = """
                 SELECT stream_ordering, event_id, events.room_id, event_json.json
@@ -915,9 +913,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
                 UPDATE room_memberships SET display_name = ?, avatar_url = ?
                 WHERE event_id = ? AND room_id = ?
             """
-            for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
-                clump = to_update[index : index + INSERT_CLUMP_SIZE]
-                txn.executemany(to_update_sql, clump)
+            txn.execute_batch(to_update_sql, to_update)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
diff --git a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
index f35c70b699..9e8f35c1d2 100644
--- a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
+++ b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
@@ -55,7 +55,7 @@ def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs
         # { "ignored_users": "@someone:example.org": {} }
         ignored_users = content.get("ignored_users", {})
         if isinstance(ignored_users, dict) and ignored_users:
-            cur.executemany(insert_sql, [(user_id, u) for u in ignored_users])
+            cur.execute_batch(insert_sql, [(user_id, u) for u in ignored_users])
 
     # Add indexes after inserting data for efficiency.
     logger.info("Adding constraints to ignored_users table")
diff --git a/synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql b/synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql
new file mode 100644
index 0000000000..9f2b5ebc5a
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql
@@ -0,0 +1,18 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We incorrectly populated these, so we delete them and let the
+-- MultiWriterIdGenerator repopulate it.
+DELETE FROM stream_positions WHERE stream_name = 'receipts' OR stream_name = 'account_data';
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index e34fce6281..f5e7d9ef98 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -24,6 +24,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import Collection
 
 logger = logging.getLogger(__name__)
 
@@ -63,7 +64,7 @@ class SearchWorkerStore(SQLBaseStore):
                 for entry in entries
             )
 
-            txn.executemany(sql, args)
+            txn.execute_batch(sql, args)
 
         elif isinstance(self.database_engine, Sqlite3Engine):
             sql = (
@@ -75,7 +76,7 @@ class SearchWorkerStore(SQLBaseStore):
                 for entry in entries
             )
 
-            txn.executemany(sql, args)
+            txn.execute_batch(sql, args)
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
@@ -460,7 +461,7 @@ class SearchStore(SearchBackgroundUpdateStore):
 
     async def search_rooms(
         self,
-        room_ids: List[str],
+        room_ids: Collection[str],
         search_term: str,
         keys: List[str],
         limit,
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 0cdb3ec1f7..d421d18f8d 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -15,11 +15,12 @@
 # limitations under the License.
 
 import logging
-from collections import Counter
 from enum import Enum
 from itertools import chain
 from typing import Any, Dict, List, Optional, Tuple
 
+from typing_extensions import Counter
+
 from twisted.internet.defer import DeferredLock
 
 from synapse.api.constants import EventTypes, Membership
@@ -319,7 +320,9 @@ class StatsStore(StateDeltasStore):
         return slice_list
 
     @cached()
-    async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
+    async def get_earliest_token_for_stats(
+        self, stats_type: str, id: str
+    ) -> Optional[int]:
         """
         Fetch the "earliest token". This is used by the room stats delta
         processor to ignore deltas that have been processed between the
@@ -339,7 +342,7 @@ class StatsStore(StateDeltasStore):
         )
 
     async def bulk_update_stats_delta(
-        self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int
+        self, ts: int, updates: Dict[str, Dict[str, Counter[str]]], stream_id: int
     ) -> None:
         """Bulk update stats tables for a given stream_id and updates the stats
         incremental position.
@@ -665,7 +668,7 @@ class StatsStore(StateDeltasStore):
 
     async def get_changes_room_total_events_and_bytes(
         self, min_pos: int, max_pos: int
-    ) -> Dict[str, Dict[str, int]]:
+    ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
         """Fetches the counts of events in the given range of stream IDs.
 
         Args:
@@ -683,18 +686,19 @@ class StatsStore(StateDeltasStore):
             max_pos,
         )
 
-    def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos):
+    def get_changes_room_total_events_and_bytes_txn(
+        self, txn, low_pos: int, high_pos: int
+    ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
         """Gets the total_events and total_event_bytes counts for rooms and
         senders, in a range of stream_orderings (including backfilled events).
 
         Args:
             txn
-            low_pos (int): Low stream ordering
-            high_pos (int): High stream ordering
+            low_pos: Low stream ordering
+            high_pos: High stream ordering
 
         Returns:
-            tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The
-            room and user deltas for total_events/total_event_bytes in the
+            The room and user deltas for total_events/total_event_bytes in the
             format of `stats_id` -> fields
         """
 
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index ef11f1c3b3..7b9729da09 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -540,7 +540,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             desc="get_user_in_directory",
         )
 
-    async def update_user_directory_stream_pos(self, stream_id: str) -> None:
+    async def update_user_directory_stream_pos(self, stream_id: int) -> None:
         await self.db_pool.simple_update_one(
             table="user_directory_stream_pos",
             keyvalues={},
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 0e31cc811a..89cdc84a9c 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -565,11 +565,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             )
 
         logger.info("[purge] removing redundant state groups")
-        txn.executemany(
+        txn.execute_batch(
             "DELETE FROM state_groups_state WHERE state_group = ?",
             ((sg,) for sg in state_groups_to_delete),
         )
-        txn.executemany(
+        txn.execute_batch(
             "DELETE FROM state_groups WHERE id = ?",
             ((sg,) for sg in state_groups_to_delete),
         )