summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/federation.py8
-rw-r--r--synapse/storage/event_federation.py16
2 files changed, 14 insertions, 10 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index fadb48fde6..cd9e655f95 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -173,6 +173,7 @@ class FederationHandler(BaseHandler):
                     context=event.room_id,
                     event_id=event.event_id,
                 )
+                # FIXME: Get auth chain for these state events
 
             current_state = state
 
@@ -288,7 +289,7 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def on_event_auth(self, event_id):
-        auth = yield self.store.get_auth_chain(event_id)
+        auth = yield self.store.get_auth_chain([event_id])
 
         for event in auth:
             event.signatures.update(
@@ -528,7 +529,10 @@ class FederationHandler(BaseHandler):
 
         yield self.replication_layer.send_pdu(new_pdu, destinations)
 
-        auth_chain = yield self.store.get_auth_chain(event.event_id)
+        state_ids = [e.event_id for e in event.state_events.values()]
+        auth_chain = yield self.store.get_auth_chain(set(
+            [event.event_id] + state_ids
+        ))
 
         defer.returnValue({
             "state": context.current_state.values(),
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index ced066f407..7a6009c9ee 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -32,15 +32,15 @@ class EventFederationStore(SQLBaseStore):
     and backfilling from another server respectively.
     """
 
-    def get_auth_chain(self, event_id):
+    def get_auth_chain(self, event_ids):
         return self.runInteraction(
             "get_auth_chain",
             self._get_auth_chain_txn,
-            event_id
+            event_ids
         )
 
-    def _get_auth_chain_txn(self, txn, event_id):
-        results = self._get_auth_chain_ids_txn(txn, event_id)
+    def _get_auth_chain_txn(self, txn, event_ids):
+        results = self._get_auth_chain_ids_txn(txn, event_ids)
 
         sql = "SELECT * FROM events WHERE event_id = ?"
         rows = []
@@ -50,21 +50,21 @@ class EventFederationStore(SQLBaseStore):
 
         return self._parse_events_txn(txn, rows)
 
-    def get_auth_chain_ids(self, event_id):
+    def get_auth_chain_ids(self, event_ids):
         return self.runInteraction(
             "get_auth_chain_ids",
             self._get_auth_chain_ids_txn,
-            event_id
+            event_ids
         )
 
-    def _get_auth_chain_ids_txn(self, txn, event_id):
+    def _get_auth_chain_ids_txn(self, txn, event_ids):
         results = set()
 
         base_sql = (
             "SELECT auth_id FROM event_auth WHERE %s"
         )
 
-        front = set([event_id])
+        front = set(event_ids)
         while front:
             sql = base_sql % (
                 " OR ".join(["event_id=?"] * len(front)),