summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/11654.misc1
-rw-r--r--mypy.ini4
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py69
3 files changed, 44 insertions, 30 deletions
diff --git a/changelog.d/11654.misc b/changelog.d/11654.misc
new file mode 100644
index 0000000000..8e405b9226
--- /dev/null
+++ b/changelog.d/11654.misc
@@ -0,0 +1 @@
+Add missing type hints to storage classes.
\ No newline at end of file
diff --git a/mypy.ini b/mypy.ini
index e0204a3c04..85fa22d28f 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -28,7 +28,6 @@ exclude = (?x)
    |synapse/storage/databases/main/cache.py
    |synapse/storage/databases/main/devices.py
    |synapse/storage/databases/main/event_federation.py
-   |synapse/storage/databases/main/events_bg_updates.py
    |synapse/storage/databases/main/group_server.py
    |synapse/storage/databases/main/metrics.py
    |synapse/storage/databases/main/monthly_active_users.py
@@ -200,6 +199,9 @@ disallow_untyped_defs = True
 [mypy-synapse.storage.databases.main.event_push_actions]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.databases.main.events_bg_updates]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.databases.main.events_worker]
 disallow_untyped_defs = True
 
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 9b36941fec..a68f14ba48 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, cast
 
 import attr
 
@@ -240,12 +240,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
         ################################################################################
 
-    async def _background_reindex_fields_sender(self, progress, batch_size):
+    async def _background_reindex_fields_sender(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
-        def reindex_txn(txn):
+        def reindex_txn(txn: LoggingTransaction) -> int:
             sql = (
                 "SELECT stream_ordering, event_id, json FROM events"
                 " INNER JOIN event_json USING (event_id)"
@@ -307,12 +309,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
         return result
 
-    async def _background_reindex_origin_server_ts(self, progress, batch_size):
+    async def _background_reindex_origin_server_ts(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
-        def reindex_search_txn(txn):
+        def reindex_search_txn(txn: LoggingTransaction) -> int:
             sql = (
                 "SELECT stream_ordering, event_id FROM events"
                 " WHERE ? <= stream_ordering AND stream_ordering < ?"
@@ -381,7 +385,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
         return result
 
-    async def _cleanup_extremities_bg_update(self, progress, batch_size):
+    async def _cleanup_extremities_bg_update(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Background update to clean out extremities that should have been
         deleted previously.
 
@@ -402,12 +408,12 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         # have any descendants, but if they do then we should delete those
         # extremities.
 
-        def _cleanup_extremities_bg_update_txn(txn):
+        def _cleanup_extremities_bg_update_txn(txn: LoggingTransaction) -> int:
             # The set of extremity event IDs that we're checking this round
             original_set = set()
 
-            # A dict[str, set[str]] of event ID to their prev events.
-            graph = {}
+            # A dict[str, Set[str]] of event ID to their prev events.
+            graph: Dict[str, Set[str]] = {}
 
             # The set of descendants of the original set that are not rejected
             # nor soft-failed. Ancestors of these events should be removed
@@ -536,7 +542,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                 room_ids = {row["room_id"] for row in rows}
                 for room_id in room_ids:
                     txn.call_after(
-                        self.get_latest_event_ids_in_room.invalidate, (room_id,)
+                        self.get_latest_event_ids_in_room.invalidate, (room_id,)  # type: ignore[attr-defined]
                     )
 
             self.db_pool.simple_delete_many_txn(
@@ -558,7 +564,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                 _BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES
             )
 
-            def _drop_table_txn(txn):
+            def _drop_table_txn(txn: LoggingTransaction) -> None:
                 txn.execute("DROP TABLE _extremities_to_check")
 
             await self.db_pool.runInteraction(
@@ -567,11 +573,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
         return num_handled
 
-    async def _redactions_received_ts(self, progress, batch_size):
+    async def _redactions_received_ts(self, progress: JsonDict, batch_size: int) -> int:
         """Handles filling out the `received_ts` column in redactions."""
         last_event_id = progress.get("last_event_id", "")
 
-        def _redactions_received_ts_txn(txn):
+        def _redactions_received_ts_txn(txn: LoggingTransaction) -> int:
             # Fetch the set of event IDs that we want to update
             sql = """
                 SELECT event_id FROM redactions
@@ -622,10 +628,12 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
         return count
 
-    async def _event_fix_redactions_bytes(self, progress, batch_size):
+    async def _event_fix_redactions_bytes(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Undoes hex encoded censored redacted event JSON."""
 
-        def _event_fix_redactions_bytes_txn(txn):
+        def _event_fix_redactions_bytes_txn(txn: LoggingTransaction) -> None:
             # This update is quite fast due to new index.
             txn.execute(
                 """
@@ -650,11 +658,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
         return 1
 
-    async def _event_store_labels(self, progress, batch_size):
+    async def _event_store_labels(self, progress: JsonDict, batch_size: int) -> int:
         """Background update handler which will store labels for existing events."""
         last_event_id = progress.get("last_event_id", "")
 
-        def _event_store_labels_txn(txn):
+        def _event_store_labels_txn(txn: LoggingTransaction) -> int:
             txn.execute(
                 """
                 SELECT event_id, json FROM event_json
@@ -754,7 +762,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                 ),
             )
 
-            return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn]  # type: ignore
+            return cast(
+                List[Tuple[str, str, JsonDict, bool, bool]],
+                [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn],
+            )
 
         results = await self.db_pool.runInteraction(
             desc="_rejected_events_metadata_get", func=get_rejected_events
@@ -912,7 +923,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
     def _calculate_chain_cover_txn(
         self,
-        txn: Cursor,
+        txn: LoggingTransaction,
         last_room_id: str,
         last_depth: int,
         last_stream: int,
@@ -1023,10 +1034,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         PersistEventsStore._add_chain_cover_index(
             txn,
             self.db_pool,
-            self.event_chain_id_gen,
+            self.event_chain_id_gen,  # type: ignore[attr-defined]
             event_to_room_id,
             event_to_types,
-            event_to_auth_chain,
+            cast(Dict[str, Sequence[str]], event_to_auth_chain),
         )
 
         return _CalculateChainCover(
@@ -1046,7 +1057,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         """
         current_event_id = progress.get("current_event_id", "")
 
-        def purged_chain_cover_txn(txn) -> int:
+        def purged_chain_cover_txn(txn: LoggingTransaction) -> int:
             # The event ID from events will be null if the chain ID / sequence
             # number points to a purged event.
             sql = """
@@ -1181,14 +1192,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                 # Iterate the parent IDs and invalidate caches.
                 for parent_id in {r[1] for r in relations_to_insert}:
                     cache_tuple = (parent_id,)
-                    self._invalidate_cache_and_stream(
-                        txn, self.get_relations_for_event, cache_tuple
+                    self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
+                        txn, self.get_relations_for_event, cache_tuple  # type: ignore[attr-defined]
                     )
-                    self._invalidate_cache_and_stream(
-                        txn, self.get_aggregation_groups_for_event, cache_tuple
+                    self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
+                        txn, self.get_aggregation_groups_for_event, cache_tuple  # type: ignore[attr-defined]
                     )
-                    self._invalidate_cache_and_stream(
-                        txn, self.get_thread_summary, cache_tuple
+                    self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
+                        txn, self.get_thread_summary, cache_tuple  # type: ignore[attr-defined]
                     )
 
             if results:
@@ -1220,7 +1231,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         """
         batch_size = max(batch_size, 1)
 
-        def process(txn: Cursor) -> int:
+        def process(txn: LoggingTransaction) -> int:
             last_stream = progress.get("last_stream", -(1 << 31))
             txn.execute(
                 """