summary refs log tree commit diff
path: root/synapse/storage/transactions.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/transactions.py')
-rw-r--r--synapse/storage/transactions.py191
1 files changed, 104 insertions, 87 deletions
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 7d22392444..9dec58c21d 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore, Table, cached
+from ._base import SQLBaseStore, cached
 
 from collections import namedtuple
 
@@ -84,13 +84,18 @@ class TransactionStore(SQLBaseStore):
 
     def _set_received_txn_response(self, txn, transaction_id, origin, code,
                                    response_json):
-        query = (
-            "UPDATE %s "
-            "SET response_code = ?, response_json = ? "
-            "WHERE transaction_id = ? AND origin = ?"
-        ) % ReceivedTransactionsTable.table_name
-
-        txn.execute(query, (code, response_json, transaction_id, origin))
+        self._simple_update_one_txn(
+            txn,
+            table=ReceivedTransactionsTable.table_name,
+            keyvalues={
+                "transaction_id": transaction_id,
+                "origin": origin,
+            },
+            updatevalues={
+                "response_code": code,
+                "response_json": response_json,
+            }
+        )
 
     def prep_send_transaction(self, transaction_id, destination,
                               origin_server_ts):
@@ -121,38 +126,32 @@ class TransactionStore(SQLBaseStore):
         # First we find out what the prev_txns should be.
         # Since we know that we are only sending one transaction at a time,
         # we can simply take the last one.
-        query = "%s ORDER BY id DESC LIMIT 1" % (
-                SentTransactions.select_statement("destination = ?"),
-            )
+        query = (
+            "SELECT * FROM sent_transactions"
+            " WHERE destination = ?"
+            " ORDER BY id DESC LIMIT 1"
+        )
 
         txn.execute(query, (destination,))
-        results = SentTransactions.decode_results(txn.fetchall())
+        results = self.cursor_to_dict(txn)
 
-        prev_txns = [r.transaction_id for r in results]
+        prev_txns = [r["transaction_id"] for r in results]
 
         # Actually add the new transaction to the sent_transactions table.
 
-        query = SentTransactions.insert_statement()
-        txn.execute(query, SentTransactions.EntryType(
-            self.get_next_stream_id(),
-            transaction_id=transaction_id,
-            destination=destination,
-            ts=origin_server_ts,
-            response_code=0,
-            response_json=None
-        ))
-
-        # Update the tx id -> pdu id mapping
-
-        # values = [
-        #     (transaction_id, destination, pdu[0], pdu[1])
-        #     for pdu in pdu_list
-        # ]
-        #
-        # logger.debug("Inserting: %s", repr(values))
-        #
-        # query = TransactionsToPduTable.insert_statement()
-        # txn.executemany(query, values)
+        self._simple_insert_txn(
+            txn,
+            table=SentTransactions.table_name,
+            values={
+                "transaction_id": self.get_next_stream_id(),
+                "destination": destination,
+                "ts": origin_server_ts,
+                "response_code": 0,
+                "response_json": None,
+            }
+        )
+
+        # TODO Update the tx id -> pdu id mapping
 
         return prev_txns
 
@@ -171,15 +170,20 @@ class TransactionStore(SQLBaseStore):
             transaction_id, destination, code, response_dict
         )
 
