diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index aa91e52733..6e216066ab 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -19,7 +19,7 @@ import attr
from twisted.internet import defer
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
from synapse.storage.stream import generate_pagination_where_clause
@@ -314,8 +314,8 @@ class RelationsWorkerStore(SQLBaseStore):
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
)
- @cachedInlineCallbacks(tree=True)
- def get_applicable_edit(self, event_id, event_type, sender):
+ @cachedInlineCallbacks()
+ def get_applicable_edit(self, event_id):
"""Get the most recent edit (if any) that has happened for the given
event.
@@ -323,8 +323,6 @@ class RelationsWorkerStore(SQLBaseStore):
Args:
event_id (str): The original event ID
- event_type (str): The original event type
- sender (str): The original event sender
Returns:
Deferred[EventBase|None]: Returns the most recent edit, if any.
@@ -332,26 +330,28 @@ class RelationsWorkerStore(SQLBaseStore):
# We only allow edits for `m.room.message` events that have the same sender
# and event type. We can't assert these things during regular event auth so
- # we have to do the post hoc.
-
- if event_type != EventTypes.Message:
- return
+ # we have to do the checks post hoc.
+ # Fetches latest edit that has the same type and sender as the
+ # original, and is an `m.room.message`.
sql = """
- SELECT event_id, origin_server_ts FROM events
+ SELECT edit.event_id FROM events AS edit
INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS original ON
+ original.event_id = relates_to_id
+ AND edit.type = original.type
+ AND edit.sender = original.sender
WHERE
relates_to_id = ?
AND relation_type = ?
- AND type = ?
- AND sender = ?
- ORDER by origin_server_ts DESC, event_id DESC
+ AND edit.type = 'm.room.message'
+ ORDER by edit.origin_server_ts DESC, edit.event_id DESC
LIMIT 1
"""
def _get_applicable_edit_txn(txn):
txn.execute(
- sql, (event_id, RelationTypes.REPLACES, event_type, sender)
+ sql, (event_id, RelationTypes.REPLACES,)
)
row = txn.fetchone()
if row:
@@ -412,4 +412,6 @@ class RelationsStore(RelationsWorkerStore):
txn.call_after(
self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
)
- txn.call_after(self.get_applicable_edit.invalidate_many, (parent_id,))
+
+ if rel_type == RelationTypes.REPLACES:
+ txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))
|