diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 60c67457b4..e16da2577d 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -14,6 +14,7 @@
# limitations under the License.
import itertools
import logging
+from typing import List, Optional, Set
from six.moves import range
from six.moves.queue import Empty, PriorityQueue
@@ -46,21 +47,37 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_ids, include_given=include_given
).addCallback(self.get_events_as_list)
- def get_auth_chain_ids(self, event_ids, include_given=False):
+ def get_auth_chain_ids(
+ self,
+ event_ids: List[str],
+ include_given: bool = False,
+ ignore_events: Optional[Set[str]] = None,
+ ):
"""Get auth events for given event_ids. The events *must* be state events.
Args:
- event_ids (list): state events
- include_given (bool): include the given events in result
+ event_ids: state events
+ include_given: include the given events in result
+ ignore_events: Set of events to exclude from the returned auth
+ chain. This is useful if the caller will just discard the
+ given events anyway, and saves us from figuring out their auth
+ chains if not required.
Returns:
list of event_ids
"""
return self.db.runInteraction(
- "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given
+ "get_auth_chain_ids",
+ self._get_auth_chain_ids_txn,
+ event_ids,
+ include_given,
+ ignore_events,
)
- def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
+ def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
+ if ignore_events is None:
+ ignore_events = set()
+
if include_given:
results = set(event_ids)
else:
@@ -80,6 +97,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(base_sql + clause, list(args))
new_front.update([r[0] for r in txn])
+ new_front -= ignore_events
new_front -= results
front = new_front
|