diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 16d0c64372..bf3c8f8dc1 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -346,7 +346,7 @@ class EventClientSerializer(object):
defer.returnValue(event)
event_id = event.event_id
- event = serialize_event(event, time_now, **kwargs)
+ serialized_event = serialize_event(event, time_now, **kwargs)
# If MSC1849 is enabled then we need to look if thre are any relations
# we need to bundle in with the event
@@ -359,14 +359,34 @@ class EventClientSerializer(object):
)
if annotations.chunk:
- r = event["unsigned"].setdefault("m.relations", {})
+ r = serialized_event["unsigned"].setdefault("m.relations", {})
r[RelationTypes.ANNOTATION] = annotations.to_dict()
if references.chunk:
- r = event["unsigned"].setdefault("m.relations", {})
+ r = serialized_event["unsigned"].setdefault("m.relations", {})
r[RelationTypes.REFERENCES] = references.to_dict()
- defer.returnValue(event)
+ edit = None
+ if event.type == EventTypes.Message:
+ edit = yield self.store.get_applicable_edit(event_id)
+
+ if edit:
+ # If there is an edit replace the content, preserving existing
+ # relations.
+
+ relations = event.content.get("m.relates_to")
+ serialized_event["content"] = edit.content.get("m.new_content", {})
+ if relations:
+ serialized_event["content"]["m.relates_to"] = relations
+ else:
+ serialized_event["content"].pop("m.relates_to", None)
+
+ r = serialized_event["unsigned"].setdefault("m.relations", {})
+ r[RelationTypes.REPLACES] = {
+ "event_id": edit.event_id,
+ }
+
+ defer.returnValue(serialized_event)
def serialize_events(self, events, time_now, **kwargs):
"""Serializes multiple events.
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 797450bc66..a3952506c1 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -143,3 +143,4 @@ class SlavedEventStore(EventFederationWorkerStore,
if relates_to:
self.get_relations_for_event.invalidate_many((relates_to,))
self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
+ self.get_applicable_edit.invalidate((relates_to,))
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index de67e305a1..6e216066ab 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -17,11 +17,13 @@ import logging
import attr
+from twisted.internet import defer
+
from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
from synapse.storage.stream import generate_pagination_where_clause
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
logger = logging.getLogger(__name__)
@@ -312,6 +314,59 @@ class RelationsWorkerStore(SQLBaseStore):
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
)
+ @cachedInlineCallbacks()
+ def get_applicable_edit(self, event_id):
+ """Get the most recent edit (if any) that has happened for the given
+ event.
+
+ Correctly handles checking whether edits were allowed to happen.
+
+ Args:
+ event_id (str): The original event ID
+
+ Returns:
+ Deferred[EventBase|None]: Returns the most recent edit, if any.
+ """
+
+ # We only allow edits for `m.room.message` events that have the same sender
+ # and event type. We can't assert these things during regular event auth so
+ # we have to do the checks post hoc.
+
+ # Fetches latest edit that has the same type and sender as the
+ # original, and is an `m.room.message`.
+ sql = """
+ SELECT edit.event_id FROM events AS edit
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS original ON
+ original.event_id = relates_to_id
+ AND edit.type = original.type
+ AND edit.sender = original.sender
+ WHERE
+ relates_to_id = ?
+ AND relation_type = ?
+ AND edit.type = 'm.room.message'
+ ORDER by edit.origin_server_ts DESC, edit.event_id DESC
+ LIMIT 1
+ """
+
+ def _get_applicable_edit_txn(txn):
+ txn.execute(
+ sql, (event_id, RelationTypes.REPLACES,)
+ )
+ row = txn.fetchone()
+ if row:
+ return row[0]
+
+ edit_id = yield self.runInteraction(
+ "get_applicable_edit", _get_applicable_edit_txn
+ )
+
+ if not edit_id:
+ return
+
+ edit_event = yield self.get_event(edit_id, allow_none=True)
+ defer.returnValue(edit_event)
+
class RelationsStore(RelationsWorkerStore):
def _handle_event_relations(self, txn, event):
@@ -357,3 +412,6 @@ class RelationsStore(RelationsWorkerStore):
txn.call_after(
self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
)
+
+ if rel_type == RelationTypes.REPLACES:
+ txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 775622bd2b..b0e4c47ae3 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -14,6 +14,7 @@
# limitations under the License.
import itertools
+import json
import six
@@ -102,11 +103,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
# relation event we sent above.
self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
self.assert_dict(
- {
- "event_id": annotation_id,
- "sender": self.user_id,
- "type": "m.reaction",
- },
+ {"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"},
channel.json_body["chunk"][0],
)
@@ -330,8 +327,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
- self.maxDiff = None
-
self.assertEquals(
channel.json_body["unsigned"].get("m.relations"),
{
@@ -347,7 +342,84 @@ class RelationsTestCase(unittest.HomeserverTestCase):
},
)
- def _send_relation(self, relation_type, event_type, key=None):
+ def test_edit(self):
+ """Test that a simple edit works.
+ """
+
+ new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+ channel = self._send_relation(
+ RelationTypes.REPLACES,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ edit_event_id = channel.json_body["event_id"]
+
+ request, channel = self.make_request(
+ "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEquals(channel.json_body["content"], new_body)
+
+ self.assertEquals(
+ channel.json_body["unsigned"].get("m.relations"),
+ {RelationTypes.REPLACES: {"event_id": edit_event_id}},
+ )
+
+ def test_multi_edit(self):
+ """Test that multiple edits, including attempts by people who
+ shouldn't be allowed, are correctly handled.
+ """
+
+ channel = self._send_relation(
+ RelationTypes.REPLACES,
+ "m.room.message",
+ content={
+ "msgtype": "m.text",
+ "body": "Wibble",
+ "m.new_content": {"msgtype": "m.text", "body": "First edit"},
+ },
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+ channel = self._send_relation(
+ RelationTypes.REPLACES,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ edit_event_id = channel.json_body["event_id"]
+
+ channel = self._send_relation(
+ RelationTypes.REPLACES,
+ "m.room.message.WRONG_TYPE",
+ content={
+ "msgtype": "m.text",
+ "body": "Wibble",
+ "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"},
+ },
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ request, channel = self.make_request(
+ "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEquals(channel.json_body["content"], new_body)
+
+ self.assertEquals(
+ channel.json_body["unsigned"].get("m.relations"),
+ {RelationTypes.REPLACES: {"event_id": edit_event_id}},
+ )
+
+ def _send_relation(self, relation_type, event_type, key=None, content={}):
"""Helper function to send a relation pointing at `self.parent_id`
Args:
@@ -355,6 +427,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
event_type (str): The type of the event to create
key (str|None): The aggregation key used for m.annotation relation
type.
+ content(dict|None): The content of the created event.
Returns:
FakeChannel
@@ -367,7 +440,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"POST",
"/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
% (self.room, self.parent_id, relation_type, event_type, query),
- b"{}",
+ json.dumps(content).encode("utf-8"),
)
self.render(request)
return channel
|