summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/federation/transaction_queue.py47
-rw-r--r--synapse/storage/transactions.py48
2 files changed, 40 insertions, 55 deletions
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 32fa5e8c15..2a7dd343f3 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -202,19 +202,6 @@ class TransactionQueue(object):
     @defer.inlineCallbacks
     @log_function
     def _attempt_new_transaction(self, destination):
-        if destination in self.pending_transactions:
-            # XXX: pending_transactions can get stuck on by a never-ending
-            # request at which point pending_pdus_by_dest just keeps growing.
-            # we need application-layer timeouts of some flavour of these
-            # requests
-            logger.debug(
-                "TX [%s] Transaction already in progress",
-                destination
-            )
-            return
-
-        logger.debug("TX [%s] _attempt_new_transaction", destination)
-
         # list of (pending_pdu, deferred, order)
         pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
         pending_edus = self.pending_edus_by_dest.pop(destination, [])
@@ -228,20 +215,34 @@ class TransactionQueue(object):
             logger.debug("TX [%s] Nothing to send", destination)
             return
 
-        # Sort based on the order field
-        pending_pdus.sort(key=lambda t: t[2])
-
-        pdus = [x[0] for x in pending_pdus]
-        edus = [x[0] for x in pending_edus]
-        failures = [x[0].get_dict() for x in pending_failures]
-        deferreds = [
-            x[1]
-            for x in pending_pdus + pending_edus + pending_failures
-        ]
+        if destination in self.pending_transactions:
+            # XXX: pending_transactions can get stuck on by a never-ending
+            # request at which point pending_pdus_by_dest just keeps growing.
+            # we need application-layer timeouts of some flavour of these
+            # requests
+            logger.debug(
+                "TX [%s] Transaction already in progress",
+                destination
+            )
+            return
 
+        # NOTE: Nothing should be between the above check and the insertion below
         try:
             self.pending_transactions[destination] = 1
 
+            logger.debug("TX [%s] _attempt_new_transaction", destination)
+
+            # Sort based on the order field
+            pending_pdus.sort(key=lambda t: t[2])
+
+            pdus = [x[0] for x in pending_pdus]
+            edus = [x[0] for x in pending_edus]
+            failures = [x[0].get_dict() for x in pending_failures]
+            deferreds = [
+                x[1]
+                for x in pending_pdus + pending_edus + pending_failures
+            ]
+
             txn_id = str(self._next_txn_id)
 
             limiter = yield get_retry_limiter(
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 15695e9831..4e0d7c9774 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -253,16 +253,6 @@ class TransactionStore(SQLBaseStore):
             retry_interval (int) - how long until next retry in ms
         """
 
-        # As this is the new value, we might as well prefill the cache
-        self.get_destination_retry_timings.prefill(
-            destination,
-            {
-                "destination": destination,
-                "retry_last_ts": retry_last_ts,
-                "retry_interval": retry_interval
-            },
-        )
-
         # XXX: we could chose to not bother persisting this if our cache thinks
         # this is a NOOP
         return self.runInteraction(
@@ -275,31 +265,25 @@ class TransactionStore(SQLBaseStore):
 
     def _set_destination_retry_timings(self, txn, destination,
                                        retry_last_ts, retry_interval):
-        query = (
-            "UPDATE destinations"
-            " SET retry_last_ts = ?, retry_interval = ?"
-            " WHERE destination = ?"
-        )
+        txn.call_after(self.get_destination_retry_timings.invalidate, (destination,))
 
-        txn.execute(
-            query,
-            (
-                retry_last_ts, retry_interval, destination,
-            )
+        self._simple_upsert_txn(
+            txn,
+            "destinations",
+            keyvalues={
+                "destination": destination,
+            },
+            values={
+                "retry_last_ts": retry_last_ts,
+                "retry_interval": retry_interval,
+            },
+            insertion_values={
+                "destination": destination,
+                "retry_last_ts": retry_last_ts,
+                "retry_interval": retry_interval,
+            }
         )
 
-        if txn.rowcount == 0:
-            # destination wasn't already in table. Insert it.
-            self._simple_insert_txn(
-                txn,
-                table="destinations",
-                values={
-                    "destination": destination,
-                    "retry_last_ts": retry_last_ts,
-                    "retry_interval": retry_interval,
-                }
-            )
-
     def get_destinations_needing_retry(self):
         """Get all destinations which are due a retry for sending a transaction.