summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/database.py62
-rw-r--r--synapse/storage/databases/main/account_data.py6
-rw-r--r--synapse/storage/databases/main/deviceinbox.py30
-rw-r--r--synapse/storage/databases/main/devices.py45
-rw-r--r--synapse/storage/databases/main/directory.py6
-rw-r--r--synapse/storage/databases/main/e2e_room_keys.py34
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py48
-rw-r--r--synapse/storage/databases/main/event_push_actions.py21
-rw-r--r--synapse/storage/databases/main/events.py136
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py42
-rw-r--r--synapse/storage/databases/main/presence.py33
-rw-r--r--synapse/storage/databases/main/pusher.py8
-rw-r--r--synapse/storage/databases/main/registration.py18
-rw-r--r--synapse/storage/databases/main/relations.py128
-rw-r--r--synapse/storage/databases/main/roommember.py6
-rw-r--r--synapse/storage/databases/main/session.py1
-rw-r--r--synapse/storage/databases/main/transactions.py11
-rw-r--r--synapse/storage/databases/main/ui_auth.py12
-rw-r--r--synapse/storage/databases/main/user_directory.py12
-rw-r--r--synapse/storage/databases/state/bg_updates.py15
-rw-r--r--synapse/storage/databases/state/store.py27
-rw-r--r--synapse/storage/keys.py6
-rw-r--r--synapse/storage/prepare_database.py6
-rw-r--r--synapse/storage/relations.py20
-rw-r--r--synapse/storage/state.py6
-rw-r--r--synapse/storage/util/id_generators.py8
26 files changed, 428 insertions, 319 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 2cacc7dd6c..57cc1d76e0 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -143,7 +143,7 @@ def make_conn(
     return db_conn
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class LoggingDatabaseConnection:
     """A wrapper around a database connection that returns `LoggingTransaction`
     as its cursor class.
@@ -151,9 +151,9 @@ class LoggingDatabaseConnection:
     This is mainly used on startup to ensure that queries get logged correctly
     """
 
-    conn = attr.ib(type=Connection)
-    engine = attr.ib(type=BaseDatabaseEngine)
-    default_txn_name = attr.ib(type=str)
+    conn: Connection
+    engine: BaseDatabaseEngine
+    default_txn_name: str
 
     def cursor(
         self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
@@ -934,56 +934,6 @@ class DatabasePool:
         txn.execute(sql, vals)
 
     async def simple_insert_many(
-        self, table: str, values: List[Dict[str, Any]], desc: str
-    ) -> None:
-        """Executes an INSERT query on the named table.
-
-        The input is given as a list of dicts, with one dict per row.
-        Generally simple_insert_many_values should be preferred for new code.
-
-        Args:
-            table: string giving the table name
-            values: dict of new column names and values for them
-            desc: description of the transaction, for logging and metrics
-        """
-        await self.runInteraction(desc, self.simple_insert_many_txn, table, values)
-
-    @staticmethod
-    def simple_insert_many_txn(
-        txn: LoggingTransaction, table: str, values: List[Dict[str, Any]]
-    ) -> None:
-        """Executes an INSERT query on the named table.
-
-        The input is given as a list of dicts, with one dict per row.
-        Generally simple_insert_many_values_txn should be preferred for new code.
-
-        Args:
-            txn: The transaction to use.
-            table: string giving the table name
-            values: dict of new column names and values for them
-        """
-        if not values:
-            return
-
-        # This is a *slight* abomination to get a list of tuples of key names
-        # and a list of tuples of value names.
-        #
-        # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
-        #         => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
-        #
-        # The sort is to ensure that we don't rely on dictionary iteration
-        # order.
-        keys, vals = zip(
-            *(zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i)
-        )
-
-        for k in keys:
-            if k != keys[0]:
-                raise RuntimeError("All items must have the same keys")
-
-        return DatabasePool.simple_insert_many_values_txn(txn, table, keys[0], vals)
-
-    async def simple_insert_many_values(
         self,
         table: str,
         keys: Collection[str],
@@ -1002,11 +952,11 @@ class DatabasePool:
             desc: description of the transaction, for logging and metrics
         """
         await self.runInteraction(
-            desc, self.simple_insert_many_values_txn, table, keys, values
+            desc, self.simple_insert_many_txn, table, keys, values
         )
 
     @staticmethod
-    def simple_insert_many_values_txn(
+    def simple_insert_many_txn(
         txn: LoggingTransaction,
         table: str,
         keys: Collection[str],
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 32a553fdd7..ef475e18c7 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -450,7 +450,7 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
     async def add_account_data_for_user(
         self, user_id: str, account_data_type: str, content: JsonDict
     ) -> int:
-        """Add some account_data to a room for a user.
+        """Add some global account_data for a user.
 
         Args:
             user_id: The user to add a tag for.
@@ -536,9 +536,9 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
         self.db_pool.simple_insert_many_txn(
             txn,
             table="ignored_users",
+            keys=("ignorer_user_id", "ignored_user_id"),
             values=[
-                {"ignorer_user_id": user_id, "ignored_user_id": u}
-                for u in currently_ignored_users - previously_ignored_users
+                (user_id, u) for u in currently_ignored_users - previously_ignored_users
             ],
         )
 
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 3682cb6a81..4eca97189b 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -432,14 +432,21 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             self.db_pool.simple_insert_many_txn(
                 txn,
                 table="device_federation_outbox",
+                keys=(
+                    "destination",
+                    "stream_id",
+                    "queued_ts",
+                    "messages_json",
+                    "instance_name",
+                ),
                 values=[
-                    {
-                        "destination": destination,
-                        "stream_id": stream_id,
-                        "queued_ts": now_ms,
-                        "messages_json": json_encoder.encode(edu),
-                        "instance_name": self._instance_name,
-                    }
+                    (
+                        destination,
+                        stream_id,
+                        now_ms,
+                        json_encoder.encode(edu),
+                        self._instance_name,
+                    )
                     for destination, edu in remote_messages_by_destination.items()
                 ],
             )
@@ -571,14 +578,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         self.db_pool.simple_insert_many_txn(
             txn,
             table="device_inbox",
+            keys=("user_id", "device_id", "stream_id", "message_json", "instance_name"),
             values=[
-                {
-                    "user_id": user_id,
-                    "device_id": device_id,
-                    "stream_id": stream_id,
-                    "message_json": message_json,
-                    "instance_name": self._instance_name,
-                }
+                (user_id, device_id, stream_id, message_json, self._instance_name)
                 for user_id, messages_by_device in local_by_user_then_device.items()
                 for device_id, message_json in messages_by_device.items()
             ],
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index bc7e876047..8f0cd0695f 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -781,7 +781,7 @@ class DeviceWorkerStore(SQLBaseStore):
     @cached(max_entries=10000)
     async def get_device_list_last_stream_id_for_remote(
         self, user_id: str
-    ) -> Optional[Any]:
+    ) -> Optional[str]:
         """Get the last stream_id we got for a user. May be None if we haven't
         got any information for them.
         """
@@ -797,7 +797,9 @@ class DeviceWorkerStore(SQLBaseStore):
         cached_method_name="get_device_list_last_stream_id_for_remote",
         list_name="user_ids",
     )
-    async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
+    async def get_device_list_last_stream_id_for_remotes(
+        self, user_ids: Iterable[str]
+    ) -> Dict[str, Optional[str]]:
         rows = await self.db_pool.simple_select_many_batch(
             table="device_lists_remote_extremeties",
             column="user_id",
@@ -1384,6 +1386,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         content: JsonDict,
         stream_id: str,
     ) -> None:
+        """Delete, update or insert a cache entry for this (user, device) pair."""
         if content.get("deleted"):
             self.db_pool.simple_delete_txn(
                 txn,
@@ -1443,6 +1446,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
     def _update_remote_device_list_cache_txn(
         self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
     ) -> None:
+        """Replace the list of cached devices for this user with the given list."""
         self.db_pool.simple_delete_txn(
             txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
         )
@@ -1450,12 +1454,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         self.db_pool.simple_insert_many_txn(
             txn,
             table="device_lists_remote_cache",
+            keys=("user_id", "device_id", "content"),
             values=[
-                {
-                    "user_id": user_id,
-                    "device_id": content["device_id"],
-                    "content": json_encoder.encode(content),
-                }
+                (user_id, content["device_id"], json_encoder.encode(content))
                 for content in devices
             ],
         )
@@ -1543,8 +1544,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         self.db_pool.simple_insert_many_txn(
             txn,
             table="device_lists_stream",
+            keys=("stream_id", "user_id", "device_id"),
             values=[
-                {"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
+                (stream_id, user_id, device_id)
                 for stream_id, device_id in zip(stream_ids, device_ids)
             ],
         )
@@ -1571,18 +1573,27 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         self.db_pool.simple_insert_many_txn(
             txn,
             table="device_lists_outbound_pokes",
+            keys=(
+                "destination",
+                "stream_id",
+                "user_id",
+                "device_id",
+                "sent",
+                "ts",
+                "opentracing_context",
+            ),
             values=[
-                {
-                    "destination": destination,
-                    "stream_id": next(next_stream_id),
-                    "user_id": user_id,
-                    "device_id": device_id,
-                    "sent": False,
-                    "ts": now,
-                    "opentracing_context": json_encoder.encode(context)
+                (
+                    destination,
+                    next(next_stream_id),
+                    user_id,
+                    device_id,
+                    False,
+                    now,
+                    json_encoder.encode(context)
                     if whitelisted_homeserver(destination)
                     else "{}",
-                }
+                )
                 for destination in hosts
                 for device_id in device_ids
             ],
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index f76c6121e8..5903fdaf00 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -112,10 +112,8 @@ class DirectoryWorkerStore(CacheInvalidationWorkerStore):
             self.db_pool.simple_insert_many_txn(
                 txn,
                 table="room_alias_servers",
-                values=[
-                    {"room_alias": room_alias.to_string(), "server": server}
-                    for server in servers
-                ],
+                keys=("room_alias", "server"),
+                values=[(room_alias.to_string(), server) for server in servers],
             )
 
             self._invalidate_cache_and_stream(
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 0cb48b9dd7..b789a588a5 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -110,16 +110,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         values = []
         for (room_id, session_id, room_key) in room_keys:
             values.append(
-                {
-                    "user_id": user_id,
-                    "version": version_int,
-                    "room_id": room_id,
-                    "session_id": session_id,
-                    "first_message_index": room_key["first_message_index"],
-                    "forwarded_count": room_key["forwarded_count"],
-                    "is_verified": room_key["is_verified"],
-                    "session_data": json_encoder.encode(room_key["session_data"]),
-                }
+                (
+                    user_id,
+                    version_int,
+                    room_id,
+                    session_id,
+                    room_key["first_message_index"],
+                    room_key["forwarded_count"],
+                    room_key["is_verified"],
+                    json_encoder.encode(room_key["session_data"]),
+                )
             )
             log_kv(
                 {
@@ -131,7 +131,19 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             )
 
         await self.db_pool.simple_insert_many(
-            table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
+            table="e2e_room_keys",
+            keys=(
+                "user_id",
+                "version",
+                "room_id",
+                "session_id",
+                "first_message_index",
+                "forwarded_count",
+                "is_verified",
+                "session_data",
+            ),
+            values=values,
+            desc="add_e2e_room_keys",
         )
 
     @trace
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 57b5ffbad3..1f8447b507 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -50,16 +50,16 @@ if TYPE_CHECKING:
     from synapse.server import HomeServer
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class DeviceKeyLookupResult:
     """The type returned by get_e2e_device_keys_and_signatures"""
 
-    display_name = attr.ib(type=Optional[str])
+    display_name: Optional[str]
 
     # the key data from e2e_device_keys_json. Typically includes fields like
     # "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
     # key) and "signatures" (a map from (user id) to (key id/device_id) to signature.)
-    keys = attr.ib(type=Optional[JsonDict])
+    keys: Optional[JsonDict]
 
 
 class EndToEndKeyBackgroundStore(SQLBaseStore):
@@ -387,15 +387,16 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             self.db_pool.simple_insert_many_txn(
                 txn,
                 table="e2e_one_time_keys_json",
+                keys=(
+                    "user_id",
+                    "device_id",
+                    "algorithm",
+                    "key_id",
+                    "ts_added_ms",
+                    "key_json",
+                ),
                 values=[
-                    {
-                        "user_id": user_id,
-                        "device_id": device_id,
-                        "algorithm": algorithm,
-                        "key_id": key_id,
-                        "ts_added_ms": time_now,
-                        "key_json": json_bytes,
-                    }
+                    (user_id, device_id, algorithm, key_id, time_now, json_bytes)
                     for algorithm, key_id, json_bytes in new_keys
                 ],
             )
@@ -1186,15 +1187,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
         """
         await self.db_pool.simple_insert_many(
             "e2e_cross_signing_signatures",
-            [
-                {
-                    "user_id": user_id,
-                    "key_id": item.signing_key_id,
-                    "target_user_id": item.target_user_id,
-                    "target_device_id": item.target_device_id,
-                    "signature": item.signature,
-                }
+            keys=(
+                "user_id",
+                "key_id",
+                "target_user_id",
+                "target_device_id",
+                "signature",
+            ),
+            values=[
+                (
+                    user_id,
+                    item.signing_key_id,
+                    item.target_user_id,
+                    item.target_device_id,
+                    item.signature,
+                )
                 for item in signatures
             ],
-            "add_e2e_signing_key",
+            desc="add_e2e_signing_key",
         )
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index a98e6b2593..b7c4c62222 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -875,14 +875,21 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         self.db_pool.simple_insert_many_txn(
             txn,
             table="event_push_summary",
+            keys=(
+                "user_id",
+                "room_id",
+                "notif_count",
+                "unread_count",
+                "stream_ordering",
+            ),
             values=[
-                {
-                    "user_id": user_id,
-                    "room_id": room_id,
-                    "notif_count": summary.notif_count,
-                    "unread_count": summary.unread_count,
-                    "stream_ordering": summary.stream_ordering,
-                }
+                (
+                    user_id,
+                    room_id,
+                    summary.notif_count,
+                    summary.unread_count,
+                    summary.stream_ordering,
+                )
                 for ((user_id, room_id), summary) in summaries.items()
                 if summary.old_user_id is None
             ],
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index dd255aefb9..de3b48524b 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -69,7 +69,7 @@ event_counter = Counter(
 )
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class DeltaState:
     """Deltas to use to update the `current_state_events` table.
 
@@ -80,9 +80,9 @@ class DeltaState:
             should e.g. be removed from `current_state_events` table.
     """
 
-    to_delete = attr.ib(type=List[Tuple[str, str]])
-    to_insert = attr.ib(type=StateMap[str])
-    no_longer_in_room = attr.ib(type=bool, default=False)
+    to_delete: List[Tuple[str, str]]
+    to_insert: StateMap[str]
+    no_longer_in_room: bool = False
 
 
 class PersistEventsStore:
@@ -442,12 +442,9 @@ class PersistEventsStore:
         self.db_pool.simple_insert_many_txn(
             txn,
             table="event_auth",
+            keys=("event_id", "room_id", "auth_id"),
             values=[
-                {
-                    "event_id": event.event_id,
-                    "room_id": event.room_id,
-                    "auth_id": auth_id,
-                }
+                (event.event_id, event.room_id, auth_id)
                 for event in events
                 for auth_id in event.auth_event_ids()
                 if event.is_state()
@@ -675,8 +672,9 @@ class PersistEventsStore:
         db_pool.simple_insert_many_txn(
             txn,
             table="event_auth_chains",
+            keys=("event_id", "chain_id", "sequence_number"),
             values=[
-                {"event_id": event_id, "chain_id": c_id, "sequence_number": seq}
+                (event_id, c_id, seq)
                 for event_id, (c_id, seq) in new_chain_tuples.items()
             ],
         )
@@ -782,13 +780,14 @@ class PersistEventsStore:
         db_pool.simple_insert_many_txn(
             txn,
             table="event_auth_chain_links",
+            keys=(
+                "origin_chain_id",
+                "origin_sequence_number",
+                "target_chain_id",
+                "target_sequence_number",
+            ),
             values=[
-                {
-                    "origin_chain_id": source_id,
-                    "origin_sequence_number": source_seq,
-                    "target_chain_id": target_id,
-                    "target_sequence_number": target_seq,
-                }
+                (source_id, source_seq, target_id, target_seq)
                 for (
                     source_id,
                     source_seq,
@@ -943,20 +942,28 @@ class PersistEventsStore:
             txn_id = getattr(event.internal_metadata, "txn_id", None)
             if token_id and txn_id:
                 to_insert.append(
-                    {
-                        "event_id": event.event_id,
-                        "room_id": event.room_id,
-                        "user_id": event.sender,
-                        "token_id": token_id,
-                        "txn_id": txn_id,
-                        "inserted_ts": self._clock.time_msec(),
-                    }
+                    (
+                        event.event_id,
+                        event.room_id,
+                        event.sender,
+                        token_id,
+                        txn_id,
+                        self._clock.time_msec(),
+                    )
                 )
 
         if to_insert:
             self.db_pool.simple_insert_many_txn(
                 txn,
                 table="event_txn_id",
+                keys=(
+                    "event_id",
+                    "room_id",
+                    "user_id",
+                    "token_id",
+                    "txn_id",
+                    "inserted_ts",
+                ),
                 values=to_insert,
             )
 
@@ -1161,8 +1168,9 @@ class PersistEventsStore:
         self.db_pool.simple_insert_many_txn(
             txn,
             table="event_forward_extremities",
+            keys=("event_id", "room_id"),
             values=[
-                {"event_id": ev_id, "room_id": room_id}
+                (ev_id, room_id)
                 for room_id, new_extrem in new_forward_extremities.items()
                 for ev_id in new_extrem
             ],
@@ -1174,12 +1182,9 @@ class PersistEventsStore:
         self.db_pool.simple_insert_many_txn(
             txn,
             table="stream_ordering_to_exterm",
+            keys=("room_id", "event_id", "stream_ordering"),
             values=[
-                {
-                    "room_id": room_id,
-                    "event_id": event_id,
-                    "stream_ordering": max_stream_order,
-                }
+                (room_id, event_id, max_stream_order)
                 for room_id, new_extrem in new_forward_extremities.items()
                 for event_id in new_extrem
             ],
@@ -1342,7 +1347,7 @@ class PersistEventsStore:
             d.pop("redacted_because", None)
             return d
 
-        self.db_pool.simple_insert_many_values_txn(
+        self.db_pool.simple_insert_many_txn(
             txn,
             table="event_json",
             keys=("event_id", "room_id", "internal_metadata", "json", "format_version"),
@@ -1358,7 +1363,7 @@ class PersistEventsStore:
             ),
         )
 
-        self.db_pool.simple_insert_many_values_txn(
+        self.db_pool.simple_insert_many_txn(
             txn,
             table="events",
             keys=(
@@ -1412,7 +1417,7 @@ class PersistEventsStore:
         )
         txn.execute(sql + clause, [False] + args)
 
-        self.db_pool.simple_insert_many_values_txn(
+        self.db_pool.simple_insert_many_txn(
             txn,
             table="state_events",
             keys=("event_id", "room_id", "type", "state_key"),
@@ -1622,14 +1627,9 @@ class PersistEventsStore:
         return self.db_pool.simple_insert_many_txn(
             txn=txn,
             table="event_labels",
+            keys=("event_id", "label", "room_id", "topological_ordering"),
             values=[
-                {
-                    "event_id": event_id,
-                    "label": label,
-                    "room_id": room_id,
-                    "topological_ordering": topological_ordering,
-                }
-                for label in labels
+                (event_id, label, room_id, topological_ordering) for label in labels
             ],
         )
 
@@ -1657,16 +1657,13 @@ class PersistEventsStore:
         vals = []
         for event in events:
             ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
-            vals.append(
-                {
-                    "event_id": event.event_id,
-                    "algorithm": ref_alg,
-                    "hash": memoryview(ref_hash_bytes),
-                }
-            )
+            vals.append((event.event_id, ref_alg, memoryview(ref_hash_bytes)))
 
         self.db_pool.simple_insert_many_txn(
-            txn, table="event_reference_hashes", values=vals
+            txn,
+            table="event_reference_hashes",
+            keys=("event_id", "algorithm", "hash"),
+            values=vals,
         )
 
     def _store_room_members_txn(
@@ -1689,18 +1686,25 @@ class PersistEventsStore:
         self.db_pool.simple_insert_many_txn(
             txn,
             table="room_memberships",
+            keys=(
+                "event_id",
+                "user_id",
+                "sender",
+                "room_id",
+                "membership",
+                "display_name",
+                "avatar_url",
+            ),
             values=[
-                {
-                    "event_id": event.event_id,
-                    "user_id": event.state_key,
-                    "sender": event.user_id,
-                    "room_id": event.room_id,
-                    "membership": event.membership,
-                    "display_name": non_null_str_or_none(
-                        event.content.get("displayname")
-                    ),
-                    "avatar_url": non_null_str_or_none(event.content.get("avatar_url")),
-                }
+                (
+                    event.event_id,
+                    event.state_key,
+                    event.user_id,
+                    event.room_id,
+                    event.membership,
+                    non_null_str_or_none(event.content.get("displayname")),
+                    non_null_str_or_none(event.content.get("avatar_url")),
+                )
                 for event in events
             ],
         )
@@ -2163,13 +2167,9 @@ class PersistEventsStore:
         self.db_pool.simple_insert_many_txn(
             txn,
             table="event_edges",
+            keys=("event_id", "prev_event_id", "room_id", "is_state"),
             values=[
-                {
-                    "event_id": ev.event_id,
-                    "prev_event_id": e_id,
-                    "room_id": ev.room_id,
-                    "is_state": False,
-                }
+                (ev.event_id, e_id, ev.room_id, False)
                 for ev in events
                 for e_id in ev.prev_event_ids()
             ],
@@ -2226,17 +2226,17 @@ class PersistEventsStore:
         )
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class _LinkMap:
     """A helper type for tracking links between chains."""
 
     # Stores the set of links as nested maps: source chain ID -> target chain ID
     # -> source sequence number -> target sequence number.
-    maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict)
+    maps: Dict[int, Dict[int, Dict[int, int]]] = attr.Factory(dict)
 
     # Stores the links that have been added (with new set to true), as tuples of
     # `(source chain ID, source sequence no, target chain ID, target sequence no.)`
-    additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set)
+    additions: Set[Tuple[int, int, int, int]] = attr.Factory(set)
 
     def add_link(
         self,
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index a68f14ba48..d5f0059665 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -65,22 +65,22 @@ class _BackgroundUpdates:
     REPLACE_STREAM_ORDERING_COLUMN = "replace_stream_ordering_column"
 
 
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class _CalculateChainCover:
     """Return value for _calculate_chain_cover_txn."""
 
     # The last room_id/depth/stream processed.
-    room_id = attr.ib(type=str)
-    depth = attr.ib(type=int)
-    stream = attr.ib(type=int)
+    room_id: str
+    depth: int
+    stream: int
 
     # Number of rows processed
-    processed_count = attr.ib(type=int)
+    processed_count: int
 
     # Map from room_id to last depth/stream processed for each room that we have
     # processed all events for (i.e. the rooms we can flip the
     # `has_auth_chain_index` for)
-    finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]])
+    finished_room_map: Dict[str, Tuple[int, int]]
 
 
 class EventsBackgroundUpdatesStore(SQLBaseStore):
@@ -684,13 +684,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                     self.db_pool.simple_insert_many_txn(
                         txn=txn,
                         table="event_labels",
+                        keys=("event_id", "label", "room_id", "topological_ordering"),
                         values=[
-                            {
-                                "event_id": event_id,
-                                "label": label,
-                                "room_id": event_json["room_id"],
-                                "topological_ordering": event_json["depth"],
-                            }
+                            (
+                                event_id,
+                                label,
+                                event_json["room_id"],
+                                event_json["depth"],
+                            )
                             for label in event_json["content"].get(
                                 EventContentFields.LABELS, []
                             )
@@ -803,29 +804,19 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             if not has_state:
                 state_events.append(
-                    {
-                        "event_id": event.event_id,
-                        "room_id": event.room_id,
-                        "type": event.type,
-                        "state_key": event.state_key,
-                    }
+                    (event.event_id, event.room_id, event.type, event.state_key)
                 )
 
             if not has_event_auth:
                 # Old, dodgy, events may have duplicate auth events, which we
                 # need to deduplicate as we have a unique constraint.
                 for auth_id in set(event.auth_event_ids()):
-                    auth_events.append(
-                        {
-                            "room_id": event.room_id,
-                            "event_id": event.event_id,
-                            "auth_id": auth_id,
-                        }
-                    )
+                    auth_events.append((event.event_id, event.room_id, auth_id))
 
         if state_events:
             await self.db_pool.simple_insert_many(
                 table="state_events",
+                keys=("event_id", "room_id", "type", "state_key"),
                 values=state_events,
                 desc="_rejected_events_metadata_state_events",
             )
@@ -833,6 +824,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         if auth_events:
             await self.db_pool.simple_insert_many(
                 table="event_auth",
+                keys=("event_id", "room_id", "auth_id"),
                 values=auth_events,
                 desc="_rejected_events_metadata_event_auth",
             )
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index cbf9ec38f7..4f05811a77 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -129,18 +129,29 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         self.db_pool.simple_insert_many_txn(
             txn,
             table="presence_stream",
+            keys=(
+                "stream_id",
+                "user_id",
+                "state",
+                "last_active_ts",
+                "last_federation_update_ts",
+                "last_user_sync_ts",
+                "status_msg",
+                "currently_active",
+                "instance_name",
+            ),
             values=[
-                {
-                    "stream_id": stream_id,
-                    "user_id": state.user_id,
-                    "state": state.state,
-                    "last_active_ts": state.last_active_ts,
-                    "last_federation_update_ts": state.last_federation_update_ts,
-                    "last_user_sync_ts": state.last_user_sync_ts,
-                    "status_msg": state.status_msg,
-                    "currently_active": state.currently_active,
-                    "instance_name": self._instance_name,
-                }
+                (
+                    stream_id,
+                    state.user_id,
+                    state.state,
+                    state.last_active_ts,
+                    state.last_federation_update_ts,
+                    state.last_user_sync_ts,
+                    state.status_msg,
+                    state.currently_active,
+                    self._instance_name,
+                )
                 for stream_id, state in zip(stream_orderings, presence_states)
             ],
         )
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 747b4f31df..cf64cd63a4 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -561,13 +561,9 @@ class PusherStore(PusherWorkerStore):
             self.db_pool.simple_insert_many_txn(
                 txn,
                 table="deleted_pushers",
+                keys=("stream_id", "app_id", "pushkey", "user_id"),
                 values=[
-                    {
-                        "stream_id": stream_id,
-                        "app_id": pusher.app_id,
-                        "pushkey": pusher.pushkey,
-                        "user_id": user_id,
-                    }
+                    (stream_id, pusher.app_id, pusher.pushkey, user_id)
                     for stream_id, pusher in zip(stream_ids, pushers)
                 ],
             )
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 4175c82a25..aac94fa464 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -51,7 +51,7 @@ class ExternalIDReuseException(Exception):
     pass
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
 class TokenLookupResult:
     """Result of looking up an access token.
 
@@ -69,14 +69,14 @@ class TokenLookupResult:
             cached.
     """
 
-    user_id = attr.ib(type=str)
-    is_guest = attr.ib(type=bool, default=False)
-    shadow_banned = attr.ib(type=bool, default=False)
-    token_id = attr.ib(type=Optional[int], default=None)
-    device_id = attr.ib(type=Optional[str], default=None)
-    valid_until_ms = attr.ib(type=Optional[int], default=None)
-    token_owner = attr.ib(type=str)
-    token_used = attr.ib(type=bool, default=False)
+    user_id: str
+    is_guest: bool = False
+    shadow_banned: bool = False
+    token_id: Optional[int] = None
+    device_id: Optional[str] = None
+    valid_until_ms: Optional[int] = None
+    token_owner: str = attr.ib()
+    token_used: bool = False
 
     # Make the token owner default to the user ID, which is the common case.
     @token_owner.default
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 4ff6aed253..c6c4bd18da 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,14 +13,30 @@
 # limitations under the License.
 
 import logging
-from typing import List, Optional, Tuple, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
 
 import attr
+from frozendict import frozendict
 
-from synapse.api.constants import RelationTypes
+from synapse.api.constants import EventTypes, RelationTypes
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+    make_in_list_sql_clause,
+)
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.relations import (
     AggregationPaginationToken,
@@ -29,10 +45,24 @@ from synapse.storage.relations import (
 )
 from synapse.util.caches.descriptors import cached
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class RelationsWorkerStore(SQLBaseStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+        self._msc1849_enabled = hs.config.experimental.msc1849_enabled
+        self._msc3440_enabled = hs.config.experimental.msc3440_enabled
+
     @cached(tree=True)
     async def get_relations_for_event(
         self,
@@ -515,6 +545,98 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
         )
 
+    async def _get_bundled_aggregation_for_event(
+        self, event: EventBase
+    ) -> Optional[Dict[str, Any]]:
+        """Generate bundled aggregations for an event.
+
+        Note that this does not use a cache, but depends on cached methods.
+
+        Args:
+            event: The event to calculate bundled aggregations for.
+
+        Returns:
+            The bundled aggregations for an event, if bundled aggregations are
+            enabled and the event can have bundled aggregations.
+        """
+        # State events and redacted events do not get bundled aggregations.
+        if event.is_state() or event.internal_metadata.is_redacted():
+            return None
+
+        # Do not bundle aggregations for an event which represents an edit or an
+        # annotation. It does not make sense for them to have related events.
+        relates_to = event.content.get("m.relates_to")
+        if isinstance(relates_to, (dict, frozendict)):
+            relation_type = relates_to.get("rel_type")
+            if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
+                return None
+
+        event_id = event.event_id
+        room_id = event.room_id
+
+        # The bundled aggregations to include, a mapping of relation type to a
+        # type-specific value. Some types include the direct return type here
+        # while others need more processing during serialization.
+        aggregations: Dict[str, Any] = {}
+
+        annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
+        if annotations.chunk:
+            aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
+
+        references = await self.get_relations_for_event(
+            event_id, room_id, RelationTypes.REFERENCE, direction="f"
+        )
+        if references.chunk:
+            aggregations[RelationTypes.REFERENCE] = references.to_dict()
+
+        edit = None
+        if event.type == EventTypes.Message:
+            edit = await self.get_applicable_edit(event_id, room_id)
+
+        if edit:
+            aggregations[RelationTypes.REPLACE] = edit
+
+        # If this event is the start of a thread, include a summary of the replies.
+        if self._msc3440_enabled:
+            (
+                thread_count,
+                latest_thread_event,
+            ) = await self.get_thread_summary(event_id, room_id)
+            if latest_thread_event:
+                aggregations[RelationTypes.THREAD] = {
+                    # Don't bundle aggregations as this could recurse forever.
+                    "latest_event": latest_thread_event,
+                    "count": thread_count,
+                }
+
+        # Store the bundled aggregations in the event metadata for later use.
+        return aggregations
+
+    async def get_bundled_aggregations(
+        self, events: Iterable[EventBase]
+    ) -> Dict[str, Dict[str, Any]]:
+        """Generate bundled aggregations for events.
+
+        Args:
+            events: The iterable of events to calculate bundled aggregations for.
+
+        Returns:
+            A map of event ID to the bundled aggregation for the event. Not all
+            events may have bundled aggregations in the results.
+        """
+        # If bundled aggregations are disabled, nothing to do.
+        if not self._msc1849_enabled:
+            return {}
+
+        # TODO Parallelize.
+        results = {}
+        for event in events:
+            event_result = await self._get_bundled_aggregation_for_event(event)
+            if event_result is not None:
+                results[event.event_id] = event_result
+
+        return results
+
 
 class RelationsStore(RelationsWorkerStore):
     pass
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index cda80d6511..4489732fda 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -1177,18 +1177,18 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
         await self.db_pool.runInteraction("forget_membership", f)
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class _JoinedHostsCache:
     """The cached data used by the `_get_joined_hosts_cache`."""
 
     # Dict of host to the set of their users in the room at the state group.
-    hosts_to_joined_users = attr.ib(type=Dict[str, Set[str]], factory=dict)
+    hosts_to_joined_users: Dict[str, Set[str]] = attr.Factory(dict)
 
     # The state group `hosts_to_joined_users` is derived from. Will be an object
     # if the instance is newly created or if the state is not based on a state
     # group. (An object is used as a sentinel value to ensure that it never is
     # equal to anything else).
-    state_group = attr.ib(type=Union[object, int], factory=object)
+    state_group: Union[object, int] = attr.Factory(object)
 
     def __len__(self):
         return sum(len(v) for v in self.hosts_to_joined_users.values())
diff --git a/synapse/storage/databases/main/session.py b/synapse/storage/databases/main/session.py
index 5a97120437..e8c776b97a 100644
--- a/synapse/storage/databases/main/session.py
+++ b/synapse/storage/databases/main/session.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 #  Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 #  Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 6c299cafa5..4b78b4d098 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -560,3 +560,14 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
         return await self.db_pool.runInteraction(
             "get_destinations_paginate_txn", get_destinations_paginate_txn
         )
+
+    async def is_destination_known(self, destination: str) -> bool:
+        """Check if a destination is known to the server."""
+        result = await self.db_pool.simple_select_one_onecol(
+            table="destinations",
+            keyvalues={"destination": destination},
+            retcol="1",
+            allow_none=True,
+            desc="is_destination_known",
+        )
+        return bool(result)
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index a1a1a6a14a..2d339b6008 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -23,19 +23,19 @@ from synapse.types import JsonDict
 from synapse.util import json_encoder, stringutils
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class UIAuthSessionData:
-    session_id = attr.ib(type=str)
+    session_id: str
     # The dictionary from the client root level, not the 'auth' key.
-    clientdict = attr.ib(type=JsonDict)
+    clientdict: JsonDict
     # The URI and method the session was intiatied with. These are checked at
     # each stage of the authentication to ensure that the asked for operation
     # has not changed.
-    uri = attr.ib(type=str)
-    method = attr.ib(type=str)
+    uri: str
+    method: str
     # A string description of the operation that the current authentication is
     # authorising.
-    description = attr.ib(type=str)
+    description: str
 
 
 class UIAuthWorkerStore(SQLBaseStore):
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 0f9b8575d3..f7c778bdf2 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -105,8 +105,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                 GROUP BY room_id
             """
             txn.execute(sql)
-            rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()]
-            self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
+            rooms = list(txn.fetchall())
+            self.db_pool.simple_insert_many_txn(
+                txn, TEMP_TABLE + "_rooms", keys=("room_id", "events"), values=rooms
+            )
             del rooms
 
             sql = (
@@ -117,9 +119,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             txn.execute(sql)
 
             txn.execute("SELECT name FROM users")
-            users = [{"user_id": x[0]} for x in txn.fetchall()]
+            users = list(txn.fetchall())
 
-            self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
+            self.db_pool.simple_insert_many_txn(
+                txn, TEMP_TABLE + "_users", keys=("user_id",), values=users
+            )
 
         new_pos = await self.get_max_stream_id_in_current_state_deltas()
         await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index eb1118d2cb..5de70f31d2 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -327,14 +327,15 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
                         self.db_pool.simple_insert_many_txn(
                             txn,
                             table="state_groups_state",
+                            keys=(
+                                "state_group",
+                                "room_id",
+                                "type",
+                                "state_key",
+                                "event_id",
+                            ),
                             values=[
-                                {
-                                    "state_group": state_group,
-                                    "room_id": room_id,
-                                    "type": key[0],
-                                    "state_key": key[1],
-                                    "event_id": state_id,
-                                }
+                                (state_group, room_id, key[0], key[1], state_id)
                                 for key, state_id in delta_state.items()
                             ],
                         )
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index c4c8c0021b..7614d76ac6 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -460,14 +460,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 self.db_pool.simple_insert_many_txn(
                     txn,
                     table="state_groups_state",
+                    keys=("state_group", "room_id", "type", "state_key", "event_id"),
                     values=[
-                        {
-                            "state_group": state_group,
-                            "room_id": room_id,
-                            "type": key[0],
-                            "state_key": key[1],
-                            "event_id": state_id,
-                        }
+                        (state_group, room_id, key[0], key[1], state_id)
                         for key, state_id in delta_ids.items()
                     ],
                 )
@@ -475,14 +470,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 self.db_pool.simple_insert_many_txn(
                     txn,
                     table="state_groups_state",
+                    keys=("state_group", "room_id", "type", "state_key", "event_id"),
                     values=[
-                        {
-                            "state_group": state_group,
-                            "room_id": room_id,
-                            "type": key[0],
-                            "state_key": key[1],
-                            "event_id": state_id,
-                        }
+                        (state_group, room_id, key[0], key[1], state_id)
                         for key, state_id in current_state_ids.items()
                     ],
                 )
@@ -589,14 +579,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             self.db_pool.simple_insert_many_txn(
                 txn,
                 table="state_groups_state",
+                keys=("state_group", "room_id", "type", "state_key", "event_id"),
                 values=[
-                    {
-                        "state_group": sg,
-                        "room_id": room_id,
-                        "type": key[0],
-                        "state_key": key[1],
-                        "event_id": state_id,
-                    }
+                    (sg, room_id, key[0], key[1], state_id)
                     for key, state_id in curr_state.items()
                 ],
             )
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 540adb8781..71584f3f74 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -21,7 +21,7 @@ from signedjson.types import VerifyKey
 logger = logging.getLogger(__name__)
 
 
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class FetchKeyResult:
-    verify_key = attr.ib(type=VerifyKey)  # the key itself
-    valid_until_ts = attr.ib(type=int)  # how long we can use this key for
+    verify_key: VerifyKey  # the key itself
+    valid_until_ts: int  # how long we can use this key for
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index e45adfcb55..1823e18720 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -696,7 +696,7 @@ def _get_or_create_schema_state(
     )
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class _DirectoryListing:
     """Helper class to store schema file name and the
     absolute path to it.
@@ -705,5 +705,5 @@ class _DirectoryListing:
     `file_name` attr is kept first.
     """
 
-    file_name = attr.ib(type=str)
-    absolute_path = attr.ib(type=str)
+    file_name: str
+    absolute_path: str
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index 10a46b5e82..b1536c1ca4 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -23,7 +23,7 @@ from synapse.types import JsonDict
 logger = logging.getLogger(__name__)
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class PaginationChunk:
     """Returned by relation pagination APIs.
 
@@ -35,9 +35,9 @@ class PaginationChunk:
             None then there are no previous results.
     """
 
-    chunk = attr.ib(type=List[JsonDict])
-    next_batch = attr.ib(type=Optional[Any], default=None)
-    prev_batch = attr.ib(type=Optional[Any], default=None)
+    chunk: List[JsonDict]
+    next_batch: Optional[Any] = None
+    prev_batch: Optional[Any] = None
 
     def to_dict(self) -> Dict[str, Any]:
         d = {"chunk": self.chunk}
@@ -51,7 +51,7 @@ class PaginationChunk:
         return d
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
 class RelationPaginationToken:
     """Pagination token for relation pagination API.
 
@@ -64,8 +64,8 @@ class RelationPaginationToken:
         stream: The stream ordering of the boundary event.
     """
 
-    topological = attr.ib(type=int)
-    stream = attr.ib(type=int)
+    topological: int
+    stream: int
 
     @staticmethod
     def from_string(string: str) -> "RelationPaginationToken":
@@ -82,7 +82,7 @@ class RelationPaginationToken:
         return attr.astuple(self)
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
 class AggregationPaginationToken:
     """Pagination token for relation aggregation pagination API.
 
@@ -94,8 +94,8 @@ class AggregationPaginationToken:
         stream: The MAX stream ordering in the boundary group.
     """
 
-    count = attr.ib(type=int)
-    stream = attr.ib(type=int)
+    count: int
+    stream: int
 
     @staticmethod
     def from_string(string: str) -> "AggregationPaginationToken":
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index b5ba1560d1..df8b2f1088 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -45,7 +45,7 @@ logger = logging.getLogger(__name__)
 T = TypeVar("T")
 
 
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class StateFilter:
     """A filter used when querying for state.
 
@@ -58,8 +58,8 @@ class StateFilter:
             appear in `types`.
     """
 
-    types = attr.ib(type="frozendict[str, Optional[FrozenSet[str]]]")
-    include_others = attr.ib(default=False, type=bool)
+    types: "frozendict[str, Optional[FrozenSet[str]]]"
+    include_others: bool = False
 
     def __attrs_post_init__(self):
         # If `include_others` is set we canonicalise the filter by removing
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index b8112e1c05..3c13859faa 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -762,13 +762,13 @@ class _AsyncCtxManagerWrapper(Generic[T]):
         return self.inner.__exit__(exc_type, exc, tb)
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class _MultiWriterCtxManager:
     """Async context manager returned by MultiWriterIdGenerator"""
 
-    id_gen = attr.ib(type=MultiWriterIdGenerator)
-    multiple_ids = attr.ib(type=Optional[int], default=None)
-    stream_ids = attr.ib(type=List[int], factory=list)
+    id_gen: MultiWriterIdGenerator
+    multiple_ids: Optional[int] = None
+    stream_ids: List[int] = attr.Factory(list)
 
     async def __aenter__(self) -> Union[int, List[int]]:
         # It's safe to run this in autocommit mode as fetching values from a