diff options
-rw-r--r-- | changelog.d/11361.feature | 1 | ||||
-rw-r--r-- | docs/jwt.md | 5 | ||||
-rw-r--r-- | docs/sample_config.yaml | 6 | ||||
-rw-r--r-- | synapse/config/jwt.py | 9 | ||||
-rw-r--r-- | synapse/rest/client/login.py | 3 | ||||
-rw-r--r-- | tests/rest/client/test_login.py | 68 |
6 files changed, 57 insertions, 35 deletions
diff --git a/changelog.d/11361.feature b/changelog.d/11361.feature new file mode 100644 index 0000000000..24c9244887 --- /dev/null +++ b/changelog.d/11361.feature @@ -0,0 +1 @@ +Update the JWT login type to support custom a `sub` claim. diff --git a/docs/jwt.md b/docs/jwt.md index 5be9fd26e3..32f58cc0cb 100644 --- a/docs/jwt.md +++ b/docs/jwt.md @@ -22,8 +22,9 @@ will be removed in a future version of Synapse. The `token` field should include the JSON web token with the following claims: -* The `sub` (subject) claim is required and should encode the local part of the - user ID. +* A claim that encodes the local part of the user ID is required. By default, + the `sub` (subject) claim is used, or a custom claim can be set in the + configuration file. * The expiration time (`exp`), not before time (`nbf`), and issued at (`iat`) claims are optional, but validated if present. * The issuer (`iss`) claim is optional, but required and validated if configured. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index aee300013f..ae476d19ac 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2039,6 +2039,12 @@ sso: # #algorithm: "provided-by-your-issuer" + # Name of the claim containing a unique identifier for the user. + # + # Optional, defaults to `sub`. + # + #subject_claim: "sub" + # The issuer to validate the "iss" claim against. # # Optional, if provided the "iss" claim will be required and diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py index 9d295f5856..24c3ef01fc 100644 --- a/synapse/config/jwt.py +++ b/synapse/config/jwt.py @@ -31,6 +31,8 @@ class JWTConfig(Config): self.jwt_secret = jwt_config["secret"] self.jwt_algorithm = jwt_config["algorithm"] + self.jwt_subject_claim = jwt_config.get("subject_claim", "sub") + # The issuer and audiences are optional, if provided, it is asserted # that the claims exist on the JWT. self.jwt_issuer = jwt_config.get("issuer") @@ -46,6 +48,7 @@ class JWTConfig(Config): self.jwt_enabled = False self.jwt_secret = None self.jwt_algorithm = None + self.jwt_subject_claim = None self.jwt_issuer = None self.jwt_audiences = None @@ -88,6 +91,12 @@ class JWTConfig(Config): # #algorithm: "provided-by-your-issuer" + # Name of the claim containing a unique identifier for the user. + # + # Optional, defaults to `sub`. + # + #subject_claim: "sub" + # The issuer to validate the "iss" claim against. # # Optional, if provided the "iss" claim will be required and diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 467444a041..00e65c66ac 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -72,6 +72,7 @@ class LoginRestServlet(RestServlet): # JWT configuration variables. self.jwt_enabled = hs.config.jwt.jwt_enabled self.jwt_secret = hs.config.jwt.jwt_secret + self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim self.jwt_algorithm = hs.config.jwt.jwt_algorithm self.jwt_issuer = hs.config.jwt.jwt_issuer self.jwt_audiences = hs.config.jwt.jwt_audiences @@ -413,7 +414,7 @@ class LoginRestServlet(RestServlet): errcode=Codes.FORBIDDEN, ) - user = payload.get("sub", None) + user = payload.get(self.jwt_subject_claim, None) if user is None: raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 0b90e3f803..19f5e46537 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -815,13 +815,20 @@ class JWTTestCase(unittest.HomeserverTestCase): jwt_secret = "secret" jwt_algorithm = "HS256" + base_config = { + "enabled": True, + "secret": jwt_secret, + "algorithm": jwt_algorithm, + } - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - self.hs.config.jwt.jwt_enabled = True - self.hs.config.jwt.jwt_secret = self.jwt_secret - self.hs.config.jwt.jwt_algorithm = self.jwt_algorithm - return self.hs + def default_config(self): + config = super().default_config() + + # If jwt_config has been defined (eg via @override_config), don't replace it. + if config.get("jwt_config") is None: + config["jwt_config"] = self.base_config + + return config def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. @@ -879,16 +886,7 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Invalid JWT") - @override_config( - { - "jwt_config": { - "jwt_enabled": True, - "secret": jwt_secret, - "algorithm": jwt_algorithm, - "issuer": "test-issuer", - } - } - ) + @override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}}) def test_login_iss(self): """Test validating the issuer claim.""" # A valid issuer. @@ -919,16 +917,7 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") - @override_config( - { - "jwt_config": { - "jwt_enabled": True, - "secret": jwt_secret, - "algorithm": jwt_algorithm, - "audiences": ["test-audience"], - } - } - ) + @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}}) def test_login_aud(self): """Test validating the audience claim.""" # A valid audience. @@ -962,6 +951,19 @@ class JWTTestCase(unittest.HomeserverTestCase): channel.json_body["error"], "JWT validation failed: Invalid audience" ) + def test_login_default_sub(self): + """Test reading user ID from the default subject claim.""" + channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + @override_config({"jwt_config": {**base_config, "subject_claim": "username"}}) + def test_login_custom_sub(self): + """Test reading user ID from a custom subject claim.""" + channel = self.jwt_login({"username": "frog"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@frog:test") + def test_login_no_token(self): params = {"type": "org.matrix.login.jwt"} channel = self.make_request(b"POST", LOGIN_URL, params) @@ -1024,12 +1026,14 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): ] ) - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - self.hs.config.jwt.jwt_enabled = True - self.hs.config.jwt.jwt_secret = self.jwt_pubkey - self.hs.config.jwt.jwt_algorithm = "RS256" - return self.hs + def default_config(self): + config = super().default_config() + config["jwt_config"] = { + "enabled": True, + "secret": self.jwt_pubkey, + "algorithm": "RS256", + } + return config def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. |