diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 00d0f48082..423cc3f02a 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -17,6 +17,8 @@ from ._base import SQLBaseStore, Table
from collections import namedtuple
+from twisted.internet import defer
+
import logging
logger = logging.getLogger(__name__)
@@ -26,6 +28,10 @@ class TransactionStore(SQLBaseStore):
"""A collection of queries for handling PDUs.
"""
+ # a write-through cache of DestinationsTable.EntryType indexed by
+ # destination string
+ destination_retry_cache = {}
+
def get_received_txn_response(self, transaction_id, origin):
"""For an incoming transaction from a given origin, check if we have
already responded to it. If so, return the response code and response
@@ -114,7 +120,7 @@ class TransactionStore(SQLBaseStore):
def _prep_send_transaction(self, txn, transaction_id, destination,
origin_server_ts):
- # First we find out what the prev_txs should be.
+ # First we find out what the prev_txns should be.
# Since we know that we are only sending one transaction at a time,
# we can simply take the last one.
query = "%s ORDER BY id DESC LIMIT 1" % (
@@ -205,6 +211,92 @@ class TransactionStore(SQLBaseStore):
return ReceivedTransactionsTable.decode_results(txn.fetchall())
+ def get_destination_retry_timings(self, destination):
+ """Gets the current retry timings (if any) for a given destination.
+
+ Args:
+ destination (str)
+
+ Returns:
+ None if not retrying
+ Otherwise a DestinationsTable.EntryType for the retry scheme
+ """
+ if destination in self.destination_retry_cache:
+ return defer.succeed(self.destination_retry_cache[destination])
+
+ return self.runInteraction(
+ "get_destination_retry_timings",
+ self._get_destination_retry_timings, destination)
+
+ def _get_destination_retry_timings(cls, txn, destination):
+ query = DestinationsTable.select_statement("destination = ?")
+ txn.execute(query, (destination,))
+ result = txn.fetchall()
+ if result:
+ result = DestinationsTable.decode_single_result(result)
+ if result.retry_last_ts > 0:
+ return result
+ else:
+ return None
+
+ def set_destination_retry_timings(self, destination,
+ retry_last_ts, retry_interval):
+ """Sets the current retry timings for a given destination.
+ Both timings should be zero if retrying is no longer occuring.
+
+ Args:
+ destination (str)
+ retry_last_ts (int) - time of last retry attempt in unix epoch ms
+ retry_interval (int) - how long until next retry in ms
+ """
+
+ self.destination_retry_cache[destination] = (
+ DestinationsTable.EntryType(
+ destination,
+ retry_last_ts,
+ retry_interval
+ )
+ )
+
+ # XXX: we could chose to not bother persisting this if our cache thinks
+ # this is a NOOP
+ return self.runInteraction(
+ "set_destination_retry_timings",
+ self._set_destination_retry_timings,
+ destination,
+ retry_last_ts,
+ retry_interval,
+ )
+
+ def _set_destination_retry_timings(cls, txn, destination,
+ retry_last_ts, retry_interval):
+
+ query = (
+ "INSERT OR REPLACE INTO %s "
+ "(destination, retry_last_ts, retry_interval) "
+ "VALUES (?, ?, ?) "
+ ) % DestinationsTable.table_name
+
+ txn.execute(query, (destination, retry_last_ts, retry_interval))
+
+ def get_destinations_needing_retry(self):
+ """Get all destinations which are due a retry for sending a transaction.
+
+ Returns:
+ list: A list of `DestinationsTable.EntryType`
+ """
+
+ return self.runInteraction(
+ "get_destinations_needing_retry",
+ self._get_destinations_needing_retry
+ )
+
+ def _get_destinations_needing_retry(cls, txn):
+ where = "retry_last_ts > 0 and retry_next_ts < now()"
+ query = DestinationsTable.select_statement(where)
+ txn.execute(query)
+ return DestinationsTable.decode_results(txn.fetchall())
+
class ReceivedTransactionsTable(Table):
table_name = "received_transactions"
@@ -247,3 +339,15 @@ class TransactionsToPduTable(Table):
]
EntryType = namedtuple("TransactionsToPduEntry", fields)
+
+
+class DestinationsTable(Table):
+ table_name = "destinations"
+
+ fields = [
+ "destination",
+ "retry_last_ts",
+ "retry_interval",
+ ]
+
+ EntryType = namedtuple("DestinationsEntry", fields)
|