summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/api/events/__init__.py2
-rw-r--r--synapse/federation/pdu_codec.py13
-rw-r--r--synapse/federation/replication.py2
-rw-r--r--synapse/federation/units.py10
-rw-r--r--synapse/state.py4
-rw-r--r--synapse/storage/__init__.py20
-rw-r--r--synapse/storage/pdu.py22
-rw-r--r--synapse/storage/schema/signatures.sql16
-rw-r--r--synapse/storage/signatures.py31
-rw-r--r--tests/federation/test_federation.py2
-rw-r--r--tests/federation/test_pdu_codec.py4
11 files changed, 95 insertions, 31 deletions
diff --git a/synapse/api/events/__init__.py b/synapse/api/events/__init__.py
index f66fea2904..a5a55742e0 100644
--- a/synapse/api/events/__init__.py
+++ b/synapse/api/events/__init__.py
@@ -65,13 +65,13 @@ class SynapseEvent(JsonEncodedObject):
 
     internal_keys = [
         "is_state",
-        "prev_events",
         "depth",
         "destinations",
         "origin",
         "outlier",
         "power_level",
         "redacted",
+        "prev_pdus",
     ]
 
     required_keys = [
diff --git a/synapse/federation/pdu_codec.py b/synapse/federation/pdu_codec.py
index bcac5f9ae8..11fd7264b3 100644
--- a/synapse/federation/pdu_codec.py
+++ b/synapse/federation/pdu_codec.py
@@ -45,9 +45,7 @@ class PduCodec(object):
         kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin)
         kwargs["room_id"] = pdu.context
         kwargs["etype"] = pdu.pdu_type
