diff options
Diffstat (limited to 'synapse/storage/transactions.py')
-rw-r--r-- | synapse/storage/transactions.py | 237 |
1 files changed, 130 insertions, 107 deletions
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 0b8a3b7a07..624da4a9dc 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, Table, cached +from ._base import SQLBaseStore, cached from collections import namedtuple +from syutil.jsonutil import encode_canonical_json import logging logger = logging.getLogger(__name__) @@ -46,15 +47,19 @@ class TransactionStore(SQLBaseStore): ) def _get_received_txn_response(self, txn, transaction_id, origin): - where_clause = "transaction_id = ? AND origin = ?" - query = ReceivedTransactionsTable.select_statement(where_clause) - - txn.execute(query, (transaction_id, origin)) - - results = ReceivedTransactionsTable.decode_results(txn.fetchall()) + result = self._simple_select_one_txn( + txn, + table=ReceivedTransactionsTable.table_name, + keyvalues={ + "transaction_id": transaction_id, + "origin": origin, + }, + retcols=ReceivedTransactionsTable.fields, + allow_none=True, + ) - if results and results[0].response_code: - return (results[0].response_code, results[0].response_json) + if result and result.response_code: + return result["response_code"], result["response_json"] else: return None @@ -72,22 +77,18 @@ class TransactionStore(SQLBaseStore): response_json (str) """ - return self.runInteraction( - "set_received_txn_response", - self._set_received_txn_response, - transaction_id, origin, code, response_dict + return self._simple_insert( + table=ReceivedTransactionsTable.table_name, + values={ + "transaction_id": transaction_id, + "origin": origin, + "response_code": code, + "response_json": buffer(encode_canonical_json(response_dict)), + }, + or_ignore=True, + desc="set_received_txn_response", ) - def _set_received_txn_response(self, txn, transaction_id, origin, code, - response_json): - query = ( - "UPDATE %s " - "SET response_code = ?, response_json = ? " - "WHERE transaction_id = ? AND origin = ?" - ) % ReceivedTransactionsTable.table_name - - txn.execute(query, (code, response_json, transaction_id, origin)) - def prep_send_transaction(self, transaction_id, destination, origin_server_ts): """Persists an outgoing transaction and calculates the values for the @@ -114,41 +115,38 @@ class TransactionStore(SQLBaseStore): def _prep_send_transaction(self, txn, transaction_id, destination, origin_server_ts): + next_id = self._transaction_id_gen.get_next_txn(txn) + # First we find out what the prev_txns should be. # Since we know that we are only sending one transaction at a time, # we can simply take the last one. - query = "%s ORDER BY id DESC LIMIT 1" % ( - SentTransactions.select_statement("destination = ?"), - ) + query = ( + "SELECT * FROM sent_transactions" + " WHERE destination = ?" + " ORDER BY id DESC LIMIT 1" + ) - results = txn.execute(query, (destination,)) - results = SentTransactions.decode_results(results) + txn.execute(query, (destination,)) + results = self.cursor_to_dict(txn) - prev_txns = [r.transaction_id for r in results] + prev_txns = [r["transaction_id"] for r in results] # Actually add the new transaction to the sent_transactions table. - query = SentTransactions.insert_statement() - txn.execute(query, SentTransactions.EntryType( - None, - transaction_id=transaction_id, - destination=destination, - ts=origin_server_ts, - response_code=0, - response_json=None - )) - - # Update the tx id -> pdu id mapping - - # values = [ - # (transaction_id, destination, pdu[0], pdu[1]) - # for pdu in pdu_list - # ] - # - # logger.debug("Inserting: %s", repr(values)) - # - # query = TransactionsToPduTable.insert_statement() - # txn.executemany(query, values) + self._simple_insert_txn( + txn, + table=SentTransactions.table_name, + values={ + "id": next_id, + "transaction_id": transaction_id, + "destination": destination, + "ts": origin_server_ts, + "response_code": 0, + "response_json": None, + } + ) + + # TODO Update the tx id -> pdu id mapping return prev_txns @@ -164,18 +162,24 @@ class TransactionStore(SQLBaseStore): return self.runInteraction( "delivered_txn", self._delivered_txn, - transaction_id, destination, code, response_dict + transaction_id, destination, code, + buffer(encode_canonical_json(response_dict)), ) - def _delivered_txn(cls, txn, transaction_id, destination, + def _delivered_txn(self, txn, transaction_id, destination, code, response_json): - query = ( - "UPDATE %s " - "SET response_code = ?, response_json = ? " - "WHERE transaction_id = ? AND destination = ?" - ) % SentTransactions.table_name - - txn.execute(query, (code, response_json, transaction_id, destination)) + self._simple_update_one_txn( + txn, + table=SentTransactions.table_name, + keyvalues={ + "transaction_id": transaction_id, + "destination": destination, + }, + updatevalues={ + "response_code": code, + "response_json": None, # For now, don't persist response_json + } + ) def get_transactions_after(self, transaction_id, destination): """Get all transactions after a given local transaction_id. @@ -185,25 +189,26 @@ class TransactionStore(SQLBaseStore): destination (str) Returns: - list: A list of `ReceivedTransactionsTable.EntryType` + list: A list of dicts """ return self.runInteraction( "get_transactions_after", self._get_transactions_after, transaction_id, destination ) - def _get_transactions_after(cls, txn, transaction_id, destination): - where = ( - "destination = ? AND id > (select id FROM %s WHERE " - "transaction_id = ? AND destination = ?)" - ) % ( - SentTransactions.table_name + def _get_transactions_after(self, txn, transaction_id, destination): + query = ( + "SELECT * FROM sent_transactions" + " WHERE destination = ? AND id >" + " (" + " SELECT id FROM sent_transactions" + " WHERE transaction_id = ? AND destination = ?" + " )" ) - query = SentTransactions.select_statement(where) txn.execute(query, (destination, transaction_id, destination)) - return ReceivedTransactionsTable.decode_results(txn.fetchall()) + return self.cursor_to_dict(txn) @cached() def get_destination_retry_timings(self, destination): @@ -214,22 +219,27 @@ class TransactionStore(SQLBaseStore): Returns: None if not retrying - Otherwise a DestinationsTable.EntryType for the retry scheme + Otherwise a dict for the retry scheme """ return self.runInteraction( "get_destination_retry_timings", self._get_destination_retry_timings, destination) - def _get_destination_retry_timings(cls, txn, destination): - query = DestinationsTable.select_statement("destination = ?") - txn.execute(query, (destination,)) - result = txn.fetchall() - if result: - result = DestinationsTable.decode_single_result(result) - if result.retry_last_ts > 0: - return result - else: - return None + def _get_destination_retry_timings(self, txn, destination): + result = self._simple_select_one_txn( + txn, + table=DestinationsTable.table_name, + keyvalues={ + "destination": destination, + }, + retcols=DestinationsTable.fields, + allow_none=True, + ) + + if result and result["retry_last_ts"] > 0: + return result + else: + return None def set_destination_retry_timings(self, destination, retry_last_ts, retry_interval): @@ -245,11 +255,11 @@ class TransactionStore(SQLBaseStore): # As this is the new value, we might as well prefill the cache self.get_destination_retry_timings.prefill( destination, - DestinationsTable.EntryType( - destination, - retry_last_ts, - retry_interval - ) + { + "destination": destination, + "retry_last_ts": retry_last_ts, + "retry_interval": retry_interval + }, ) # XXX: we could chose to not bother persisting this if our cache thinks @@ -262,22 +272,38 @@ class TransactionStore(SQLBaseStore): retry_interval, ) - def _set_destination_retry_timings(cls, txn, destination, + def _set_destination_retry_timings(self, txn, destination, retry_last_ts, retry_interval): - query = ( - "INSERT OR REPLACE INTO %s " - "(destination, retry_last_ts, retry_interval) " - "VALUES (?, ?, ?) " - ) % DestinationsTable.table_name + "UPDATE destinations" + " SET retry_last_ts = ?, retry_interval = ?" + " WHERE destination = ?" + ) + + txn.execute( + query, + ( + retry_last_ts, retry_interval, destination, + ) + ) - txn.execute(query, (destination, retry_last_ts, retry_interval)) + if txn.rowcount == 0: + # destination wasn't already in table. Insert it. + self._simple_insert_txn( + txn, + table="destinations", + values={ + "destination": destination, + "retry_last_ts": retry_last_ts, + "retry_interval": retry_interval, + } + ) def get_destinations_needing_retry(self): """Get all destinations which are due a retry for sending a transaction. Returns: - list: A list of `DestinationsTable.EntryType` + list: A list of dicts """ return self.runInteraction( @@ -285,14 +311,17 @@ class TransactionStore(SQLBaseStore): self._get_destinations_needing_retry ) - def _get_destinations_needing_retry(cls, txn): - where = "retry_last_ts > 0 and retry_next_ts < now()" - query = DestinationsTable.select_statement(where) - txn.execute(query) - return DestinationsTable.decode_results(txn.fetchall()) + def _get_destinations_needing_retry(self, txn): + query = ( + "SELECT * FROM destinations" + " WHERE retry_last_ts > 0 and retry_next_ts < ?" + ) + txn.execute(query, (self._clock.time_msec(),)) + return self.cursor_to_dict(txn) -class ReceivedTransactionsTable(Table): + +class ReceivedTransactionsTable(object): table_name = "received_transactions" fields = [ @@ -304,10 +333,8 @@ class ReceivedTransactionsTable(Table): "has_been_referenced", ] - EntryType = namedtuple("ReceivedTransactionsEntry", fields) - -class SentTransactions(Table): +class SentTransactions(object): table_name = "sent_transactions" fields = [ @@ -322,7 +349,7 @@ class SentTransactions(Table): EntryType = namedtuple("SentTransactionsEntry", fields) -class TransactionsToPduTable(Table): +class TransactionsToPduTable(object): table_name = "transaction_id_to_pdu" fields = [ @@ -332,10 +359,8 @@ class TransactionsToPduTable(Table): "pdu_origin", ] - EntryType = namedtuple("TransactionsToPduEntry", fields) - -class DestinationsTable(Table): +class DestinationsTable(object): table_name = "destinations" fields = [ @@ -343,5 +368,3 @@ class DestinationsTable(Table): "retry_last_ts", "retry_interval", ] - - EntryType = namedtuple("DestinationsEntry", fields) |