summary refs log tree commit diff
path: root/tests/storage/test_redaction.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test_redaction.py')
-rw-r--r--tests/storage/test_redaction.py125
1 files changed, 69 insertions, 56 deletions
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 6c4e63b77c..df4740f9d9 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -11,27 +11,35 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional
+from typing import List, Optional, cast
 
 from canonicaljson import json
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.room_versions import RoomVersions
-from synapse.types import RoomID, UserID
+from synapse.events import EventBase, _EventInternalMetadata
+from synapse.events.builder import EventBuilder
+from synapse.server import HomeServer
+from synapse.types import JsonDict, RoomID, UserID
+from synapse.util import Clock
 
 from tests import unittest
 from tests.utils import create_room
 
 
 class RedactionTestCase(unittest.HomeserverTestCase):
-    def default_config(self):
+    def default_config(self) -> JsonDict:
         config = super().default_config()
         config["redaction_retention_period"] = "30d"
         return config
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
-        self._storage = hs.get_storage_controllers()
+        storage = hs.get_storage_controllers()
+        assert storage.persistence is not None
+        self._persistence = storage.persistence
         self.event_builder_factory = hs.get_event_builder_factory()
         self.event_creation_handler = hs.get_event_creation_handler()
 
@@ -46,14 +54,13 @@ class RedactionTestCase(unittest.HomeserverTestCase):
 
         self.depth = 1
 
