summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/events/utils.py4
-rw-r--r--synapse/replication/slave/storage/events.py2
-rw-r--r--synapse/storage/relations.py32
3 files changed, 19 insertions, 19 deletions
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 2019ce9b1c..bf3c8f8dc1 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -368,9 +368,7 @@ class EventClientSerializer(object):
 
             edit = None
             if event.type == EventTypes.Message:
-                edit = yield self.store.get_applicable_edit(
-                    event.event_id, event.type, event.sender,
-                )
+                edit = yield self.store.get_applicable_edit(event_id)
 
             if edit:
                 # If there is an edit replace the content, preserving existing
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 857128b311..a3952506c1 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -143,4 +143,4 @@ class SlavedEventStore(EventFederationWorkerStore,
         if relates_to:
             self.get_relations_for_event.invalidate_many((relates_to,))
             self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
-            self.get_applicable_edit.invalidate_many((relates_to,))
+            self.get_applicable_edit.invalidate((relates_to,))
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index aa91e52733..6e216066ab 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -19,7 +19,7 @@ import attr
 
 from twisted.internet import defer
 
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import RelationTypes
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.stream import generate_pagination_where_clause
@@ -314,8 +314,8 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
         )
 
-    @cachedInlineCallbacks(tree=True)
-    def get_applicable_edit(self, event_id, event_type, sender):
+    @cachedInlineCallbacks()
+    def get_applicable_edit(self, event_id):
         """Get the most recent edit (if any) that has happened for the given
         event.
 
@@ -323,8 +323,6 @@ class RelationsWorkerStore(SQLBaseStore):
 
         Args:
             event_id (str): The original event ID
-            event_type (str): The original event type
-            sender (str): The original event sender
 
         Returns:
             Deferred[EventBase|None]: Returns the most recent edit, if any.
@@ -332,26 +330,28 @@ class RelationsWorkerStore(SQLBaseStore):
 
         # We only allow edits for `m.room.message` events that have the same sender
         # and event type. We can't assert these things during regular event auth so
-        # we have to do the post hoc.
-
-        if event_type != EventTypes.Message:
-            return
+        # we have to do the checks post hoc.
 
+        # Fetches latest edit that has the same type and sender as the
+        # original, and is an `m.room.message`.
         sql = """
-            SELECT event_id, origin_server_ts FROM events
+            SELECT edit.event_id FROM events AS edit
             INNER JOIN event_relations USING (event_id)
+            INNER JOIN events AS original ON
+                original.event_id = relates_to_id
+                AND edit.type = original.type
+                AND edit.sender = original.sender
             WHERE
                 relates_to_id = ?
                 AND relation_type = ?
-                AND type = ?
-                AND sender = ?
-            ORDER by origin_server_ts DESC, event_id DESC
+                AND edit.type = 'm.room.message'
+            ORDER by edit.origin_server_ts DESC, edit.event_id DESC
             LIMIT 1
         """
 
         def _get_applicable_edit_txn(txn):
             txn.execute(
-                sql, (event_id, RelationTypes.REPLACES, event_type, sender)
+                sql, (event_id, RelationTypes.REPLACES,)
             )
             row = txn.fetchone()
             if row:
@@ -412,4 +412,6 @@ class RelationsStore(RelationsWorkerStore):
         txn.call_after(
             self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
         )
-        txn.call_after(self.get_applicable_edit.invalidate_many, (parent_id,))
+
+        if rel_type == RelationTypes.REPLACES:
+            txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))