summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/config/jwt_config.py28
-rw-r--r--synapse/rest/client/v1/login.py25
2 files changed, 47 insertions, 6 deletions
diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py
index fce96b4acf..3252ad9e7f 100644
--- a/synapse/config/jwt_config.py
+++ b/synapse/config/jwt_config.py
@@ -32,6 +32,11 @@ class JWTConfig(Config):
             self.jwt_secret = jwt_config["secret"]
             self.jwt_algorithm = jwt_config["algorithm"]
 
+            # The issuer and audiences are optional, if provided, it is asserted
+            # that the claims exist on the JWT.
+            self.jwt_issuer = jwt_config.get("issuer")
+            self.jwt_audiences = jwt_config.get("audiences")
+
             try:
                 import jwt
 
@@ -42,6 +47,8 @@ class JWTConfig(Config):
             self.jwt_enabled = False
             self.jwt_secret = None
             self.jwt_algorithm = None
+            self.jwt_issuer = None
+            self.jwt_audiences = None
 
     def generate_config_section(self, **kwargs):
         return """\
@@ -52,6 +59,9 @@ class JWTConfig(Config):
         # Each JSON Web Token needs to contain a "sub" (subject) claim, which is
         # used as the localpart of the mxid.
         #
+        # Additionally, the expiration time ("exp"), not before time ("nbf"),
+        # and issued at ("iat") claims are validated if present.
+        #
         # Note that this is a non-standard login type and client support is
         # expected to be non-existant.
         #
@@ -78,4 +88,22 @@ class JWTConfig(Config):
             # Required if 'enabled' is true.
             #
             #algorithm: "provided-by-your-issuer"
+
+            # The issuer to validate the "iss" claim against.
+            #
+            # Optional, if provided the "iss" claim will be required and
+            # validated for all JSON web tokens.
+            #
+            #issuer: "provided-by-your-issuer"
+
+            # A list of audiences to validate the "aud" claim against.
+            #
+            # Optional, if provided the "aud" claim will be required and
+            # validated for all JSON web tokens.
+            #
+            # Note that if the "aud" claim is included in a JSON web token then
+            # validation will fail without configuring audiences.
+            #
+            #audiences:
+            #    - "provided-by-your-issuer"
         """
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 64d5c58b65..326ffa0056 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -89,12 +89,19 @@ class LoginRestServlet(RestServlet):
     def __init__(self, hs):
         super(LoginRestServlet, self).__init__()
         self.hs = hs
+
+        # JWT configuration variables.
         self.jwt_enabled = hs.config.jwt_enabled
         self.jwt_secret = hs.config.jwt_secret
         self.jwt_algorithm = hs.config.jwt_algorithm
+        self.jwt_issuer = hs.config.jwt_issuer
+        self.jwt_audiences = hs.config.jwt_audiences
+
+        # SSO configuration.
         self.saml2_enabled = hs.config.saml2_enabled
         self.cas_enabled = hs.config.cas_enabled
         self.oidc_enabled = hs.config.oidc_enabled
+
         self.auth_handler = self.hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
         self.handlers = hs.get_handlers()
@@ -368,16 +375,22 @@ class LoginRestServlet(RestServlet):
             )
 
         import jwt
-        from jwt.exceptions import InvalidTokenError
 
         try:
             payload = jwt.decode(
-                token, self.jwt_secret, algorithms=[self.jwt_algorithm]
+                token,
+                self.jwt_secret,
+                algorithms=[self.jwt_algorithm],
+                issuer=self.jwt_issuer,
+                audience=self.jwt_audiences,
+            )
+        except jwt.PyJWTError as e:
+            # A JWT error occurred, return some info back to the client.
+            raise LoginError(
+                401,
+                "JWT validation failed: %s" % (str(e),),
+                errcode=Codes.UNAUTHORIZED,
             )
-        except jwt.ExpiredSignatureError:
-            raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
-        except InvalidTokenError:
-            raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
 
         user = payload.get("sub", None)
         if user is None: