summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/test__base.py14
-rw-r--r--tests/storage/test_monthly_active_users.py47
-rw-r--r--tests/storage/test_roommember.py25
3 files changed, 50 insertions, 36 deletions
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 4899cd5c36..366398e39d 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -14,12 +14,18 @@
 # limitations under the License.
 
 import secrets
+from typing import Any, Dict, Generator, List, Tuple
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 
 
 class UpsertManyTests(unittest.HomeserverTestCase):
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.storage = hs.get_datastores().main
 
         self.table_name = "table_" + secrets.token_hex(6)
@@ -40,11 +46,13 @@ class UpsertManyTests(unittest.HomeserverTestCase):
             )
         )
 
-    def _dump_to_tuple(self, res):
+    def _dump_to_tuple(
+        self, res: List[Dict[str, Any]]
+    ) -> Generator[Tuple[int, str, str], None, None]:
         for i in res:
             yield (i["id"], i["username"], i["value"])
 
-    def test_upsert_many(self):
+    def test_upsert_many(self) -> None:
         """
         Upsert_many will perform the upsert operation across a batch of data.
         """
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 79648d45db..60c8d37594 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -11,11 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Dict, List
 from unittest.mock import Mock
 
 from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import UserTypes
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils import make_awaitable
@@ -24,7 +28,7 @@ from tests.unittest import default_config, override_config
 FORTY_DAYS = 40 * 24 * 60 * 60
 
 
-def gen_3pids(count):
+def gen_3pids(count: int) -> List[Dict[str, Any]]:
     """Generate `count` threepids as a list."""
     return [
         {"medium": "email", "address": "user%i@matrix.org" % i} for i in range(count)
@@ -32,7 +36,7 @@ def gen_3pids(count):
 
 
 class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = default_config("test")
 
         config.update({"limit_usage_by_mau": True, "max_mau_value": 50})
@@ -44,10 +48,10 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
 
         return config
 
-    def prepare(self, reactor, clock, homeserver):
-        self.store = homeserver.get_datastores().main
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
         # Advance the clock a bit
-        reactor.advance(FORTY_DAYS)
+        self.reactor.advance(FORTY_DAYS)
 
     @override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)})
     def test_initialise_reserved_users(self):
@@ -245,7 +249,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         )
         self.get_success(d)
 
-        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
+        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
 
         d = self.store.populate_monthly_active_users(user_id)
         self.get_success(d)
@@ -253,9 +257,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.upsert_monthly_active_user.assert_not_called()
 
     def test_populate_monthly_users_should_update(self):
-        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
+        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
 
