From 64c73c6ac88a740ee480a0ad1f9afc8596bccfa4 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 23 Feb 2022 14:33:19 +0100 Subject: Add type hints to `tests/rest/client` (#12066) --- tests/rest/client/test_login.py | 120 ++++++++++++++++++++++------------------ 1 file changed, 66 insertions(+), 54 deletions(-) (limited to 'tests/rest/client/test_login.py') diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 26d0d83e00..d48defda63 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -20,6 +20,7 @@ from urllib.parse import urlencode import pymacaroons +from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource import synapse.rest.admin @@ -27,12 +28,15 @@ from synapse.appservice import ApplicationService from synapse.rest.client import devices, login, logout, register from synapse.rest.client.account import WhoamiRestServlet from synapse.rest.synapse.client import build_synapse_client_resource_tree +from synapse.server import HomeServer from synapse.types import create_requester +from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_saml import has_saml2 from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG +from tests.server import FakeChannel from tests.test_utils.html_parsers import TestHtmlParser from tests.unittest import HomeserverTestCase, override_config, skip_unless @@ -95,7 +99,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): lambda hs, http_server: WhoamiRestServlet(hs).register(http_server), ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs = self.setup_test_homeserver() self.hs.config.registration.enable_registration = True self.hs.config.registration.registrations_require_3pid = [] @@ -117,7 +121,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } } ) - def test_POST_ratelimiting_per_address(self): + def test_POST_ratelimiting_per_address(self) -> None: # Create different users so we're sure not to be bothered by the per-user # ratelimiter. for i in range(0, 6): @@ -165,7 +169,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } } ) - def test_POST_ratelimiting_per_account(self): + def test_POST_ratelimiting_per_account(self) -> None: self.register_user("kermit", "monkey") for i in range(0, 6): @@ -210,7 +214,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } } ) - def test_POST_ratelimiting_per_account_failed_attempts(self): + def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: self.register_user("kermit", "monkey") for i in range(0, 6): @@ -243,7 +247,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"403", channel.result) @override_config({"session_lifetime": "24h"}) - def test_soft_logout(self): + def test_soft_logout(self) -> None: self.register_user("kermit", "monkey") # we shouldn't be able to make requests without an access token @@ -298,7 +302,9 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEquals(channel.json_body["soft_logout"], False) - def _delete_device(self, access_token, user_id, password, device_id): + def _delete_device( + self, access_token: str, user_id: str, password: str, device_id: str + ) -> None: """Perform the UI-Auth to delete a device""" channel = self.make_request( b"DELETE", "devices/" + device_id, access_token=access_token @@ -329,7 +335,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.code, 200, channel.result) @override_config({"session_lifetime": "24h"}) - def test_session_can_hard_logout_after_being_soft_logged_out(self): + def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None: self.register_user("kermit", "monkey") # log in as normal @@ -353,7 +359,9 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) @override_config({"session_lifetime": "24h"}) - def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self): + def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out( + self, + ) -> None: self.register_user("kermit", "monkey") # log in as normal @@ -432,7 +440,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): d.update(build_synapse_client_resource_tree(self.hs)) return d - def test_get_login_flows(self): + def test_get_login_flows(self) -> None: """GET /login should return password and SSO flows""" channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) @@ -459,12 +467,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): ], ) - def test_multi_sso_redirect(self): + def test_multi_sso_redirect(self) -> None: """/login/sso/redirect should redirect to an identity picker""" # first hit the redirect url, which should redirect to our idp picker channel = self._make_sso_redirect_request(None) self.assertEqual(channel.code, 302, channel.result) - uri = channel.headers.getRawHeaders("Location")[0] + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + uri = location_headers[0] # hitting that picker should give us some HTML channel = self.make_request("GET", uri) @@ -487,7 +497,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"]) - def test_multi_sso_redirect_to_cas(self): + def test_multi_sso_redirect_to_cas(self) -> None: """If CAS is chosen, should redirect to the CAS server""" channel = self.make_request( @@ -514,7 +524,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): service_uri_params = urllib.parse.parse_qs(service_uri_query) self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL) - def test_multi_sso_redirect_to_saml(self): + def test_multi_sso_redirect_to_saml(self) -> None: """If SAML is chosen, should redirect to the SAML server""" channel = self.make_request( "GET", @@ -536,7 +546,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): relay_state_param = saml_uri_params["RelayState"][0] self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL) - def test_login_via_oidc(self): + def test_login_via_oidc(self) -> None: """If OIDC is chosen, should redirect to the OIDC auth endpoint""" # pick the default OIDC provider @@ -604,7 +614,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): self.assertEqual(chan.code, 200, chan.result) self.assertEqual(chan.json_body["user_id"], "@user1:test") - def test_multi_sso_redirect_to_unknown(self): + def test_multi_sso_redirect_to_unknown(self) -> None: """An unknown IdP should cause a 400""" channel = self.make_request( "GET", @@ -612,23 +622,25 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) - def test_client_idp_redirect_to_unknown(self): + def test_client_idp_redirect_to_unknown(self) -> None: """If the client tries to pick an unknown IdP, return a 404""" channel = self._make_sso_redirect_request("xxx") self.assertEqual(channel.code, 404, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") - def test_client_idp_redirect_to_oidc(self): + def test_client_idp_redirect_to_oidc(self) -> None: """If the client pick a known IdP, redirect to it""" channel = self._make_sso_redirect_request("oidc") self.assertEqual(channel.code, 302, channel.result) - oidc_uri = channel.headers.getRawHeaders("Location")[0] + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + oidc_uri = location_headers[0] oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) # it should redirect us to the auth page of the OIDC server self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) - def _make_sso_redirect_request(self, idp_prov: Optional[str] = None): + def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel: """Send a request to /_matrix/client/r0/login/sso/redirect ... possibly specifying an IDP provider @@ -659,7 +671,7 @@ class CASTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.base_url = "https://matrix.goodserver.com/" self.redirect_path = "_synapse/client/login/sso/redirect/confirm" @@ -675,7 +687,7 @@ class CASTestCase(unittest.HomeserverTestCase): cas_user_id = "username" self.user_id = "@%s:test" % cas_user_id - async def get_raw(uri, args): + async def get_raw(uri: str, args: Any) -> bytes: """Return an example response payload from a call to the `/proxyValidate` endpoint of a CAS server, copied from https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20 @@ -709,10 +721,10 @@ class CASTestCase(unittest.HomeserverTestCase): return self.hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.deactivate_account_handler = hs.get_deactivate_account_handler() - def test_cas_redirect_confirm(self): + def test_cas_redirect_confirm(self) -> None: """Tests that the SSO login flow serves a confirmation page before redirecting a user to the redirect URL. """ @@ -754,15 +766,15 @@ class CASTestCase(unittest.HomeserverTestCase): } } ) - def test_cas_redirect_whitelisted(self): + def test_cas_redirect_whitelisted(self) -> None: """Tests that the SSO login flow serves a redirect to a whitelisted url""" self._test_redirect("https://legit-site.com/") @override_config({"public_baseurl": "https://example.com"}) - def test_cas_redirect_login_fallback(self): + def test_cas_redirect_login_fallback(self) -> None: self._test_redirect("https://example.com/_matrix/static/client/login") - def _test_redirect(self, redirect_url): + def _test_redirect(self, redirect_url: str) -> None: """Tests that the SSO login flow serves a redirect for the given redirect URL.""" cas_ticket_url = ( "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" @@ -778,7 +790,7 @@ class CASTestCase(unittest.HomeserverTestCase): self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url) @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}}) - def test_deactivated_user(self): + def test_deactivated_user(self) -> None: """Logging in as a deactivated account should error.""" redirect_url = "https://legit-site.com/" @@ -821,7 +833,7 @@ class JWTTestCase(unittest.HomeserverTestCase): "algorithm": jwt_algorithm, } - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() # If jwt_config has been defined (eg via @override_config), don't replace it. @@ -837,23 +849,23 @@ class JWTTestCase(unittest.HomeserverTestCase): return result.decode("ascii") return result - def jwt_login(self, *args): + def jwt_login(self, *args: Any) -> FakeChannel: params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} channel = self.make_request(b"POST", LOGIN_URL, params) return channel - def test_login_jwt_valid_registered(self): + def test_login_jwt_valid_registered(self) -> None: self.register_user("kermit", "monkey") channel = self.jwt_login({"sub": "kermit"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") - def test_login_jwt_valid_unregistered(self): + def test_login_jwt_valid_unregistered(self) -> None: channel = self.jwt_login({"sub": "frog"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@frog:test") - def test_login_jwt_invalid_signature(self): + def test_login_jwt_invalid_signature(self) -> None: channel = self.jwt_login({"sub": "frog"}, "notsecret") self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -862,7 +874,7 @@ class JWTTestCase(unittest.HomeserverTestCase): "JWT validation failed: Signature verification failed", ) - def test_login_jwt_expired(self): + def test_login_jwt_expired(self) -> None: channel = self.jwt_login({"sub": "frog", "exp": 864000}) self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -870,7 +882,7 @@ class JWTTestCase(unittest.HomeserverTestCase): channel.json_body["error"], "JWT validation failed: Signature has expired" ) - def test_login_jwt_not_before(self): + def test_login_jwt_not_before(self) -> None: now = int(time.time()) channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) self.assertEqual(channel.result["code"], b"403", channel.result) @@ -880,14 +892,14 @@ class JWTTestCase(unittest.HomeserverTestCase): "JWT validation failed: The token is not yet valid (nbf)", ) - def test_login_no_sub(self): + def test_login_no_sub(self) -> None: channel = self.jwt_login({"username": "root"}) self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Invalid JWT") @override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}}) - def test_login_iss(self): + def test_login_iss(self) -> None: """Test validating the issuer claim.""" # A valid issuer. channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) @@ -911,14 +923,14 @@ class JWTTestCase(unittest.HomeserverTestCase): 'JWT validation failed: Token is missing the "iss" claim', ) - def test_login_iss_no_config(self): + def test_login_iss_no_config(self) -> None: """Test providing an issuer claim without requiring it in the configuration.""" channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}}) - def test_login_aud(self): + def test_login_aud(self) -> None: """Test validating the audience claim.""" # A valid audience. channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"}) @@ -942,7 +954,7 @@ class JWTTestCase(unittest.HomeserverTestCase): 'JWT validation failed: Token is missing the "aud" claim', ) - def test_login_aud_no_config(self): + def test_login_aud_no_config(self) -> None: """Test providing an audience without requiring it in the configuration.""" channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) self.assertEqual(channel.result["code"], b"403", channel.result) @@ -951,20 +963,20 @@ class JWTTestCase(unittest.HomeserverTestCase): channel.json_body["error"], "JWT validation failed: Invalid audience" ) - def test_login_default_sub(self): + def test_login_default_sub(self) -> None: """Test reading user ID from the default subject claim.""" channel = self.jwt_login({"sub": "kermit"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") @override_config({"jwt_config": {**base_config, "subject_claim": "username"}}) - def test_login_custom_sub(self): + def test_login_custom_sub(self) -> None: """Test reading user ID from a custom subject claim.""" channel = self.jwt_login({"username": "frog"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@frog:test") - def test_login_no_token(self): + def test_login_no_token(self) -> None: params = {"type": "org.matrix.login.jwt"} channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEqual(channel.result["code"], b"403", channel.result) @@ -1026,7 +1038,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): ] ) - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["jwt_config"] = { "enabled": True, @@ -1042,17 +1054,17 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): return result.decode("ascii") return result - def jwt_login(self, *args): + def jwt_login(self, *args: Any) -> FakeChannel: params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} channel = self.make_request(b"POST", LOGIN_URL, params) return channel - def test_login_jwt_valid(self): + def test_login_jwt_valid(self) -> None: channel = self.jwt_login({"sub": "kermit"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") - def test_login_jwt_invalid_signature(self): + def test_login_jwt_invalid_signature(self) -> None: channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -1071,7 +1083,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): register.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs = self.setup_test_homeserver() self.service = ApplicationService( @@ -1105,7 +1117,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.hs.get_datastores().main.services_cache.append(self.another_service) return self.hs - def test_login_appservice_user(self): + def test_login_appservice_user(self) -> None: """Test that an appservice user can use /login""" self.register_appservice_user(AS_USER, self.service.token) @@ -1119,7 +1131,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) - def test_login_appservice_user_bot(self): + def test_login_appservice_user_bot(self) -> None: """Test that the appservice bot can use /login""" self.register_appservice_user(AS_USER, self.service.token) @@ -1133,7 +1145,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) - def test_login_appservice_wrong_user(self): + def test_login_appservice_wrong_user(self) -> None: """Test that non-as users cannot login with the as token""" self.register_appservice_user(AS_USER, self.service.token) @@ -1147,7 +1159,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"403", channel.result) - def test_login_appservice_wrong_as(self): + def test_login_appservice_wrong_as(self) -> None: """Test that as users cannot login with wrong as token""" self.register_appservice_user(AS_USER, self.service.token) @@ -1161,7 +1173,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"403", channel.result) - def test_login_appservice_no_token(self): + def test_login_appservice_no_token(self) -> None: """Test that users must provide a token when using the appservice login method """ @@ -1182,7 +1194,7 @@ class UsernamePickerTestCase(HomeserverTestCase): servlets = [login.register_servlets] - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["public_baseurl"] = BASE_URL @@ -1202,7 +1214,7 @@ class UsernamePickerTestCase(HomeserverTestCase): d.update(build_synapse_client_resource_tree(self.hs)) return d - def test_username_picker(self): + def test_username_picker(self) -> None: """Test the happy path of a username picker flow.""" # do the start of the login flow -- cgit 1.4.1