summary refs log tree commit diff
path: root/tests/storage/test_client_ips.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test_client_ips.py')
-rw-r--r--tests/storage/test_client_ips.py58
1 files changed, 30 insertions, 28 deletions
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index a9af1babed..81e4e596e4 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -13,15 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any, Dict
 from unittest.mock import Mock
 
 from parameterized import parameterized
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.rest.admin
 from synapse.http.site import XForwardedForRequest
 from synapse.rest.client import login
+from synapse.server import HomeServer
 from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
 from synapse.types import UserID
+from synapse.util import Clock
 
 from tests import unittest
 from tests.server import make_request
@@ -30,14 +35,10 @@ from tests.unittest import override_config
 
 
 class ClientIpStoreTestCase(unittest.HomeserverTestCase):
-    def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver()
-        return hs
-
-    def prepare(self, hs, reactor, clock):
-        self.store = self.hs.get_datastores().main
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
 
-    def test_insert_new_client_ip(self):
+    def test_insert_new_client_ip(self) -> None:
         self.reactor.advance(12345678)
 
         user_id = "@user:id"
@@ -76,7 +77,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
             r,
         )
 
-    def test_insert_new_client_ip_none_device_id(self):
+    def test_insert_new_client_ip_none_device_id(self) -> None:
         """
         An insert with a device ID of NULL will not create a new entry, but
         update an existing entry in the user_ips table.
@@ -148,7 +149,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         )
 
     @parameterized.expand([(False,), (True,)])
-    def test_get_last_client_ip_by_device(self, after_persisting: bool):
+    def test_get_last_client_ip_by_device(self, after_persisting: bool) -> None:
         """Test `get_last_client_ip_by_device` for persisted and unpersisted data"""
         self.reactor.advance(12345678)
 
@@ -213,7 +214,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
             },
         )
 
-    def test_get_last_client_ip_by_device_combined_data(self):
+    def test_get_last_client_ip_by_device_combined_data(self) -> None:
         """Test that `get_last_client_ip_by_device` combines persisted and unpersisted
         data together correctly
         """
@@ -312,7 +313,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         )
 
     @parameterized.expand([(False,), (True,)])
-    def test_get_user_ip_and_agents(self, after_persisting: bool):
+    def test_get_user_ip_and_agents(self, after_persisting: bool) -> None:
         """Test `get_user_ip_and_agents` for persisted and unpersisted data"""
         self.reactor.advance(12345678)
 
@@ -352,7 +353,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
             ],
         )
 
-    def test_get_user_ip_and_agents_combined_data(self):
+    def test_get_user_ip_and_agents_combined_data(self) -> None:
         """Test that `get_user_ip_and_agents` combines persisted and unpersisted data
         together correctly
         """
@@ -429,7 +430,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         )
 
     @override_config({"limit_usage_by_mau": False, "max_mau_value": 50})
-    def test_disabled_monthly_active_user(self):
+    def test_disabled_monthly_active_user(self) -> None:
         user_id = "@user:server"
         self.get_success(
             self.store.insert_client_ip(
@@ -440,7 +441,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.assertFalse(active)
 
     @override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
-    def test_adding_monthly_active_user_when_full(self):
+    def test_adding_monthly_active_user_when_full(self) -> None:
         lots_of_users = 100
         user_id = "@user:server"
 
@@ -456,7 +457,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.assertFalse(active)
 
     @override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
-    def test_adding_monthly_active_user_when_space(self):
+    def test_adding_monthly_active_user_when_space(self) -> None:
         user_id = "@user:server"
         active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
         self.assertFalse(active)
@@ -473,7 +474,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.assertTrue(active)
 
     @override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
-    def test_updating_monthly_active_user_when_space(self):
+    def test_updating_monthly_active_user_when_space(self) -> None:
         user_id = "@user:server"
         self.get_success(self.store.register_user(user_id=user_id, password_hash=None))
 
@@ -491,7 +492,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
         self.assertTrue(active)
 
-    def test_devices_last_seen_bg_update(self):
+    def test_devices_last_seen_bg_update(self) -> None:
         # First make sure we have completed all updates.
         self.wait_for_background_updates()
 
@@ -576,7 +577,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
             r,
         )
 
-    def test_old_user_ips_pruned(self):
+    def test_old_user_ips_pruned(self) -> None:
         # First make sure we have completed all updates.
         self.wait_for_background_updates()
 
@@ -639,11 +640,11 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.assertEqual(result, [])
 
         # But we should still get the correct values for the device
-        result = self.get_success(
+        result2 = self.get_success(
             self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, device_id)]
+        r = result2[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
@@ -663,15 +664,11 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
         login.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver()
-        return hs
-
-    def prepare(self, hs, reactor, clock):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = self.hs.get_datastores().main
         self.user_id = self.register_user("bob", "abc123", True)
 
-    def test_request_with_xforwarded(self):
+    def test_request_with_xforwarded(self) -> None:
         """
         The IP in X-Forwarded-For is entered into the client IPs table.
         """
@@ -681,14 +678,19 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
             {"request": XForwardedForRequest},
         )
 
-    def test_request_from_getPeer(self):
+    def test_request_from_getPeer(self) -> None:
         """
         The IP returned by getPeer is entered into the client IPs table, if
         there's no X-Forwarded-For header.
         """
         self._runtest({}, "127.0.0.1", {})
 
-    def _runtest(self, headers, expected_ip, make_request_args):
+    def _runtest(
+        self,
+        headers: Dict[bytes, bytes],
+        expected_ip: str,
+        make_request_args: Dict[str, Any],
+    ) -> None:
         device_id = "bleb"
 
         access_token = self.login("bob", "abc123", device_id=device_id)