diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 925eb5376e..692c2d8a7b 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -285,7 +285,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def on_event_auth(self, event_id):
- auth = yield self.store.get_auth_chain(event_id)
+ auth = yield self.store.get_auth_chain([event_id])
for event in auth:
event.signatures.update(
@@ -494,7 +494,10 @@ class FederationHandler(BaseHandler):
yield self.replication_layer.send_pdu(new_pdu)
- auth_chain = yield self.store.get_auth_chain(event.event_id)
+ state_ids = [e.event_id for e in event.state_events.values()]
+ auth_chain = yield self.store.get_auth_chain(set(
+ [event.event_id] + state_ids
+ ))
defer.returnValue({
"state": event.state_events.values(),
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 6c559f8f63..0ff9a23ee0 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -32,15 +32,15 @@ class EventFederationStore(SQLBaseStore):
and backfilling from another server respectively.
"""
- def get_auth_chain(self, event_id):
+ def get_auth_chain(self, event_ids):
return self.runInteraction(
"get_auth_chain",
self._get_auth_chain_txn,
- event_id
+ event_ids
)
- def _get_auth_chain_txn(self, txn, event_id):
- results = self._get_auth_chain_ids_txn(txn, event_id)
+ def _get_auth_chain_txn(self, txn, event_ids):
+ results = self._get_auth_chain_ids_txn(txn, event_ids)
sql = "SELECT * FROM events WHERE event_id = ?"
rows = []
@@ -50,21 +50,21 @@ class EventFederationStore(SQLBaseStore):
return self._parse_events_txn(txn, rows)
- def get_auth_chain_ids(self, event_id):
+ def get_auth_chain_ids(self, event_ids):
return self.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
- event_id
+ event_ids
)
- def _get_auth_chain_ids_txn(self, txn, event_id):
+ def _get_auth_chain_ids_txn(self, txn, event_ids):
results = set()
base_sql = (
"SELECT auth_id FROM event_auth WHERE %s"
)
- front = set([event_id])
+ front = set(event_ids)
while front:
sql = base_sql % (
" OR ".join(["event_id=?"] * len(front)),
|