summary refs log tree commit diff
path: root/synapse/storage/transactions.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/transactions.py')
-rw-r--r--synapse/storage/transactions.py237
1 files changed, 130 insertions, 107 deletions
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 0b8a3b7a07..624da4a9dc 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -13,10 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore, Table, cached
+from ._base import SQLBaseStore, cached
 
 from collections import namedtuple
 
+from syutil.jsonutil import encode_canonical_json
 import logging
 
 logger = logging.getLogger(__name__)
@@ -46,15 +47,19 @@ class TransactionStore(SQLBaseStore):
         )
 
     def _get_received_txn_response(self, txn, transaction_id, origin):
-        where_clause = "transaction_id = ? AND origin = ?"
-        query = ReceivedTransactionsTable.select_statement(where_clause)
-
-        txn.execute(query, (transaction_id, origin))
-
-        results = ReceivedTransactionsTable.decode_results(txn.fetchall())
+        result = self._simple_select_one_txn(
+            txn,
+            table=ReceivedTransactionsTable.table_name,
+            keyvalues={
+                "transaction_id": transaction_id,
+                "origin": origin,
+            },
+            retcols=ReceivedTransactionsTable.fields,
+            allow_none=True,
+        )
 
-        if results and results[0].response_code:
-            return (results[0].response_code, results[0].response_json)
+        if result and result.response_code:
+            return result["response_code"], result["response_json"]
         else:
             return None
 
@@ -72,22 +77,18 @@ class TransactionStore(SQLBaseStore):
             response_json (str)
         """
 
-        return self.runInteraction(
-            "set_received_txn_response",
-            self._set_received_txn_response,
-            transaction_id, origin, code, response_dict
+        return self._simple_insert(
+            table=ReceivedTransactionsTable.table_name,
+            values={
+                "transaction_id": transaction_id,
+                "origin": origin,
+                "response_code": code,
+                "response_json": buffer(encode_canonical_json(response_dict)),
+            },
+            or_ignore=True,
+            desc="set_received_txn_response",
         )
 
-    def _set_received_txn_response(self, txn, transaction_id, origin, code,
-                                   response_json):
-        query = (
-            "UPDATE %s "
-            "SET response_code = ?, response_json = ? "
-            "WHERE transaction_id = ? AND origin = ?"
-        ) % ReceivedTransactionsTable.table_name
-
-        txn.execute(query, (code, response_json, transaction_id, origin))
-
     def prep_send_transaction(self, transaction_id, destination,
                               origin_server_ts):
         """Persists an outgoing transaction and calculates the values for the
@@ -114,41 +115,38 @@ class TransactionStore(SQLBaseStore):
     def _prep_send_transaction(self, txn, transaction_id, destination,
                                origin_server_ts):
 
+        next_id = self._transaction_id_gen.get_next_txn(txn)
+
         # First we find out what the prev_txns should be.
         # Since we know that we are only sending one transaction at a time,
         # we can simply take the last one.
-        query = "%s ORDER BY id DESC LIMIT 1" % (
-                SentTransactions.select_statement("destination = ?"),
-            )
+        query = (
+            "SELECT * FROM sent_transactions"
+            " WHERE destination = ?"
+            " ORDER BY id DESC LIMIT 1"
+        )
 
-        results = txn.execute(query, (destination,))
-        results = SentTransactions.decode_results(results)
+        txn.execute(query, (destination,))
+        results = self.cursor_to_dict(txn)
 
-        prev_txns = [r.transaction_id for r in results]
+        prev_txns = [r["transaction_id"] for r in results]
 
         # Actually add the new transaction to the sent_transactions table.
 
