summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/config/jwt.py10
-rw-r--r--synapse/rest/client/login.py46
2 files changed, 43 insertions, 13 deletions
diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py
index 7e3c764b2c..49aaca7cf6 100644
--- a/synapse/config/jwt.py
+++ b/synapse/config/jwt.py
@@ -18,10 +18,10 @@ from synapse.types import JsonDict
 
 from ._base import Config, ConfigError
 
-MISSING_JWT = """Missing jwt library. This is required for jwt login.
+MISSING_AUTHLIB = """Missing authlib library. This is required for jwt login.
 
     Install by running:
-        pip install pyjwt
+        pip install synapse[jwt]
     """
 
 
@@ -43,11 +43,11 @@ class JWTConfig(Config):
             self.jwt_audiences = jwt_config.get("audiences")
 
             try:
-                import jwt
+                from authlib.jose import JsonWebToken
 
-                jwt  # To stop unused lint.
+                JsonWebToken  # To stop unused lint.
             except ImportError:
-                raise ConfigError(MISSING_JWT)
+                raise ConfigError(MISSING_AUTHLIB)
         else:
             self.jwt_enabled = False
             self.jwt_secret = None
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index cf4196ac0a..dd75e40f34 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -420,17 +420,31 @@ class LoginRestServlet(RestServlet):
                 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
             )
 
-        import jwt
+        from authlib.jose import JsonWebToken, JWTClaims
+        from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError
+
+        jwt = JsonWebToken([self.jwt_algorithm])
+        claim_options = {}
+        if self.jwt_issuer is not None:
+            claim_options["iss"] = {"value": self.jwt_issuer, "essential": True}
+        if self.jwt_audiences is not None:
+            claim_options["aud"] = {"values": self.jwt_audiences, "essential": True}
 
         try:
-            payload = jwt.decode(
+            claims = jwt.decode(
                 token,
-                self.jwt_secret,
-                algorithms=[self.jwt_algorithm],
-                issuer=self.jwt_issuer,
-                audience=self.jwt_audiences,
+                key=self.jwt_secret,
+                claims_cls=JWTClaims,
+                claims_options=claim_options,
+            )
+        except BadSignatureError:
+            # We handle this case separately to provide a better error message
+            raise LoginError(
+                403,
+                "JWT validation failed: Signature verification failed",
+                errcode=Codes.FORBIDDEN,
             )
-        except jwt.PyJWTError as e:
+        except JoseError as e:
             # A JWT error occurred, return some info back to the client.
             raise LoginError(
                 403,
@@ -438,7 +452,23 @@ class LoginRestServlet(RestServlet):
                 errcode=Codes.FORBIDDEN,
             )
 
-        user = payload.get(self.jwt_subject_claim, None)
+        try:
+            claims.validate(leeway=120)  # allows 2 min of clock skew
+
+            # Enforce the old behavior which is rolled out in productive
+            # servers: if the JWT contains an 'aud' claim but none is
+            # configured, the login attempt will fail
+            if claims.get("aud") is not None:
+                if self.jwt_audiences is None or len(self.jwt_audiences) == 0:
+                    raise InvalidClaimError("aud")
+        except JoseError as e:
+            raise LoginError(
+                403,
+                "JWT validation failed: %s" % (str(e),),
+                errcode=Codes.FORBIDDEN,
+            )
+
+        user = claims.get(self.jwt_subject_claim, None)
         if user is None:
             raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)