summary refs log tree commit diff
path: root/synapse/storage/event_federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/event_federation.py')
-rw-r--r--synapse/storage/event_federation.py35
1 files changed, 15 insertions, 20 deletions
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 6c559f8f63..fb2eb21713 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -32,39 +32,33 @@ class EventFederationStore(SQLBaseStore):
     and backfilling from another server respectively.
     """
 
-    def get_auth_chain(self, event_id):
+    def get_auth_chain(self, event_ids):
         return self.runInteraction(
             "get_auth_chain",
             self._get_auth_chain_txn,
-            event_id
+            event_ids
         )
 
-    def _get_auth_chain_txn(self, txn, event_id):
-        results = self._get_auth_chain_ids_txn(txn, event_id)
+    def _get_auth_chain_txn(self, txn, event_ids):
+        results = self._get_auth_chain_ids_txn(txn, event_ids)
 
-        sql = "SELECT * FROM events WHERE event_id = ?"
-        rows = []
-        for ev_id in results:
-            c = txn.execute(sql, (ev_id,))
-            rows.extend(self.cursor_to_dict(c))
+        return self._get_events_txn(txn, results)
 
-        return self._parse_events_txn(txn, rows)
-
-    def get_auth_chain_ids(self, event_id):
+    def get_auth_chain_ids(self, event_ids):
         return self.runInteraction(
             "get_auth_chain_ids",
             self._get_auth_chain_ids_txn,
-            event_id
+            event_ids
         )
 
-    def _get_auth_chain_ids_txn(self, txn, event_id):
+    def _get_auth_chain_ids_txn(self, txn, event_ids):
         results = set()
 
         base_sql = (
             "SELECT auth_id FROM event_auth WHERE %s"
         )
 
-        front = set([event_id])
+        front = set(event_ids)
         while front:
             sql = base_sql % (
                 " OR ".join(["event_id=?"] * len(front)),
@@ -177,14 +171,15 @@ class EventFederationStore(SQLBaseStore):
             retcols=["prev_event_id", "is_state"],
         )
 
+        hashes = self._get_prev_event_hashes_txn(txn, event_id)
+
         results = []
         for d in res:
-            hashes = self._get_event_reference_hashes_txn(
-                txn,
-                d["prev_event_id"]
-            )
+            edge_hash = self._get_event_reference_hashes_txn(txn, d["prev_event_id"])
+            edge_hash.update(hashes.get(d["prev_event_id"], {}))
             prev_hashes = {
-                k: encode_base64(v) for k, v in hashes.items()
+                k: encode_base64(v)
+                for k, v in edge_hash.items()
                 if k == "sha256"
             }
             results.append((d["prev_event_id"], prev_hashes, d["is_state"]))