-        query = SentTransactions.insert_statement()
-        txn.execute(query, SentTransactions.EntryType(
-            None,
-            transaction_id=transaction_id,
-            destination=destination,
-            ts=origin_server_ts,
-            response_code=0,
-            response_json=None
-        ))
-
-        # Update the tx id -> pdu id mapping
-
-        # values = [
-        #     (transaction_id, destination, pdu[0], pdu[1])
-        #     for pdu in pdu_list
-        # ]
-        #
-        # logger.debug("Inserting: %s", repr(values))
-        #
-        # query = TransactionsToPduTable.insert_statement()
-        # txn.executemany(query, values)
+        self._simple_insert_txn(
+            txn,
+            table=SentTransactions.table_name,
+            values={
+                "id": next_id,
+                "transaction_id": transaction_id,
+                "destination": destination,
+                "ts": origin_server_ts,
+                "response_code": 0,
+                "response_json": None,
+            }
+        )
+
+        # TODO Update the tx id -> pdu id mapping
 
         return prev_txns
 
@@ -164,18 +162,24 @@ class TransactionStore(SQLBaseStore):
         return self.runInteraction(
             "delivered_txn",
             self._delivered_txn,
-            transaction_id, destination, code, response_dict
+            transaction_id, destination, code,
+            buffer(encode_canonical_json(response_dict)),
         )
 
