summary refs log tree commit diff
path: root/tests/rest/client/test_auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/test_auth.py')
-rw-r--r--tests/rest/client/test_auth.py70
1 files changed, 38 insertions, 32 deletions
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 4a68d66573..9653f45837 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -13,17 +13,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from http import HTTPStatus
-from typing import Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 from twisted.internet.defer import succeed
+from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.resource import Resource
 
 import synapse.rest.admin
 from synapse.api.constants import LoginType
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
 from synapse.rest.client import account, auth, devices, login, logout, register
 from synapse.rest.synapse.client import build_synapse_client_resource_tree
+from synapse.server import HomeServer
 from synapse.storage.database import LoggingTransaction
 from synapse.types import JsonDict, UserID
+from synapse.util import Clock
 
 from tests import unittest
 from tests.handlers.test_oidc import HAS_OIDC
@@ -33,11 +37,11 @@ from tests.unittest import override_config, skip_unless
 
 
 class DummyRecaptchaChecker(UserInteractiveAuthChecker):
-    def __init__(self, hs):
+    def __init__(self, hs: HomeServer) -> None:
         super().__init__(hs)
-        self.recaptcha_attempts = []
+        self.recaptcha_attempts: List[Tuple[dict, str]] = []
 
-    def check_auth(self, authdict, clientip):
+    def check_auth(self, authdict: dict, clientip: str) -> Any:
         self.recaptcha_attempts.append((authdict, clientip))
         return succeed(True)
 
@@ -50,7 +54,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
     ]
     hijack_auth = False
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
 
         config = self.default_config()
 
@@ -61,7 +65,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         hs = self.setup_test_homeserver(config=config)
         return hs
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.recaptcha_checker = DummyRecaptchaChecker(hs)
         auth_handler = hs.get_auth_handler()
         auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
@@ -101,7 +105,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         self.assertEqual(len(attempts), 1)
         self.assertEqual(attempts[0][0]["response"], "a")
 
-    def test_fallback_captcha(self):
+    def test_fallback_captcha(self) -> None:
         """Ensure that fallback auth via a captcha works."""
         # Returns a 401 as per the spec
         channel = self.register(
@@ -132,7 +136,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         # We're given a registered user.
         self.assertEqual(channel.json_body["user_id"], "@user:test")
 
-    def test_complete_operation_unknown_session(self):
+    def test_complete_operation_unknown_session(self) -> None:
         """
         Attempting to mark an invalid session as complete should error.
         """
@@ -165,7 +169,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         register.register_servlets,
     ]
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
 
         # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns
@@ -182,12 +186,12 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
         return config
 
-    def create_resource_dict(self):
+    def create_resource_dict(self) -> Dict[str, Resource]:
         resource_dict = super().create_resource_dict()
         resource_dict.update(build_synapse_client_resource_tree(self.hs))
         return resource_dict
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.user_pass = "pass"
         self.user = self.register_user("test", self.user_pass)
         self.device_id = "dev1"
@@ -229,7 +233,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
         return channel
 
-    def test_ui_auth(self):
+    def test_ui_auth(self) -> None:
         """
         Test user interactive authentication outside of registration.
         """
@@ -259,7 +263,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_grandfathered_identifier(self):
+    def test_grandfathered_identifier(self) -> None:
         """Check behaviour without "identifier" dict
 
         Synapse used to require clients to submit a "user" field for m.login.password
@@ -286,7 +290,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_can_change_body(self):
+    def test_can_change_body(self) -> None:
         """
         The client dict can be modified during the user interactive authentication session.
 
@@ -325,7 +329,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_cannot_change_uri(self):
+    def test_cannot_change_uri(self) -> None:
         """
         The initial requested URI cannot be modified during the user interactive authentication session.
         """
@@ -362,7 +366,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         )
 
     @unittest.override_config({"ui_auth": {"session_timeout": "5s"}})
-    def test_can_reuse_session(self):
+    def test_can_reuse_session(self) -> None:
         """
         The session can be reused if configured.
 
@@ -409,7 +413,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
     @skip_unless(HAS_OIDC, "requires OIDC")
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
-    def test_ui_auth_via_sso(self):
+    def test_ui_auth_via_sso(self) -> None:
         """Test a successful UI Auth flow via SSO
 
         This includes:
