summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/state/__init__.py15
-rw-r--r--synapse/state/v2.py2
-rw-r--r--synapse/storage/data_stores/main/event_federation.py28
3 files changed, 34 insertions, 11 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index fdd6bef6b4..df7a4f6a89 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,7 +16,7 @@
 
 import logging
 from collections import namedtuple
-from typing import Dict, Iterable, List, Optional
+from typing import Dict, Iterable, List, Optional, Set
 
 from six import iteritems, itervalues
 
@@ -662,7 +662,7 @@ class StateResolutionStore(object):
             allow_rejected=allow_rejected,
         )
 
-    def get_auth_chain(self, event_ids):
+    def get_auth_chain(self, event_ids: List[str], ignore_events: Set[str]):
         """Gets the full auth chain for a set of events (including rejected
         events).
 
@@ -674,11 +674,16 @@ class StateResolutionStore(object):
                presence of rejected events
 
         Args:
-            event_ids (list): The event IDs of the events to fetch the auth
-                chain for. Must be state events.
+            event_ids: The event IDs of the events to fetch the auth chain for.
+                Must be state events.
+            ignore_events: Set of events to exclude from the returned auth
+                chain.
+
 
         Returns:
             Deferred[list[str]]: List of event IDs of the auth chain.
         """
 
-        return self.store.get_auth_chain_ids(event_ids, include_given=True)
+        return self.store.get_auth_chain_ids(
+            event_ids, include_given=True, ignore_events=ignore_events,
+        )
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 531018c6a5..75fe58305a 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -248,7 +248,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
             and eid not in common
         )
 
-        auth_chain = yield state_res_store.get_auth_chain(auth_ids)
+        auth_chain = yield state_res_store.get_auth_chain(auth_ids, common)
         auth_ids.update(auth_chain)
 
         auth_sets.append(auth_ids)
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 60c67457b4..e16da2577d 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 import itertools
 import logging
+from typing import List, Optional, Set
 
 from six.moves import range
 from six.moves.queue import Empty, PriorityQueue
@@ -46,21 +47,37 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             event_ids, include_given=include_given
         ).addCallback(self.get_events_as_list)
 
-    def get_auth_chain_ids(self, event_ids, include_given=False):
+    def get_auth_chain_ids(
+        self,
+        event_ids: List[str],
+        include_given: bool = False,
+        ignore_events: Optional[Set[str]] = None,
+    ):
         """Get auth events for given event_ids. The events *must* be state events.
 
         Args:
-            event_ids (list): state events
-            include_given (bool): include the given events in result
+            event_ids: state events
+            include_given: include the given events in result
+            ignore_events: Set of events to exclude from the returned auth
+                chain. This is useful if the caller will just discard the
+                given events anyway, and saves us from figuring out their auth
+                chains if not required.
 
         Returns:
             list of event_ids
         """
         return self.db.runInteraction(
-            "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given
+            "get_auth_chain_ids",
+            self._get_auth_chain_ids_txn,
+            event_ids,
+            include_given,
+            ignore_events,
         )
 
-    def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
+    def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
+        if ignore_events is None:
+            ignore_events = set()
+
         if include_given:
             results = set(event_ids)
         else:
@@ -80,6 +97,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
                 txn.execute(base_sql + clause, list(args))
                 new_front.update([r[0] for r in txn])
 
+            new_front -= ignore_events
             new_front -= results
 
             front = new_front