diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 1622822552..6c299cafa5 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -13,16 +13,19 @@
# limitations under the License.

import logging
-from collections import namedtuple
from enum import Enum
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, cast

import attr
from canonicaljson import encode_canonical_json

from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import db_to_json
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
@@ -35,16 +38,6 @@ db_binary_type = memoryview
logger = logging.getLogger(__name__)


-_TransactionRow = namedtuple(
- "_TransactionRow",
- ("id", "transaction_id", "destination", "ts", "response_code", "response_json"),
-)
-
-_UpdateTransactionRow = namedtuple(
- "_TransactionRow", ("response_code", "response_json")
-)
-
-
class DestinationSortOrder(Enum):
"""Enum to define the sorting method used when returning destinations."""
@@ -71,7 +64,12 @@ class DestinationRetryTimings:
class TransactionWorkerStore(CacheInvalidationWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)

if hs.config.worker.run_background_tasks:
@@ -82,7 +80,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
now = self._clock.time_msec()
month_ago = now - 30 * 24 * 60 * 60 * 1000

- def _cleanup_transactions_txn(txn):
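+ # Runs on the database thread: delete any transaction received more than 30 days ago.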
+ def _cleanup_transactions_txn(txn: LoggingTransaction) -> None:
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))

await self.db_pool.runInteraction(
@@ -112,7 +110,9 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
origin,
)

- def _get_received_txn_response(self, txn, transaction_id, origin):
+ def _get_received_txn_response(
+ self, txn: LoggingTransaction, transaction_id: str, origin: str
+ ) -> Optional[Tuple[int, JsonDict]]:
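+ # Gives the (response code, response JSON) previously sent for this transaction,
+ # or None if no response is recorded.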
result = self.db_pool.simple_select_one_txn(
txn,
table="received_transactions",
@@ -187,7 +187,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
return result

def _get_destination_retry_timings(
- self, txn, destination: str
+ self, txn: LoggingTransaction, destination: str
) -> Optional[DestinationRetryTimings]:
result = self.db_pool.simple_select_one_txn(
txn,
@@ -222,7 +222,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
"""
if self.database_engine.can_native_upsert:
- return await self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"set_destination_retry_timings",
self._set_destination_retry_timings_native,
destination,
@@ -232,7 +232,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
db_autocommit=True, # Safe as it's a single upsert
)
else:
- return await self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"set_destination_retry_timings",
self._set_destination_retry_timings_emulated,
destination,
@@ -242,8 +242,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
)

def _set_destination_retry_timings_native(
- self, txn, destination, failure_ts, retry_last_ts, retry_interval
- ):
+ self,
+ txn: LoggingTransaction,
+ destination: str,
+ failure_ts: Optional[int],
+ retry_last_ts: int,
+ retry_interval: int,
+ ) -> None:
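+ # With native upsert support this is a single atomic statement, so no lock is needed.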
assert self.database_engine.can_native_upsert

# Upsert retry time interval if retry_interval is zero (i.e. we're
@@ -273,8 +278,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
)

def _set_destination_retry_timings_emulated(
- self, txn, destination, failure_ts, retry_last_ts, retry_interval
- ):
+ self,
+ txn: LoggingTransaction,
+ destination: str,
+ failure_ts: Optional[int],
+ retry_last_ts: int,
+ retry_interval: int,
+ ) -> None:
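+ # Without native upsert support, lock the table and emulate the upsert.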
self.database_engine.lock_table(txn, "destinations")
# We need to be careful here as the data may have changed from under us
@@ -384,7 +394,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
last_successful_stream_ordering: the stream_ordering of the most
recent successfully-sent PDU
"""
- return await self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
"destinations",
keyvalues={"destination": destination},
values={"last_successful_stream_ordering": last_successful_stream_ordering},
@@ -525,7 +535,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
else:
order = "ASC"
- args = []
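+ # The collected SQL parameters vary in type, so annotate with a broad element type.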
+ args: List[object] = []
where_statement = ""
if destination:
args.extend(["%" + destination.lower() + "%"])
@@ -534,7 +544,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
sql_base = f"FROM destinations {where_statement} "
sql = f"SELECT COUNT(*) as total_destinations {sql_base}"
txn.execute(sql, args)
- count = txn.fetchone()[0]
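+ # A COUNT(*) query always returns exactly one row, so fetchone() can never be None here.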
+ count = cast(Tuple[int], txn.fetchone())[0]

sql = f"""
SELECT destination, retry_last_ts, retry_interval, failure_ts,