diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index f4ea1209d9..f6efa5fe37 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -14,7 +14,7 @@
import json
import time
import urllib.parse
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional
from unittest.mock import Mock
from urllib.parse import urlencode
@@ -41,7 +41,7 @@ from tests.test_utils.html_parsers import TestHtmlParser
from tests.unittest import HomeserverTestCase, override_config, skip_unless
try:
- import jwt
+ from authlib.jose import jwk, jwt
HAS_JWT = True
except ImportError:
@@ -841,7 +841,7 @@ class CASTestCase(unittest.HomeserverTestCase):
self.assertIn(b"SSO account deactivated", channel.result["body"])
-@skip_unless(HAS_JWT, "requires jwt")
+@skip_unless(HAS_JWT, "requires authlib")
class JWTTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -866,11 +866,9 @@ class JWTTestCase(unittest.HomeserverTestCase):
return config
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
- # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
- result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm)
- if isinstance(result, bytes):
- return result.decode("ascii")
- return result
+ header = {"alg": self.jwt_algorithm}
+ result: bytes = jwt.encode(header, payload, secret)
+ return result.decode("ascii")
def jwt_login(self, *args: Any) -> FakeChannel:
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
@@ -902,7 +900,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
- channel.json_body["error"], "JWT validation failed: Signature has expired"
+ channel.json_body["error"],
+ "JWT validation failed: expired_token: The token is expired",
)
def test_login_jwt_not_before(self) -> None:
@@ -912,7 +911,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
- "JWT validation failed: The token is not yet valid (nbf)",
+ "JWT validation failed: invalid_token: The token is not valid yet",
)
def test_login_no_sub(self) -> None:
@@ -934,7 +933,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
- channel.json_body["error"], "JWT validation failed: Invalid issuer"
+ channel.json_body["error"],
+ 'JWT validation failed: invalid_claim: Invalid claim "iss"',
)
# Not providing an issuer.
@@ -943,7 +943,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
- 'JWT validation failed: Token is missing the "iss" claim',
+ 'JWT validation failed: missing_claim: Missing "iss" claim',
)
def test_login_iss_no_config(self) -> None:
@@ -965,7 +965,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
- channel.json_body["error"], "JWT validation failed: Invalid audience"
+ channel.json_body["error"],
+ 'JWT validation failed: invalid_claim: Invalid claim "aud"',
)
# Not providing an audience.
@@ -974,7 +975,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
- 'JWT validation failed: Token is missing the "aud" claim',
+ 'JWT validation failed: missing_claim: Missing "aud" claim',
)
def test_login_aud_no_config(self) -> None:
@@ -983,7 +984,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
- channel.json_body["error"], "JWT validation failed: Invalid audience"
+ channel.json_body["error"],
+ 'JWT validation failed: invalid_claim: Invalid claim "aud"',
)
def test_login_default_sub(self) -> None:
@@ -1010,7 +1012,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
# RSS256, with a public key configured in synapse as "jwt_secret", and tokens
# signed by the private key.
-@skip_unless(HAS_JWT, "requires jwt")
+@skip_unless(HAS_JWT, "requires authlib")
class JWTPubKeyTestCase(unittest.HomeserverTestCase):
servlets = [
login.register_servlets,
@@ -1071,11 +1073,11 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
return config
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
- # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
- result: Union[bytes, str] = jwt.encode(payload, secret, "RS256")
- if isinstance(result, bytes):
- return result.decode("ascii")
- return result
+ header = {"alg": "RS256"}
+ if secret.startswith("-----BEGIN RSA PRIVATE KEY-----"):
+ secret = jwk.dumps(secret, kty="RSA")
+ result: bytes = jwt.encode(header, payload, secret)
+ return result.decode("ascii")
def jwt_login(self, *args: Any) -> FakeChannel:
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
|