summary refs log tree commit diff
path: root/synapse/storage/pdu.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/pdu.py')
-rw-r--r--synapse/storage/pdu.py29
1 files changed, 23 insertions, 6 deletions
diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py
index b1cb0185a6..9bdc831fd8 100644
--- a/synapse/storage/pdu.py
+++ b/synapse/storage/pdu.py
@@ -20,10 +20,13 @@ from ._base import SQLBaseStore, Table, JoinHelper
 from synapse.federation.units import Pdu
 from synapse.util.logutils import log_function
 
+from syutil.base64util import encode_base64
+
 from collections import namedtuple
 
 import logging
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -64,6 +67,13 @@ class PduStore(SQLBaseStore):
                 for r in PduEdgesTable.decode_results(txn.fetchall())
             ]
 
+            edge_hashes = self._get_prev_pdu_hashes_txn(txn, pdu_id, origin)
+
+            hashes = self._get_pdu_content_hashes_txn(txn, pdu_id, origin)
+            signatures = self._get_pdu_origin_signatures_txn(
+                txn, pdu_id, origin
+            )
+
             query = (
                 "SELECT %(fields)s FROM %(pdus)s as p "
                 "LEFT JOIN %(state)s as s "
@@ -80,7 +90,9 @@ class PduStore(SQLBaseStore):
 
             row = txn.fetchone()
             if row:
-                results.append(PduTuple(PduEntry(*row), edges))
+                results.append(PduTuple(
+                    PduEntry(*row), edges, hashes, signatures, edge_hashes
+                ))
 
         return results
 
@@ -309,9 +321,14 @@ class PduStore(SQLBaseStore):
             (context, )
         )
 
-        results = txn.fetchall()
+        results = []
+        for pdu_id, origin, depth in txn.fetchall():
+            hashes = self._get_pdu_reference_hashes_txn(txn, pdu_id, origin)
+            sha256_bytes = hashes["sha256"]
+            prev_hashes = {"sha256": encode_base64(sha256_bytes)}
+            results.append((pdu_id, origin, prev_hashes, depth))
 
-        return [(row[0], row[1], row[2]) for row in results]
+        return results
 
     @defer.inlineCallbacks
     def get_oldest_pdus_in_context(self, context):
@@ -430,7 +447,7 @@ class PduStore(SQLBaseStore):
                 "DELETE FROM %s WHERE pdu_id = ? AND origin = ?"
                 % PduForwardExtremitiesTable.table_name
             )
-            txn.executemany(query, prev_pdus)
+            txn.executemany(query, list(p[:2] for p in prev_pdus))
 
             # We only insert as a forward extremety the new pdu if there are no
             # other pdus that reference it as a prev pdu
@@ -453,7 +470,7 @@ class PduStore(SQLBaseStore):
             # deleted in a second if they're incorrect anyway.
             txn.executemany(
                 PduBackwardExtremitiesTable.insert_statement(),
-                [(i, o, context) for i, o in prev_pdus]
+                [(i, o, context) for i, o, _ in prev_pdus]
             )
 
             # Also delete from the backwards extremities table all ones that
@@ -914,7 +931,7 @@ This does not include a prev_pdus key.
 
 PduTuple = namedtuple(
     "PduTuple",
-    ("pdu_entry", "prev_pdu_list")
+    ("pdu_entry", "prev_pdu_list", "hashes", "signatures", "edge_hashes")
 )
 """ This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
 the `prev_pdus` key of a PDU.