summary refs log tree commit diff
path: root/synapse/storage/pdu.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/pdu.py')
-rw-r--r--synapse/storage/pdu.py48
1 files changed, 41 insertions, 7 deletions
diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py
index d70467dcd6..4a4341907b 100644
--- a/synapse/storage/pdu.py
+++ b/synapse/storage/pdu.py
@@ -20,10 +20,13 @@ from ._base import SQLBaseStore, Table, JoinHelper
 from synapse.federation.units import Pdu
 from synapse.util.logutils import log_function
 
+from syutil.base64util import encode_base64
+
 from collections import namedtuple
 
 import logging
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -44,7 +47,7 @@ class PduStore(SQLBaseStore):
         """
 
         return self.runInteraction(
-            self._get_pdu_tuple, pdu_id, origin
+            "get_pdu", self._get_pdu_tuple, pdu_id, origin
         )
 
     def _get_pdu_tuple(self, txn, pdu_id, origin):
@@ -64,6 +67,13 @@ class PduStore(SQLBaseStore):
                 for r in PduEdgesTable.decode_results(txn.fetchall())
             ]
 
+            edge_hashes = self._get_prev_pdu_hashes_txn(txn, pdu_id, origin)
+
+            hashes = self._get_pdu_content_hashes_txn(txn, pdu_id, origin)
+            signatures = self._get_pdu_origin_signatures_txn(
+                txn, pdu_id, origin
+            )
+
             query = (
                 "SELECT %(fields)s FROM %(pdus)s as p "
                 "LEFT JOIN %(state)s as s "
@@ -80,7 +90,9 @@ class PduStore(SQLBaseStore):
 
             row = txn.fetchone()
             if row:
-                results.append(PduTuple(PduEntry(*row), edges))
+                results.append(PduTuple(
+                    PduEntry(*row), edges, hashes, signatures, edge_hashes
+                ))
 
         return results
 
@@ -96,6 +108,7 @@ class PduStore(SQLBaseStore):
         """
 
         return self.runInteraction(
+            "get_current_state_for_context",
             self._get_current_state_for_context,
             context
         )
@@ -144,6 +157,7 @@ class PduStore(SQLBaseStore):
         """
 
         return self.runInteraction(
+            "mark_pdu_as_processed",
             self._mark_as_processed, pdu_id, pdu_origin
         )
 
@@ -153,6 +167,7 @@ class PduStore(SQLBaseStore):
     def get_all_pdus_from_context(self, context):
         """Get a list of all PDUs for a given context."""
         return self.runInteraction(
+            "get_all_pdus_from_context",
             self._get_all_pdus_from_context, context,
         )
 
@@ -180,6 +195,7 @@ class PduStore(SQLBaseStore):
             list: A list of PduTuples
         """
         return self.runInteraction(
+            "get_backfill",
             self._get_backfill, context, pdu_list, limit
         )
 
@@ -241,6 +257,7 @@ class PduStore(SQLBaseStore):
             context (str)
         """
         return self.runInteraction(
+            "get_min_depth_for_context",
             self._get_min_depth_for_context, context
         )
 
@@ -277,6 +294,13 @@ class PduStore(SQLBaseStore):
                 (context, depth)
             )
 
+    def get_latest_pdus_in_context(self, context):
+        return self.runInteraction(
+            "get_latest_pdus_in_context",
+            self._get_latest_pdus_in_context,
+            context
+        )
+
     def _get_latest_pdus_in_context(self, txn, context):
         """Get's a list of the most current pdus for a given context. This is
         used when we are sending a Pdu and need to fill out the `prev_pdus`
@@ -303,9 +327,14 @@ class PduStore(SQLBaseStore):
             (context, )
         )
 
-        results = txn.fetchall()
+        results = []
+        for pdu_id, origin, depth in txn.fetchall():
+            hashes = self._get_pdu_reference_hashes_txn(txn, pdu_id, origin)
+            sha256_bytes = hashes["sha256"]
+            prev_hashes = {"sha256": encode_base64(sha256_bytes)}
+            results.append((pdu_id, origin, prev_hashes, depth))
 
-        return [(row[0], row[1], row[2]) for row in results]
+        return results
 
     @defer.inlineCallbacks
     def get_oldest_pdus_in_context(self, context):
@@ -347,6 +376,7 @@ class PduStore(SQLBaseStore):
         """
 
         return self.runInteraction(
+            "is_pdu_new",
             self._is_pdu_new,
             pdu_id=pdu_id,
             origin=origin,
@@ -424,7 +454,7 @@ class PduStore(SQLBaseStore):
                 "DELETE FROM %s WHERE pdu_id = ? AND origin = ?"
                 % PduForwardExtremitiesTable.table_name
             )
-            txn.executemany(query, prev_pdus)
+            txn.executemany(query, list(p[:2] for p in prev_pdus))
 
             # We only insert as a forward extremety the new pdu if there are no
             # other pdus that reference it as a prev pdu
@@ -447,7 +477,7 @@ class PduStore(SQLBaseStore):
             # deleted in a second if they're incorrect anyway.
             txn.executemany(
                 PduBackwardExtremitiesTable.insert_statement(),
-                [(i, o, context) for i, o in prev_pdus]
+                [(i, o, context) for i, o, _ in prev_pdus]
             )
 
             # Also delete from the backwards extremities table all ones that
@@ -500,6 +530,7 @@ class StatePduStore(SQLBaseStore):
 
     def get_unresolved_state_tree(self, new_state_pdu):
         return self.runInteraction(
+            "get_unresolved_state_tree",
             self._get_unresolved_state_tree, new_state_pdu
         )
 
@@ -539,6 +570,7 @@ class StatePduStore(SQLBaseStore):
     def update_current_state(self, pdu_id, origin, context, pdu_type,
                              state_key):
         return self.runInteraction(
+            "update_current_state",
             self._update_current_state,
             pdu_id, origin, context, pdu_type, state_key
         )
@@ -578,6 +610,7 @@ class StatePduStore(SQLBaseStore):
         """
 
         return self.runInteraction(
+            "get_current_state_pdu",
             self._get_current_state_pdu, context, pdu_type, state_key
         )
 
@@ -637,6 +670,7 @@ class StatePduStore(SQLBaseStore):
             bool: True if the new_pdu clobbered the current state, False if not
         """
         return self.runInteraction(
+            "handle_new_state",
             self._handle_new_state, new_pdu
         )
 
@@ -908,7 +942,7 @@ This does not include a prev_pdus key.
 
 PduTuple = namedtuple(
     "PduTuple",
-    ("pdu_entry", "prev_pdu_list")
+    ("pdu_entry", "prev_pdu_list", "hashes", "signatures", "edge_hashes")
 )
 """ This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
 the `prev_pdus` key of a PDU.