summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/7827.feature1
-rw-r--r--docs/jwt.md16
-rw-r--r--docs/sample_config.yaml21
-rw-r--r--synapse/config/jwt_config.py28
-rw-r--r--synapse/rest/client/v1/login.py25
-rw-r--r--tests/rest/client/v1/test_login.py106
6 files changed, 182 insertions, 15 deletions
diff --git a/changelog.d/7827.feature b/changelog.d/7827.feature
new file mode 100644
index 0000000000..0fd116e198
--- /dev/null
+++ b/changelog.d/7827.feature
@@ -0,0 +1 @@
+Add the option to validate the `iss` and `aud` claims for JWT logins.
diff --git a/docs/jwt.md b/docs/jwt.md
index 289d66b365..93b8d05236 100644
--- a/docs/jwt.md
+++ b/docs/jwt.md
@@ -20,8 +20,17 @@ follows:
 Note that the login type of `m.login.jwt` is supported, but is deprecated. This
 will be removed in a future version of Synapse.
 
-The `jwt` should encode the local part of the user ID as the standard `sub`
-claim. In the case that the token is not valid, the homeserver must respond with
+The `token` field should include the JSON web token with the following claims:
+
+* The `sub` (subject) claim is required and should encode the local part of the
+  user ID.
+* The expiration time (`exp`), not before time (`nbf`), and issued at (`iat`)
+  claims are optional, but validated if present.
+* The issuer (`iss`) claim is optional, but required and validated if configured.
+* The audience (`aud`) claim is optional, but required and validated if configured.
+  Providing the audience claim when not configured will cause validation to fail.
+
+In the case that the token is not valid, the homeserver must respond with
 `401 Unauthorized` and an error code of `M_UNAUTHORIZED`.
 
 (Note that this differs from the token based logins which return a
@@ -55,7 +64,8 @@ sample settings.
 Although JSON Web Tokens are typically generated from an external server, the
 examples below use [PyJWT](https://pyjwt.readthedocs.io/en/latest/) directly.
 
-1.  Configure Synapse with JWT logins:
+1.  Configure Synapse with JWT logins, note that this example uses a pre-shared
+    secret and an algorithm of HS256:
 
     ```yaml
     jwt_config:
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 1a2d9fb153..9d94495464 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1812,6 +1812,9 @@ sso:
 # Each JSON Web Token needs to contain a "sub" (subject) claim, which is
 # used as the localpart of the mxid.
 #
+# Additionally, the expiration time ("exp"), not before time ("nbf"),
+# and issued at ("iat") claims are validated if present.
+#
 # Note that this is a non-standard login type and client support is
 # expected to be non-existant.
 #
@@ -1839,6 +1842,24 @@ sso:
     #
     #algorithm: "provided-by-your-issuer"
 
+    # The issuer to validate the "iss" claim against.
+    #
+    # Optional, if provided the "iss" claim will be required and
+    # validated for all JSON web tokens.
+    #
+    #issuer: "provided-by-your-issuer"
+
+    # A list of audiences to validate the "aud" claim against.
+    #
+    # Optional, if provided the "aud" claim will be required and
+    # validated for all JSON web tokens.
+    #
+    # Note that if the "aud" claim is included in a JSON web token then
+    # validation will fail without configuring audiences.
+    #
+    #audiences:
+    #    - "provided-by-your-issuer"
+
 
 password_config:
    # Uncomment to disable password login
diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py
index fce96b4acf..3252ad9e7f 100644
--- a/synapse/config/jwt_config.py
+++ b/synapse/config/jwt_config.py
@@ -32,6 +32,11 @@ class JWTConfig(Config):
             self.jwt_secret = jwt_config["secret"]
             self.jwt_algorithm = jwt_config["algorithm"]
 
+            # The issuer and audiences are optional, if provided, it is asserted
+            # that the claims exist on the JWT.
+            self.jwt_issuer = jwt_config.get("issuer")
+            self.jwt_audiences = jwt_config.get("audiences")
+
             try:
                 import jwt
 
@@ -42,6 +47,8 @@ class JWTConfig(Config):
             self.jwt_enabled = False
             self.jwt_secret = None
             self.jwt_algorithm = None
+            self.jwt_issuer = None
+            self.jwt_audiences = None
 
     def generate_config_section(self, **kwargs):
         return """\
@@ -52,6 +59,9 @@ class JWTConfig(Config):
         # Each JSON Web Token needs to contain a "sub" (subject) claim, which is
         # used as the localpart of the mxid.
         #
+        # Additionally, the expiration time ("exp"), not before time ("nbf"),
+        # and issued at ("iat") claims are validated if present.
+        #
         # Note that this is a non-standard login type and client support is
         # expected to be non-existant.
         #
@@ -78,4 +88,22 @@ class JWTConfig(Config):
             # Required if 'enabled' is true.
             #
             #algorithm: "provided-by-your-issuer"
+
+            # The issuer to validate the "iss" claim against.
+            #
+            # Optional, if provided the "iss" claim will be required and
+            # validated for all JSON web tokens.
+            #
+            #issuer: "provided-by-your-issuer"
+
+            # A list of audiences to validate the "aud" claim against.
+            #
+            # Optional, if provided the "aud" claim will be required and
+            # validated for all JSON web tokens.
+            #
+            # Note that if the "aud" claim is included in a JSON web token then
+            # validation will fail without configuring audiences.
+            #
+            #audiences:
+            #    - "provided-by-your-issuer"
         """
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 64d5c58b65..326ffa0056 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -89,12 +89,19 @@ class LoginRestServlet(RestServlet):
     def __init__(self, hs):
         super(LoginRestServlet, self).__init__()
         self.hs = hs
+
+        # JWT configuration variables.
         self.jwt_enabled = hs.config.jwt_enabled
         self.jwt_secret = hs.config.jwt_secret
         self.jwt_algorithm = hs.config.jwt_algorithm
+        self.jwt_issuer = hs.config.jwt_issuer
+        self.jwt_audiences = hs.config.jwt_audiences
+
+        # SSO configuration.
         self.saml2_enabled = hs.config.saml2_enabled
         self.cas_enabled = hs.config.cas_enabled
         self.oidc_enabled = hs.config.oidc_enabled
+
         self.auth_handler = self.hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
         self.handlers = hs.get_handlers()
@@ -368,16 +375,22 @@ class LoginRestServlet(RestServlet):
             )
 
         import jwt
-        from jwt.exceptions import InvalidTokenError
 
         try:
             payload = jwt.decode(
-                token, self.jwt_secret, algorithms=[self.jwt_algorithm]
+                token,
+                self.jwt_secret,
+                algorithms=[self.jwt_algorithm],
+                issuer=self.jwt_issuer,
+                audience=self.jwt_audiences,
+            )
+        except jwt.PyJWTError as e:
+            # A JWT error occurred, return some info back to the client.
+            raise LoginError(
+                401,
+                "JWT validation failed: %s" % (str(e),),
+                errcode=Codes.UNAUTHORIZED,
             )
-        except jwt.ExpiredSignatureError:
-            raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
-        except InvalidTokenError:
-            raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
 
         user = payload.get("sub", None)
         if user is None:
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 2be7238b00..4413bb3932 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -514,16 +514,17 @@ class JWTTestCase(unittest.HomeserverTestCase):
     ]
 
     jwt_secret = "secret"
+    jwt_algorithm = "HS256"
 
     def make_homeserver(self, reactor, clock):
         self.hs = self.setup_test_homeserver()
         self.hs.config.jwt_enabled = True
         self.hs.config.jwt_secret = self.jwt_secret
-        self.hs.config.jwt_algorithm = "HS256"
+        self.hs.config.jwt_algorithm = self.jwt_algorithm
         return self.hs
 
     def jwt_encode(self, token, secret=jwt_secret):
-        return jwt.encode(token, secret, "HS256").decode("ascii")
+        return jwt.encode(token, secret, self.jwt_algorithm).decode("ascii")
 
     def jwt_login(self, *args):
         params = json.dumps(
@@ -548,20 +549,28 @@ class JWTTestCase(unittest.HomeserverTestCase):
         channel = self.jwt_login({"sub": "frog"}, "notsecret")
         self.assertEqual(channel.result["code"], b"401", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
-        self.assertEqual(channel.json_body["error"], "Invalid JWT")
+        self.assertEqual(
+            channel.json_body["error"],
+            "JWT validation failed: Signature verification failed",
+        )
 
     def test_login_jwt_expired(self):
         channel = self.jwt_login({"sub": "frog", "exp": 864000})
         self.assertEqual(channel.result["code"], b"401", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
-        self.assertEqual(channel.json_body["error"], "JWT expired")
+        self.assertEqual(
+            channel.json_body["error"], "JWT validation failed: Signature has expired"
+        )
 
     def test_login_jwt_not_before(self):
         now = int(time.time())
         channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
         self.assertEqual(channel.result["code"], b"401", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
-        self.assertEqual(channel.json_body["error"], "Invalid JWT")
+        self.assertEqual(
+            channel.json_body["error"],
+            "JWT validation failed: The token is not yet valid (nbf)",
+        )
 
     def test_login_no_sub(self):
         channel = self.jwt_login({"username": "root"})
@@ -569,6 +578,88 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
         self.assertEqual(channel.json_body["error"], "Invalid JWT")
 
+    @override_config(
+        {
+            "jwt_config": {
+                "jwt_enabled": True,
+                "secret": jwt_secret,
+                "algorithm": jwt_algorithm,
+                "issuer": "test-issuer",
+            }
+        }
+    )
+    def test_login_iss(self):
+        """Test validating the issuer claim."""
+        # A valid issuer.
+        channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+        # An invalid issuer.
+        channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
+        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(
+            channel.json_body["error"], "JWT validation failed: Invalid issuer"
+        )
+
+        # Not providing an issuer.
+        channel = self.jwt_login({"sub": "kermit"})
+        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(
+            channel.json_body["error"],
+            'JWT validation failed: Token is missing the "iss" claim',
+        )
+
+    def test_login_iss_no_config(self):
+        """Test providing an issuer claim without requiring it in the configuration."""
+        channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+    @override_config(
+        {
+            "jwt_config": {
+                "jwt_enabled": True,
+                "secret": jwt_secret,
+                "algorithm": jwt_algorithm,
+                "audiences": ["test-audience"],
+            }
+        }
+    )
+    def test_login_aud(self):
+        """Test validating the audience claim."""
+        # A valid audience.
+        channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+        # An invalid audience.
+        channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
+        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(
+            channel.json_body["error"], "JWT validation failed: Invalid audience"
+        )
+
+        # Not providing an audience.
+        channel = self.jwt_login({"sub": "kermit"})
+        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(
+            channel.json_body["error"],
+            'JWT validation failed: Token is missing the "aud" claim',
+        )
+
+    def test_login_aud_no_config(self):
+        """Test providing an audience without requiring it in the configuration."""
+        channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
+        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(
+            channel.json_body["error"], "JWT validation failed: Invalid audience"
+        )
+
     def test_login_no_token(self):
         params = json.dumps({"type": "org.matrix.login.jwt"})
         request, channel = self.make_request(b"POST", LOGIN_URL, params)
@@ -658,4 +749,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
         channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
         self.assertEqual(channel.result["code"], b"401", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
-        self.assertEqual(channel.json_body["error"], "Invalid JWT")
+        self.assertEqual(
+            channel.json_body["error"],
+            "JWT validation failed: Signature verification failed",
+        )