diff options
Diffstat (limited to 'synapse/storage/pdu.py')
-rw-r--r-- | synapse/storage/pdu.py | 48 |
1 files changed, 41 insertions, 7 deletions
diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py index d70467dcd6..4a4341907b 100644 --- a/synapse/storage/pdu.py +++ b/synapse/storage/pdu.py @@ -20,10 +20,13 @@ from ._base import SQLBaseStore, Table, JoinHelper from synapse.federation.units import Pdu from synapse.util.logutils import log_function +from syutil.base64util import encode_base64 + from collections import namedtuple import logging + logger = logging.getLogger(__name__) @@ -44,7 +47,7 @@ class PduStore(SQLBaseStore): """ return self.runInteraction( - self._get_pdu_tuple, pdu_id, origin + "get_pdu", self._get_pdu_tuple, pdu_id, origin ) def _get_pdu_tuple(self, txn, pdu_id, origin): @@ -64,6 +67,13 @@ class PduStore(SQLBaseStore): for r in PduEdgesTable.decode_results(txn.fetchall()) ] + edge_hashes = self._get_prev_pdu_hashes_txn(txn, pdu_id, origin) + + hashes = self._get_pdu_content_hashes_txn(txn, pdu_id, origin) + signatures = self._get_pdu_origin_signatures_txn( + txn, pdu_id, origin + ) + query = ( "SELECT %(fields)s FROM %(pdus)s as p " "LEFT JOIN %(state)s as s " @@ -80,7 +90,9 @@ class PduStore(SQLBaseStore): row = txn.fetchone() if row: - results.append(PduTuple(PduEntry(*row), edges)) + results.append(PduTuple( + PduEntry(*row), edges, hashes, signatures, edge_hashes + )) return results @@ -96,6 +108,7 @@ class PduStore(SQLBaseStore): """ return self.runInteraction( + "get_current_state_for_context", self._get_current_state_for_context, context ) @@ -144,6 +157,7 @@ class PduStore(SQLBaseStore): """ return self.runInteraction( + "mark_pdu_as_processed", self._mark_as_processed, pdu_id, pdu_origin ) @@ -153,6 +167,7 @@ class PduStore(SQLBaseStore): def get_all_pdus_from_context(self, context): """Get a list of all PDUs for a given context.""" return self.runInteraction( + "get_all_pdus_from_context", self._get_all_pdus_from_context, context, ) @@ -180,6 +195,7 @@ class PduStore(SQLBaseStore): list: A list of PduTuples """ return self.runInteraction( + "get_backfill", self._get_backfill, context, pdu_list, limit ) @@ -241,6 +257,7 @@ class PduStore(SQLBaseStore): context (str) """ return self.runInteraction( + "get_min_depth_for_context", self._get_min_depth_for_context, context ) @@ -277,6 +294,13 @@ class PduStore(SQLBaseStore): (context, depth) ) + def get_latest_pdus_in_context(self, context): + return self.runInteraction( + "get_latest_pdus_in_context", + self._get_latest_pdus_in_context, + context + ) + def _get_latest_pdus_in_context(self, txn, context): """Get's a list of the most current pdus for a given context. This is used when we are sending a Pdu and need to fill out the `prev_pdus` @@ -303,9 +327,14 @@ class PduStore(SQLBaseStore): (context, ) ) - results = txn.fetchall() + results = [] + for pdu_id, origin, depth in txn.fetchall(): + hashes = self._get_pdu_reference_hashes_txn(txn, pdu_id, origin) + sha256_bytes = hashes["sha256"] + prev_hashes = {"sha256": encode_base64(sha256_bytes)} + results.append((pdu_id, origin, prev_hashes, depth)) - return [(row[0], row[1], row[2]) for row in results] + return results @defer.inlineCallbacks def get_oldest_pdus_in_context(self, context): @@ -347,6 +376,7 @@ class PduStore(SQLBaseStore): """ return self.runInteraction( + "is_pdu_new", self._is_pdu_new, pdu_id=pdu_id, origin=origin, @@ -424,7 +454,7 @@ class PduStore(SQLBaseStore): "DELETE FROM %s WHERE pdu_id = ? AND origin = ?" % PduForwardExtremitiesTable.table_name ) - txn.executemany(query, prev_pdus) + txn.executemany(query, list(p[:2] for p in prev_pdus)) # We only insert as a forward extremety the new pdu if there are no # other pdus that reference it as a prev pdu @@ -447,7 +477,7 @@ class PduStore(SQLBaseStore): # deleted in a second if they're incorrect anyway. txn.executemany( PduBackwardExtremitiesTable.insert_statement(), - [(i, o, context) for i, o in prev_pdus] + [(i, o, context) for i, o, _ in prev_pdus] ) # Also delete from the backwards extremities table all ones that @@ -500,6 +530,7 @@ class StatePduStore(SQLBaseStore): def get_unresolved_state_tree(self, new_state_pdu): return self.runInteraction( + "get_unresolved_state_tree", self._get_unresolved_state_tree, new_state_pdu ) @@ -539,6 +570,7 @@ class StatePduStore(SQLBaseStore): def update_current_state(self, pdu_id, origin, context, pdu_type, state_key): return self.runInteraction( + "update_current_state", self._update_current_state, pdu_id, origin, context, pdu_type, state_key ) @@ -578,6 +610,7 @@ class StatePduStore(SQLBaseStore): """ return self.runInteraction( + "get_current_state_pdu", self._get_current_state_pdu, context, pdu_type, state_key ) @@ -637,6 +670,7 @@ class StatePduStore(SQLBaseStore): bool: True if the new_pdu clobbered the current state, False if not """ return self.runInteraction( + "handle_new_state", self._handle_new_state, new_pdu ) @@ -908,7 +942,7 @@ This does not include a prev_pdus key. PduTuple = namedtuple( "PduTuple", - ("pdu_entry", "prev_pdu_list") + ("pdu_entry", "prev_pdu_list", "hashes", "signatures", "edge_hashes") ) """ This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent the `prev_pdus` key of a PDU. |