-    def _delivered_txn(cls, txn, transaction_id, destination,
+    def _delivered_txn(self, txn, transaction_id, destination,
                        code, response_json):
-        query = (
-            "UPDATE %s "
-            "SET response_code = ?, response_json = ? "
-            "WHERE transaction_id = ? AND destination = ?"
-        ) % SentTransactions.table_name
-
-        txn.execute(query, (code, response_json, transaction_id, destination))
+        self._simple_update_one_txn(
+            txn,
+            table=SentTransactions.table_name,
+            keyvalues={
+                "transaction_id": transaction_id,
+                "destination": destination,
+            },
+            updatevalues={
+                "response_code": code,
+                "response_json": None,  # For now, don't persist response_json
+            }
+        )
 
     def get_transactions_after(self, transaction_id, destination):
         """Get all transactions after a given local transaction_id.
@@ -185,25 +189,26 @@ class TransactionStore(SQLBaseStore):
             destination (str)
 
         Returns:
-            list: A list of `ReceivedTransactionsTable.EntryType`
+            list: A list of dicts
         """
         return self.runInteraction(
             "get_transactions_after",
             self._get_transactions_after, transaction_id, destination
         )
 
-    def _get_transactions_after(cls, txn, transaction_id, destination):
-        where = (
-            "destination = ? AND id > (select id FROM %s WHERE "
-            "transaction_id = ? AND destination = ?)"
-        ) % (
-            SentTransactions.table_name
+    def _get_transactions_after(self, txn, transaction_id, destination):
+        query = (
+            "SELECT * FROM sent_transactions"
+            " WHERE destination = ? AND id >"
+            " ("
+            " SELECT id FROM sent_transactions"
+            " WHERE transaction_id = ? AND destination = ?"
+            " )"
         )
-        query = SentTransactions.select_statement(where)
 
         txn.execute(query, (destination, transaction_id, destination))
 
-        return ReceivedTransactionsTable.decode_results(txn.fetchall())
+        return self.cursor_to_dict(txn)
 
     @cached()
     def get_destination_retry_timings(self, destination):
@@ -214,22 +219,27 @@ class TransactionStore(SQLBaseStore):
 
         Returns:
             None if not retrying
-            Otherwise a DestinationsTable.EntryType for the retry scheme
+            Otherwise a dict for the retry scheme
         """
         return self.runInteraction(
             "get_destination_retry_timings",
             self._get_destination_retry_timings, destination)
 
-    def _get_destination_retry_timings(cls, txn, destination):
-        query = DestinationsTable.select_statement("destination = ?")
-        txn.execute(query, (destination,))
-        result = txn.fetchall()
-        if result:
-            result = DestinationsTable.decode_single_result(result)
-            if result.retry_last_ts > 0:
-                return result
-            else:
-                return None
+    def _get_destination_retry_timings(self, txn, destination):
+        result = self._simple_select_one_txn(
+            txn,
+            table=DestinationsTable.table_name,
+            keyvalues={
+                "destination": destination,
+            },
+            retcols=DestinationsTable.fields,
+            allow_none=True,
+        )
+
+        if result and result["retry_last_ts"] > 0:
+            return result
+        else:
+            return None
 
     def set_destination_retry_timings(self, destination,
                                       retry_last_ts, retry_interval):
@@ -245,11 +255,11 @@ class TransactionStore(SQLBaseStore):
         # As this is the new value, we might as well prefill the cache
         self.get_destination_retry_timings.prefill(
             destination,
-            DestinationsTable.EntryType(
-                destination,
-                retry_last_ts,
-                retry_interval
-            )
+            {
+                "destination": destination,
+                "retry_last_ts": retry_last_ts,
+                "retry_interval": retry_interval
+            },
         )
 
         # XXX: we could chose to not bother persisting this if our cache thinks
@@ -262,22 +272,38 @@ class TransactionStore(SQLBaseStore):
             retry_interval,
         )
 
-    def _set_destination_retry_timings(cls, txn, destination,
+    def _set_destination_retry_timings(self, txn, destination,
                                        retry_last_ts, retry_interval):
-
         query = (
-            "INSERT OR REPLACE INTO %s "
-            "(destination, retry_last_ts, retry_interval) "
-            "VALUES (?, ?, ?) "
-        ) % DestinationsTable.table_name
+            "UPDATE destinations"
+            " SET retry_last_ts = ?, retry_interval = ?"
+            " WHERE destination = ?"
+        )
+
+        txn.execute(
+            query,
+            (
+                retry_last_ts, retry_interval, destination,
+            )
+        )
 
-        txn.execute(query, (destination, retry_last_ts, retry_interval))
+        if txn.rowcount == 0:
+            # destination wasn't already in table. Insert it.
+            self._simple_insert_txn(
+                txn,
+                table="destinations",
+                values={
+                    "destination": destination,
+                    "retry_last_ts": retry_last_ts,
+                    "retry_interval": retry_interval,
+                }
+            )
 
     def get_destinations_needing_retry(self):
         """Get all destinations which are due a retry for sending a transaction.
 
         Returns:
-            list: A list of `DestinationsTable.EntryType`
+            list: A list of dicts
         """
 
         return self.runInteraction(
@@ -285,14 +311,17 @@ class TransactionStore(SQLBaseStore):
             self._get_destinations_needing_retry
         )
 
-    def _get_destinations_needing_retry(cls, txn):
-        where = "retry_last_ts > 0 and retry_next_ts < now()"
-        query = DestinationsTable.select_statement(where)
-        txn.execute(query)
-        return DestinationsTable.decode_results(txn.fetchall())
+    def _get_destinations_needing_retry(self, txn):
+        query = (
+            "SELECT * FROM destinations"
+            " WHERE retry_last_ts > 0 and retry_next_ts < ?"
+        )
 
+        txn.execute(query, (self._clock.time_msec(),))
+        return self.cursor_to_dict(txn)
 
-class ReceivedTransactionsTable(Table):
+
+class ReceivedTransactionsTable(object):
     table_name = "received_transactions"
 
     fields = [
@@ -304,10 +333,8 @@ class ReceivedTransactionsTable(Table):
         "has_been_referenced",
     ]
 
-    EntryType = namedtuple("ReceivedTransactionsEntry", fields)
-
 
-class SentTransactions(Table):
+class SentTransactions(object):
     table_name = "sent_transactions"
 
     fields = [
@@ -322,7 +349,7 @@ class SentTransactions(Table):
     EntryType = namedtuple("SentTransactionsEntry", fields)
 
 
-class TransactionsToPduTable(Table):
+class TransactionsToPduTable(object):
     table_name = "transaction_id_to_pdu"
 
     fields = [
@@ -332,10 +359,8 @@ class TransactionsToPduTable(Table):
         "pdu_origin",
     ]
 
-    EntryType = namedtuple("TransactionsToPduEntry", fields)
 
-
-class DestinationsTable(Table):
+class DestinationsTable(object):
     table_name = "destinations"
 
     fields = [
@@ -343,5 +368,3 @@ class DestinationsTable(Table):
         "retry_last_ts",
         "retry_interval",
     ]
-
-    EntryType = namedtuple("DestinationsEntry", fields)