summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/rest/client/test_login.py44
1 files changed, 23 insertions, 21 deletions
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index f4ea1209d9..f6efa5fe37 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -14,7 +14,7 @@
 import json
 import time
 import urllib.parse
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional
 from unittest.mock import Mock
 from urllib.parse import urlencode
 
@@ -41,7 +41,7 @@ from tests.test_utils.html_parsers import TestHtmlParser
 from tests.unittest import HomeserverTestCase, override_config, skip_unless
 
 try:
-    import jwt
+    from authlib.jose import jwk, jwt
 
     HAS_JWT = True
 except ImportError:
@@ -841,7 +841,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         self.assertIn(b"SSO account deactivated", channel.result["body"])
 
 
-@skip_unless(HAS_JWT, "requires jwt")
+@skip_unless(HAS_JWT, "requires authlib")
 class JWTTestCase(unittest.HomeserverTestCase):
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -866,11 +866,9 @@ class JWTTestCase(unittest.HomeserverTestCase):
         return config
 
     def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
-        # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
-        result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm)
-        if isinstance(result, bytes):
-            return result.decode("ascii")
-        return result
+        header = {"alg": self.jwt_algorithm}
+        result: bytes = jwt.encode(header, payload, secret)
+        return result.decode("ascii")
 
     def jwt_login(self, *args: Any) -> FakeChannel:
         params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
@@ -902,7 +900,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.result["code"], b"403", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
-            channel.json_body["error"], "JWT validation failed: Signature has expired"
+            channel.json_body["error"],
+            "JWT validation failed: expired_token: The token is expired",
         )
 
     def test_login_jwt_not_before(self) -> None:
@@ -912,7 +911,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
-            "JWT validation failed: The token is not yet valid (nbf)",
+            "JWT validation failed: invalid_token: The token is not valid yet",
         )
 
     def test_login_no_sub(self) -> None:
@@ -934,7 +933,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.result["code"], b"403", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
-            channel.json_body["error"], "JWT validation failed: Invalid issuer"
+            channel.json_body["error"],
+            'JWT validation failed: invalid_claim: Invalid claim "iss"',
         )
 
         # Not providing an issuer.
@@ -943,7 +943,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
-            'JWT validation failed: Token is missing the "iss" claim',
+            'JWT validation failed: missing_claim: Missing "iss" claim',
         )
 
     def test_login_iss_no_config(self) -> None:
@@ -965,7 +965,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.result["code"], b"403", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
-            channel.json_body["error"], "JWT validation failed: Invalid audience"
+            channel.json_body["error"],
+            'JWT validation failed: invalid_claim: Invalid claim "aud"',
         )
 
         # Not providing an audience.
@@ -974,7 +975,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
-            'JWT validation failed: Token is missing the "aud" claim',
+            'JWT validation failed: missing_claim: Missing "aud" claim',
         )
 
     def test_login_aud_no_config(self) -> None:
@@ -983,7 +984,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.result["code"], b"403", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
-            channel.json_body["error"], "JWT validation failed: Invalid audience"
+            channel.json_body["error"],
+            'JWT validation failed: invalid_claim: Invalid claim "aud"',
         )
 
     def test_login_default_sub(self) -> None:
@@ -1010,7 +1012,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
 # The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
 # RSS256, with a public key configured in synapse as "jwt_secret", and tokens
 # signed by the private key.
-@skip_unless(HAS_JWT, "requires jwt")
+@skip_unless(HAS_JWT, "requires authlib")
 class JWTPubKeyTestCase(unittest.HomeserverTestCase):
     servlets = [
         login.register_servlets,
@@ -1071,11 +1073,11 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
         return config
 
     def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
-        # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
-        result: Union[bytes, str] = jwt.encode(payload, secret, "RS256")
-        if isinstance(result, bytes):
-            return result.decode("ascii")
-        return result
+        header = {"alg": "RS256"}
+        if secret.startswith("-----BEGIN RSA PRIVATE KEY-----"):
+            secret = jwk.dumps(secret, kty="RSA")
+        result: bytes = jwt.encode(header, payload, secret)
+        return result.decode("ascii")
 
     def jwt_login(self, *args: Any) -> FakeChannel:
         params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}