diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index b2a3f0b56c..af05b47932 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -177,6 +177,14 @@ class DataStore(RoomMemberStore, RoomStore,
txn, pdu.pdu_id, pdu.origin, key_id, signature_bytes,
)
+ for prev_pdu_id, prev_origin, prev_hashes in pdu.prev_pdus:
+ for alg, hash_base64 in prev_hashes.items():
+ hash_bytes = decode_base64(hash_base64)
+ self._store_prev_pdu_hash_txn(
+ txn, pdu.pdu_id, pdu.origin, prev_pdu_id, prev_origin, alg,
+ hash_bytes
+ )
+
if pdu.is_state:
self._persist_state_txn(txn, pdu.prev_pdus, cols)
else:
@@ -352,6 +360,7 @@ class DataStore(RoomMemberStore, RoomStore,
prev_pdus = self._get_latest_pdus_in_context(
txn, room_id
)
+
if state_type is not None and state_key is not None:
prev_state_pdu = self._get_current_state_pdu(
txn, room_id, state_type, state_key
@@ -401,17 +410,16 @@ class Snapshot(object):
self.prev_state_pdu = prev_state_pdu
def fill_out_prev_events(self, event):
- if hasattr(event, "prev_events"):
+ if hasattr(event, "prev_pdus"):
return
- es = [
- "%s@%s" % (p_id, origin) for p_id, origin, _ in self.prev_pdus
+ event.prev_pdus = [
+ (p_id, origin, hashes)
+ for p_id, origin, hashes, _ in self.prev_pdus
]
- event.prev_events = [e for e in es if e != event.event_id]
-
if self.prev_pdus:
- event.depth = max([int(v) for _, _, v in self.prev_pdus]) + 1
+ event.depth = max([int(v) for _, _, _, v in self.prev_pdus]) + 1
else:
event.depth = 0
diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py
index 9d624429b7..a423b42dbd 100644
--- a/synapse/storage/pdu.py
+++ b/synapse/storage/pdu.py
@@ -20,10 +20,13 @@ from ._base import SQLBaseStore, Table, JoinHelper
from synapse.federation.units import Pdu
from synapse.util.logutils import log_function
+from syutil.base64util import encode_base64
+
from collections import namedtuple
import logging
+
logger = logging.getLogger(__name__)
@@ -64,6 +67,8 @@ class PduStore(SQLBaseStore):
for r in PduEdgesTable.decode_results(txn.fetchall())
]
+ edge_hashes = self._get_prev_pdu_hashes_txn(txn, pdu_id, origin)
+
hashes = self._get_pdu_hashes_txn(txn, pdu_id, origin)
signatures = self._get_pdu_origin_signatures_txn(
txn, pdu_id, origin
@@ -86,7 +91,7 @@ class PduStore(SQLBaseStore):
row = txn.fetchone()
if row:
results.append(PduTuple(
- PduEntry(*row), edges, hashes, signatures
+ PduEntry(*row), edges, hashes, signatures, edge_hashes
))
return results
@@ -310,9 +315,14 @@ class PduStore(SQLBaseStore):
(context, )
)
- results = txn.fetchall()
+ results = []
+ for pdu_id, origin, depth in txn.fetchall():
+ hashes = self._get_pdu_hashes_txn(txn, pdu_id, origin)
+ sha256_bytes = hashes["sha256"]
+ prev_hashes = {"sha256": encode_base64(sha256_bytes)}
+ results.append((pdu_id, origin, prev_hashes, depth))
- return [(row[0], row[1], row[2]) for row in results]
+ return results
@defer.inlineCallbacks
def get_oldest_pdus_in_context(self, context):
@@ -431,7 +441,7 @@ class PduStore(SQLBaseStore):
"DELETE FROM %s WHERE pdu_id = ? AND origin = ?"
% PduForwardExtremitiesTable.table_name
)
- txn.executemany(query, prev_pdus)
+ txn.executemany(query, list(p[:2] for p in prev_pdus))
# We only insert as a forward extremety the new pdu if there are no
# other pdus that reference it as a prev pdu
@@ -454,7 +464,7 @@ class PduStore(SQLBaseStore):
# deleted in a second if they're incorrect anyway.
txn.executemany(
PduBackwardExtremitiesTable.insert_statement(),
- [(i, o, context) for i, o in prev_pdus]
+ [(i, o, context) for i, o, _ in prev_pdus]
)
# Also delete from the backwards extremities table all ones that
@@ -915,7 +925,7 @@ This does not include a prev_pdus key.
PduTuple = namedtuple(
"PduTuple",
- ("pdu_entry", "prev_pdu_list", "hashes", "signatures")
+ ("pdu_entry", "prev_pdu_list", "hashes", "signatures", "edge_hashes")
)
""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
the `prev_pdus` key of a PDU.
diff --git a/synapse/storage/schema/signatures.sql b/synapse/storage/schema/signatures.sql
index 86ee0f2377..a72c4dc35f 100644
--- a/synapse/storage/schema/signatures.sql
+++ b/synapse/storage/schema/signatures.sql
@@ -34,3 +34,19 @@ CREATE TABLE IF NOT EXISTS pdu_origin_signatures (
CREATE INDEX IF NOT EXISTS pdu_origin_signatures_id ON pdu_origin_signatures (
pdu_id, origin
);
+
+CREATE TABLE IF NOT EXISTS pdu_edge_hashes(
+ pdu_id TEXT,
+ origin TEXT,
+ prev_pdu_id TEXT,
+ prev_origin TEXT,
+ algorithm TEXT,
+ hash BLOB,
+ CONSTRAINT uniqueness UNIQUE (
+ pdu_id, origin, prev_pdu_id, prev_origin, algorithm
+ )
+);
+
+CREATE INDEX IF NOT EXISTS pdu_edge_hashes_id ON pdu_edge_hashes(
+ pdu_id, origin
+);
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index 1f0a680500..1147102489 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -88,3 +88,34 @@ class SignatureStore(SQLBaseStore):
"signature": buffer(signature_bytes),
})
+ def _get_prev_pdu_hashes_txn(self, txn, pdu_id, origin):
+ """Get all the hashes for previous PDUs of a PDU
+ Args:
+ txn (cursor):
+ pdu_id (str): Id of the PDU.
+ origin (str): Origin of the PDU.
+ Returns:
+ dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes.
+ """
+ query = (
+ "SELECT prev_pdu_id, prev_origin, algorithm, hash"
+ " FROM pdu_edge_hashes"
+ " WHERE pdu_id = ? and origin = ?"
+ )
+ txn.execute(query, (pdu_id, origin))
+ results = {}
+ for prev_pdu_id, prev_origin, algorithm, hash_bytes in txn.fetchall():
+ hashes = results.setdefault((prev_pdu_id, prev_origin), {})
+ hashes[algorithm] = hash_bytes
+ return results
+
+ def _store_prev_pdu_hash_txn(self, txn, pdu_id, origin, prev_pdu_id,
+ prev_origin, algorithm, hash_bytes):
+ self._simple_insert_txn(txn, "pdu_edge_hashes", {
+ "pdu_id": pdu_id,
+ "origin": origin,
+ "prev_pdu_id": prev_pdu_id,
+ "prev_origin": prev_origin,
+ "algorithm": algorithm,
+ "hash": buffer(hash_bytes),
+ })
|