summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/events/utils.py30
-rw-r--r--synapse/replication/slave/storage/events.py1
-rw-r--r--synapse/storage/relations.py60
-rw-r--r--tests/rest/client/v2_alpha/test_relations.py91
4 files changed, 167 insertions, 15 deletions
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 16d0c64372..2019ce9b1c 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -346,7 +346,7 @@ class EventClientSerializer(object):
             defer.returnValue(event)
 
         event_id = event.event_id
-        event = serialize_event(event, time_now, **kwargs)
+        serialized_event = serialize_event(event, time_now, **kwargs)
 
         # If MSC1849 is enabled then we need to look if thre are any relations
         # we need to bundle in with the event
@@ -359,14 +359,36 @@ class EventClientSerializer(object):
             )
 
             if annotations.chunk:
-                r = event["unsigned"].setdefault("m.relations", {})
+                r = serialized_event["unsigned"].setdefault("m.relations", {})
                 r[RelationTypes.ANNOTATION] = annotations.to_dict()
 
             if references.chunk:
-                r = event["unsigned"].setdefault("m.relations", {})
+                r = serialized_event["unsigned"].setdefault("m.relations", {})
                 r[RelationTypes.REFERENCES] = references.to_dict()
 
-        defer.returnValue(event)
+            edit = None
+            if event.type == EventTypes.Message:
+                edit = yield self.store.get_applicable_edit(
+                    event.event_id, event.type, event.sender,
+                )
+
+            if edit:
+                # If there is an edit replace the content, preserving existing
+                # relations.
+
+                relations = event.content.get("m.relates_to")
+                serialized_event["content"] = edit.content.get("m.new_content", {})
+                if relations:
+                    serialized_event["content"]["m.relates_to"] = relations
+                else:
+                    serialized_event["content"].pop("m.relates_to", None)
+
+                r = serialized_event["unsigned"].setdefault("m.relations", {})
+                r[RelationTypes.REPLACES] = {
+                    "event_id": edit.event_id,
+                }
+
+        defer.returnValue(serialized_event)
 
     def serialize_events(self, events, time_now, **kwargs):
         """Serializes multiple events.
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 797450bc66..857128b311 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -143,3 +143,4 @@ class SlavedEventStore(EventFederationWorkerStore,
         if relates_to:
             self.get_relations_for_event.invalidate_many((relates_to,))
             self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
+            self.get_applicable_edit.invalidate_many((relates_to,))
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index de67e305a1..aa91e52733 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -17,11 +17,13 @@ import logging
 
 import attr
 
-from synapse.api.constants import RelationTypes
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, RelationTypes
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.stream import generate_pagination_where_clause
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
 logger = logging.getLogger(__name__)
 
@@ -312,6 +314,59 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
         )
 
+    @cachedInlineCallbacks(tree=True)
+    def get_applicable_edit(self, event_id, event_type, sender):
+        """Get the most recent edit (if any) that has happened for the given
+        event.
+
+        Correctly handles checking whether edits were allowed to happen.
+
+        Args:
+            event_id (str): The original event ID
+            event_type (str): The original event type
+            sender (str): The original event sender
+
+        Returns:
+            Deferred[EventBase|None]: Returns the most recent edit, if any.
+        """
+
+        # We only allow edits for `m.room.message` events that have the same sender
+        # and event type. We can't assert these things during regular event auth so
+        # we have to do the post hoc.
+
+        if event_type != EventTypes.Message:
+            return
+
+        sql = """
+            SELECT event_id, origin_server_ts FROM events
+            INNER JOIN event_relations USING (event_id)
+            WHERE
+                relates_to_id = ?
+                AND relation_type = ?
+                AND type = ?
+                AND sender = ?
+            ORDER by origin_server_ts DESC, event_id DESC
+            LIMIT 1
+        """
+
+        def _get_applicable_edit_txn(txn):
+            txn.execute(
+                sql, (event_id, RelationTypes.REPLACES, event_type, sender)
+            )
+            row = txn.fetchone()
+            if row:
+                return row[0]
+
+        edit_id = yield self.runInteraction(
+            "get_applicable_edit", _get_applicable_edit_txn
+        )
+
+        if not edit_id:
+            return
+
+        edit_event = yield self.get_event(edit_id, allow_none=True)
+        defer.returnValue(edit_event)
+
 
 class RelationsStore(RelationsWorkerStore):
     def _handle_event_relations(self, txn, event):
@@ -357,3 +412,4 @@ class RelationsStore(RelationsWorkerStore):
         txn.call_after(
             self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
         )
+        txn.call_after(self.get_applicable_edit.invalidate_many, (parent_id,))
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 775622bd2b..b0e4c47ae3 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import itertools
+import json
 
 import six
 
@@ -102,11 +103,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         # relation event we sent above.
         self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
         self.assert_dict(
-            {
-                "event_id": annotation_id,
-                "sender": self.user_id,
-                "type": "m.reaction",
-            },
+            {"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"},
             channel.json_body["chunk"][0],
         )
 
@@ -330,8 +327,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.render(request)
         self.assertEquals(200, channel.code, channel.json_body)
 
-        self.maxDiff = None
-
         self.assertEquals(
             channel.json_body["unsigned"].get("m.relations"),
             {
@@ -347,7 +342,84 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             },
         )
 
-    def _send_relation(self, relation_type, event_type, key=None):
+    def test_edit(self):
+        """Test that a simple edit works.
+        """
+
+        new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+        channel = self._send_relation(
+            RelationTypes.REPLACES,
+            "m.room.message",
+            content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        edit_event_id = channel.json_body["event_id"]
+
+        request, channel = self.make_request(
+            "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id)
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        self.assertEquals(channel.json_body["content"], new_body)
+
+        self.assertEquals(
+            channel.json_body["unsigned"].get("m.relations"),
+            {RelationTypes.REPLACES: {"event_id": edit_event_id}},
+        )
+
+    def test_multi_edit(self):
+        """Test that multiple edits, including attempts by people who
+        shouldn't be allowed, are correctly handled.
+        """
+
+        channel = self._send_relation(
+            RelationTypes.REPLACES,
+            "m.room.message",
+            content={
+                "msgtype": "m.text",
+                "body": "Wibble",
+                "m.new_content": {"msgtype": "m.text", "body": "First edit"},
+            },
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+        channel = self._send_relation(
+            RelationTypes.REPLACES,
+            "m.room.message",
+            content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        edit_event_id = channel.json_body["event_id"]
+
+        channel = self._send_relation(
+            RelationTypes.REPLACES,
+            "m.room.message.WRONG_TYPE",
+            content={
+                "msgtype": "m.text",
+                "body": "Wibble",
+                "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"},
+            },
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        request, channel = self.make_request(
+            "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id)
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        self.assertEquals(channel.json_body["content"], new_body)
+
+        self.assertEquals(
+            channel.json_body["unsigned"].get("m.relations"),
+            {RelationTypes.REPLACES: {"event_id": edit_event_id}},
+        )
+
+    def _send_relation(self, relation_type, event_type, key=None, content={}):
         """Helper function to send a relation pointing at `self.parent_id`
 
         Args:
@@ -355,6 +427,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             event_type (str): The type of the event to create
             key (str|None): The aggregation key used for m.annotation relation
                 type.
+            content(dict|None): The content of the created event.
 
         Returns:
             FakeChannel
@@ -367,7 +440,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             "POST",
             "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
             % (self.room, self.parent_id, relation_type, event_type, query),
-            b"{}",
+            json.dumps(content).encode("utf-8"),
         )
         self.render(request)
         return channel