diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 49c7606d51..9cce62ae6c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -67,7 +67,7 @@ from .state import StateStore
from .stats import StatsStore
from .stream import StreamStore
from .tags import TagsStore
-from .transactions import TransactionStore
+from .transactions import TransactionWorkerStore
from .ui_auth import UIAuthStore
from .user_directory import UserDirectoryStore
from .user_erasure_store import UserErasureStore
@@ -83,7 +83,7 @@ class DataStore(
StreamStore,
ProfileStore,
PresenceStore,
- TransactionStore,
+ TransactionWorkerStore,
DirectoryStore,
KeyStore,
StateStore,
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 82335e7a9d..d211c423b2 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -16,13 +16,15 @@ import logging
from collections import namedtuple
from typing import Iterable, List, Optional, Tuple
+import attr
from canonicaljson import encode_canonical_json
from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.types import JsonDict
-from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.caches.descriptors import cached
db_binary_type = memoryview
@@ -38,10 +40,23 @@ _UpdateTransactionRow = namedtuple(
"_TransactionRow", ("response_code", "response_json")
)
-SENTINEL = object()
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class DestinationRetryTimings:
+ """The current destination retry timing info for a remote server."""
-class TransactionWorkerStore(SQLBaseStore):
+ # The first time we tried and failed to reach the remote server, in ms.
+ failure_ts: int
+
+ # The last time we tried and failed to reach the remote server, in ms.
+ retry_last_ts: int
+
+ # How long since the last time we tried to reach the remote server before
+ # trying again, in ms.
+ retry_interval: int
+
+
+class TransactionWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@@ -60,19 +75,6 @@ class TransactionWorkerStore(SQLBaseStore):
"_cleanup_transactions", _cleanup_transactions_txn
)
-
-class TransactionStore(TransactionWorkerStore):
- """A collection of queries for handling PDUs."""
-
- def __init__(self, database: DatabasePool, db_conn, hs):
- super().__init__(database, db_conn, hs)
-
- self._destination_retry_cache = ExpiringCache(
- cache_name="get_destination_retry_timings",
- clock=self._clock,
- expiry_ms=5 * 60 * 1000,
- )
-
async def get_received_txn_response(
self, transaction_id: str, origin: str
) -> Optional[Tuple[int, JsonDict]]:
@@ -145,7 +147,11 @@ class TransactionStore(TransactionWorkerStore):
desc="set_received_txn_response",
)
- async def get_destination_retry_timings(self, destination):
+ @cached(max_entries=10000)
+ async def get_destination_retry_timings(
+ self,
+ destination: str,
+ ) -> Optional[DestinationRetryTimings]:
"""Gets the current retry timings (if any) for a given destination.
Args:
@@ -156,34 +162,29 @@ class TransactionStore(TransactionWorkerStore):
-            Otherwise a dict for the retry scheme
+            Otherwise the DestinationRetryTimings info for the retry scheme
"""
- result = self._destination_retry_cache.get(destination, SENTINEL)
- if result is not SENTINEL:
- return result
-
result = await self.db_pool.runInteraction(
"get_destination_retry_timings",
self._get_destination_retry_timings,
destination,
)
- # We don't hugely care about race conditions between getting and
- # invalidating the cache, since we time out fairly quickly anyway.
- self._destination_retry_cache[destination] = result
return result
- def _get_destination_retry_timings(self, txn, destination):
+ def _get_destination_retry_timings(
+ self, txn, destination: str
+ ) -> Optional[DestinationRetryTimings]:
result = self.db_pool.simple_select_one_txn(
txn,
table="destinations",
keyvalues={"destination": destination},
- retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"),
+ retcols=("failure_ts", "retry_last_ts", "retry_interval"),
allow_none=True,
)
# check we have a row and retry_last_ts is not null or zero
# (retry_last_ts can't be negative)
if result and result["retry_last_ts"]:
- return result
+ return DestinationRetryTimings(**result)
else:
return None
@@ -204,7 +205,6 @@ class TransactionStore(TransactionWorkerStore):
retry_interval: how long until next retry in ms
"""
- self._destination_retry_cache.pop(destination, None)
if self.database_engine.can_native_upsert:
return await self.db_pool.runInteraction(
"set_destination_retry_timings",
@@ -252,6 +252,10 @@ class TransactionStore(TransactionWorkerStore):
txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
+ self._invalidate_cache_and_stream(
+ txn, self.get_destination_retry_timings, (destination,)
+ )
+
def _set_destination_retry_timings_emulated(
self, txn, destination, failure_ts, retry_last_ts, retry_interval
):
@@ -295,6 +299,10 @@ class TransactionStore(TransactionWorkerStore):
},
)
+ self._invalidate_cache_and_stream(
+ txn, self.get_destination_retry_timings, (destination,)
+ )
+
async def store_destination_rooms_entries(
self,
destinations: Iterable[str],
|