summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/10036.misc1
-rw-r--r--synapse/app/generic_worker.py2
-rw-r--r--synapse/federation/transport/server.py2
-rw-r--r--synapse/replication/slave/storage/transactions.py21
-rw-r--r--synapse/storage/databases/main/__init__.py4
-rw-r--r--synapse/storage/databases/main/transactions.py66
-rw-r--r--synapse/util/retryutils.py8
-rw-r--r--tests/handlers/test_typing.py8
-rw-r--r--tests/storage/test_transactions.py8
-rw-r--r--tests/util/test_retryutils.py18
10 files changed, 62 insertions, 76 deletions
diff --git a/changelog.d/10036.misc b/changelog.d/10036.misc
new file mode 100644
index 0000000000..d2cf1e5473
--- /dev/null
+++ b/changelog.d/10036.misc
@@ -0,0 +1 @@
+Properly invalidate caches for destination retry timings every (instead of expiring entries every 5 minutes).
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index f730cdbd78..91ad326f19 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -61,7 +61,6 @@ from synapse.replication.slave.storage.pushers import SlavedPusherStore
 from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.replication.slave.storage.room import RoomStore
-from synapse.replication.slave.storage.transactions import SlavedTransactionStore
 from synapse.rest.admin import register_servlets_for_media_repo
 from synapse.rest.client.v1 import events, login, presence, room
 from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
