summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest')
-rw-r--r--synapse/rest/client/v1/login.py25
1 files changed, 19 insertions, 6 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 64d5c58b65..326ffa0056 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -89,12 +89,19 @@ class LoginRestServlet(RestServlet):
     def __init__(self, hs):
         super(LoginRestServlet, self).__init__()
         self.hs = hs
+
+        # JWT configuration variables.
         self.jwt_enabled = hs.config.jwt_enabled
         self.jwt_secret = hs.config.jwt_secret
         self.jwt_algorithm = hs.config.jwt_algorithm
+        self.jwt_issuer = hs.config.jwt_issuer
+        self.jwt_audiences = hs.config.jwt_audiences
+
+        # SSO configuration.
         self.saml2_enabled = hs.config.saml2_enabled
         self.cas_enabled = hs.config.cas_enabled
         self.oidc_enabled = hs.config.oidc_enabled
+
         self.auth_handler = self.hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
         self.handlers = hs.get_handlers()
@@ -368,16 +375,22 @@ class LoginRestServlet(RestServlet):
             )
 
         import jwt
-        from jwt.exceptions import InvalidTokenError
 
         try:
             payload = jwt.decode(
-                token, self.jwt_secret, algorithms=[self.jwt_algorithm]
+                token,
+                self.jwt_secret,
+                algorithms=[self.jwt_algorithm],
+                issuer=self.jwt_issuer,
+                audience=self.jwt_audiences,
+            )
+        except jwt.PyJWTError as e:
+            # A JWT error occurred, return some info back to the client.
+            raise LoginError(
+                401,
+                "JWT validation failed: %s" % (str(e),),
+                errcode=Codes.UNAUTHORIZED,
             )
-        except jwt.ExpiredSignatureError:
-            raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
-        except InvalidTokenError:
-            raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
 
         user = payload.get("sub", None)
         if user is None: