diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py
index d70467dcd6..4a4341907b 100644
--- a/synapse/storage/pdu.py
+++ b/synapse/storage/pdu.py
@@ -20,10 +20,13 @@ from ._base import SQLBaseStore, Table, JoinHelper
from synapse.federation.units import Pdu
from synapse.util.logutils import log_function
+from syutil.base64util import encode_base64
+
from collections import namedtuple
import logging
+
logger = logging.getLogger(__name__)
@@ -44,7 +47,7 @@ class PduStore(SQLBaseStore):
"""
return self.runInteraction(
- self._get_pdu_tuple, pdu_id, origin
+ "get_pdu", self._get_pdu_tuple, pdu_id, origin
)
def _get_pdu_tuple(self, txn, pdu_id, origin):
@@ -64,6 +67,13 @@ class PduStore(SQLBaseStore):
for r in PduEdgesTable.decode_results(txn.fetchall())
]
+ edge_hashes = self._get_prev_pdu_hashes_txn(txn, pdu_id, origin)
+
+ hashes = self._get_pdu_content_hashes_txn(txn, pdu_id, origin)
+ signatures = self._get_pdu_origin_signatures_txn(
+ txn, pdu_id, origin
+ )
+
query = (
"SELECT %(fields)s FROM %(pdus)s as p "
"LEFT JOIN %(state)s as s "
@@ -80,7 +90,9 @@ class PduStore(SQLBaseStore):
row = txn.fetchone()
if row:
- results.append(PduTuple(PduEntry(*row), edges))
+ results.append(PduTuple(
+ PduEntry(*row), edges, hashes, signatures, edge_hashes
+ ))
return results
@@ -96,6 +108,7 @@ class PduStore(SQLBaseStore):
"""
return self.runInteraction(
+ "get_current_state_for_context",
self._get_current_state_for_context,
context
)
@@ -144,6 +157,7 @@ class PduStore(SQLBaseStore):
"""
return self.runInteraction(
+ "mark_pdu_as_processed",
self._mark_as_processed, pdu_id, pdu_origin
)
@@ -153,6 +167,7 @@ class PduStore(SQLBaseStore):
def get_all_pdus_from_context(self, context):
"""Get a list of all PDUs for a given context."""
return self.runInteraction(
+ "get_all_pdus_from_context",
self._get_all_pdus_from_context, context,
)
@@ -180,6 +195,7 @@ class PduStore(SQLBaseStore):
list: A list of PduTuples
"""
return self.runInteraction(
+ "get_backfill",
self._get_backfill, context, pdu_list, limit
)
@@ -241,6 +257,7 @@ class PduStore(SQLBaseStore):
context (str)
"""
return self.runInteraction(
+ "get_min_depth_for_context",
self._get_min_depth_for_context, context
)
@@ -277,6 +294,13 @@ class PduStore(SQLBaseStore):
(context, depth)
)
+ def get_latest_pdus_in_context(self, context):
+ return self.runInteraction(
+ "get_latest_pdus_in_context",
+ self._get_latest_pdus_in_context,
+ context
+ )
+
def _get_latest_pdus_in_context(self, txn, context):
"""Get's a list of the most current pdus for a given context. This is
used when we are sending a Pdu and need to fill out the `prev_pdus`
@@ -303,9 +327,14 @@ class PduStore(SQLBaseStore):
(context, )
)
- results = txn.fetchall()
+ results = []
+ for pdu_id, origin, depth in txn.fetchall():
+ hashes = self._get_pdu_reference_hashes_txn(txn, pdu_id, origin)
+ sha256_bytes = hashes["sha256"]
+ prev_hashes = {"sha256": encode_base64(sha256_bytes)}
+ results.append((pdu_id, origin, prev_hashes, depth))
- return [(row[0], row[1], row[2]) for row in results]
+ return results
@defer.inlineCallbacks
def get_oldest_pdus_in_context(self, context):
@@ -347,6 +376,7 @@ class PduStore(SQLBaseStore):
"""
return self.runInteraction(
+ "is_pdu_new",
self._is_pdu_new,
pdu_id=pdu_id,
origin=origin,
@@ -424,7 +454,7 @@ class PduStore(SQLBaseStore):
"DELETE FROM %s WHERE pdu_id = ? AND origin = ?"
% PduForwardExtremitiesTable.table_name
)
- txn.executemany(query, prev_pdus)
+ txn.executemany(query, list(p[:2] for p in prev_pdus))
# We only insert as a forward extremety the new pdu if there are no
# other pdus that reference it as a prev pdu
@@ -447,7 +477,7 @@ class PduStore(SQLBaseStore):
# deleted in a second if they're incorrect anyway.
txn.executemany(
PduBackwardExtremitiesTable.insert_statement(),
- [(i, o, context) for i, o in prev_pdus]
+ [(i, o, context) for i, o, _ in prev_pdus]
)
# Also delete from the backwards extremities table all ones that
@@ -500,6 +530,7 @@ class StatePduStore(SQLBaseStore):
def get_unresolved_state_tree(self, new_state_pdu):
return self.runInteraction(
+ "get_unresolved_state_tree",
self._get_unresolved_state_tree, new_state_pdu
)
@@ -539,6 +570,7 @@ class StatePduStore(SQLBaseStore):
def update_current_state(self, pdu_id, origin, context, pdu_type,
state_key):
return self.runInteraction(
+ "update_current_state",
self._update_current_state,
pdu_id, origin, context, pdu_type, state_key
)
@@ -578,6 +610,7 @@ class StatePduStore(SQLBaseStore):
"""
return self.runInteraction(
+ "get_current_state_pdu",
self._get_current_state_pdu, context, pdu_type, state_key
)
@@ -637,6 +670,7 @@ class StatePduStore(SQLBaseStore):
bool: True if the new_pdu clobbered the current state, False if not
"""
return self.runInteraction(
+ "handle_new_state",
self._handle_new_state, new_pdu
)
@@ -908,7 +942,7 @@ This does not include a prev_pdus key.
PduTuple = namedtuple(
"PduTuple",
- ("pdu_entry", "prev_pdu_list")
+ ("pdu_entry", "prev_pdu_list", "hashes", "signatures", "edge_hashes")
)
""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
the `prev_pdus` key of a PDU.
|