-        kwargs["prev_events"] = [
-            encode_event_id(p[0], p[1]) for p in pdu.prev_pdus
-        ]
+        kwargs["prev_pdus"] = pdu.prev_pdus
 
         if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"):
             kwargs["prev_state"] = encode_event_id(
@@ -78,11 +76,8 @@ class PduCodec(object):
         d["context"] = event.room_id
         d["pdu_type"] = event.type
 
-        if hasattr(event, "prev_events"):
-            d["prev_pdus"] = [
-                decode_event_id(e, self.server_name)
-                for e in event.prev_events
-            ]
+        if hasattr(event, "prev_pdus"):
+            d["prev_pdus"] = event.prev_pdus
 
         if hasattr(event, "prev_state"):
             d["prev_state_id"], d["prev_state_origin"] = (
@@ -95,7 +90,7 @@ class PduCodec(object):
         kwargs = copy.deepcopy(event.unrecognized_keys)
         kwargs.update({
             k: v for k, v in d.items()
-            if k not in ["event_id", "room_id", "type", "prev_events"]
+            if k not in ["event_id", "room_id", "type"]
         })
 
         if "ts" not in kwargs:
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 9363ac7300..788a49b8e8 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -443,7 +443,7 @@ class ReplicationLayer(object):
             min_depth = yield self.store.get_min_depth_for_context(pdu.context)
 
             if min_depth and pdu.depth > min_depth:
-                for pdu_id, origin in pdu.prev_pdus:
+                for pdu_id, origin, hashes in pdu.prev_pdus:
                     exists = yield self._get_persisted_pdu(pdu_id, origin)
 
                     if not exists:
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index 3518efb215..6a43007837 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -141,8 +141,16 @@ class Pdu(JsonEncodedObject):
                 for kid, sig in pdu_tuple.signatures.items()
             }
 
+            prev_pdus = []
+            for prev_pdu in pdu_tuple.prev_pdu_list:
+                prev_hashes = pdu_tuple.edge_hashes.get(prev_pdu, {})
+                prev_hashes = {
+                    alg: encode_base64(hsh) for alg, hsh in prev_hashes.items()
+                }
+                prev_pdus.append((prev_pdu[0], prev_pdu[1], prev_hashes))
+
             return Pdu(
-                prev_pdus=pdu_tuple.prev_pdu_list,
+                prev_pdus=prev_pdus,
                 **args
             )
         else:
diff --git a/synapse/state.py b/synapse/state.py
index 9db84c9b5c..bc6b928ec7 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -72,10 +72,6 @@ class StateHandler(object):
 
         snapshot.fill_out_prev_events(event)
 
-        event.prev_events = [
-            e for e in event.prev_events if e != event.event_id
-        ]
-
         current_state = snapshot.prev_state_pdu
 
         if current_state:
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index b2a3f0b56c..af05b47932 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -177,6 +177,14 @@ class DataStore(RoomMemberStore, RoomStore,
                 txn, pdu.pdu_id, pdu.origin, key_id, signature_bytes,
             )
 
+        for prev_pdu_id, prev_origin, prev_hashes in pdu.prev_pdus:
+            for alg, hash_base64 in prev_hashes.items():
+                hash_bytes = decode_base64(hash_base64)
+                self._store_prev_pdu_hash_txn(
+                    txn, pdu.pdu_id, pdu.origin, prev_pdu_id, prev_origin, alg,
+                    hash_bytes
+                )
+
         if pdu.is_state:
             self._persist_state_txn(txn, pdu.prev_pdus, cols)
         else:
@@ -352,6 +360,7 @@ class DataStore(RoomMemberStore, RoomStore,
             prev_pdus = self._get_latest_pdus_in_context(
                 txn, room_id
             )
+
             if state_type is not None and state_key is not None:
                 prev_state_pdu = self._get_current_state_pdu(
                     txn, room_id, state_type, state_key
@@ -401,17 +410,16 @@ class Snapshot(object):
         self.prev_state_pdu = prev_state_pdu
 
     def fill_out_prev_events(self, event):
-        if hasattr(event, "prev_events"):
+        if hasattr(event, "prev_pdus"):
             return
 
-        es = [
-            "%s@%s" % (p_id, origin) for p_id, origin, _ in self.prev_pdus
+        event.prev_pdus = [
+            (p_id, origin, hashes)
+            for p_id, origin, hashes, _ in self.prev_pdus
         ]
 
-        event.prev_events = [e for e in es if e != event.event_id]
-
         if self.prev_pdus:
-            event.depth = max([int(v) for _, _, v in self.prev_pdus]) + 1
+            event.depth = max([int(v) for _, _, _, v in self.prev_pdus]) + 1
         else:
             event.depth = 0
 
diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py
index 9d624429b7..a423b42dbd 100644
--- a/synapse/storage/pdu.py
+++ b/synapse/storage/pdu.py
@@ -20,10 +20,13 @@ from ._base import SQLBaseStore, Table, JoinHelper
 from synapse.federation.units import Pdu
 from synapse.util.logutils import log_function
 
+from syutil.base64util import encode_base64
+
 from collections import namedtuple
 
 import logging
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -64,6 +67,8 @@ class PduStore(SQLBaseStore):
                 for r in PduEdgesTable.decode_results(txn.fetchall())
             ]
 
+            edge_hashes = self._get_prev_pdu_hashes_txn(txn, pdu_id, origin)
+
             hashes = self._get_pdu_hashes_txn(txn, pdu_id, origin)
             signatures = self._get_pdu_origin_signatures_txn(
                 txn, pdu_id, origin
@@ -86,7 +91,7 @@ class PduStore(SQLBaseStore):
             row = txn.fetchone()
             if row:
                 results.append(PduTuple(
-                    PduEntry(*row), edges, hashes, signatures
+                    PduEntry(*row), edges, hashes, signatures, edge_hashes
                 ))
 
         return results
@@ -310,9 +315,14 @@ class PduStore(SQLBaseStore):
             (context, )
         )
 
-        results = txn.fetchall()
+        results = []
+        for pdu_id, origin, depth in txn.fetchall():
+            hashes = self._get_pdu_hashes_txn(txn, pdu_id, origin)
+            sha256_bytes = hashes["sha256"]
+            prev_hashes = {"sha256": encode_base64(sha256_bytes)}
+            results.append((pdu_id, origin, prev_hashes, depth))
 
-        return [(row[0], row[1], row[2]) for row in results]
+        return results
 
     @defer.inlineCallbacks
     def get_oldest_pdus_in_context(self, context):
@@ -431,7 +441,7 @@ class PduStore(SQLBaseStore):
                 "DELETE FROM %s WHERE pdu_id = ? AND origin = ?"
                 % PduForwardExtremitiesTable.table_name
             )
-            txn.executemany(query, prev_pdus)
+            txn.executemany(query, list(p[:2] for p in prev_pdus))
 
             # We only insert as a forward extremety the new pdu if there are no
             # other pdus that reference it as a prev pdu
@@ -454,7 +464,7 @@ class PduStore(SQLBaseStore):
             # deleted in a second if they're incorrect anyway.
             txn.executemany(
                 PduBackwardExtremitiesTable.insert_statement(),
-                [(i, o, context) for i, o in prev_pdus]
+                [(i, o, context) for i, o, _ in prev_pdus]
             )
 
             # Also delete from the backwards extremities table all ones that
@@ -915,7 +925,7 @@ This does not include a prev_pdus key.
 
 PduTuple = namedtuple(
     "PduTuple",
-    ("pdu_entry", "prev_pdu_list", "hashes", "signatures")
+    ("pdu_entry", "prev_pdu_list", "hashes", "signatures", "edge_hashes")
 )
 """ This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
 the `prev_pdus` key of a PDU.
diff --git a/synapse/storage/schema/signatures.sql b/synapse/storage/schema/signatures.sql
index 86ee0f2377..a72c4dc35f 100644
--- a/synapse/storage/schema/signatures.sql
+++ b/synapse/storage/schema/signatures.sql
@@ -34,3 +34,19 @@ CREATE TABLE IF NOT EXISTS pdu_origin_signatures (
 CREATE INDEX IF NOT EXISTS pdu_origin_signatures_id ON pdu_origin_signatures (
     pdu_id, origin
 );
+
+CREATE TABLE IF NOT EXISTS pdu_edge_hashes(
+    pdu_id TEXT,
+    origin TEXT,
+    prev_pdu_id TEXT,
+    prev_origin TEXT,
+    algorithm TEXT,
+    hash BLOB,
+    CONSTRAINT uniqueness UNIQUE (
+        pdu_id, origin, prev_pdu_id, prev_origin, algorithm
+    )
+);
+
+CREATE INDEX IF NOT EXISTS pdu_edge_hashes_id ON pdu_edge_hashes(
+    pdu_id, origin
+);
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index 1f0a680500..1147102489 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -88,3 +88,34 @@ class SignatureStore(SQLBaseStore):
             "signature": buffer(signature_bytes),
         })
 
+    def _get_prev_pdu_hashes_txn(self, txn, pdu_id, origin):
+        """Get all the hashes for previous PDUs of a PDU
+        Args:
+            txn (cursor):
+            pdu_id (str): Id of the PDU.
+            origin (str): Origin of the PDU.
+        Returns:
+            dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes.
+        """
+        query = (
+            "SELECT prev_pdu_id, prev_origin, algorithm, hash"
+            " FROM pdu_edge_hashes"
+            " WHERE pdu_id = ? and origin = ?"
+        )
+        txn.execute(query, (pdu_id, origin))
+        results = {}
+        for prev_pdu_id, prev_origin, algorithm, hash_bytes in txn.fetchall():
+            hashes = results.setdefault((prev_pdu_id, prev_origin), {})
+            hashes[algorithm] = hash_bytes
+        return results
+
+    def _store_prev_pdu_hash_txn(self, txn, pdu_id, origin, prev_pdu_id,
+                             prev_origin, algorithm, hash_bytes):
+        self._simple_insert_txn(txn, "pdu_edge_hashes", {
+            "pdu_id": pdu_id,
+            "origin": origin,
+            "prev_pdu_id": prev_pdu_id,
+            "prev_origin": prev_origin,
+            "algorithm": algorithm,
+            "hash": buffer(hash_bytes),
+        })
diff --git a/tests/federation/test_federation.py b/tests/federation/test_federation.py
index 03b2167cf7..eed50e6335 100644
--- a/tests/federation/test_federation.py
+++ b/tests/federation/test_federation.py
@@ -41,7 +41,7 @@ def make_pdu(prev_pdus=[], **kwargs):
     }
     pdu_fields.update(kwargs)
 
-    return PduTuple(PduEntry(**pdu_fields), prev_pdus, {}, {})
+    return PduTuple(PduEntry(**pdu_fields), prev_pdus, {}, {}, {})
 
 
 class FederationTestCase(unittest.TestCase):
diff --git a/tests/federation/test_pdu_codec.py b/tests/federation/test_pdu_codec.py
index 80851a4258..0ad8cf6641 100644
--- a/tests/federation/test_pdu_codec.py
+++ b/tests/federation/test_pdu_codec.py
@@ -88,7 +88,7 @@ class PduCodecTestCase(unittest.TestCase):
         self.assertEquals(pdu.context, event.room_id)
         self.assertEquals(pdu.is_state, event.is_state)
         self.assertEquals(pdu.depth, event.depth)
-        self.assertEquals(["alice@bob.com"], event.prev_events)
+        self.assertEquals(pdu.prev_pdus, event.prev_pdus)
         self.assertEquals(pdu.content, event.content)
 
     def test_pdu_from_event(self):
@@ -144,7 +144,7 @@ class PduCodecTestCase(unittest.TestCase):
         self.assertEquals(pdu.context, event.room_id)
         self.assertEquals(pdu.is_state, event.is_state)
         self.assertEquals(pdu.depth, event.depth)
-        self.assertEquals(["alice@bob.com"], event.prev_events)
+        self.assertEquals(pdu.prev_pdus, event.prev_pdus)
         self.assertEquals(pdu.content, event.content)
         self.assertEquals(pdu.state_key, event.state_key)