summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/test_redaction.py70
-rw-r--r--tests/storage/test_registration.py1
2 files changed, 71 insertions, 0 deletions
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 8488b6edc8..d961b81d48 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -17,6 +17,8 @@
 
 from mock import Mock
 
+from twisted.internet import defer
+
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.types import RoomID, UserID
@@ -216,3 +218,71 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             },
             event.unsigned["redacted_because"],
         )
+
+    def test_circular_redaction(self):
+        redaction_event_id1 = "$redaction1_id:test"
+        redaction_event_id2 = "$redaction2_id:test"
+
+        class EventIdManglingBuilder:
+            def __init__(self, base_builder, event_id):
+                self._base_builder = base_builder
+                self._event_id = event_id
+
+            @defer.inlineCallbacks
+            def build(self, prev_event_ids):
+                built_event = yield self._base_builder.build(prev_event_ids)
+                built_event.event_id = self._event_id
+                built_event._event_dict["event_id"] = self._event_id
+                return built_event
+
+            @property
+            def room_id(self):
+                return self._base_builder.room_id
+
+        event_1, context_1 = self.get_success(
+            self.event_creation_handler.create_new_client_event(
+                EventIdManglingBuilder(
+                    self.event_builder_factory.for_room_version(
+                        RoomVersions.V1,
+                        {
+                            "type": EventTypes.Redaction,
+                            "sender": self.u_alice.to_string(),
+                            "room_id": self.room1.to_string(),
+                            "content": {"reason": "test"},
+                            "redacts": redaction_event_id2,
+                        },
+                    ),
+                    redaction_event_id1,
+                )
+            )
+        )
+
+        self.get_success(self.store.persist_event(event_1, context_1))
+
+        event_2, context_2 = self.get_success(
+            self.event_creation_handler.create_new_client_event(
+                EventIdManglingBuilder(
+                    self.event_builder_factory.for_room_version(
+                        RoomVersions.V1,
+                        {
+                            "type": EventTypes.Redaction,
+                            "sender": self.u_alice.to_string(),
+                            "room_id": self.room1.to_string(),
+                            "content": {"reason": "test"},
+                            "redacts": redaction_event_id1,
+                        },
+                    ),
+                    redaction_event_id2,
+                )
+            )
+        )
+        self.get_success(self.store.persist_event(event_2, context_2))
+
+        # fetch one of the redactions
+        fetched = self.get_success(self.store.get_event(redaction_event_id1))
+
+        # it should have been redacted
+        self.assertEqual(fetched.unsigned["redacted_by"], redaction_event_id2)
+        self.assertEqual(
+            fetched.unsigned["redacted_because"].event_id, redaction_event_id2
+        )
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 0253c4ac05..4578cc3b60 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -49,6 +49,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
                 "consent_server_notice_sent": None,
                 "appservice_id": None,
                 "creation_ts": 1000,
+                "user_type": None,
             },
             (yield self.store.get_user_by_id(self.user_id)),
         )