From b76f1a4d5f918def1f643910939b80e9e035e07f Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 27 Apr 2022 14:05:00 +0200 Subject: Add some type hints to datastore (#12485) --- synapse/storage/databases/main/devices.py | 51 ++++++++++++++++++++----------- 1 file changed, 33 insertions(+), 18 deletions(-) (limited to 'synapse/storage/databases/main/devices.py') diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 483dd80406..2df4dd4ed4 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -25,6 +25,7 @@ from typing import ( Optional, Set, Tuple, + cast, ) from synapse.api.errors import Codes, StoreError @@ -136,7 +137,9 @@ class DeviceWorkerStore(SQLBaseStore): Number of devices of this users. """ - def count_devices_by_users_txn(txn, user_ids): + def count_devices_by_users_txn( + txn: LoggingTransaction, user_ids: List[str] + ) -> int: sql = """ SELECT count(*) FROM devices @@ -149,7 +152,7 @@ class DeviceWorkerStore(SQLBaseStore): ) txn.execute(sql + clause, args) - return txn.fetchone()[0] + return cast(Tuple[int], txn.fetchone())[0] if not user_ids: return 0 @@ -468,7 +471,7 @@ class DeviceWorkerStore(SQLBaseStore): """ txn.execute(sql, (destination, from_stream_id, now_stream_id, limit)) - return list(txn) + return cast(List[Tuple[str, str, int, Optional[str]]], txn.fetchall()) async def _get_device_update_edus_by_remote( self, @@ -549,7 +552,7 @@ class DeviceWorkerStore(SQLBaseStore): async def _get_last_device_update_for_remote_user( self, destination: str, user_id: str, from_stream_id: int ) -> int: - def f(txn): + def f(txn: LoggingTransaction) -> int: prev_sent_id_sql = """ SELECT coalesce(max(stream_id), 0) as stream_id FROM device_lists_outbound_last_success @@ -767,7 +770,7 @@ class DeviceWorkerStore(SQLBaseStore): if not user_ids_to_check: return set() - def _get_users_whose_devices_changed_txn(txn): + def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]: changes = set() stream_id_where_clause = "stream_id > ?" @@ -966,7 +969,9 @@ class DeviceWorkerStore(SQLBaseStore): async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None: """Mark that we no longer track device lists for remote user.""" - def _mark_remote_user_device_list_as_unsubscribed_txn(txn): + def _mark_remote_user_device_list_as_unsubscribed_txn( + txn: LoggingTransaction, + ) -> None: self.db_pool.simple_delete_txn( txn, table="device_lists_remote_extremeties", @@ -1004,7 +1009,7 @@ class DeviceWorkerStore(SQLBaseStore): ) def _store_dehydrated_device_txn( - self, txn, user_id: str, device_id: str, device_data: str + self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str ) -> Optional[str]: old_device_id = self.db_pool.simple_select_one_onecol_txn( txn, @@ -1081,7 +1086,7 @@ class DeviceWorkerStore(SQLBaseStore): """ yesterday = self._clock.time_msec() - prune_age - def _prune_txn(txn): + def _prune_txn(txn: LoggingTransaction) -> None: # look for (user, destination) pairs which have an update older than # the cutoff. # @@ -1204,8 +1209,10 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): "drop_device_lists_outbound_last_success_non_unique_idx", ) - async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): - def f(conn): + async def _drop_device_list_streams_non_unique_indexes( + self, progress: JsonDict, batch_size: int + ) -> int: + def f(conn: LoggingDatabaseConnection) -> None: txn = conn.cursor() txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id") txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") @@ -1217,7 +1224,9 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): ) return 1 - async def _remove_duplicate_outbound_pokes(self, progress, batch_size): + async def _remove_duplicate_outbound_pokes( + self, progress: JsonDict, batch_size: int + ) -> int: # for some reason, we have accumulated duplicate entries in # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less # efficient. @@ -1230,7 +1239,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""}, ) - def _txn(txn): + def _txn(txn: LoggingTransaction) -> int: clause, args = make_tuple_comparison_clause( [(x, last_row[x]) for x in KEY_COLS] ) @@ -1602,7 +1611,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): context = get_active_span_text_map() - def add_device_changes_txn(txn, stream_ids): + def add_device_changes_txn( + txn: LoggingTransaction, stream_ids: List[int] + ) -> None: self._add_device_change_to_stream_txn( txn, user_id, @@ -1635,8 +1646,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): txn: LoggingTransaction, user_id: str, device_ids: Collection[str], - stream_ids: List[str], - ): + stream_ids: List[int], + ) -> None: txn.call_after( self._device_list_stream_cache.entity_has_changed, user_id, @@ -1720,7 +1731,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: str, device_ids: Iterable[str], room_ids: Collection[str], - stream_ids: List[str], + stream_ids: List[int], context: Dict[str, str], ) -> None: """Record the user in the room has updated their device.""" @@ -1775,7 +1786,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): LIMIT ? """ - def get_uncoverted_outbound_room_pokes_txn(txn): + def get_uncoverted_outbound_room_pokes_txn( + txn: LoggingTransaction, + ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: txn.execute(sql, (limit,)) return [ @@ -1808,7 +1821,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): Marks the associated row in `device_lists_changes_in_room` as handled. """ - def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]): + def add_device_list_outbound_pokes_txn( + txn: LoggingTransaction, stream_ids: List[int] + ) -> None: if hosts: self._add_device_outbound_poke_to_stream_txn( txn, -- cgit 1.4.1