summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/events_forward_extremities.py49
1 files changed, 48 insertions, 1 deletions
diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py
index 250a424cc0..cc684a94fe 100644
--- a/synapse/storage/databases/main/events_forward_extremities.py
+++ b/synapse/storage/databases/main/events_forward_extremities.py
@@ -4,7 +4,54 @@ from synapse.storage._base import SQLBaseStore
 
 
 class EventForwardExtremitiesStore(SQLBaseStore):
+
+    async def delete_forward_extremities_for_room(self, room_id: str) -> int:
+        """Delete any extra forward extremities for a room.
+
+        Returns count deleted.
+        """
+        def delete_forward_extremities_for_room_txn(txn):
+            # First we need to get the event_id to not delete
+            sql = (
+                "SELECT "
+                "   last_value(event_id) OVER w AS event_id"
+                "   FROM event_forward_extremities"
+                "   NATURAL JOIN events"
+                " where room_id = ?"
+                "   WINDOW w AS ("
+                "   PARTITION BY room_id"
+                "       ORDER BY stream_ordering"
+                "       range between unbounded preceding and unbounded following"
+                "   )"
+                "   ORDER BY stream_ordering"
+            )
+            txn.execute(sql, (room_id,))
+            rows = txn.fetchall()
+
+            # TODO: should this raise a SynapseError instead of better to blow?
+            event_id = rows[0][0]
+
+            # Now delete the extra forward extremities
+            sql = (
+                "DELETE FROM event_forward_extremities "
+                "WHERE"
+                "   event_id != ?"
+                "   AND room_id = ?"
+            )
+
+            # TODO we should not commit yet
+            txn.execute(sql, (event_id, room_id))
+
+            # TODO flush the cache then commit
+
+            return txn.rowcount
+
+        return await self.db_pool.runInteraction(
+            "delete_forward_extremities_for_room", delete_forward_extremities_for_room_txn,
+        )
+
     async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]:
+        """Get list of forward extremities for a room."""
         def get_forward_extremities_for_room_txn(txn):
             sql = (
                 "SELECT event_id, state_group FROM event_forward_extremities NATURAL JOIN event_to_state_groups "
@@ -16,5 +63,5 @@ class EventForwardExtremitiesStore(SQLBaseStore):
             return [{"event_id": row[0], "state_group": row[1]} for row in rows]
 
         return await self.db_pool.runInteraction(
-            "get_forward_extremities_for_room", get_forward_extremities_for_room_txn
+            "get_forward_extremities_for_room", get_forward_extremities_for_room_txn,
         )