summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/rest/client/v1/test_login.py106
1 files changed, 100 insertions, 6 deletions
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 2be7238b00..4413bb3932 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -514,16 +514,17 @@ class JWTTestCase(unittest.HomeserverTestCase):
     ]
 
     jwt_secret = "secret"
+    jwt_algorithm = "HS256"
 
     def make_homeserver(self, reactor, clock):
         self.hs = self.setup_test_homeserver()
         self.hs.config.jwt_enabled = True
         self.hs.config.jwt_secret = self.jwt_secret
-        self.hs.config.jwt_algorithm = "HS256"
+        self.hs.config.jwt_algorithm = self.jwt_algorithm
         return self.hs
 
     def jwt_encode(self, token, secret=jwt_secret):
-        return jwt.encode(token, secret, "HS256").decode("ascii")
+        return jwt.encode(token, secret, self.jwt_algorithm).decode("ascii")
 
     def jwt_login(self, *args):
         params = json.dumps(
@@ -548,20 +549,28 @@ class JWTTestCase(unittest.HomeserverTestCase):
         channel = self.jwt_login({"sub": "frog"}, "notsecret")
         self.assertEqual(channel.result["code"], b"401", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
-        self.assertEqual(channel.json_body["error"], "Invalid JWT")
+        self.assertEqual(
+            channel.json_body["error"],
+            "JWT validation failed: Signature verification failed",
+        )
 
     def test_login_jwt_expired(self):
         channel = self.jwt_login({"sub": "frog", "exp": 864000})
         self.assertEqual(channel.result["code"], b"401", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
-        self.assertEqual(channel.json_body["error"], "JWT expired")
+        self.assertEqual(
+            channel.json_body["error"], "JWT validation failed: Signature has expired"
+        )
 
     def test_login_jwt_not_before(self):
         now = int(time.time())
         channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
         self.assertEqual(channel.result["code"], b"401", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
-        self.assertEqual(channel.json_body["error"], "Invalid JWT")
+        self.assertEqual(
+            channel.json_body["error"],
+            "JWT validation failed: The token is not yet valid (nbf)",
+        )
 
     def test_login_no_sub(self):
         channel = self.jwt_login({"username": "root"})
@@ -569,6 +578,88 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
         self.assertEqual(channel.json_body["error"], "Invalid JWT")
 
+    @override_config(
+        {
+            "jwt_config": {
+                "jwt_enabled": True,
+                "secret": jwt_secret,
+                "algorithm": jwt_algorithm,
+                "issuer": "test-issuer",
+            }
+        }
+    )
+    def test_login_iss(self):
+        """Test validating the issuer claim."""
+        # A valid issuer.
+        channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+        # An invalid issuer.
+        channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
+        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(
+            channel.json_body["error"], "JWT validation failed: Invalid issuer"
+        )
+
+        # Not providing an issuer.
+        channel = self.jwt_login({"sub": "kermit"})
+        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(
+            channel.json_body["error"],
+            'JWT validation failed: Token is missing the "iss" claim',
+        )
+
+    def test_login_iss_no_config(self):
+        """Test providing an issuer claim without requiring it in the configuration."""
+        channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+    @override_config(
+        {
+            "jwt_config": {
+                "jwt_enabled": True,
+                "secret": jwt_secret,
+                "algorithm": jwt_algorithm,
+                "audiences": ["test-audience"],
+            }
+        }
+    )
+    def test_login_aud(self):
+        """Test validating the audience claim."""
+        # A valid audience.
+        channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+        # An invalid audience.
+        channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
+        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(
+            channel.json_body["error"], "JWT validation failed: Invalid audience"
+        )
+
+        # Not providing an audience.
+        channel = self.jwt_login({"sub": "kermit"})
+        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(
+            channel.json_body["error"],
+            'JWT validation failed: Token is missing the "aud" claim',
+        )
+
+    def test_login_aud_no_config(self):
+        """Test providing an audience without requiring it in the configuration."""
+        channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
+        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(
+            channel.json_body["error"], "JWT validation failed: Invalid audience"
+        )
+
     def test_login_no_token(self):
         params = json.dumps({"type": "org.matrix.login.jwt"})
         request, channel = self.make_request(b"POST", LOGIN_URL, params)
@@ -658,4 +749,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
         channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
         self.assertEqual(channel.result["code"], b"401", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
-        self.assertEqual(channel.json_body["error"], "Invalid JWT")
+        self.assertEqual(
+            channel.json_body["error"],
+            "JWT validation failed: Signature verification failed",
+        )