summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11361.feature1
-rw-r--r--docs/jwt.md5
-rw-r--r--docs/sample_config.yaml6
-rw-r--r--synapse/config/jwt.py9
-rw-r--r--synapse/rest/client/login.py3
-rw-r--r--tests/rest/client/test_login.py68
6 files changed, 57 insertions, 35 deletions
diff --git a/changelog.d/11361.feature b/changelog.d/11361.feature
new file mode 100644
index 0000000000..24c9244887
--- /dev/null
+++ b/changelog.d/11361.feature
@@ -0,0 +1 @@
+Update the JWT login type to support custom a `sub` claim.
diff --git a/docs/jwt.md b/docs/jwt.md
index 5be9fd26e3..32f58cc0cb 100644
--- a/docs/jwt.md
+++ b/docs/jwt.md
@@ -22,8 +22,9 @@ will be removed in a future version of Synapse.
 
 The `token` field should include the JSON web token with the following claims:
 
-* The `sub` (subject) claim is required and should encode the local part of the
-  user ID.
+* A claim that encodes the local part of the user ID is required. By default,
+  the `sub` (subject) claim is used, or a custom claim can be set in the
+  configuration file.
 * The expiration time (`exp`), not before time (`nbf`), and issued at (`iat`)
   claims are optional, but validated if present.
 * The issuer (`iss`) claim is optional, but required and validated if configured.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index aee300013f..ae476d19ac 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -2039,6 +2039,12 @@ sso:
     #
     #algorithm: "provided-by-your-issuer"
 
+    # Name of the claim containing a unique identifier for the user.
+    #
+    # Optional, defaults to `sub`.
+    #
+    #subject_claim: "sub"
+
     # The issuer to validate the "iss" claim against.
     #
     # Optional, if provided the "iss" claim will be required and
diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py
index 9d295f5856..24c3ef01fc 100644
--- a/synapse/config/jwt.py
+++ b/synapse/config/jwt.py
@@ -31,6 +31,8 @@ class JWTConfig(Config):
             self.jwt_secret = jwt_config["secret"]
             self.jwt_algorithm = jwt_config["algorithm"]
 
+            self.jwt_subject_claim = jwt_config.get("subject_claim", "sub")
+
             # The issuer and audiences are optional, if provided, it is asserted
             # that the claims exist on the JWT.
             self.jwt_issuer = jwt_config.get("issuer")
@@ -46,6 +48,7 @@ class JWTConfig(Config):
             self.jwt_enabled = False
             self.jwt_secret = None
             self.jwt_algorithm = None
+            self.jwt_subject_claim = None
             self.jwt_issuer = None
             self.jwt_audiences = None
 
@@ -88,6 +91,12 @@ class JWTConfig(Config):
             #
             #algorithm: "provided-by-your-issuer"
 
+            # Name of the claim containing a unique identifier for the user.
+            #
+            # Optional, defaults to `sub`.
+            #
+            #subject_claim: "sub"
+
             # The issuer to validate the "iss" claim against.
             #
             # Optional, if provided the "iss" claim will be required and
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index 467444a041..00e65c66ac 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -72,6 +72,7 @@ class LoginRestServlet(RestServlet):
         # JWT configuration variables.
         self.jwt_enabled = hs.config.jwt.jwt_enabled
         self.jwt_secret = hs.config.jwt.jwt_secret
+        self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim
         self.jwt_algorithm = hs.config.jwt.jwt_algorithm
         self.jwt_issuer = hs.config.jwt.jwt_issuer
         self.jwt_audiences = hs.config.jwt.jwt_audiences
@@ -413,7 +414,7 @@ class LoginRestServlet(RestServlet):
                 errcode=Codes.FORBIDDEN,
             )
 
-        user = payload.get("sub", None)
+        user = payload.get(self.jwt_subject_claim, None)
         if user is None:
             raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
 
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index 0b90e3f803..19f5e46537 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -815,13 +815,20 @@ class JWTTestCase(unittest.HomeserverTestCase):
 
     jwt_secret = "secret"
     jwt_algorithm = "HS256"
+    base_config = {
+        "enabled": True,
+        "secret": jwt_secret,
+        "algorithm": jwt_algorithm,
+    }
 
-    def make_homeserver(self, reactor, clock):
-        self.hs = self.setup_test_homeserver()
-        self.hs.config.jwt.jwt_enabled = True
-        self.hs.config.jwt.jwt_secret = self.jwt_secret
-        self.hs.config.jwt.jwt_algorithm = self.jwt_algorithm
-        return self.hs
+    def default_config(self):
+        config = super().default_config()
+
+        # If jwt_config has been defined (eg via @override_config), don't replace it.
+        if config.get("jwt_config") is None:
+            config["jwt_config"] = self.base_config
+
+        return config
 
     def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
         # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
@@ -879,16 +886,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(channel.json_body["error"], "Invalid JWT")
 
-    @override_config(
-        {
-            "jwt_config": {
-                "jwt_enabled": True,
-                "secret": jwt_secret,
-                "algorithm": jwt_algorithm,
-                "issuer": "test-issuer",
-            }
-        }
-    )
+    @override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}})
     def test_login_iss(self):
         """Test validating the issuer claim."""
         # A valid issuer.
@@ -919,16 +917,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.result["code"], b"200", channel.result)
         self.assertEqual(channel.json_body["user_id"], "@kermit:test")
 
-    @override_config(
-        {
-            "jwt_config": {
-                "jwt_enabled": True,
-                "secret": jwt_secret,
-                "algorithm": jwt_algorithm,
-                "audiences": ["test-audience"],
-            }
-        }
-    )
+    @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
     def test_login_aud(self):
         """Test validating the audience claim."""
         # A valid audience.
@@ -962,6 +951,19 @@ class JWTTestCase(unittest.HomeserverTestCase):
             channel.json_body["error"], "JWT validation failed: Invalid audience"
         )
 
+    def test_login_default_sub(self):
+        """Test reading user ID from the default subject claim."""
+        channel = self.jwt_login({"sub": "kermit"})
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+    @override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
+    def test_login_custom_sub(self):
+        """Test reading user ID from a custom subject claim."""
+        channel = self.jwt_login({"username": "frog"})
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.json_body["user_id"], "@frog:test")
+
     def test_login_no_token(self):
         params = {"type": "org.matrix.login.jwt"}
         channel = self.make_request(b"POST", LOGIN_URL, params)
@@ -1024,12 +1026,14 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
         ]
     )
 
-    def make_homeserver(self, reactor, clock):
-        self.hs = self.setup_test_homeserver()
-        self.hs.config.jwt.jwt_enabled = True
-        self.hs.config.jwt.jwt_secret = self.jwt_pubkey
-        self.hs.config.jwt.jwt_algorithm = "RS256"
-        return self.hs
+    def default_config(self):
+        config = super().default_config()
+        config["jwt_config"] = {
+            "enabled": True,
+            "secret": self.jwt_pubkey,
+            "algorithm": "RS256",
+        }
+        return config
 
     def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
         # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.