summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/client/v2_alpha/test_relations.py91
1 files changed, 82 insertions, 9 deletions
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 775622bd2b..b0e4c47ae3 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import itertools
+import json
 
 import six
 
@@ -102,11 +103,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         # relation event we sent above.
         self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
         self.assert_dict(
-            {
-                "event_id": annotation_id,
-                "sender": self.user_id,
-                "type": "m.reaction",
-            },
+            {"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"},
             channel.json_body["chunk"][0],
         )
 
@@ -330,8 +327,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.render(request)
         self.assertEquals(200, channel.code, channel.json_body)
 
-        self.maxDiff = None
-
         self.assertEquals(
             channel.json_body["unsigned"].get("m.relations"),
             {
@@ -347,7 +342,84 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             },
         )
 
-    def _send_relation(self, relation_type, event_type, key=None):
+    def test_edit(self):
+        """Test that a simple edit works.
+        """
+
+        new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+        channel = self._send_relation(
+            RelationTypes.REPLACES,
+            "m.room.message",
+            content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        edit_event_id = channel.json_body["event_id"]
+
+        request, channel = self.make_request(
+            "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id)
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        self.assertEquals(channel.json_body["content"], new_body)
+
+        self.assertEquals(
+            channel.json_body["unsigned"].get("m.relations"),
+            {RelationTypes.REPLACES: {"event_id": edit_event_id}},
+        )
+
+    def test_multi_edit(self):
+        """Test that multiple edits, including attempts by people who
+        shouldn't be allowed, are correctly handled.
+        """
+
+        channel = self._send_relation(
+            RelationTypes.REPLACES,
+            "m.room.message",
+            content={
+                "msgtype": "m.text",
+                "body": "Wibble",
+                "m.new_content": {"msgtype": "m.text", "body": "First edit"},
+            },
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+        channel = self._send_relation(
+            RelationTypes.REPLACES,
+            "m.room.message",
+            content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        edit_event_id = channel.json_body["event_id"]
+
+        channel = self._send_relation(
+            RelationTypes.REPLACES,
+            "m.room.message.WRONG_TYPE",
+            content={
+                "msgtype": "m.text",
+                "body": "Wibble",
+                "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"},
+            },
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        request, channel = self.make_request(
+            "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id)
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        self.assertEquals(channel.json_body["content"], new_body)
+
+        self.assertEquals(
+            channel.json_body["unsigned"].get("m.relations"),
+            {RelationTypes.REPLACES: {"event_id": edit_event_id}},
+        )
+
+    def _send_relation(self, relation_type, event_type, key=None, content={}):
         """Helper function to send a relation pointing at `self.parent_id`
 
         Args:
@@ -355,6 +427,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             event_type (str): The type of the event to create
             key (str|None): The aggregation key used for m.annotation relation
                 type.
+            content(dict|None): The content of the created event.
 
         Returns:
             FakeChannel
@@ -367,7 +440,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             "POST",
             "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
             % (self.room, self.parent_id, relation_type, event_type, query),
-            b"{}",
+            json.dumps(content).encode("utf-8"),
         )
         self.render(request)
         return channel