@@ -237,7 +236,6 @@ class GenericWorkerSlavedStore(
     DirectoryStore,
     SlavedApplicationServiceStore,
     SlavedRegistrationStore,
-    SlavedTransactionStore,
     SlavedProfileStore,
     SlavedClientIpStore,
     SlavedFilteringStore,
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index c17a085a4f..9d50b05d01 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -160,7 +160,7 @@ class Authenticator:
         # If we get a valid signed request from the other side, its probably
         # alive
         retry_timings = await self.store.get_destination_retry_timings(origin)
-        if retry_timings and retry_timings["retry_last_ts"]:
+        if retry_timings and retry_timings.retry_last_ts:
             run_in_background(self._reset_retry_timings, origin)
 
         return origin
diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py
deleted file mode 100644
index a59e543924..0000000000
--- a/synapse/replication/slave/storage/transactions.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.storage.databases.main.transactions import TransactionStore
-
-from ._base import BaseSlavedStore
-
-
-class SlavedTransactionStore(TransactionStore, BaseSlavedStore):
-    pass
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 49c7606d51..9cce62ae6c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -67,7 +67,7 @@ from .state import StateStore
 from .stats import StatsStore
 from .stream import StreamStore
 from .tags import TagsStore
-from .transactions import TransactionStore
+from .transactions import TransactionWorkerStore
 from .ui_auth import UIAuthStore
 from .user_directory import UserDirectoryStore
 from .user_erasure_store import UserErasureStore
@@ -83,7 +83,7 @@ class DataStore(
     StreamStore,
     ProfileStore,
     PresenceStore,
-    TransactionStore,
+    TransactionWorkerStore,
     DirectoryStore,
     KeyStore,
     StateStore,
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 82335e7a9d..d211c423b2 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -16,13 +16,15 @@ import logging
 from collections import namedtuple
 from typing import Iterable, List, Optional, Tuple
 
+import attr
 from canonicaljson import encode_canonical_json
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import db_to_json
 from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.types import JsonDict
-from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.caches.descriptors import cached
 
 db_binary_type = memoryview
 
@@ -38,10 +40,23 @@ _UpdateTransactionRow = namedtuple(
     "_TransactionRow", ("response_code", "response_json")
 )
 
-SENTINEL = object()
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class DestinationRetryTimings:
+    """The current destination retry timing info for a remote server."""
 
-class TransactionWorkerStore(SQLBaseStore):
+    # The first time we tried and failed to reach the remote server, in ms.
+    failure_ts: int
+
+    # The last time we tried and failed to reach the remote server, in ms.
+    retry_last_ts: int
+
+    # How long since the last time we tried to reach the remote server before
+    # trying again, in ms.
+    retry_interval: int
+
+
+class TransactionWorkerStore(CacheInvalidationWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
 
@@ -60,19 +75,6 @@ class TransactionWorkerStore(SQLBaseStore):
             "_cleanup_transactions", _cleanup_transactions_txn
         )
 
-
-class TransactionStore(TransactionWorkerStore):
-    """A collection of queries for handling PDUs."""
-
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-
-        self._destination_retry_cache = ExpiringCache(
-            cache_name="get_destination_retry_timings",
-            clock=self._clock,
-            expiry_ms=5 * 60 * 1000,
-        )
-
     async def get_received_txn_response(
         self, transaction_id: str, origin: str
     ) -> Optional[Tuple[int, JsonDict]]:
@@ -145,7 +147,11 @@ class TransactionStore(TransactionWorkerStore):
             desc="set_received_txn_response",
         )
 
-    async def get_destination_retry_timings(self, destination):
+    @cached(max_entries=10000)
+    async def get_destination_retry_timings(
+        self,
+        destination: str,
+    ) -> Optional[DestinationRetryTimings]:
         """Gets the current retry timings (if any) for a given destination.
 
         Args:
@@ -156,34 +162,29 @@ class TransactionStore(TransactionWorkerStore):
             Otherwise a dict for the retry scheme
         """
 
-        result = self._destination_retry_cache.get(destination, SENTINEL)
-        if result is not SENTINEL:
-            return result
-
         result = await self.db_pool.runInteraction(
             "get_destination_retry_timings",
             self._get_destination_retry_timings,
             destination,
         )
 
-        # We don't hugely care about race conditions between getting and
-        # invalidating the cache, since we time out fairly quickly anyway.
-        self._destination_retry_cache[destination] = result
         return result
 
-    def _get_destination_retry_timings(self, txn, destination):
+    def _get_destination_retry_timings(
+        self, txn, destination: str
+    ) -> Optional[DestinationRetryTimings]:
         result = self.db_pool.simple_select_one_txn(
             txn,
             table="destinations",
             keyvalues={"destination": destination},
-            retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"),
+            retcols=("failure_ts", "retry_last_ts", "retry_interval"),
             allow_none=True,
         )
 
         # check we have a row and retry_last_ts is not null or zero
         # (retry_last_ts can't be negative)
         if result and result["retry_last_ts"]:
-            return result
+            return DestinationRetryTimings(**result)
         else:
             return None
 
@@ -204,7 +205,6 @@ class TransactionStore(TransactionWorkerStore):
             retry_interval: how long until next retry in ms
         """
 
-        self._destination_retry_cache.pop(destination, None)
         if self.database_engine.can_native_upsert:
             return await self.db_pool.runInteraction(
                 "set_destination_retry_timings",
@@ -252,6 +252,10 @@ class TransactionStore(TransactionWorkerStore):
 
         txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
 
+        self._invalidate_cache_and_stream(
+            txn, self.get_destination_retry_timings, (destination,)
+        )
+
     def _set_destination_retry_timings_emulated(
         self, txn, destination, failure_ts, retry_last_ts, retry_interval
     ):
@@ -295,6 +299,10 @@ class TransactionStore(TransactionWorkerStore):
                 },
             )
 
+        self._invalidate_cache_and_stream(
+            txn, self.get_destination_retry_timings, (destination,)
+        )
+
     async def store_destination_rooms_entries(
         self,
         destinations: Iterable[str],
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index f9c370a814..129b47cd49 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -82,11 +82,9 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k
     retry_timings = await store.get_destination_retry_timings(destination)
 
     if retry_timings:
-        failure_ts = retry_timings["failure_ts"]
-        retry_last_ts, retry_interval = (
-            retry_timings["retry_last_ts"],
-            retry_timings["retry_interval"],
-        )
+        failure_ts = retry_timings.failure_ts
+        retry_last_ts = retry_timings.retry_last_ts
+        retry_interval = retry_timings.retry_interval
 
         now = int(clock.time_msec())
 
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 0c89487eaf..f58afbc244 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -89,14 +89,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.event_source = hs.get_event_sources().sources["typing"]
 
         self.datastore = hs.get_datastore()
-        retry_timings_res = {
-            "destination": "",
-            "retry_last_ts": 0,
-            "retry_interval": 0,
-            "failure_ts": None,
-        }
         self.datastore.get_destination_retry_timings = Mock(
-            return_value=defer.succeed(retry_timings_res)
+            return_value=defer.succeed(None)
         )
 
         self.datastore.get_device_updates_by_remote = Mock(
diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py
index b7f7eae8d0..bea9091d30 100644
--- a/tests/storage/test_transactions.py
+++ b/tests/storage/test_transactions.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.storage.databases.main.transactions import DestinationRetryTimings
 from synapse.util.retryutils import MAX_RETRY_INTERVAL
 
 from tests.unittest import HomeserverTestCase
@@ -36,8 +37,11 @@ class TransactionStoreTestCase(HomeserverTestCase):
         d = self.store.get_destination_retry_timings("example.com")
         r = self.get_success(d)
 
-        self.assert_dict(
-            {"retry_last_ts": 50, "retry_interval": 100, "failure_ts": 1000}, r
+        self.assertEqual(
+            DestinationRetryTimings(
+                retry_last_ts=50, retry_interval=100, failure_ts=1000
+            ),
+            r,
         )
 
     def test_initial_set_transactions(self):
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 9b2be83a43..9e1bebdc83 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -51,10 +51,12 @@ class RetryLimiterTestCase(HomeserverTestCase):
         except AssertionError:
             pass
 
+        self.pump()
+
         new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
-        self.assertEqual(new_timings["failure_ts"], failure_ts)
-        self.assertEqual(new_timings["retry_last_ts"], failure_ts)
-        self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL)
+        self.assertEqual(new_timings.failure_ts, failure_ts)
+        self.assertEqual(new_timings.retry_last_ts, failure_ts)
+        self.assertEqual(new_timings.retry_interval, MIN_RETRY_INTERVAL)
 
         # now if we try again we should get a failure
         self.get_failure(
@@ -77,14 +79,16 @@ class RetryLimiterTestCase(HomeserverTestCase):
         except AssertionError:
             pass
 
+        self.pump()
+
         new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
-        self.assertEqual(new_timings["failure_ts"], failure_ts)
-        self.assertEqual(new_timings["retry_last_ts"], retry_ts)
+        self.assertEqual(new_timings.failure_ts, failure_ts)
+        self.assertEqual(new_timings.retry_last_ts, retry_ts)
         self.assertGreaterEqual(
-            new_timings["retry_interval"], MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 0.5
+            new_timings.retry_interval, MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 0.5
         )
         self.assertLessEqual(
-            new_timings["retry_interval"], MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0
+            new_timings.retry_interval, MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0
         )
 
         #