-    def _delivered_txn(cls, txn, transaction_id, destination,
+    def _delivered_txn(self, txn, transaction_id, destination,
                        code, response_json):
-        query = (
-            "UPDATE %s "
-            "SET response_code = ?, response_json = ? "
-            "WHERE transaction_id = ? AND destination = ?"
-        ) % SentTransactions.table_name
-
-        txn.execute(query, (code, response_json, transaction_id, destination))
+        self._simple_update_one_txn(
+            txn,
+            table=SentTransactions.table_name,
+            keyvalues={
+                "transaction_id": transaction_id,
+                "destination": destination,
+            },
+            updatevalues={
+                "response_code": code,
+                "response_json": response_json,
+            }
+        )
 
     def get_transactions_after(self, transaction_id, destination):
         """Get all transactions after a given local transaction_id.
@@ -189,25 +193,26 @@ class TransactionStore(SQLBaseStore):
             destination (str)
 
         Returns:
-            list: A list of `ReceivedTransactionsTable.EntryType`
+            list: A list of dicts
         """
         return self.runInteraction(
             "get_transactions_after",
             self._get_transactions_after, transaction_id, destination
         )
 
-    def _get_transactions_after(cls, txn, transaction_id, destination):
-        where = (
-            "destination = ? AND id > (select id FROM %s WHERE "
-            "transaction_id = ? AND destination = ?)"
-        ) % (
-            SentTransactions.table_name
+    def _get_transactions_after(self, txn, transaction_id, destination):
+        query = (
+            "SELECT * FROM sent_transactions"
+            " WHERE destination = ? AND id >"
+            " ("
+            " SELECT id FROM sent_transactions"
+            " WHERE transaction_id = ? AND destination = ?"
+            " )"
         )
-        query = SentTransactions.select_statement(where)
 
         txn.execute(query, (destination, transaction_id, destination))
 
-        return ReceivedTransactionsTable.decode_results(txn.fetchall())
+        return self.cursor_to_dict(txn)
 
     @cached()
     def get_destination_retry_timings(self, destination):
@@ -218,22 +223,27 @@ class TransactionStore(SQLBaseStore):
 
         Returns:
             None if not retrying
-            Otherwise a DestinationsTable.EntryType for the retry scheme
+            Otherwise a dict for the retry scheme
         """
         return self.runInteraction(
             "get_destination_retry_timings",
             self._get_destination_retry_timings, destination)
 
-    def _get_destination_retry_timings(cls, txn, destination):
-        query = DestinationsTable.select_statement("destination = ?")
-        txn.execute(query, (destination,))
-        result = txn.fetchall()
-        if result:
-            result = DestinationsTable.decode_single_result(result)
-            if result.retry_last_ts > 0:
-                return result
-            else:
-                return None
+    def _get_destination_retry_timings(self, txn, destination):
+        result = self._simple_select_one_txn(
+            txn,
+            table=DestinationsTable.table_name,
+            keyvalues={
+                "destination": destination,
+            },
+            retcols=DestinationsTable.fields,
+            allow_none=True,
+        )
+
+        if result["retry_last_ts"] > 0:
+            return result
+        else:
+            return None
 
     def set_destination_retry_timings(self, destination,
                                       retry_last_ts, retry_interval):
@@ -249,11 +259,11 @@ class TransactionStore(SQLBaseStore):
         # As this is the new value, we might as well prefill the cache
         self.get_destination_retry_timings.prefill(
             destination,
-            DestinationsTable.EntryType(
-                destination,
-                retry_last_ts,
-                retry_interval
-            )
+            {
+                "destination": destination,
+                "retry_last_ts": retry_last_ts,
+                "retry_interval": retry_interval
+            },
         )
 
         # XXX: we could chose to not bother persisting this if our cache thinks
@@ -270,18 +280,27 @@ class TransactionStore(SQLBaseStore):
                                        retry_last_ts, retry_interval):
 
         query = (
-            "REPLACE INTO %s "
-            "(destination, retry_last_ts, retry_interval) "
-            "VALUES (?, ?, ?) "
-        ) % DestinationsTable.table_name
+            "INSERT INTO destinations"
+            " (destination, retry_last_ts, retry_interval)"
+            " VALUES (?, ?, ?)"
+            " ON DUPLICATE KEY UPDATE"
+            " retry_last_ts=?, retry_interval=?"
+        )
 
-        txn.execute(query, (destination, retry_last_ts, retry_interval))
+        txn.execute(
+            query,
+            (
+                destination,
+                retry_last_ts, retry_interval,
+                retry_last_ts, retry_interval,
+            )
+        )
 
     def get_destinations_needing_retry(self):
         """Get all destinations which are due a retry for sending a transaction.
 
         Returns:
-            list: A list of `DestinationsTable.EntryType`
+            list: A list of dicts
         """
 
         return self.runInteraction(
@@ -289,14 +308,17 @@ class TransactionStore(SQLBaseStore):
             self._get_destinations_needing_retry
         )
 
-    def _get_destinations_needing_retry(cls, txn):
-        where = "retry_last_ts > 0 and retry_next_ts < now()"
-        query = DestinationsTable.select_statement(where)
-        txn.execute(query)
-        return DestinationsTable.decode_results(txn.fetchall())
+    def _get_destinations_needing_retry(self, txn):
+        query = (
+            "SELECT * FROM destinations"
+            " WHERE retry_last_ts > 0 and retry_next_ts < ?"
+        )
+
+        txn.execute(query, (self._clock.time_msec(),))
+        return self.cursor_to_dict(txn)
 
 
-class ReceivedTransactionsTable(Table):
+class ReceivedTransactionsTable(object):
     table_name = "received_transactions"
 
     fields = [
@@ -308,10 +330,8 @@ class ReceivedTransactionsTable(Table):
         "has_been_referenced",
     ]
 
-    EntryType = namedtuple("ReceivedTransactionsEntry", fields)
 
-
-class SentTransactions(Table):
+class SentTransactions(object):
     table_name = "sent_transactions"
 
     fields = [
@@ -326,7 +346,7 @@ class SentTransactions(Table):
     EntryType = namedtuple("SentTransactionsEntry", fields)
 
 
-class TransactionsToPduTable(Table):
+class TransactionsToPduTable(object):
     table_name = "transaction_id_to_pdu"
 
     fields = [
@@ -336,10 +356,8 @@ class TransactionsToPduTable(Table):
         "pdu_origin",
     ]
 
-    EntryType = namedtuple("TransactionsToPduEntry", fields)
-
 
-class DestinationsTable(Table):
+class DestinationsTable(object):
     table_name = "destinations"
 
     fields = [
@@ -348,4 +366,3 @@ class DestinationsTable(Table):
         "retry_interval",
     ]
 
-    EntryType = namedtuple("DestinationsEntry", fields)