summary refs log tree commit diff
path: root/synapse/storage/transactions.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/transactions.py')
-rw-r--r--synapse/storage/transactions.py287
1 files changed, 287 insertions, 0 deletions
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
new file mode 100644
index 0000000000..aa41e2ad7f
--- /dev/null
+++ b/synapse/storage/transactions.py
@@ -0,0 +1,287 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore, Table
+from .pdu import PdusTable
+
+from collections import namedtuple
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class TransactionStore(SQLBaseStore):
+    """A collection of queries for handling PDUs.
+    """
+
+    def get_received_txn_response(self, transaction_id, origin):
+        """For an incoming transaction from a given origin, check if we have
+        already responded to it. If so, return the response code and response
+        body (as a dict).
+
+        Args:
+            transaction_id (str)
+            origin(str)
+
+        Returns:
+            tuple: None if we have not previously responded to
+            this transaction or a 2-tuple of (int, dict)
+        """
+
+        return self._db_pool.runInteraction(
+            self._get_received_txn_response, transaction_id, origin
+        )
+
+    def _get_received_txn_response(self, txn, transaction_id, origin):
+        where_clause = "transaction_id = ? AND origin = ?"
+        query = ReceivedTransactionsTable.select_statement(where_clause)
+
+        txn.execute(query, (transaction_id, origin))
+
+        results = ReceivedTransactionsTable.decode_results(txn.fetchall())
+
+        if results and results[0].response_code:
+            return (results[0].response_code, results[0].response_json)
+        else:
+            return None
+
+    def set_received_txn_response(self, transaction_id, origin, code,
+                                  response_dict):
+        """Persist the response we returened for an incoming transaction, and
+        should return for subsequent transactions with the same transaction_id
+        and origin.
+
+        Args:
+            txn
+            transaction_id (str)
+            origin (str)
+            code (int)
+            response_json (str)
+        """
+
+        return self._db_pool.runInteraction(
+            self._set_received_txn_response,
+            transaction_id, origin, code, response_dict
+        )
+
+    def _set_received_txn_response(self, txn, transaction_id, origin, code,
+                                   response_json):
+        query = (
+            "UPDATE %s "
+            "SET response_code = ?, response_json = ? "
+            "WHERE transaction_id = ? AND origin = ?"
+        ) % ReceivedTransactionsTable.table_name
+
+        txn.execute(query, (code, response_json, transaction_id, origin))
+
+    def prep_send_transaction(self, transaction_id, destination, ts, pdu_list):
+        """Persists an outgoing transaction and calculates the values for the
+        previous transaction id list.
+
+        This should be called before sending the transaction so that it has the
+        correct value for the `prev_ids` key.
+
+        Args:
+            transaction_id (str)
+            destination (str)
+            ts (int)
+            pdu_list (list)
+
+        Returns:
+            list: A list of previous transaction ids.
+        """
+
+        return self._db_pool.runInteraction(
+            self._prep_send_transaction,
+            transaction_id, destination, ts, pdu_list
+        )
+
+    def _prep_send_transaction(self, txn, transaction_id, destination, ts,
+                               pdu_list):
+
+        # First we find out what the prev_txs should be.
+        # Since we know that we are only sending one transaction at a time,
+        # we can simply take the last one.
+        query = "%s ORDER BY id DESC LIMIT 1" % (
+                SentTransactions.select_statement("destination = ?"),
+            )
+
+        results = txn.execute(query, (destination,))
+        results = SentTransactions.decode_results(results)
+
+        prev_txns = [r.transaction_id for r in results]
+
+        # Actually add the new transaction to the sent_transactions table.
+
+        query = SentTransactions.insert_statement()
+        txn.execute(query, SentTransactions.EntryType(
+            None,
+            transaction_id=transaction_id,
+            destination=destination,
+            ts=ts,
+            response_code=0,
+            response_json=None
+        ))
+
+        # Update the tx id -> pdu id mapping
+
+        values = [
+            (transaction_id, destination, pdu[0], pdu[1])
+            for pdu in pdu_list
+        ]
+
+        logger.debug("Inserting: %s", repr(values))
+
+        query = TransactionsToPduTable.insert_statement()
+        txn.executemany(query, values)
+
+        return prev_txns
+
+    def delivered_txn(self, transaction_id, destination, code, response_dict):
+        """Persists the response for an outgoing transaction.
+
+        Args:
+            transaction_id (str)
+            destination (str)
+            code (int)
+            response_json (str)
+        """
+        return self._db_pool.runInteraction(
+            self._delivered_txn,
+            transaction_id, destination, code, response_dict
+        )
+
+    def _delivered_txn(cls, txn, transaction_id, destination,
+                       code, response_json):
+        query = (
+            "UPDATE %s "
+            "SET response_code = ?, response_json = ? "
+            "WHERE transaction_id = ? AND destination = ?"
+        ) % SentTransactions.table_name
+
+        txn.execute(query, (code, response_json, transaction_id, destination))
+
+    def get_transactions_after(self, transaction_id, destination):
+        """Get all transactions after a given local transaction_id.
+
+        Args:
+            transaction_id (str)
+            destination (str)
+
+        Returns:
+            list: A list of `ReceivedTransactionsTable.EntryType`
+        """
+        return self._db_pool.runInteraction(
+            self._get_transactions_after, transaction_id, destination
+        )
+
+    def _get_transactions_after(cls, txn, transaction_id, destination):
+        where = (
+            "destination = ? AND id > (select id FROM %s WHERE "
+            "transaction_id = ? AND destination = ?)"
+        ) % (
+            SentTransactions.table_name
+        )
+        query = SentTransactions.select_statement(where)
+
+        txn.execute(query, (destination, transaction_id, destination))
+
+        return ReceivedTransactionsTable.decode_results(txn.fetchall())
+
+    def get_pdus_after_transaction(self, transaction_id, destination):
+        """For a given local transaction_id that we sent to a given destination
+        home server, return a list of PDUs that were sent to that destination
+        after it.
+
+        Args:
+            txn
+            transaction_id (str)
+            destination (str)
+
+        Returns
+            list: A list of PduTuple
+        """
+        return self._db_pool.runInteraction(
+            self._get_pdus_after_transaction,
+            transaction_id, destination
+        )
+
+    def _get_pdus_after_transaction(self, txn, transaction_id, destination):
+
+        # Query that first get's all transaction_ids with an id greater than
+        # the one given from the `sent_transactions` table. Then JOIN on this
+        # from the `tx->pdu` table to get a list of (pdu_id, origin) that
+        # specify the pdus that were sent in those transactions.
+        query = (
+            "SELECT pdu_id, pdu_origin FROM %(tx_pdu)s as tp "
+            "INNER JOIN %(sent_tx)s as st "
+            "ON tp.transaction_id = st.transaction_id "
+            "AND tp.destination = st.destination "
+            "WHERE st.id > ("
+            "SELECT id FROM %(sent_tx)s "
+            "WHERE transaction_id = ? AND destination = ?"
+        ) % {
+            "tx_pdu": TransactionsToPduTable.table_name,
+            "sent_tx": SentTransactions.table_name,
+        }
+
+        txn.execute(query, (transaction_id, destination))
+
+        pdus = PdusTable.decode_results(txn.fetchall())
+
+        return self._get_pdu_tuples(txn, pdus)
+
+
+class ReceivedTransactionsTable(Table):
+    table_name = "received_transactions"
+
+    fields = [
+        "transaction_id",
+        "origin",
+        "ts",
+        "response_code",
+        "response_json",
+        "has_been_referenced",
+    ]
+
+    EntryType = namedtuple("ReceivedTransactionsEntry", fields)
+
+
+class SentTransactions(Table):
+    table_name = "sent_transactions"
+
+    fields = [
+        "id",
+        "transaction_id",
+        "destination",
+        "ts",
+        "response_code",
+        "response_json",
+    ]
+
+    EntryType = namedtuple("SentTransactionsEntry", fields)
+
+
+class TransactionsToPduTable(Table):
+    table_name = "transaction_id_to_pdu"
+
+    fields = [
+        "transaction_id",
+        "destination",
+        "pdu_id",
+        "pdu_origin",
+    ]
+
+    EntryType = namedtuple("TransactionsToPduEntry", fields)