diff options
Diffstat (limited to '')
-rw-r--r-- | changelog.d/11654.misc | 1 | ||||
-rw-r--r-- | mypy.ini | 4 | ||||
-rw-r--r-- | synapse/storage/databases/main/events_bg_updates.py | 69 |
3 files changed, 44 insertions, 30 deletions
diff --git a/changelog.d/11654.misc b/changelog.d/11654.misc new file mode 100644 index 0000000000..8e405b9226 --- /dev/null +++ b/changelog.d/11654.misc @@ -0,0 +1 @@ +Add missing type hints to storage classes. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index e0204a3c04..85fa22d28f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -28,7 +28,6 @@ exclude = (?x) |synapse/storage/databases/main/cache.py |synapse/storage/databases/main/devices.py |synapse/storage/databases/main/event_federation.py - |synapse/storage/databases/main/events_bg_updates.py |synapse/storage/databases/main/group_server.py |synapse/storage/databases/main/metrics.py |synapse/storage/databases/main/monthly_active_users.py @@ -200,6 +199,9 @@ disallow_untyped_defs = True [mypy-synapse.storage.databases.main.event_push_actions] disallow_untyped_defs = True +[mypy-synapse.storage.databases.main.events_bg_updates] +disallow_untyped_defs = True + [mypy-synapse.storage.databases.main.events_worker] disallow_untyped_defs = True diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 9b36941fec..a68f14ba48 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, cast import attr @@ -240,12 +240,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ################################################################################ - async def _background_reindex_fields_sender(self, progress, batch_size): + async def _background_reindex_fields_sender( + self, progress: JsonDict, batch_size: int + ) -> int: target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) - def reindex_txn(txn): + def reindex_txn(txn: LoggingTransaction) -> int: sql = ( "SELECT stream_ordering, event_id, json FROM events" " INNER JOIN event_json USING (event_id)" @@ -307,12 +309,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return result - async def _background_reindex_origin_server_ts(self, progress, batch_size): + async def _background_reindex_origin_server_ts( + self, progress: JsonDict, batch_size: int + ) -> int: target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) - def reindex_search_txn(txn): + def reindex_search_txn(txn: LoggingTransaction) -> int: sql = ( "SELECT stream_ordering, event_id FROM events" " WHERE ? <= stream_ordering AND stream_ordering < ?" @@ -381,7 +385,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return result - async def _cleanup_extremities_bg_update(self, progress, batch_size): + async def _cleanup_extremities_bg_update( + self, progress: JsonDict, batch_size: int + ) -> int: """Background update to clean out extremities that should have been deleted previously. @@ -402,12 +408,12 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): # have any descendants, but if they do then we should delete those # extremities. - def _cleanup_extremities_bg_update_txn(txn): + def _cleanup_extremities_bg_update_txn(txn: LoggingTransaction) -> int: # The set of extremity event IDs that we're checking this round original_set = set() - # A dict[str, set[str]] of event ID to their prev events. - graph = {} + # A dict[str, Set[str]] of event ID to their prev events. + graph: Dict[str, Set[str]] = {} # The set of descendants of the original set that are not rejected # nor soft-failed. Ancestors of these events should be removed @@ -536,7 +542,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): room_ids = {row["room_id"] for row in rows} for room_id in room_ids: txn.call_after( - self.get_latest_event_ids_in_room.invalidate, (room_id,) + self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined] ) self.db_pool.simple_delete_many_txn( @@ -558,7 +564,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): _BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES ) - def _drop_table_txn(txn): + def _drop_table_txn(txn: LoggingTransaction) -> None: txn.execute("DROP TABLE _extremities_to_check") await self.db_pool.runInteraction( @@ -567,11 +573,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return num_handled - async def _redactions_received_ts(self, progress, batch_size): + async def _redactions_received_ts(self, progress: JsonDict, batch_size: int) -> int: """Handles filling out the `received_ts` column in redactions.""" last_event_id = progress.get("last_event_id", "") - def _redactions_received_ts_txn(txn): + def _redactions_received_ts_txn(txn: LoggingTransaction) -> int: # Fetch the set of event IDs that we want to update sql = """ SELECT event_id FROM redactions @@ -622,10 +628,12 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return count - async def _event_fix_redactions_bytes(self, progress, batch_size): + async def _event_fix_redactions_bytes( + self, progress: JsonDict, batch_size: int + ) -> int: """Undoes hex encoded censored redacted event JSON.""" - def _event_fix_redactions_bytes_txn(txn): + def _event_fix_redactions_bytes_txn(txn: LoggingTransaction) -> None: # This update is quite fast due to new index. txn.execute( """ @@ -650,11 +658,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return 1 - async def _event_store_labels(self, progress, batch_size): + async def _event_store_labels(self, progress: JsonDict, batch_size: int) -> int: """Background update handler which will store labels for existing events.""" last_event_id = progress.get("last_event_id", "") - def _event_store_labels_txn(txn): + def _event_store_labels_txn(txn: LoggingTransaction) -> int: txn.execute( """ SELECT event_id, json FROM event_json @@ -754,7 +762,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ), ) - return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn] # type: ignore + return cast( + List[Tuple[str, str, JsonDict, bool, bool]], + [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn], + ) results = await self.db_pool.runInteraction( desc="_rejected_events_metadata_get", func=get_rejected_events @@ -912,7 +923,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): def _calculate_chain_cover_txn( self, - txn: Cursor, + txn: LoggingTransaction, last_room_id: str, last_depth: int, last_stream: int, @@ -1023,10 +1034,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): PersistEventsStore._add_chain_cover_index( txn, self.db_pool, - self.event_chain_id_gen, + self.event_chain_id_gen, # type: ignore[attr-defined] event_to_room_id, event_to_types, - event_to_auth_chain, + cast(Dict[str, Sequence[str]], event_to_auth_chain), ) return _CalculateChainCover( @@ -1046,7 +1057,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): """ current_event_id = progress.get("current_event_id", "") - def purged_chain_cover_txn(txn) -> int: + def purged_chain_cover_txn(txn: LoggingTransaction) -> int: # The event ID from events will be null if the chain ID / sequence # number points to a purged event. sql = """ @@ -1181,14 +1192,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): # Iterate the parent IDs and invalidate caches. for parent_id in {r[1] for r in relations_to_insert}: cache_tuple = (parent_id,) - self._invalidate_cache_and_stream( - txn, self.get_relations_for_event, cache_tuple + self._invalidate_cache_and_stream( # type: ignore[attr-defined] + txn, self.get_relations_for_event, cache_tuple # type: ignore[attr-defined] ) - self._invalidate_cache_and_stream( - txn, self.get_aggregation_groups_for_event, cache_tuple + self._invalidate_cache_and_stream( # type: ignore[attr-defined] + txn, self.get_aggregation_groups_for_event, cache_tuple # type: ignore[attr-defined] ) - self._invalidate_cache_and_stream( - txn, self.get_thread_summary, cache_tuple + self._invalidate_cache_and_stream( # type: ignore[attr-defined] + txn, self.get_thread_summary, cache_tuple # type: ignore[attr-defined] ) if results: @@ -1220,7 +1231,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): """ batch_size = max(batch_size, 1) - def process(txn: Cursor) -> int: + def process(txn: LoggingTransaction) -> int: last_stream = progress.get("last_stream", -(1 << 31)) txn.execute( """ |