summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/event_federation.py40
1 files changed, 17 insertions, 23 deletions
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 4826be630c..e6a97b018c 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -15,14 +15,16 @@
 import itertools
 import logging
 from queue import Empty, PriorityQueue
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import Dict, Iterable, List, Set, Tuple
 
 from synapse.api.errors import StoreError
+from synapse.events import EventBase
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.signatures import SignatureWorkerStore
+from synapse.types import Collection
 from synapse.util.caches.descriptors import cached
 from synapse.util.iterutils import batch_iter
 
@@ -30,12 +32,14 @@ logger = logging.getLogger(__name__)
 
 
 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
-    async def get_auth_chain(self, event_ids, include_given=False):
+    async def get_auth_chain(
+        self, event_ids: Collection[str], include_given: bool = False
+    ) -> List[EventBase]:
         """Get auth events for given event_ids. The events *must* be state events.
 
         Args:
-            event_ids (list): state events
-            include_given (bool): include the given events in result
+            event_ids: state events
+            include_given: include the given events in result
 
         Returns:
             list of events
@@ -45,43 +49,34 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         )
         return await self.get_events_as_list(event_ids)
 
-    def get_auth_chain_ids(
-        self,
-        event_ids: List[str],
-        include_given: bool = False,
-        ignore_events: Optional[Set[str]] = None,
-    ):
+    async def get_auth_chain_ids(
+        self, event_ids: Collection[str], include_given: bool = False,
+    ) -> List[str]:
         """Get auth events for given event_ids. The events *must* be state events.
 
         Args:
             event_ids: state events
             include_given: include the given events in result
-            ignore_events: Set of events to exclude from the returned auth
-                chain. This is useful if the caller will just discard the
-                given events anyway, and saves us from figuring out their auth
-                chains if not required.
 
         Returns:
             list of event_ids
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_auth_chain_ids",
             self._get_auth_chain_ids_txn,
             event_ids,
             include_given,
-            ignore_events,
         )
 
-    def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
-        if ignore_events is None:
-            ignore_events = set()
-
+    def _get_auth_chain_ids_txn(
+        self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
+    ) -> List[str]:
         if include_given:
             results = set(event_ids)
         else:
             results = set()
 
-        base_sql = "SELECT auth_id FROM event_auth WHERE "
+        base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE "
 
         front = set(event_ids)
         while front:
@@ -93,7 +88,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
                 txn.execute(base_sql + clause, args)
                 new_front.update(r[0] for r in txn)
 
-            new_front -= ignore_events
             new_front -= results
 
             front = new_front