summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/event_federation.py50
1 files changed, 45 insertions, 5 deletions
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ff81d5cd17..c0ea445550 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -16,6 +16,7 @@ import logging
 from queue import Empty, PriorityQueue
 from typing import Collection, Dict, Iterable, List, Set, Tuple
 
+from synapse.api.constants import MAX_DEPTH
 from synapse.api.errors import StoreError
 from synapse.events import EventBase
 from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -670,8 +671,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         return dict(txn)
 
-    async def get_max_depth_of(self, event_ids: List[str]) -> int:
-        """Returns the max depth of a set of event IDs
+    async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
+        """Returns the event ID and depth for the event that has the max depth from a set of event IDs
 
         Args:
             event_ids: The event IDs to calculate the max depth of.
@@ -680,14 +681,53 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             table="events",
             column="event_id",
             iterable=event_ids,
-            retcols=("depth",),
+            retcols=(
+                "event_id",
+                "depth",
+            ),
             desc="get_max_depth_of",
         )
 
         if not rows:
-            return 0
+            return None, 0
         else:
-            return max(row["depth"] for row in rows)
+            max_depth_event_id = ""
+            current_max_depth = 0
+            for row in rows:
+                if row["depth"] > current_max_depth:
+                    max_depth_event_id = row["event_id"]
+                    current_max_depth = row["depth"]
+
+            return max_depth_event_id, current_max_depth
+
+    async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
+        """Returns the event ID and depth for the event that has the min depth from a set of event IDs
+
+        Args:
+            event_ids: The event IDs to calculate the max depth of.
+        """
+        rows = await self.db_pool.simple_select_many_batch(
+            table="events",
+            column="event_id",
+            iterable=event_ids,
+            retcols=(
+                "event_id",
+                "depth",
+            ),
+            desc="get_min_depth_of",
+        )
+
+        if not rows:
+            return None, 0
+        else:
+            min_depth_event_id = ""
+            current_min_depth = MAX_DEPTH
+            for row in rows:
+                if row["depth"] < current_min_depth:
+                    min_depth_event_id = row["event_id"]
+                    current_min_depth = row["depth"]
+
+            return min_depth_event_id, current_min_depth
 
     async def get_prev_events_for_room(self, room_id: str) -> List[str]:
         """