summary refs log tree commit diff
path: root/tests/replication/slave
diff options
context:
space:
mode:
authorMark Haines <mark.haines@matrix.org>2016-04-07 16:26:52 +0100
committerMark Haines <mark.haines@matrix.org>2016-04-07 16:52:07 +0100
commitceb599e789ef34a3733a99d17701a75fc5409d17 (patch)
tree8641d93b820f619f28b8cd588e7b532839b955c4 /tests/replication/slave
parentMerge pull request #704 from matrix-org/markh/slaveIII (diff)
downloadsynapse-ceb599e789ef34a3733a99d17701a75fc5409d17.tar.xz
Add tests for redactions
Diffstat (limited to 'tests/replication/slave')
-rw-r--r--tests/replication/slave/storage/_base.py2
-rw-r--r--tests/replication/slave/storage/test_events.py51
2 files changed, 51 insertions, 2 deletions
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 0f525a8943..983caafe8a 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -51,7 +51,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
     def check(self, method, args, expected_result=None):
         master_result = yield getattr(self.master_store, method)(*args)
         slaved_result = yield getattr(self.slaved_store, method)(*args)
-        self.assertEqual(master_result, slaved_result)
         if expected_result is not None:
             self.assertEqual(master_result, expected_result)
             self.assertEqual(slaved_result, expected_result)
+        self.assertEqual(master_result, slaved_result)
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 9af62702b3..baa4a26eb5 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -205,13 +205,59 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             [join3]
         )
 
+    @defer.inlineCallbacks
+    def test_redactions(self):
+        yield self.persist(type="m.room.create", key="", creator=USER_ID)
+        yield self.persist(type="m.room.member", key=USER_ID, membership="join")
+
+        msg = yield self.persist(
+            type="m.room.message", msgtype="m.text", body="Hello"
+        )
+        yield self.replicate()
+        yield self.check("get_event", [msg.event_id], msg)
+
+        redaction = yield self.persist(
+            type="m.room.redaction", redacts=msg.event_id
+        )
+        yield self.replicate()
+
+        msg_dict = msg.get_dict()
+        msg_dict["content"] = {}
+        msg_dict["unsigned"]["redacted_by"] = redaction.event_id
+        msg_dict["unsigned"]["redacted_because"] = redaction
+        redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
+        yield self.check("get_event", [msg.event_id], redacted)
+
+    @defer.inlineCallbacks
+    def test_backfilled_redactions(self):
+        yield self.persist(type="m.room.create", key="", creator=USER_ID)
+        yield self.persist(type="m.room.member", key=USER_ID, membership="join")
+
+        msg = yield self.persist(
+            type="m.room.message", msgtype="m.text", body="Hello"
+        )
+        yield self.replicate()
+        yield self.check("get_event", [msg.event_id], msg)
+
+        redaction = yield self.persist(
+            type="m.room.redaction", redacts=msg.event_id, backfill=True
+        )
+        yield self.replicate()
+
+        msg_dict = msg.get_dict()
+        msg_dict["content"] = {}
+        msg_dict["unsigned"]["redacted_by"] = redaction.event_id
+        msg_dict["unsigned"]["redacted_because"] = redaction
+        redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
+        yield self.check("get_event", [msg.event_id], redacted)
+
     event_id = 0
 
     @defer.inlineCallbacks
     def persist(
         self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={},
         state=None, reset_state=False, backfill=False,
-        depth=None, prev_events=[], auth_events=[], prev_state=[],
+        depth=None, prev_events=[], auth_events=[], prev_state=[], redacts=None,
         **content
     ):
         """
@@ -236,6 +282,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             event_dict["state_key"] = key
             event_dict["prev_state"] = prev_state
 
+        if redacts is not None:
+            event_dict["redacts"] = redacts
+
         event = FrozenEvent(event_dict, internal_metadata_dict=internal)
 
         self.event_id += 1