@@ -452,7 +456,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
     @skip_unless(HAS_OIDC, "requires OIDC")
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
-    def test_does_not_offer_password_for_sso_user(self):
+    def test_does_not_offer_password_for_sso_user(self) -> None:
         login_resp = self.helper.login_via_oidc("username")
         user_tok = login_resp["access_token"]
         device_id = login_resp["device_id"]
@@ -464,7 +468,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         flows = channel.json_body["flows"]
         self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
 
-    def test_does_not_offer_sso_for_password_user(self):
+    def test_does_not_offer_sso_for_password_user(self) -> None:
         channel = self.delete_device(
             self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
         )
@@ -474,7 +478,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
     @skip_unless(HAS_OIDC, "requires OIDC")
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
-    def test_offers_both_flows_for_upgraded_user(self):
+    def test_offers_both_flows_for_upgraded_user(self) -> None:
         """A user that had a password and then logged in with SSO should get both flows"""
         login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
         self.assertEqual(login_resp["user_id"], self.user)
@@ -491,7 +495,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
     @skip_unless(HAS_OIDC, "requires OIDC")
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
-    def test_ui_auth_fails_for_incorrect_sso_user(self):
+    def test_ui_auth_fails_for_incorrect_sso_user(self) -> None:
         """If the user tries to authenticate with the wrong SSO user, they get an error"""
         # log the user in
         login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
@@ -534,7 +538,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
     ]
     hijack_auth = False
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.user_pass = "pass"
         self.user = self.register_user("test", self.user_pass)
 
@@ -548,7 +552,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             {"refresh_token": refresh_token},
         )
 
-    def is_access_token_valid(self, access_token) -> bool:
+    def is_access_token_valid(self, access_token: str) -> bool:
         """
         Checks whether an access token is valid, returning whether it is or not.
         """
@@ -561,7 +565,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
 
         return code == HTTPStatus.OK
 
-    def test_login_issue_refresh_token(self):
+    def test_login_issue_refresh_token(self) -> None:
         """
         A login response should include a refresh_token only if asked.
         """
@@ -591,7 +595,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         self.assertIn("refresh_token", login_with_refresh.json_body)
         self.assertIn("expires_in_ms", login_with_refresh.json_body)
 
-    def test_register_issue_refresh_token(self):
+    def test_register_issue_refresh_token(self) -> None:
         """
         A register response should include a refresh_token only if asked.
         """
@@ -627,7 +631,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         self.assertIn("refresh_token", register_with_refresh.json_body)
         self.assertIn("expires_in_ms", register_with_refresh.json_body)
 
-    def test_token_refresh(self):
+    def test_token_refresh(self) -> None:
         """
         A refresh token can be used to issue a new access token.
         """
@@ -665,7 +669,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         )
 
     @override_config({"refreshable_access_token_lifetime": "1m"})
-    def test_refreshable_access_token_expiration(self):
+    def test_refreshable_access_token_expiration(self) -> None:
         """
         The access token should have some time as specified in the config.
         """
@@ -722,7 +726,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "nonrefreshable_access_token_lifetime": "10m",
         }
     )
-    def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self):
+    def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(
+        self,
+    ) -> None:
         """
         Tests that the expiry times for refreshable and non-refreshable access
         tokens can be different.
@@ -782,7 +788,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
     @override_config(
         {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"}
     )
-    def test_refresh_token_expiry(self):
+    def test_refresh_token_expiry(self) -> None:
         """
         The refresh token can be configured to have a limited lifetime.
         When that lifetime has ended, the refresh token can no longer be used to
@@ -834,7 +840,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "session_lifetime": "3m",
         }
     )
-    def test_ultimate_session_expiry(self):
+    def test_ultimate_session_expiry(self) -> None:
         """
         The session can be configured to have an ultimate, limited lifetime.
         """
@@ -882,7 +888,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result
         )
 
-    def test_refresh_token_invalidation(self):
+    def test_refresh_token_invalidation(self) -> None:
         """Refresh tokens are invalidated after first use of the next token.
 
         A refresh token is considered invalid if:
@@ -987,7 +993,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result
         )
 
-    def test_many_token_refresh(self):
+    def test_many_token_refresh(self) -> None:
         """
         If a refresh is performed many times during a session, there shouldn't be
         extra 'cruft' built up over time.