summary refs log tree commit diff
path: root/tests/rest/client/test_retention.py
diff options
context:
space:
mode:
authorDirk Klimpel <5740567+dklimpel@users.noreply.github.com>2022-02-28 18:47:37 +0100
committerGitHub <noreply@github.com>2022-02-28 17:47:37 +0000
commit1901cb1d4a8b7d9af64493fbd336e9aa2561c20c (patch)
tree60ac88ede37911a2cb92de0c70b55710fd52bd80 /tests/rest/client/test_retention.py
parentFix `PushRuleEvaluator` and `Filter` to work on frozendicts (#12100) (diff)
downloadsynapse-1901cb1d4a8b7d9af64493fbd336e9aa2561c20c.tar.xz
Add type hints to `tests/rest/client` (#12084)
Diffstat (limited to 'tests/rest/client/test_retention.py')
-rw-r--r--tests/rest/client/test_retention.py41
1 files changed, 27 insertions, 14 deletions
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index c41a1c14a1..f3bf8d0934 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -13,9 +13,14 @@
 # limitations under the License.
 from unittest.mock import Mock
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.constants import EventTypes
 from synapse.rest import admin
 from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
 from synapse.visibility import filter_events_for_client
 
 from tests import unittest
@@ -31,7 +36,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
         config["retention"] = {
             "enabled": True,
@@ -47,7 +52,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
 
         return self.hs
 
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.user_id = self.register_user("user", "password")
         self.token = self.login("user", "password")
 
@@ -55,7 +60,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         self.serializer = self.hs.get_event_client_serializer()
         self.clock = self.hs.get_clock()
 
-    def test_retention_event_purged_with_state_event(self):
+    def test_retention_event_purged_with_state_event(self) -> None:
         """Tests that expired events are correctly purged when the room's retention policy
         is defined by a state event.
         """
@@ -72,7 +77,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
 
         self._test_retention_event_purged(room_id, one_day_ms * 1.5)
 
-    def test_retention_event_purged_with_state_event_outside_allowed(self):
+    def test_retention_event_purged_with_state_event_outside_allowed(self) -> None:
         """Tests that the server configuration can override the policy for a room when
         running the purge jobs.
         """
@@ -102,7 +107,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         # instead of the one specified in the room's policy.
         self._test_retention_event_purged(room_id, one_day_ms * 0.5)
 
-    def test_retention_event_purged_without_state_event(self):
+    def test_retention_event_purged_without_state_event(self) -> None:
         """Tests that expired events are correctly purged when the room's retention policy
         is defined by the server's configuration's default retention policy.
         """
@@ -110,7 +115,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
 
         self._test_retention_event_purged(room_id, one_day_ms * 2)
 
-    def test_visibility(self):
+    def test_visibility(self) -> None:
         """Tests that synapse.visibility.filter_events_for_client correctly filters out
         outdated events
         """
@@ -152,7 +157,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         # That event should be the second, not outdated event.
         self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events)
 
-    def _test_retention_event_purged(self, room_id: str, increment: float):
+    def _test_retention_event_purged(self, room_id: str, increment: float) -> None:
         """Run the following test scenario to test the message retention policy support:
 
         1. Send event 1
@@ -186,6 +191,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
 
         expired_event_id = resp.get("event_id")
+        assert expired_event_id is not None
 
         # Check that we can retrieve the event.
         expired_event = self.get_event(expired_event_id)
@@ -201,6 +207,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
 
         valid_event_id = resp.get("event_id")
+        assert valid_event_id is not None
 
         # Advance the time again. Now our first event should have expired but our second
         # one should still be kept.
@@ -218,7 +225,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         # has been purged.
         self.get_event(room_id, create_event.event_id)
 
-    def get_event(self, event_id, expect_none=False):
+    def get_event(self, event_id: str, expect_none: bool = False) -> JsonDict:
         event = self.get_success(self.store.get_event(event_id, allow_none=True))
 
         if expect_none:
@@ -240,7 +247,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
         config["retention"] = {
             "enabled": True,
@@ -254,11 +261,11 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
         )
         return self.hs
 
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.user_id = self.register_user("user", "password")
         self.token = self.login("user", "password")
 
-    def test_no_default_policy(self):
+    def test_no_default_policy(self) -> None:
         """Tests that an event doesn't get expired if there is neither a default retention
         policy nor a policy specific to the room.
         """
@@ -266,7 +273,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
 
         self._test_retention(room_id)
 
-    def test_state_policy(self):
+    def test_state_policy(self) -> None:
         """Tests that an event gets correctly expired if there is no default retention
         policy but there's a policy specific to the room.
         """
@@ -283,12 +290,15 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
 
         self._test_retention(room_id, expected_code_for_first_event=404)
 
-    def _test_retention(self, room_id, expected_code_for_first_event=200):
+    def _test_retention(
+        self, room_id: str, expected_code_for_first_event: int = 200
+    ) -> None:
         # Send a first event to the room. This is the event we'll want to be purged at the
         # end of the test.
         resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
 
         first_event_id = resp.get("event_id")
+        assert first_event_id is not None
 
         # Check that we can retrieve the event.
         expired_event = self.get_event(room_id, first_event_id)
@@ -304,6 +314,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
         resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
 
         second_event_id = resp.get("event_id")
+        assert second_event_id is not None
 
         # Advance the time by another month.
         self.reactor.advance(one_day_ms * 30 / 1000)
@@ -322,7 +333,9 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
         second_event = self.get_event(room_id, second_event_id)
         self.assertEqual(second_event.get("content", {}).get("body"), "2", second_event)
 
-    def get_event(self, room_id, event_id, expected_code=200):
+    def get_event(
+        self, room_id: str, event_id: str, expected_code: int = 200
+    ) -> JsonDict:
         url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
 
         channel = self.make_request("GET", url, access_token=self.token)