diff options
author | Hannes Lerchl <aytchell@users.noreply.github.com> | 2022-06-15 18:45:16 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-06-15 16:45:16 +0000 |
commit | 7d99414edf2c5c7e602a88c72245add665e6afb4 (patch) | |
tree | da17d91c48acdae424833784f40efc29a14c4416 /tests/rest | |
parent | Sort failing jobs in Complement CI to the top of the logs to make them easier... (diff) | |
download | synapse-7d99414edf2c5c7e602a88c72245add665e6afb4.tar.xz |
Replace pyjwt with authlib in `org.matrix.login.jwt` (#13011)
Diffstat (limited to 'tests/rest')
-rw-r--r-- | tests/rest/client/test_login.py | 44 |
1 files changed, 23 insertions, 21 deletions
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index f4ea1209d9..f6efa5fe37 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -14,7 +14,7 @@ import json import time import urllib.parse -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional from unittest.mock import Mock from urllib.parse import urlencode @@ -41,7 +41,7 @@ from tests.test_utils.html_parsers import TestHtmlParser from tests.unittest import HomeserverTestCase, override_config, skip_unless try: - import jwt + from authlib.jose import jwk, jwt HAS_JWT = True except ImportError: @@ -841,7 +841,7 @@ class CASTestCase(unittest.HomeserverTestCase): self.assertIn(b"SSO account deactivated", channel.result["body"]) -@skip_unless(HAS_JWT, "requires jwt") +@skip_unless(HAS_JWT, "requires authlib") class JWTTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -866,11 +866,9 @@ class JWTTestCase(unittest.HomeserverTestCase): return config def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: - # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm) - if isinstance(result, bytes): - return result.decode("ascii") - return result + header = {"alg": self.jwt_algorithm} + result: bytes = jwt.encode(header, payload, secret) + return result.decode("ascii") def jwt_login(self, *args: Any) -> FakeChannel: params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} @@ -902,7 +900,8 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( - channel.json_body["error"], "JWT validation failed: Signature has expired" + channel.json_body["error"], + "JWT validation failed: expired_token: The token is expired", ) def test_login_jwt_not_before(self) -> None: @@ -912,7 +911,7 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], - "JWT validation failed: The token is not yet valid (nbf)", + "JWT validation failed: invalid_token: The token is not valid yet", ) def test_login_no_sub(self) -> None: @@ -934,7 +933,8 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( - channel.json_body["error"], "JWT validation failed: Invalid issuer" + channel.json_body["error"], + 'JWT validation failed: invalid_claim: Invalid claim "iss"', ) # Not providing an issuer. @@ -943,7 +943,7 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], - 'JWT validation failed: Token is missing the "iss" claim', + 'JWT validation failed: missing_claim: Missing "iss" claim', ) def test_login_iss_no_config(self) -> None: @@ -965,7 +965,8 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( - channel.json_body["error"], "JWT validation failed: Invalid audience" + channel.json_body["error"], + 'JWT validation failed: invalid_claim: Invalid claim "aud"', ) # Not providing an audience. @@ -974,7 +975,7 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], - 'JWT validation failed: Token is missing the "aud" claim', + 'JWT validation failed: missing_claim: Missing "aud" claim', ) def test_login_aud_no_config(self) -> None: @@ -983,7 +984,8 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( - channel.json_body["error"], "JWT validation failed: Invalid audience" + channel.json_body["error"], + 'JWT validation failed: invalid_claim: Invalid claim "aud"', ) def test_login_default_sub(self) -> None: @@ -1010,7 +1012,7 @@ class JWTTestCase(unittest.HomeserverTestCase): # The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use # RSS256, with a public key configured in synapse as "jwt_secret", and tokens # signed by the private key. -@skip_unless(HAS_JWT, "requires jwt") +@skip_unless(HAS_JWT, "requires authlib") class JWTPubKeyTestCase(unittest.HomeserverTestCase): servlets = [ login.register_servlets, @@ -1071,11 +1073,11 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): return config def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: - # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result: Union[bytes, str] = jwt.encode(payload, secret, "RS256") - if isinstance(result, bytes): - return result.decode("ascii") - return result + header = {"alg": "RS256"} + if secret.startswith("-----BEGIN RSA PRIVATE KEY-----"): + secret = jwk.dumps(secret, kty="RSA") + result: bytes = jwt.encode(header, payload, secret) + return result.decode("ascii") def jwt_login(self, *args: Any) -> FakeChannel: params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} |