summary refs log tree commit diff
path: root/synapse/rest/client/login.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/login.py')
-rw-r--r--synapse/rest/client/login.py46
1 files changed, 38 insertions, 8 deletions
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index cf4196ac0a..dd75e40f34 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -420,17 +420,31 @@ class LoginRestServlet(RestServlet):
                 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
             )
 
-        import jwt
+        from authlib.jose import JsonWebToken, JWTClaims
+        from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError
+
+        jwt = JsonWebToken([self.jwt_algorithm])
+        claim_options = {}
+        if self.jwt_issuer is not None:
+            claim_options["iss"] = {"value": self.jwt_issuer, "essential": True}
+        if self.jwt_audiences is not None:
+            claim_options["aud"] = {"values": self.jwt_audiences, "essential": True}
 
         try:
-            payload = jwt.decode(
+            claims = jwt.decode(
                 token,
-                self.jwt_secret,
-                algorithms=[self.jwt_algorithm],
-                issuer=self.jwt_issuer,
-                audience=self.jwt_audiences,
+                key=self.jwt_secret,
+                claims_cls=JWTClaims,
+                claims_options=claim_options,
+            )
+        except BadSignatureError:
+            # We handle this case separately to provide a better error message
+            raise LoginError(
+                403,
+                "JWT validation failed: Signature verification failed",
+                errcode=Codes.FORBIDDEN,
             )
-        except jwt.PyJWTError as e:
+        except JoseError as e:
             # A JWT error occurred, return some info back to the client.
             raise LoginError(
                 403,
@@ -438,7 +452,23 @@ class LoginRestServlet(RestServlet):
                 errcode=Codes.FORBIDDEN,
             )
 
-        user = payload.get(self.jwt_subject_claim, None)
+        try:
+            claims.validate(leeway=120)  # allows 2 min of clock skew
+
+            # Enforce the old behavior which is rolled out in productive
+            # servers: if the JWT contains an 'aud' claim but none is
+            # configured, the login attempt will fail
+            if claims.get("aud") is not None:
+                if self.jwt_audiences is None or len(self.jwt_audiences) == 0:
+                    raise InvalidClaimError("aud")
+        except JoseError as e:
+            raise LoginError(
+                403,
+                "JWT validation failed: %s" % (str(e),),
+                errcode=Codes.FORBIDDEN,
+            )
+
+        user = claims.get(self.jwt_subject_claim, None)
         if user is None:
             raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)