From ac95167d2f11096f6338e2c7cfe159b1dbb02454 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 28 Mar 2022 20:11:14 +0200 Subject: Add some type hints to datastore. (#12255) --- synapse/storage/databases/main/media_repository.py | 2 +- synapse/storage/databases/main/receipts.py | 37 +++++++++++++++------- synapse/storage/databases/main/registration.py | 3 +- synapse/storage/databases/main/room.py | 3 +- synapse/storage/databases/main/search.py | 26 +++++++-------- synapse/storage/databases/main/state.py | 24 ++++++++------ synapse/storage/databases/main/stats.py | 2 +- synapse/storage/databases/main/user_directory.py | 2 +- 8 files changed, 60 insertions(+), 39 deletions(-) (limited to 'synapse/storage/databases/main') diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index cbba356b4a..322ed05390 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -156,7 +156,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): hs: "HomeServer", ): super().__init__(database, db_conn, hs) - self.server_name = hs.hostname + self.server_name: str = hs.hostname async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: """Get the metadata for a local piece of media diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index bf0b903af2..e6f97aeece 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -24,10 +24,9 @@ from typing import ( Optional, Set, Tuple, + cast, ) -from twisted.internet import defer - from synapse.api.constants import ReceiptTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream @@ -38,7 +37,11 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.engines import PostgresEngine -from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.id_generators import ( + AbstractStreamIdTracker, + MultiWriterIdGenerator, + StreamIdGenerator, +) from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -58,6 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore): hs: "HomeServer", ): self._instance_name = hs.get_instance_name() + self._receipts_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): self._can_write_to_receipts = ( @@ -161,7 +165,7 @@ class ReceiptsWorkerStore(SQLBaseStore): " AND user_id = ?" ) txn.execute(sql, (user_id,)) - return txn.fetchall() + return cast(List[Tuple[str, str, int, int]], txn.fetchall()) rows = await self.db_pool.runInteraction( "get_receipts_for_user_with_orderings", f @@ -257,7 +261,7 @@ class ReceiptsWorkerStore(SQLBaseStore): if not rows: return [] - content = {} + content: JsonDict = {} for row in rows: content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[ row["user_id"] @@ -305,7 +309,7 @@ class ReceiptsWorkerStore(SQLBaseStore): "_get_linearized_receipts_for_rooms", f ) - results = {} + results: JsonDict = {} for row in txn_results: # We want a single event per room, since we want to batch the # receipts by room, event and type. @@ -370,7 +374,7 @@ class ReceiptsWorkerStore(SQLBaseStore): "get_linearized_receipts_for_all_rooms", f ) - results = {} + results: JsonDict = {} for row in txn_results: # We want a single event per room, since we want to batch the # receipts by room, event and type. @@ -399,7 +403,7 @@ class ReceiptsWorkerStore(SQLBaseStore): """ if last_id == current_id: - return defer.succeed([]) + return [] def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]: sql = """ @@ -453,7 +457,10 @@ class ReceiptsWorkerStore(SQLBaseStore): """ txn.execute(sql, (last_id, current_id, limit)) - updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn] + updates = cast( + List[Tuple[int, list]], + [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn], + ) limited = False upper_bound = current_id @@ -496,7 +503,13 @@ class ReceiptsWorkerStore(SQLBaseStore): self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) self.get_receipts_for_room.invalidate((room_id, receipt_type)) - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: int, + rows: Iterable[Any], + ) -> None: if stream_name == ReceiptsStream.NAME: self._receipts_id_gen.advance(instance_name, token) for row in rows: @@ -584,7 +597,7 @@ class ReceiptsWorkerStore(SQLBaseStore): ) if receipt_type == ReceiptTypes.READ and stream_ordering is not None: - self._remove_old_push_actions_before_txn( + self._remove_old_push_actions_before_txn( # type: ignore[attr-defined] txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering ) @@ -637,7 +650,7 @@ class ReceiptsWorkerStore(SQLBaseStore): "insert_receipt_conv", graph_to_linear ) - async with self._receipts_id_gen.get_next() as stream_id: + async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined] event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index a698d10cc5..7f3d190e94 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -22,6 +22,7 @@ import attr from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError +from synapse.config.homeserver import HomeServerConfig from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage.database import ( DatabasePool, @@ -123,7 +124,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ): super().__init__(database, db_conn, hs) - self.config = hs.config + self.config: HomeServerConfig = hs.config # Note: we don't check this sequence for consistency as we'd have to # call `find_max_generated_user_id_localpart` each time, which is diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 94068940b9..18b1acd9e1 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -34,6 +34,7 @@ import attr from synapse.api.constants import EventContentFields, EventTypes, JoinRules from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions +from synapse.config.homeserver import HomeServerConfig from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -98,7 +99,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): ): super().__init__(database, db_conn, hs) - self.config = hs.config + self.config: HomeServerConfig = hs.config async def store_room( self, diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index c5e9010c83..d4482c06db 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -14,7 +14,7 @@ import logging import re -from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set +from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set import attr @@ -74,7 +74,7 @@ class SearchWorkerStore(SQLBaseStore): " VALUES (?,?,?,to_tsvector('english', ?),?,?)" ) - args = ( + args1 = ( ( entry.event_id, entry.room_id, @@ -86,14 +86,14 @@ class SearchWorkerStore(SQLBaseStore): for entry in entries ) - txn.execute_batch(sql, args) + txn.execute_batch(sql, args1) elif isinstance(self.database_engine, Sqlite3Engine): sql = ( "INSERT INTO event_search (event_id, room_id, key, value)" " VALUES (?,?,?,?)" ) - args = ( + args2 = ( ( entry.event_id, entry.room_id, @@ -102,7 +102,7 @@ class SearchWorkerStore(SQLBaseStore): ) for entry in entries ) - txn.execute_batch(sql, args) + txn.execute_batch(sql, args2) else: # This should be unreachable. @@ -427,7 +427,7 @@ class SearchStore(SearchBackgroundUpdateStore): search_query = _parse_query(self.database_engine, search_term) - args = [] + args: List[Any] = [] # Make sure we don't explode because the person is in too many rooms. # We filter the results below regardless. @@ -496,7 +496,7 @@ class SearchStore(SearchBackgroundUpdateStore): # We set redact_behaviour to BLOCK here to prevent redacted events being returned in # search results (which is a data leak) - events = await self.get_events_as_list( + events = await self.get_events_as_list( # type: ignore[attr-defined] [r["event_id"] for r in results], redact_behaviour=EventRedactBehaviour.BLOCK, ) @@ -530,7 +530,7 @@ class SearchStore(SearchBackgroundUpdateStore): room_ids: Collection[str], search_term: str, keys: Iterable[str], - limit, + limit: int, pagination_token: Optional[str] = None, ) -> JsonDict: """Performs a full text search over events with given keys. @@ -549,7 +549,7 @@ class SearchStore(SearchBackgroundUpdateStore): search_query = _parse_query(self.database_engine, search_term) - args = [] + args: List[Any] = [] # Make sure we don't explode because the person is in too many rooms. # We filter the results below regardless. @@ -573,9 +573,9 @@ class SearchStore(SearchBackgroundUpdateStore): if pagination_token: try: - origin_server_ts, stream = pagination_token.split(",") - origin_server_ts = int(origin_server_ts) - stream = int(stream) + origin_server_ts_str, stream_str = pagination_token.split(",") + origin_server_ts = int(origin_server_ts_str) + stream = int(stream_str) except Exception: raise SynapseError(400, "Invalid pagination token") @@ -654,7 +654,7 @@ class SearchStore(SearchBackgroundUpdateStore): # We set redact_behaviour to BLOCK here to prevent redacted events being returned in # search results (which is a data leak) - events = await self.get_events_as_list( + events = await self.get_events_as_list( # type: ignore[attr-defined] [r["event_id"] for r in results], redact_behaviour=EventRedactBehaviour.BLOCK, ) diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 417aef1dbc..28460fd364 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -14,7 +14,7 @@ # limitations under the License. import collections.abc import logging -from typing import TYPE_CHECKING, Iterable, Optional, Set +from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError @@ -29,7 +29,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.state import StateFilter -from synapse.types import StateMap +from synapse.types import JsonDict, StateMap from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList @@ -241,7 +241,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # We delegate to the cached version return await self.get_current_state_ids(room_id) - def _get_filtered_current_state_ids_txn(txn): + def _get_filtered_current_state_ids_txn( + txn: LoggingTransaction, + ) -> StateMap[str]: results = {} sql = """ SELECT type, state_key, event_id FROM current_state_events @@ -281,11 +283,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): event_id = state.get((EventTypes.CanonicalAlias, "")) if not event_id: - return + return None event = await self.get_event(event_id, allow_none=True) if not event: - return + return None return event.content.get("canonical_alias") @@ -304,7 +306,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): list_name="event_ids", num_args=1, ) - async def _get_state_group_for_events(self, event_ids): + async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict: """Returns mapping event_id -> state_group""" rows = await self.db_pool.simple_select_many_batch( table="event_to_state_groups", @@ -355,7 +357,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): ): super().__init__(database, db_conn, hs) - self.server_name = hs.hostname + self.server_name: str = hs.hostname self.db_pool.updates.register_background_index_update( self.CURRENT_STATE_INDEX_UPDATE_NAME, @@ -375,7 +377,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): self._background_remove_left_rooms, ) - async def _background_remove_left_rooms(self, progress, batch_size): + async def _background_remove_left_rooms( + self, progress: JsonDict, batch_size: int + ) -> int: """Background update to delete rows from `current_state_events` and `event_forward_extremities` tables of rooms that the server is no longer joined to. @@ -383,7 +387,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): last_room_id = progress.get("last_room_id", "") - def _background_remove_left_rooms_txn(txn): + def _background_remove_left_rooms_txn( + txn: LoggingTransaction, + ) -> Tuple[bool, Set[str]]: # get a batch of room ids to consider sql = """ SELECT DISTINCT room_id FROM current_state_events diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 427ae1f649..b95dbef678 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -108,7 +108,7 @@ class StatsStore(StateDeltasStore): ): super().__init__(database, db_conn, hs) - self.server_name = hs.hostname + self.server_name: str = hs.hostname self.clock = self.hs.get_clock() self.stats_enabled = hs.config.stats.stats_enabled diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 0595df01d3..df772d4721 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -68,7 +68,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) -> None: super().__init__(database, db_conn, hs) - self.server_name = hs.hostname + self.server_name: str = hs.hostname self.db_pool.updates.register_background_update_handler( "populate_user_directory_createtables", -- cgit 1.4.1