diff --git a/changelog.d/7827.feature b/changelog.d/7827.feature
new file mode 100644
index 0000000000..0fd116e198
--- /dev/null
+++ b/changelog.d/7827.feature
@@ -0,0 +1 @@
+Add the option to validate the `iss` and `aud` claims for JWT logins.
diff --git a/docs/jwt.md b/docs/jwt.md
index 289d66b365..93b8d05236 100644
--- a/docs/jwt.md
+++ b/docs/jwt.md
@@ -20,8 +20,17 @@ follows:
Note that the login type of `m.login.jwt` is supported, but is deprecated. This
will be removed in a future version of Synapse.
-The `jwt` should encode the local part of the user ID as the standard `sub`
-claim. In the case that the token is not valid, the homeserver must respond with
+The `token` field should include the JSON web token with the following claims:
+
+* The `sub` (subject) claim is required and should encode the local part of the
+ user ID.
+* The expiration time (`exp`), not before time (`nbf`), and issued at (`iat`)
+ claims are optional, but validated if present.
+* The issuer (`iss`) claim is optional, but required and validated if configured.
+* The audience (`aud`) claim is optional, but required and validated if configured.
+ Providing the audience claim when not configured will cause validation to fail.
+
+In the case that the token is not valid, the homeserver must respond with
`401 Unauthorized` and an error code of `M_UNAUTHORIZED`.
(Note that this differs from the token based logins which return a
@@ -55,7 +64,8 @@ sample settings.
Although JSON Web Tokens are typically generated from an external server, the
examples below use [PyJWT](https://pyjwt.readthedocs.io/en/latest/) directly.
-1. Configure Synapse with JWT logins:
+1. Configure Synapse with JWT logins, note that this example uses a pre-shared
+ secret and an algorithm of HS256:
```yaml
jwt_config:
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 31ed49da33..578034a31b 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1987,6 +1987,9 @@ sso:
# Each JSON Web Token needs to contain a "sub" (subject) claim, which is
# used as the localpart of the mxid.
#
+# Additionally, the expiration time ("exp"), not before time ("nbf"),
+# and issued at ("iat") claims are validated if present.
+#
# Note that this is a non-standard login type and client support is
# expected to be non-existant.
#
@@ -2014,6 +2017,24 @@ sso:
#
#algorithm: "provided-by-your-issuer"
+ # The issuer to validate the "iss" claim against.
+ #
+ # Optional, if provided the "iss" claim will be required and
+ # validated for all JSON web tokens.
+ #
+ #issuer: "provided-by-your-issuer"
+
+ # A list of audiences to validate the "aud" claim against.
+ #
+ # Optional, if provided the "aud" claim will be required and
+ # validated for all JSON web tokens.
+ #
+ # Note that if the "aud" claim is included in a JSON web token then
+ # validation will fail without configuring audiences.
+ #
+ #audiences:
+ # - "provided-by-your-issuer"
+
password_config:
# Uncomment to disable password login
diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py
index fce96b4acf..3252ad9e7f 100644
--- a/synapse/config/jwt_config.py
+++ b/synapse/config/jwt_config.py
@@ -32,6 +32,11 @@ class JWTConfig(Config):
self.jwt_secret = jwt_config["secret"]
self.jwt_algorithm = jwt_config["algorithm"]
+ # The issuer and audiences are optional, if provided, it is asserted
+ # that the claims exist on the JWT.
+ self.jwt_issuer = jwt_config.get("issuer")
+ self.jwt_audiences = jwt_config.get("audiences")
+
try:
import jwt
@@ -42,6 +47,8 @@ class JWTConfig(Config):
self.jwt_enabled = False
self.jwt_secret = None
self.jwt_algorithm = None
+ self.jwt_issuer = None
+ self.jwt_audiences = None
def generate_config_section(self, **kwargs):
return """\
@@ -52,6 +59,9 @@ class JWTConfig(Config):
# Each JSON Web Token needs to contain a "sub" (subject) claim, which is
# used as the localpart of the mxid.
#
+ # Additionally, the expiration time ("exp"), not before time ("nbf"),
+ # and issued at ("iat") claims are validated if present.
+ #
# Note that this is a non-standard login type and client support is
# expected to be non-existant.
#
@@ -78,4 +88,22 @@ class JWTConfig(Config):
# Required if 'enabled' is true.
#
#algorithm: "provided-by-your-issuer"
+
+ # The issuer to validate the "iss" claim against.
+ #
+ # Optional, if provided the "iss" claim will be required and
+ # validated for all JSON web tokens.
+ #
+ #issuer: "provided-by-your-issuer"
+
+ # A list of audiences to validate the "aud" claim against.
+ #
+ # Optional, if provided the "aud" claim will be required and
+ # validated for all JSON web tokens.
+ #
+ # Note that if the "aud" claim is included in a JSON web token then
+ # validation will fail without configuring audiences.
+ #
+ #audiences:
+ # - "provided-by-your-issuer"
"""
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 64d5c58b65..326ffa0056 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -89,12 +89,19 @@ class LoginRestServlet(RestServlet):
def __init__(self, hs):
super(LoginRestServlet, self).__init__()
self.hs = hs
+
+ # JWT configuration variables.
self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm
+ self.jwt_issuer = hs.config.jwt_issuer
+ self.jwt_audiences = hs.config.jwt_audiences
+
+ # SSO configuration.
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
self.oidc_enabled = hs.config.oidc_enabled
+
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers()
@@ -368,16 +375,22 @@ class LoginRestServlet(RestServlet):
)
import jwt
- from jwt.exceptions import InvalidTokenError
try:
payload = jwt.decode(
- token, self.jwt_secret, algorithms=[self.jwt_algorithm]
+ token,
+ self.jwt_secret,
+ algorithms=[self.jwt_algorithm],
+ issuer=self.jwt_issuer,
+ audience=self.jwt_audiences,
+ )
+ except jwt.PyJWTError as e:
+ # A JWT error occurred, return some info back to the client.
+ raise LoginError(
+ 401,
+ "JWT validation failed: %s" % (str(e),),
+ errcode=Codes.UNAUTHORIZED,
)
- except jwt.ExpiredSignatureError:
- raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
- except InvalidTokenError:
- raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user = payload.get("sub", None)
if user is None:
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 2be7238b00..4413bb3932 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -514,16 +514,17 @@ class JWTTestCase(unittest.HomeserverTestCase):
]
jwt_secret = "secret"
+ jwt_algorithm = "HS256"
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
self.hs.config.jwt_enabled = True
self.hs.config.jwt_secret = self.jwt_secret
- self.hs.config.jwt_algorithm = "HS256"
+ self.hs.config.jwt_algorithm = self.jwt_algorithm
return self.hs
def jwt_encode(self, token, secret=jwt_secret):
- return jwt.encode(token, secret, "HS256").decode("ascii")
+ return jwt.encode(token, secret, self.jwt_algorithm).decode("ascii")
def jwt_login(self, *args):
params = json.dumps(
@@ -548,20 +549,28 @@ class JWTTestCase(unittest.HomeserverTestCase):
channel = self.jwt_login({"sub": "frog"}, "notsecret")
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
- self.assertEqual(channel.json_body["error"], "Invalid JWT")
+ self.assertEqual(
+ channel.json_body["error"],
+ "JWT validation failed: Signature verification failed",
+ )
def test_login_jwt_expired(self):
channel = self.jwt_login({"sub": "frog", "exp": 864000})
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
- self.assertEqual(channel.json_body["error"], "JWT expired")
+ self.assertEqual(
+ channel.json_body["error"], "JWT validation failed: Signature has expired"
+ )
def test_login_jwt_not_before(self):
now = int(time.time())
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
- self.assertEqual(channel.json_body["error"], "Invalid JWT")
+ self.assertEqual(
+ channel.json_body["error"],
+ "JWT validation failed: The token is not yet valid (nbf)",
+ )
def test_login_no_sub(self):
channel = self.jwt_login({"username": "root"})
@@ -569,6 +578,88 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.json_body["error"], "Invalid JWT")
+ @override_config(
+ {
+ "jwt_config": {
+ "jwt_enabled": True,
+ "secret": jwt_secret,
+ "algorithm": jwt_algorithm,
+ "issuer": "test-issuer",
+ }
+ }
+ )
+ def test_login_iss(self):
+ """Test validating the issuer claim."""
+ # A valid issuer.
+ channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ # An invalid issuer.
+ channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(
+ channel.json_body["error"], "JWT validation failed: Invalid issuer"
+ )
+
+ # Not providing an issuer.
+ channel = self.jwt_login({"sub": "kermit"})
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(
+ channel.json_body["error"],
+ 'JWT validation failed: Token is missing the "iss" claim',
+ )
+
+ def test_login_iss_no_config(self):
+ """Test providing an issuer claim without requiring it in the configuration."""
+ channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ @override_config(
+ {
+ "jwt_config": {
+ "jwt_enabled": True,
+ "secret": jwt_secret,
+ "algorithm": jwt_algorithm,
+ "audiences": ["test-audience"],
+ }
+ }
+ )
+ def test_login_aud(self):
+ """Test validating the audience claim."""
+ # A valid audience.
+ channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ # An invalid audience.
+ channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(
+ channel.json_body["error"], "JWT validation failed: Invalid audience"
+ )
+
+ # Not providing an audience.
+ channel = self.jwt_login({"sub": "kermit"})
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(
+ channel.json_body["error"],
+ 'JWT validation failed: Token is missing the "aud" claim',
+ )
+
+ def test_login_aud_no_config(self):
+ """Test providing an audience without requiring it in the configuration."""
+ channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(
+ channel.json_body["error"], "JWT validation failed: Invalid audience"
+ )
+
def test_login_no_token(self):
params = json.dumps({"type": "org.matrix.login.jwt"})
request, channel = self.make_request(b"POST", LOGIN_URL, params)
@@ -658,4 +749,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
- self.assertEqual(channel.json_body["error"], "Invalid JWT")
+ self.assertEqual(
+ channel.json_body["error"],
+ "JWT validation failed: Signature verification failed",
+ )
|