-    def inject_room_member(
+    def inject_room_member(  # type: ignore[override]
         self,
-        room,
-        user,
-        membership,
-        replaces_state=None,
-        extra_content: Optional[dict] = None,
-    ):
+        room: RoomID,
+        user: UserID,
+        membership: str,
+        extra_content: Optional[JsonDict] = None,
+    ) -> EventBase:
         content = {"membership": membership}
         content.update(extra_content or {})
         builder = self.event_builder_factory.for_room_version(
@@ -71,11 +78,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(self._storage.persistence.persist_event(event, context))
+        self.get_success(self._persistence.persist_event(event, context))
 
         return event
 
-    def inject_message(self, room, user, body):
+    def inject_message(self, room: RoomID, user: UserID, body: str) -> EventBase:
         self.depth += 1
 
         builder = self.event_builder_factory.for_room_version(
@@ -93,11 +100,13 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(self._storage.persistence.persist_event(event, context))
+        self.get_success(self._persistence.persist_event(event, context))
 
         return event
 
-    def inject_redaction(self, room, event_id, user, reason):
+    def inject_redaction(
+        self, room: RoomID, event_id: str, user: UserID, reason: str
+    ) -> EventBase:
         builder = self.event_builder_factory.for_room_version(
             RoomVersions.V1,
             {
@@ -114,11 +123,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(self._storage.persistence.persist_event(event, context))
+        self.get_success(self._persistence.persist_event(event, context))
 
         return event
 
-    def test_redact(self):
+    def test_redact(self) -> None:
         self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
 
         msg_event = self.inject_message(self.room1, self.u_alice, "t")
@@ -165,7 +174,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             event.unsigned["redacted_because"],
         )
 
-    def test_redact_join(self):
+    def test_redact_join(self) -> None:
         self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
 
         msg_event = self.inject_room_member(
@@ -213,12 +222,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             event.unsigned["redacted_because"],
         )
 
-    def test_circular_redaction(self):
+    def test_circular_redaction(self) -> None:
         redaction_event_id1 = "$redaction1_id:test"
         redaction_event_id2 = "$redaction2_id:test"
 
         class EventIdManglingBuilder:
-            def __init__(self, base_builder, event_id):
+            def __init__(self, base_builder: EventBuilder, event_id: str):
                 self._base_builder = base_builder
                 self._event_id = event_id
 
@@ -227,67 +236,73 @@ class RedactionTestCase(unittest.HomeserverTestCase):
                 prev_event_ids: List[str],
                 auth_event_ids: Optional[List[str]],
                 depth: Optional[int] = None,
-            ):
+            ) -> EventBase:
                 built_event = await self._base_builder.build(
                     prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids
                 )
 
-                built_event._event_id = self._event_id
+                built_event._event_id = self._event_id  # type: ignore[attr-defined]
                 built_event._dict["event_id"] = self._event_id
                 assert built_event.event_id == self._event_id
 
                 return built_event
 
             @property
-            def room_id(self):
+            def room_id(self) -> str:
                 return self._base_builder.room_id
 
             @property
-            def type(self):
+            def type(self) -> str:
                 return self._base_builder.type
 
             @property
-            def internal_metadata(self):
+            def internal_metadata(self) -> _EventInternalMetadata:
                 return self._base_builder.internal_metadata
 
         event_1, context_1 = self.get_success(
             self.event_creation_handler.create_new_client_event(
-                EventIdManglingBuilder(
-                    self.event_builder_factory.for_room_version(
-                        RoomVersions.V1,
-                        {
-                            "type": EventTypes.Redaction,
-                            "sender": self.u_alice.to_string(),
-                            "room_id": self.room1.to_string(),
-                            "content": {"reason": "test"},
-                            "redacts": redaction_event_id2,
-                        },
+                cast(
+                    EventBuilder,
+                    EventIdManglingBuilder(
+                        self.event_builder_factory.for_room_version(
+                            RoomVersions.V1,
+                            {
+                                "type": EventTypes.Redaction,
+                                "sender": self.u_alice.to_string(),
+                                "room_id": self.room1.to_string(),
+                                "content": {"reason": "test"},
+                                "redacts": redaction_event_id2,
+                            },
+                        ),
+                        redaction_event_id1,
                     ),
-                    redaction_event_id1,
                 )
             )
         )
 
-        self.get_success(self._storage.persistence.persist_event(event_1, context_1))
+        self.get_success(self._persistence.persist_event(event_1, context_1))
 
         event_2, context_2 = self.get_success(
             self.event_creation_handler.create_new_client_event(
-                EventIdManglingBuilder(
-                    self.event_builder_factory.for_room_version(
-                        RoomVersions.V1,
-                        {
-                            "type": EventTypes.Redaction,
-                            "sender": self.u_alice.to_string(),
-                            "room_id": self.room1.to_string(),
-                            "content": {"reason": "test"},
-                            "redacts": redaction_event_id1,
-                        },
+                cast(
+                    EventBuilder,
+                    EventIdManglingBuilder(
+                        self.event_builder_factory.for_room_version(
+                            RoomVersions.V1,
+                            {
+                                "type": EventTypes.Redaction,
+                                "sender": self.u_alice.to_string(),
+                                "room_id": self.room1.to_string(),
+                                "content": {"reason": "test"},
+                                "redacts": redaction_event_id1,
+                            },
+                        ),
+                        redaction_event_id2,
                     ),
-                    redaction_event_id2,
                 )
             )
         )
-        self.get_success(self._storage.persistence.persist_event(event_2, context_2))
+        self.get_success(self._persistence.persist_event(event_2, context_2))
 
         # fetch one of the redactions
         fetched = self.get_success(self.store.get_event(redaction_event_id1))
@@ -298,7 +313,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             fetched.unsigned["redacted_because"].event_id, redaction_event_id2
         )
 
-    def test_redact_censor(self):
+    def test_redact_censor(self) -> None:
         """Test that a redacted event gets censored in the DB after a month"""
 
         self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -364,7 +379,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
 
         self.assert_dict({"content": {}}, json.loads(event_json))
 
-    def test_redact_redaction(self):
+    def test_redact_redaction(self) -> None:
         """Tests that we can redact a redaction and can fetch it again."""
 
         self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -391,7 +406,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             self.store.get_event(first_redact_event.event_id, allow_none=True)
         )
 
-    def test_store_redacted_redaction(self):
+    def test_store_redacted_redaction(self) -> None:
         """Tests that we can store a redacted redaction."""
 
         self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -410,9 +425,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(
-            self._storage.persistence.persist_event(redaction_event, context)
-        )
+        self.get_success(self._persistence.persist_event(redaction_event, context))
 
         # Now lets jump to the future where we have censored the redaction event
         # in the DB.