diff options
Diffstat (limited to 'synapse/storage/transactions.py')
-rw-r--r-- | synapse/storage/transactions.py | 106 |
1 files changed, 105 insertions, 1 deletions
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 00d0f48082..423cc3f02a 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -17,6 +17,8 @@ from ._base import SQLBaseStore, Table from collections import namedtuple +from twisted.internet import defer + import logging logger = logging.getLogger(__name__) @@ -26,6 +28,10 @@ class TransactionStore(SQLBaseStore): """A collection of queries for handling PDUs. """ + # a write-through cache of DestinationsTable.EntryType indexed by + # destination string + destination_retry_cache = {} + def get_received_txn_response(self, transaction_id, origin): """For an incoming transaction from a given origin, check if we have already responded to it. If so, return the response code and response @@ -114,7 +120,7 @@ class TransactionStore(SQLBaseStore): def _prep_send_transaction(self, txn, transaction_id, destination, origin_server_ts): - # First we find out what the prev_txs should be. + # First we find out what the prev_txns should be. # Since we know that we are only sending one transaction at a time, # we can simply take the last one. query = "%s ORDER BY id DESC LIMIT 1" % ( @@ -205,6 +211,92 @@ class TransactionStore(SQLBaseStore): return ReceivedTransactionsTable.decode_results(txn.fetchall()) + def get_destination_retry_timings(self, destination): + """Gets the current retry timings (if any) for a given destination. + + Args: + destination (str) + + Returns: + None if not retrying + Otherwise a DestinationsTable.EntryType for the retry scheme + """ + if destination in self.destination_retry_cache: + return defer.succeed(self.destination_retry_cache[destination]) + + return self.runInteraction( + "get_destination_retry_timings", + self._get_destination_retry_timings, destination) + + def _get_destination_retry_timings(cls, txn, destination): + query = DestinationsTable.select_statement("destination = ?") + txn.execute(query, (destination,)) + result = txn.fetchall() + if result: + result = DestinationsTable.decode_single_result(result) + if result.retry_last_ts > 0: + return result + else: + return None + + def set_destination_retry_timings(self, destination, + retry_last_ts, retry_interval): + """Sets the current retry timings for a given destination. + Both timings should be zero if retrying is no longer occuring. + + Args: + destination (str) + retry_last_ts (int) - time of last retry attempt in unix epoch ms + retry_interval (int) - how long until next retry in ms + """ + + self.destination_retry_cache[destination] = ( + DestinationsTable.EntryType( + destination, + retry_last_ts, + retry_interval + ) + ) + + # XXX: we could chose to not bother persisting this if our cache thinks + # this is a NOOP + return self.runInteraction( + "set_destination_retry_timings", + self._set_destination_retry_timings, + destination, + retry_last_ts, + retry_interval, + ) + + def _set_destination_retry_timings(cls, txn, destination, + retry_last_ts, retry_interval): + + query = ( + "INSERT OR REPLACE INTO %s " + "(destination, retry_last_ts, retry_interval) " + "VALUES (?, ?, ?) " + ) % DestinationsTable.table_name + + txn.execute(query, (destination, retry_last_ts, retry_interval)) + + def get_destinations_needing_retry(self): + """Get all destinations which are due a retry for sending a transaction. + + Returns: + list: A list of `DestinationsTable.EntryType` + """ + + return self.runInteraction( + "get_destinations_needing_retry", + self._get_destinations_needing_retry + ) + + def _get_destinations_needing_retry(cls, txn): + where = "retry_last_ts > 0 and retry_next_ts < now()" + query = DestinationsTable.select_statement(where) + txn.execute(query) + return DestinationsTable.decode_results(txn.fetchall()) + class ReceivedTransactionsTable(Table): table_name = "received_transactions" @@ -247,3 +339,15 @@ class TransactionsToPduTable(Table): ] EntryType = namedtuple("TransactionsToPduEntry", fields) + + +class DestinationsTable(Table): + table_name = "destinations" + + fields = [ + "destination", + "retry_last_ts", + "retry_interval", + ] + + EntryType = namedtuple("DestinationsEntry", fields) |