-        self.store.is_trial_user = Mock(return_value=defer.succeed(False))
+        self.store.is_trial_user = Mock(return_value=defer.succeed(False))  # type: ignore[assignment]
 
         self.store.user_last_seen_monthly_active = Mock(
             return_value=defer.succeed(None)
@@ -266,9 +270,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.upsert_monthly_active_user.assert_called_once()
 
     def test_populate_monthly_users_should_not_update(self):
-        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
+        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
 
-        self.store.is_trial_user = Mock(return_value=defer.succeed(False))
+        self.store.is_trial_user = Mock(return_value=defer.succeed(False))  # type: ignore[assignment]
         self.store.user_last_seen_monthly_active = Mock(
             return_value=defer.succeed(self.hs.get_clock().time_msec())
         )
@@ -324,16 +328,15 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         count = self.get_success(self.store.get_monthly_active_count())
         self.assertEqual(count, 0)
 
-        d = self.store.register_user(
-            user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT
+        self.get_success(
+            self.store.register_user(
+                user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT
+            )
         )
-        self.get_success(d)
 
-        d = self.store.upsert_monthly_active_user(support_user_id)
-        self.get_success(d)
+        self.get_success(self.store.upsert_monthly_active_user(support_user_id))
 
-        d = self.store.get_monthly_active_count()
-        count = self.get_success(d)
+        count = self.get_success(self.store.get_monthly_active_count())
         self.assertEqual(count, 0)
 
     # Note that the max_mau_value setting should not matter.
@@ -352,7 +355,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
 
     @override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
     def test_no_users_when_not_tracking(self):
-        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
+        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
 
         self.get_success(self.store.populate_monthly_active_users("@user:sever"))
 
@@ -388,16 +391,16 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
             self.store.register_user(user_id=native_user1, password_hash=None)
         )
 
-        count = self.get_success(self.store.get_monthly_active_count_by_service())
-        self.assertEqual(count, {})
+        count1 = self.get_success(self.store.get_monthly_active_count_by_service())
+        self.assertEqual(count1, {})
 
         self.get_success(self.store.upsert_monthly_active_user(native_user1))
         self.get_success(self.store.upsert_monthly_active_user(appservice1_user1))
         self.get_success(self.store.upsert_monthly_active_user(appservice1_user2))
         self.get_success(self.store.upsert_monthly_active_user(appservice2_user1))
 
-        count = self.get_success(self.store.get_monthly_active_count())
-        self.assertEqual(count, 1)
+        count2 = self.get_success(self.store.get_monthly_active_count())
+        self.assertEqual(count2, 1)
 
         d = self.store.get_monthly_active_count_by_service()
         result = self.get_success(d)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index b8f09a8ee0..a2a9c05f24 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -12,11 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import Membership
 from synapse.rest.admin import register_servlets_for_client_rest_resource
 from synapse.rest.client import login, room
+from synapse.server import HomeServer
 from synapse.types import UserID, create_requester
+from synapse.util import Clock
 
 from tests import unittest
 from tests.server import TestHomeServer
@@ -31,7 +34,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs: TestHomeServer):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: TestHomeServer) -> None:
 
         # We can't test the RoomMemberStore on its own without the other event
         # storage logic
@@ -44,7 +47,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         # User elsewhere on another host
         self.u_charlie = UserID.from_string("@charlie:elsewhere")
 
-    def test_one_member(self):
+    def test_one_member(self) -> None:
 
         # Alice creates the room, and is automatically joined
         self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
@@ -57,7 +60,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual([self.room], [m.room_id for m in rooms_for_user])
 
-    def test_count_known_servers(self):
+    def test_count_known_servers(self) -> None:
         """
         _count_known_servers will calculate how many servers are in a room.
         """
@@ -68,7 +71,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         servers = self.get_success(self.store._count_known_servers())
         self.assertEqual(servers, 2)
 
-    def test_count_known_servers_stat_counter_disabled(self):
+    def test_count_known_servers_stat_counter_disabled(self) -> None:
         """
         If enabled, the metrics for how many servers are known will be counted.
         """
@@ -85,7 +88,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
     @unittest.override_config(
         {"enable_metrics": True, "metrics_flags": {"known_servers": True}}
     )
-    def test_count_known_servers_stat_counter_enabled(self):
+    def test_count_known_servers_stat_counter_enabled(self) -> None:
         """
         If enabled, the metrics for how many servers are known will be counted.
         """
@@ -107,7 +110,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         # It now knows about Charlie's server.
         self.assertEqual(self.store._known_servers_count, 2)
 
-    def test_get_joined_users_from_context(self):
+    def test_get_joined_users_from_context(self) -> None:
         room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
         bob_event = self.get_success(
             event_injection.inject_member_event(
@@ -161,7 +164,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(users.keys(), {self.u_alice, self.u_bob})
 
-    def test__null_byte_in_display_name_properly_handled(self):
+    def test__null_byte_in_display_name_properly_handled(self) -> None:
         room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
 
         res = self.get_success(
@@ -211,11 +214,11 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
 
 
 class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
-    def prepare(self, reactor, clock, homeserver):
-        self.store = homeserver.get_datastores().main
-        self.room_creator = homeserver.get_room_creation_handler()
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+        self.room_creator = hs.get_room_creation_handler()
 
-    def test_can_rerun_update(self):
+    def test_can_rerun_update(self) -> None:
         # First make sure we have completed all updates.
         self.wait_for_background_updates()