summary refs log tree commit diff
path: root/synapse/storage/relations.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/relations.py')
-rw-r--r--synapse/storage/relations.py60
1 files changed, 58 insertions, 2 deletions
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index de67e305a1..aa91e52733 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -17,11 +17,13 @@ import logging
 
 import attr
 
-from synapse.api.constants import RelationTypes
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, RelationTypes
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.stream import generate_pagination_where_clause
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
 logger = logging.getLogger(__name__)
 
@@ -312,6 +314,59 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
         )
 
+    @cachedInlineCallbacks(tree=True)
+    def get_applicable_edit(self, event_id, event_type, sender):
+        """Get the most recent edit (if any) that has happened for the given
+        event.
+
+        Correctly handles checking whether edits were allowed to happen.
+
+        Args:
+            event_id (str): The original event ID
+            event_type (str): The original event type
+            sender (str): The original event sender
+
+        Returns:
+            Deferred[EventBase|None]: Returns the most recent edit, if any.
+        """
+
+        # We only allow edits for `m.room.message` events that have the same sender
+        # and event type. We can't assert these things during regular event auth so
+        # we have to do the post hoc.
+
+        if event_type != EventTypes.Message:
+            return
+
+        sql = """
+            SELECT event_id, origin_server_ts FROM events
+            INNER JOIN event_relations USING (event_id)
+            WHERE
+                relates_to_id = ?
+                AND relation_type = ?
+                AND type = ?
+                AND sender = ?
+            ORDER by origin_server_ts DESC, event_id DESC
+            LIMIT 1
+        """
+
+        def _get_applicable_edit_txn(txn):
+            txn.execute(
+                sql, (event_id, RelationTypes.REPLACES, event_type, sender)
+            )
+            row = txn.fetchone()
+            if row:
+                return row[0]
+
+        edit_id = yield self.runInteraction(
+            "get_applicable_edit", _get_applicable_edit_txn
+        )
+
+        if not edit_id:
+            return
+
+        edit_event = yield self.get_event(edit_id, allow_none=True)
+        defer.returnValue(edit_event)
+
 
 class RelationsStore(RelationsWorkerStore):
     def _handle_event_relations(self, txn, event):
@@ -357,3 +412,4 @@ class RelationsStore(RelationsWorkerStore):
         txn.call_after(
             self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
         )
+        txn.call_after(self.get_applicable_edit.invalidate_many, (parent_id,))