diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 6c4e63b77c..df4740f9d9 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -11,27 +11,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Optional
+from typing import List, Optional, cast
from canonicaljson import json
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
-from synapse.types import RoomID, UserID
+from synapse.events import EventBase, _EventInternalMetadata
+from synapse.events.builder import EventBuilder
+from synapse.server import HomeServer
+from synapse.types import JsonDict, RoomID, UserID
+from synapse.util import Clock
from tests import unittest
from tests.utils import create_room
class RedactionTestCase(unittest.HomeserverTestCase):
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = super().default_config()
config["redaction_retention_period"] = "30d"
return config
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- self._storage = hs.get_storage_controllers()
+ storage = hs.get_storage_controllers()
+ assert storage.persistence is not None
+ self._persistence = storage.persistence
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -46,14 +54,13 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.depth = 1
- def inject_room_member(
+ def inject_room_member( # type: ignore[override]
self,
- room,
- user,
- membership,
- replaces_state=None,
- extra_content: Optional[dict] = None,
- ):
+ room: RoomID,
+ user: UserID,
+ membership: str,
+ extra_content: Optional[JsonDict] = None,
+ ) -> EventBase:
content = {"membership": membership}
content.update(extra_content or {})
builder = self.event_builder_factory.for_room_version(
@@ -71,11 +78,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self._storage.persistence.persist_event(event, context))
+ self.get_success(self._persistence.persist_event(event, context))
return event
- def inject_message(self, room, user, body):
+ def inject_message(self, room: RoomID, user: UserID, body: str) -> EventBase:
self.depth += 1
builder = self.event_builder_factory.for_room_version(
@@ -93,11 +100,13 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self._storage.persistence.persist_event(event, context))
+ self.get_success(self._persistence.persist_event(event, context))
return event
- def inject_redaction(self, room, event_id, user, reason):
+ def inject_redaction(
+ self, room: RoomID, event_id: str, user: UserID, reason: str
+ ) -> EventBase:
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
@@ -114,11 +123,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self._storage.persistence.persist_event(event, context))
+ self.get_success(self._persistence.persist_event(event, context))
return event
- def test_redact(self):
+ def test_redact(self) -> None:
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
msg_event = self.inject_message(self.room1, self.u_alice, "t")
@@ -165,7 +174,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
event.unsigned["redacted_because"],
)
- def test_redact_join(self):
+ def test_redact_join(self) -> None:
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
msg_event = self.inject_room_member(
@@ -213,12 +222,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
event.unsigned["redacted_because"],
)
- def test_circular_redaction(self):
+ def test_circular_redaction(self) -> None:
redaction_event_id1 = "$redaction1_id:test"
redaction_event_id2 = "$redaction2_id:test"
class EventIdManglingBuilder:
- def __init__(self, base_builder, event_id):
+ def __init__(self, base_builder: EventBuilder, event_id: str):
self._base_builder = base_builder
self._event_id = event_id
@@ -227,67 +236,73 @@ class RedactionTestCase(unittest.HomeserverTestCase):
prev_event_ids: List[str],
auth_event_ids: Optional[List[str]],
depth: Optional[int] = None,
- ):
+ ) -> EventBase:
built_event = await self._base_builder.build(
prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids
)
- built_event._event_id = self._event_id
+ built_event._event_id = self._event_id # type: ignore[attr-defined]
built_event._dict["event_id"] = self._event_id
assert built_event.event_id == self._event_id
return built_event
@property
- def room_id(self):
+ def room_id(self) -> str:
return self._base_builder.room_id
@property
- def type(self):
+ def type(self) -> str:
return self._base_builder.type
@property
- def internal_metadata(self):
+ def internal_metadata(self) -> _EventInternalMetadata:
return self._base_builder.internal_metadata
event_1, context_1 = self.get_success(
self.event_creation_handler.create_new_client_event(
- EventIdManglingBuilder(
- self.event_builder_factory.for_room_version(
- RoomVersions.V1,
- {
- "type": EventTypes.Redaction,
- "sender": self.u_alice.to_string(),
- "room_id": self.room1.to_string(),
- "content": {"reason": "test"},
- "redacts": redaction_event_id2,
- },
+ cast(
+ EventBuilder,
+ EventIdManglingBuilder(
+ self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": EventTypes.Redaction,
+ "sender": self.u_alice.to_string(),
+ "room_id": self.room1.to_string(),
+ "content": {"reason": "test"},
+ "redacts": redaction_event_id2,
+ },
+ ),
+ redaction_event_id1,
),
- redaction_event_id1,
)
)
)
- self.get_success(self._storage.persistence.persist_event(event_1, context_1))
+ self.get_success(self._persistence.persist_event(event_1, context_1))
event_2, context_2 = self.get_success(
self.event_creation_handler.create_new_client_event(
- EventIdManglingBuilder(
- self.event_builder_factory.for_room_version(
- RoomVersions.V1,
- {
- "type": EventTypes.Redaction,
- "sender": self.u_alice.to_string(),
- "room_id": self.room1.to_string(),
- "content": {"reason": "test"},
- "redacts": redaction_event_id1,
- },
+ cast(
+ EventBuilder,
+ EventIdManglingBuilder(
+ self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": EventTypes.Redaction,
+ "sender": self.u_alice.to_string(),
+ "room_id": self.room1.to_string(),
+ "content": {"reason": "test"},
+ "redacts": redaction_event_id1,
+ },
+ ),
+ redaction_event_id2,
),
- redaction_event_id2,
)
)
)
- self.get_success(self._storage.persistence.persist_event(event_2, context_2))
+ self.get_success(self._persistence.persist_event(event_2, context_2))
# fetch one of the redactions
fetched = self.get_success(self.store.get_event(redaction_event_id1))
@@ -298,7 +313,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
fetched.unsigned["redacted_because"].event_id, redaction_event_id2
)
- def test_redact_censor(self):
+ def test_redact_censor(self) -> None:
"""Test that a redacted event gets censored in the DB after a month"""
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -364,7 +379,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.assert_dict({"content": {}}, json.loads(event_json))
- def test_redact_redaction(self):
+ def test_redact_redaction(self) -> None:
"""Tests that we can redact a redaction and can fetch it again."""
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -391,7 +406,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.store.get_event(first_redact_event.event_id, allow_none=True)
)
- def test_store_redacted_redaction(self):
+ def test_store_redacted_redaction(self) -> None:
"""Tests that we can store a redacted redaction."""
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -410,9 +425,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(
- self._storage.persistence.persist_event(redaction_event, context)
- )
+ self.get_success(self._persistence.persist_event(redaction_event, context))
# Now lets jump to the future where we have censored the redaction event
# in the DB.
|