diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py
index 7e3c764b2c..49aaca7cf6 100644
--- a/synapse/config/jwt.py
+++ b/synapse/config/jwt.py
@@ -18,10 +18,10 @@ from synapse.types import JsonDict
from ._base import Config, ConfigError
-MISSING_JWT = """Missing jwt library. This is required for jwt login.
+MISSING_AUTHLIB = """Missing authlib library. This is required for jwt login.
Install by running:
- pip install pyjwt
+ pip install synapse[jwt]
"""
@@ -43,11 +43,11 @@ class JWTConfig(Config):
self.jwt_audiences = jwt_config.get("audiences")
try:
- import jwt
+ from authlib.jose import JsonWebToken
- jwt # To stop unused lint.
+ JsonWebToken # To stop unused lint.
except ImportError:
- raise ConfigError(MISSING_JWT)
+ raise ConfigError(MISSING_AUTHLIB)
else:
self.jwt_enabled = False
self.jwt_secret = None
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index cf4196ac0a..dd75e40f34 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -420,17 +420,31 @@ class LoginRestServlet(RestServlet):
403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
)
- import jwt
+ from authlib.jose import JsonWebToken, JWTClaims
+ from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError
+
+ jwt = JsonWebToken([self.jwt_algorithm])
+ claim_options = {}
+ if self.jwt_issuer is not None:
+ claim_options["iss"] = {"value": self.jwt_issuer, "essential": True}
+ if self.jwt_audiences is not None:
+ claim_options["aud"] = {"values": self.jwt_audiences, "essential": True}
try:
- payload = jwt.decode(
+ claims = jwt.decode(
token,
- self.jwt_secret,
- algorithms=[self.jwt_algorithm],
- issuer=self.jwt_issuer,
- audience=self.jwt_audiences,
+ key=self.jwt_secret,
+ claims_cls=JWTClaims,
+ claims_options=claim_options,
+ )
+ except BadSignatureError:
+ # We handle this case separately to provide a better error message
+ raise LoginError(
+ 403,
+ "JWT validation failed: Signature verification failed",
+ errcode=Codes.FORBIDDEN,
)
- except jwt.PyJWTError as e:
+ except JoseError as e:
# A JWT error occurred, return some info back to the client.
raise LoginError(
403,
@@ -438,7 +452,23 @@ class LoginRestServlet(RestServlet):
errcode=Codes.FORBIDDEN,
)
- user = payload.get(self.jwt_subject_claim, None)
+ try:
+ claims.validate(leeway=120) # allows 2 min of clock skew
+
+ # Enforce the old behavior which is rolled out in productive
+ # servers: if the JWT contains an 'aud' claim but none is
+ # configured, the login attempt will fail
+ if claims.get("aud") is not None:
+ if self.jwt_audiences is None or len(self.jwt_audiences) == 0:
+ raise InvalidClaimError("aud")
+ except JoseError as e:
+ raise LoginError(
+ 403,
+ "JWT validation failed: %s" % (str(e),),
+ errcode=Codes.FORBIDDEN,
+ )
+
+ user = claims.get(self.jwt_subject_claim, None)
if user is None:
raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
|