diff options
-rw-r--r-- | changelog.d/11589.misc | 1 | ||||
-rw-r--r-- | mypy.ini | 4 | ||||
-rw-r--r-- | synapse/storage/databases/main/transactions.py | 49 |
3 files changed, 29 insertions, 25 deletions
diff --git a/changelog.d/11589.misc b/changelog.d/11589.misc new file mode 100644 index 0000000000..8e405b9226 --- /dev/null +++ b/changelog.d/11589.misc @@ -0,0 +1 @@ +Add missing type hints to storage classes. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index cbe1e8302c..c546487bdb 100644 --- a/mypy.ini +++ b/mypy.ini @@ -41,7 +41,6 @@ exclude = (?x) |synapse/storage/databases/main/search.py |synapse/storage/databases/main/state.py |synapse/storage/databases/main/stats.py - |synapse/storage/databases/main/transactions.py |synapse/storage/databases/main/user_directory.py |synapse/storage/schema/ @@ -216,6 +215,9 @@ disallow_untyped_defs = True [mypy-synapse.storage.databases.main.state_deltas] disallow_untyped_defs = True +[mypy-synapse.storage.databases.main.transactions] +disallow_untyped_defs = True + [mypy-synapse.storage.databases.main.user_erasure_store] disallow_untyped_defs = True diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 54b41513ee..6c299cafa5 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -13,9 +13,8 @@ # limitations under the License. import logging -from collections import namedtuple from enum import Enum -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, cast import attr from canonicaljson import encode_canonical_json @@ -39,16 +38,6 @@ db_binary_type = memoryview logger = logging.getLogger(__name__) -_TransactionRow = namedtuple( - "_TransactionRow", - ("id", "transaction_id", "destination", "ts", "response_code", "response_json"), -) - -_UpdateTransactionRow = namedtuple( - "_TransactionRow", ("response_code", "response_json") -) - - class DestinationSortOrder(Enum): """Enum to define the sorting method used when returning destinations.""" @@ -91,7 +80,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): now = self._clock.time_msec() month_ago = now - 30 * 24 * 60 * 60 * 1000 - def _cleanup_transactions_txn(txn): + def _cleanup_transactions_txn(txn: LoggingTransaction) -> None: txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) await self.db_pool.runInteraction( @@ -121,7 +110,9 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): origin, ) - def _get_received_txn_response(self, txn, transaction_id, origin): + def _get_received_txn_response( + self, txn: LoggingTransaction, transaction_id: str, origin: str + ) -> Optional[Tuple[int, JsonDict]]: result = self.db_pool.simple_select_one_txn( txn, table="received_transactions", @@ -196,7 +187,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): return result def _get_destination_retry_timings( - self, txn, destination: str + self, txn: LoggingTransaction, destination: str ) -> Optional[DestinationRetryTimings]: result = self.db_pool.simple_select_one_txn( txn, @@ -231,7 +222,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): """ if self.database_engine.can_native_upsert: - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings_native, destination, @@ -241,7 +232,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): db_autocommit=True, # Safe as its a single upsert ) else: - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings_emulated, destination, @@ -251,8 +242,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): ) def _set_destination_retry_timings_native( - self, txn, destination, failure_ts, retry_last_ts, retry_interval - ): + self, + txn: LoggingTransaction, + destination: str, + failure_ts: Optional[int], + retry_last_ts: int, + retry_interval: int, + ) -> None: assert self.database_engine.can_native_upsert # Upsert retry time interval if retry_interval is zero (i.e. we're @@ -282,8 +278,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): ) def _set_destination_retry_timings_emulated( - self, txn, destination, failure_ts, retry_last_ts, retry_interval - ): + self, + txn: LoggingTransaction, + destination: str, + failure_ts: Optional[int], + retry_last_ts: int, + retry_interval: int, + ) -> None: self.database_engine.lock_table(txn, "destinations") # We need to be careful here as the data may have changed from under us @@ -393,7 +394,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): last_successful_stream_ordering: the stream_ordering of the most recent successfully-sent PDU """ - return await self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( "destinations", keyvalues={"destination": destination}, values={"last_successful_stream_ordering": last_successful_stream_ordering}, @@ -534,7 +535,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): else: order = "ASC" - args = [] + args: List[object] = [] where_statement = "" if destination: args.extend(["%" + destination.lower() + "%"]) @@ -543,7 +544,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): sql_base = f"FROM destinations {where_statement} " sql = f"SELECT COUNT(*) as total_destinations {sql_base}" txn.execute(sql, args) - count = txn.fetchone()[0] + count = cast(Tuple[int], txn.fetchone())[0] sql = f""" SELECT destination, retry_last_ts, retry_interval, failure_ts, |