diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 4a68d66573..9653f45837 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -13,17 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
-from typing import Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
from twisted.internet.defer import succeed
+from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.resource import Resource
import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.rest.client import account, auth, devices, login, logout, register
from synapse.rest.synapse.client import build_synapse_client_resource_tree
+from synapse.server import HomeServer
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict, UserID
+from synapse.util import Clock
from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
@@ -33,11 +37,11 @@ from tests.unittest import override_config, skip_unless
class DummyRecaptchaChecker(UserInteractiveAuthChecker):
- def __init__(self, hs):
+ def __init__(self, hs: HomeServer) -> None:
super().__init__(hs)
- self.recaptcha_attempts = []
+ self.recaptcha_attempts: List[Tuple[dict, str]] = []
- def check_auth(self, authdict, clientip):
+ def check_auth(self, authdict: dict, clientip: str) -> Any:
self.recaptcha_attempts.append((authdict, clientip))
return succeed(True)
@@ -50,7 +54,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
]
hijack_auth = False
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
@@ -61,7 +65,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(config=config)
return hs
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.recaptcha_checker = DummyRecaptchaChecker(hs)
auth_handler = hs.get_auth_handler()
auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
@@ -101,7 +105,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
self.assertEqual(len(attempts), 1)
self.assertEqual(attempts[0][0]["response"], "a")
- def test_fallback_captcha(self):
+ def test_fallback_captcha(self) -> None:
"""Ensure that fallback auth via a captcha works."""
# Returns a 401 as per the spec
channel = self.register(
@@ -132,7 +136,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
# We're given a registered user.
self.assertEqual(channel.json_body["user_id"], "@user:test")
- def test_complete_operation_unknown_session(self):
+ def test_complete_operation_unknown_session(self) -> None:
"""
Attempting to mark an invalid session as complete should error.
"""
@@ -165,7 +169,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
register.register_servlets,
]
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
# public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns
@@ -182,12 +186,12 @@ class UIAuthTests(unittest.HomeserverTestCase):
return config
- def create_resource_dict(self):
+ def create_resource_dict(self) -> Dict[str, Resource]:
resource_dict = super().create_resource_dict()
resource_dict.update(build_synapse_client_resource_tree(self.hs))
return resource_dict
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)
self.device_id = "dev1"
@@ -229,7 +233,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
return channel
- def test_ui_auth(self):
+ def test_ui_auth(self) -> None:
"""
Test user interactive authentication outside of registration.
"""
@@ -259,7 +263,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
},
)
- def test_grandfathered_identifier(self):
+ def test_grandfathered_identifier(self) -> None:
"""Check behaviour without "identifier" dict
Synapse used to require clients to submit a "user" field for m.login.password
@@ -286,7 +290,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
},
)
- def test_can_change_body(self):
+ def test_can_change_body(self) -> None:
"""
The client dict can be modified during the user interactive authentication session.
@@ -325,7 +329,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
},
)
- def test_cannot_change_uri(self):
+ def test_cannot_change_uri(self) -> None:
"""
The initial requested URI cannot be modified during the user interactive authentication session.
"""
@@ -362,7 +366,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
)
@unittest.override_config({"ui_auth": {"session_timeout": "5s"}})
- def test_can_reuse_session(self):
+ def test_can_reuse_session(self) -> None:
"""
The session can be reused if configured.
@@ -409,7 +413,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
- def test_ui_auth_via_sso(self):
+ def test_ui_auth_via_sso(self) -> None:
"""Test a successful UI Auth flow via SSO
This includes:
@@ -452,7 +456,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
- def test_does_not_offer_password_for_sso_user(self):
+ def test_does_not_offer_password_for_sso_user(self) -> None:
login_resp = self.helper.login_via_oidc("username")
user_tok = login_resp["access_token"]
device_id = login_resp["device_id"]
@@ -464,7 +468,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
flows = channel.json_body["flows"]
self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
- def test_does_not_offer_sso_for_password_user(self):
+ def test_does_not_offer_sso_for_password_user(self) -> None:
channel = self.delete_device(
self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
)
@@ -474,7 +478,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
- def test_offers_both_flows_for_upgraded_user(self):
+ def test_offers_both_flows_for_upgraded_user(self) -> None:
"""A user that had a password and then logged in with SSO should get both flows"""
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
self.assertEqual(login_resp["user_id"], self.user)
@@ -491,7 +495,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
- def test_ui_auth_fails_for_incorrect_sso_user(self):
+ def test_ui_auth_fails_for_incorrect_sso_user(self) -> None:
"""If the user tries to authenticate with the wrong SSO user, they get an error"""
# log the user in
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
@@ -534,7 +538,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
]
hijack_auth = False
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)
@@ -548,7 +552,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
{"refresh_token": refresh_token},
)
- def is_access_token_valid(self, access_token) -> bool:
+ def is_access_token_valid(self, access_token: str) -> bool:
"""
Checks whether an access token is valid, returning whether it is or not.
"""
@@ -561,7 +565,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
return code == HTTPStatus.OK
- def test_login_issue_refresh_token(self):
+ def test_login_issue_refresh_token(self) -> None:
"""
A login response should include a refresh_token only if asked.
"""
@@ -591,7 +595,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
self.assertIn("refresh_token", login_with_refresh.json_body)
self.assertIn("expires_in_ms", login_with_refresh.json_body)
- def test_register_issue_refresh_token(self):
+ def test_register_issue_refresh_token(self) -> None:
"""
A register response should include a refresh_token only if asked.
"""
@@ -627,7 +631,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
self.assertIn("refresh_token", register_with_refresh.json_body)
self.assertIn("expires_in_ms", register_with_refresh.json_body)
- def test_token_refresh(self):
+ def test_token_refresh(self) -> None:
"""
A refresh token can be used to issue a new access token.
"""
@@ -665,7 +669,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
)
@override_config({"refreshable_access_token_lifetime": "1m"})
- def test_refreshable_access_token_expiration(self):
+ def test_refreshable_access_token_expiration(self) -> None:
"""
The access token should have some time as specified in the config.
"""
@@ -722,7 +726,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
"nonrefreshable_access_token_lifetime": "10m",
}
)
- def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self):
+ def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(
+ self,
+ ) -> None:
"""
Tests that the expiry times for refreshable and non-refreshable access
tokens can be different.
@@ -782,7 +788,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
@override_config(
{"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"}
)
- def test_refresh_token_expiry(self):
+ def test_refresh_token_expiry(self) -> None:
"""
The refresh token can be configured to have a limited lifetime.
When that lifetime has ended, the refresh token can no longer be used to
@@ -834,7 +840,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
"session_lifetime": "3m",
}
)
- def test_ultimate_session_expiry(self):
+ def test_ultimate_session_expiry(self) -> None:
"""
The session can be configured to have an ultimate, limited lifetime.
"""
@@ -882,7 +888,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result
)
- def test_refresh_token_invalidation(self):
+ def test_refresh_token_invalidation(self) -> None:
"""Refresh tokens are invalidated after first use of the next token.
A refresh token is considered invalid if:
@@ -987,7 +993,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result
)
- def test_many_token_refresh(self):
+ def test_many_token_refresh(self) -> None:
"""
If a refresh is performed many times during a session, there shouldn't be
extra 'cruft' built up over time.
|