summary refs log tree commit diff
diff options
context:
space:
mode:
authorHannes Lerchl <aytchell@users.noreply.github.com>2022-06-15 18:45:16 +0200
committerGitHub <noreply@github.com>2022-06-15 16:45:16 +0000
commit7d99414edf2c5c7e602a88c72245add665e6afb4 (patch)
treeda17d91c48acdae424833784f40efc29a14c4416
parentSort failing jobs in Complement CI to the top of the logs to make them easier... (diff)
downloadsynapse-7d99414edf2c5c7e602a88c72245add665e6afb4.tar.xz
Replace pyjwt with authlib in `org.matrix.login.jwt` (#13011)
-rw-r--r--changelog.d/13011.misc1
-rw-r--r--docs/jwt.md35
-rw-r--r--docs/usage/configuration/config_documentation.md6
-rw-r--r--poetry.lock8
-rw-r--r--pyproject.toml7
-rw-r--r--synapse/config/jwt.py10
-rw-r--r--synapse/rest/client/login.py46
-rw-r--r--tests/rest/client/test_login.py44
8 files changed, 100 insertions, 57 deletions
diff --git a/changelog.d/13011.misc b/changelog.d/13011.misc
new file mode 100644
index 0000000000..4da223219f
--- /dev/null
+++ b/changelog.d/13011.misc
@@ -0,0 +1 @@
+Replaced usage of PyJWT with methods from Authlib in `org.matrix.login.jwt`. Contributed by Hannes Lerchl.
diff --git a/docs/jwt.md b/docs/jwt.md
index 346daf78ad..8f859d59a6 100644
--- a/docs/jwt.md
+++ b/docs/jwt.md
@@ -37,19 +37,19 @@ As with other login types, there are additional fields (e.g. `device_id` and
 ## Preparing Synapse
 
 The JSON Web Token integration in Synapse uses the
-[`PyJWT`](https://pypi.org/project/pyjwt/) library, which must be installed
+[`Authlib`](https://docs.authlib.org/en/latest/index.html) library, which must be installed
 as follows:
 
- * The relevant libraries are included in the Docker images and Debian packages
-   provided by `matrix.org` so no further action is needed.
+* The relevant libraries are included in the Docker images and Debian packages
+  provided by `matrix.org` so no further action is needed.
 
- * If you installed Synapse into a virtualenv, run `/path/to/env/bin/pip
-   install synapse[pyjwt]` to install the necessary dependencies.
+* If you installed Synapse into a virtualenv, run `/path/to/env/bin/pip
+  install synapse[jwt]` to install the necessary dependencies.
 
- * For other installation mechanisms, see the documentation provided by the
-   maintainer.
+* For other installation mechanisms, see the documentation provided by the
+  maintainer.
 
-To enable the JSON web token integration, you should then add an `jwt_config` section
+To enable the JSON web token integration, you should then add a `jwt_config` section
 to your configuration file (or uncomment the `enabled: true` line in the
 existing section). See [sample_config.yaml](./sample_config.yaml) for some
 sample settings.
@@ -57,7 +57,7 @@ sample settings.
 ## How to test JWT as a developer
 
 Although JSON Web Tokens are typically generated from an external server, the
-examples below use [PyJWT](https://pyjwt.readthedocs.io/en/latest/) directly.
+example below uses a locally generated JWT.
 
 1.  Configure Synapse with JWT logins, note that this example uses a pre-shared
     secret and an algorithm of HS256:
@@ -70,10 +70,21 @@ examples below use [PyJWT](https://pyjwt.readthedocs.io/en/latest/) directly.
     ```
 2.  Generate a JSON web token:
 
-    ```bash
-    $ pyjwt --key=my-secret-token --alg=HS256 encode sub=test-user
-    eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0LXVzZXIifQ.Ag71GT8v01UO3w80aqRPTeuVPBIBZkYhNTJJ-_-zQIc
+    You can use the following short Python snippet to generate a JWT
+    protected by an HMAC.
+    Take care that the `secret` and the algorithm given in the `header` match
+    the entries from `jwt_config` above.
+
+    ```python
+    from authlib.jose import jwt
+
+    header = {"alg": "HS256"}
+    payload = {"sub": "user1", "aud": ["audience"]}
+    secret = "my-secret-token"
+    result = jwt.encode(header, payload, secret)
+    print(result.decode("ascii"))
     ```
+
 3.  Query for the login types and ensure `org.matrix.login.jwt` is there:
 
     ```bash
diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md
index 392ae80a75..e88f68d2b8 100644
--- a/docs/usage/configuration/config_documentation.md
+++ b/docs/usage/configuration/config_documentation.md
@@ -2946,8 +2946,10 @@ Additional sub-options for this setting include:
    tokens. Defaults to false.
 * `secret`: This is either the private shared secret or the public key used to
    decode the contents of the JSON web token. Required if `enabled` is set to true.
-* `algorithm`: The algorithm used to sign the JSON web token. Supported algorithms are listed at
-   https://pyjwt.readthedocs.io/en/latest/algorithms.html Required if `enabled` is set to true.
+* `algorithm`: The algorithm used to sign (or HMAC) the JSON web token.
+   Supported algorithms are listed
+   [here (section JWS)](https://docs.authlib.org/en/latest/specs/rfc7518.html).
+   Required if `enabled` is set to true.
 * `subject_claim`: Name of the claim containing a unique identifier for the user.
    Optional, defaults to `sub`.
 * `issuer`: The issuer to validate the "iss" claim against. Optional. If provided the 
diff --git a/poetry.lock b/poetry.lock
index 6a67f59bca..849e8a7a99 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -815,7 +815,7 @@ python-versions = ">=3.5"
 name = "pyjwt"
 version = "2.4.0"
 description = "JSON Web Token implementation in Python"
-category = "main"
+category = "dev"
 optional = false
 python-versions = ">=3.6"
 
@@ -1546,9 +1546,9 @@ docs = ["sphinx", "repoze.sphinx.autointerface"]
 test = ["zope.i18nmessageid", "zope.testing", "zope.testrunner"]
 
 [extras]
-all = ["matrix-synapse-ldap3", "psycopg2", "psycopg2cffi", "psycopg2cffi-compat", "pysaml2", "authlib", "lxml", "sentry-sdk", "jaeger-client", "opentracing", "pyjwt", "txredisapi", "hiredis", "Pympler"]
+all = ["matrix-synapse-ldap3", "psycopg2", "psycopg2cffi", "psycopg2cffi-compat", "pysaml2", "authlib", "lxml", "sentry-sdk", "jaeger-client", "opentracing", "txredisapi", "hiredis", "Pympler"]
 cache_memory = ["Pympler"]
-jwt = ["pyjwt"]
+jwt = ["authlib"]
 matrix-synapse-ldap3 = ["matrix-synapse-ldap3"]
 oidc = ["authlib"]
 opentracing = ["jaeger-client", "opentracing"]
@@ -1563,7 +1563,7 @@ url_preview = ["lxml"]
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.7.1"
-content-hash = "37bd4bccfdb5a869635f2135a85bea4a0729af7375a27de153b4fd9a4aebc195"
+content-hash = "73882e279e0379482f2fc7414cb71addfd408ca48ad508ff8a02b0cb544762af"
 
 [metadata.files]
 attrs = [
diff --git a/pyproject.toml b/pyproject.toml
index 85c2c9534f..44aa775c33 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -175,7 +175,6 @@ lxml = { version = ">=4.2.0", optional = true }
 sentry-sdk = { version = ">=0.7.2", optional = true }
 opentracing = { version = ">=2.2.0", optional = true }
 jaeger-client = { version = ">=4.0.0", optional = true }
-pyjwt = { version = ">=1.6.4", optional = true }
 txredisapi = { version = ">=1.4.7", optional = true }
 hiredis = { version = "*", optional = true }
 Pympler = { version = "*", optional = true }
@@ -196,7 +195,7 @@ systemd = ["systemd-python"]
 url_preview = ["lxml"]
 sentry = ["sentry-sdk"]
 opentracing = ["jaeger-client", "opentracing"]
-jwt = ["pyjwt"]
+jwt = ["authlib"]
 # hiredis is not a *strict* dependency, but it makes things much faster.
 # (if it is not installed, we fall back to slow code.)
 redis = ["txredisapi", "hiredis"]
@@ -222,7 +221,7 @@ all = [
     "psycopg2", "psycopg2cffi", "psycopg2cffi-compat",
     # saml2
     "pysaml2",
-    # oidc
+    # oidc and jwt
     "authlib",
     # url_preview
     "lxml",
@@ -230,8 +229,6 @@ all = [
     "sentry-sdk",
     # opentracing
     "jaeger-client", "opentracing",
-    # jwt
-    "pyjwt",
     # redis
     "txredisapi", "hiredis",
     # cache_memory
diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py
index 7e3c764b2c..49aaca7cf6 100644
--- a/synapse/config/jwt.py
+++ b/synapse/config/jwt.py
@@ -18,10 +18,10 @@ from synapse.types import JsonDict
 
 from ._base import Config, ConfigError
 
-MISSING_JWT = """Missing jwt library. This is required for jwt login.
+MISSING_AUTHLIB = """Missing authlib library. This is required for jwt login.
 
     Install by running:
-        pip install pyjwt
+        pip install synapse[jwt]
     """
 
 
@@ -43,11 +43,11 @@ class JWTConfig(Config):
             self.jwt_audiences = jwt_config.get("audiences")
 
             try:
-                import jwt
+                from authlib.jose import JsonWebToken
 
-                jwt  # To stop unused lint.
+                JsonWebToken  # To stop unused lint.
             except ImportError:
-                raise ConfigError(MISSING_JWT)
+                raise ConfigError(MISSING_AUTHLIB)
         else:
             self.jwt_enabled = False
             self.jwt_secret = None
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index cf4196ac0a..dd75e40f34 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -420,17 +420,31 @@ class LoginRestServlet(RestServlet):
                 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
             )
 
-        import jwt
+        from authlib.jose import JsonWebToken, JWTClaims
+        from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError
+
+        jwt = JsonWebToken([self.jwt_algorithm])
+        claim_options = {}
+        if self.jwt_issuer is not None:
+            claim_options["iss"] = {"value": self.jwt_issuer, "essential": True}
+        if self.jwt_audiences is not None:
+            claim_options["aud"] = {"values": self.jwt_audiences, "essential": True}
 
         try:
-            payload = jwt.decode(
+            claims = jwt.decode(
                 token,
-                self.jwt_secret,
-                algorithms=[self.jwt_algorithm],
-                issuer=self.jwt_issuer,
-                audience=self.jwt_audiences,
+                key=self.jwt_secret,
+                claims_cls=JWTClaims,
+                claims_options=claim_options,
+            )
+        except BadSignatureError:
+            # We handle this case separately to provide a better error message
+            raise LoginError(
+                403,
+                "JWT validation failed: Signature verification failed",
+                errcode=Codes.FORBIDDEN,
             )
-        except jwt.PyJWTError as e:
+        except JoseError as e:
             # A JWT error occurred, return some info back to the client.
             raise LoginError(
                 403,
@@ -438,7 +452,23 @@ class LoginRestServlet(RestServlet):
                 errcode=Codes.FORBIDDEN,
             )
 
-        user = payload.get(self.jwt_subject_claim, None)
+        try:
+            claims.validate(leeway=120)  # allows 2 min of clock skew
+
+            # Enforce the old behavior which is rolled out in productive
+            # servers: if the JWT contains an 'aud' claim but none is
+            # configured, the login attempt will fail
+            if claims.get("aud") is not None:
+                if self.jwt_audiences is None or len(self.jwt_audiences) == 0:
+                    raise InvalidClaimError("aud")
+        except JoseError as e:
+            raise LoginError(
+                403,
+                "JWT validation failed: %s" % (str(e),),
+                errcode=Codes.FORBIDDEN,
+            )
+
+        user = claims.get(self.jwt_subject_claim, None)
         if user is None:
             raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
 
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index f4ea1209d9..f6efa5fe37 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -14,7 +14,7 @@
 import json
 import time
 import urllib.parse
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional
 from unittest.mock import Mock
 from urllib.parse import urlencode
 
@@ -41,7 +41,7 @@ from tests.test_utils.html_parsers import TestHtmlParser
 from tests.unittest import HomeserverTestCase, override_config, skip_unless
 
 try:
-    import jwt
+    from authlib.jose import jwk, jwt
 
     HAS_JWT = True
 except ImportError:
@@ -841,7 +841,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         self.assertIn(b"SSO account deactivated", channel.result["body"])
 
 
-@skip_unless(HAS_JWT, "requires jwt")
+@skip_unless(HAS_JWT, "requires authlib")
 class JWTTestCase(unittest.HomeserverTestCase):
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -866,11 +866,9 @@ class JWTTestCase(unittest.HomeserverTestCase):
         return config
 
     def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
-        # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
-        result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm)
-        if isinstance(result, bytes):
-            return result.decode("ascii")
-        return result
+        header = {"alg": self.jwt_algorithm}
+        result: bytes = jwt.encode(header, payload, secret)
+        return result.decode("ascii")
 
     def jwt_login(self, *args: Any) -> FakeChannel:
         params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
@@ -902,7 +900,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.result["code"], b"403", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
-            channel.json_body["error"], "JWT validation failed: Signature has expired"
+            channel.json_body["error"],
+            "JWT validation failed: expired_token: The token is expired",
         )
 
     def test_login_jwt_not_before(self) -> None:
@@ -912,7 +911,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
-            "JWT validation failed: The token is not yet valid (nbf)",
+            "JWT validation failed: invalid_token: The token is not valid yet",
         )
 
     def test_login_no_sub(self) -> None:
@@ -934,7 +933,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.result["code"], b"403", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
-            channel.json_body["error"], "JWT validation failed: Invalid issuer"
+            channel.json_body["error"],
+            'JWT validation failed: invalid_claim: Invalid claim "iss"',
         )
 
         # Not providing an issuer.
@@ -943,7 +943,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
-            'JWT validation failed: Token is missing the "iss" claim',
+            'JWT validation failed: missing_claim: Missing "iss" claim',
         )
 
     def test_login_iss_no_config(self) -> None:
@@ -965,7 +965,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.result["code"], b"403", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
-            channel.json_body["error"], "JWT validation failed: Invalid audience"
+            channel.json_body["error"],
+            'JWT validation failed: invalid_claim: Invalid claim "aud"',
         )
 
         # Not providing an audience.
@@ -974,7 +975,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
-            'JWT validation failed: Token is missing the "aud" claim',
+            'JWT validation failed: missing_claim: Missing "aud" claim',
         )
 
     def test_login_aud_no_config(self) -> None:
@@ -983,7 +984,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.result["code"], b"403", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
-            channel.json_body["error"], "JWT validation failed: Invalid audience"
+            channel.json_body["error"],
+            'JWT validation failed: invalid_claim: Invalid claim "aud"',
         )
 
     def test_login_default_sub(self) -> None:
@@ -1010,7 +1012,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
 # The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
 # RSS256, with a public key configured in synapse as "jwt_secret", and tokens
 # signed by the private key.
-@skip_unless(HAS_JWT, "requires jwt")
+@skip_unless(HAS_JWT, "requires authlib")
 class JWTPubKeyTestCase(unittest.HomeserverTestCase):
     servlets = [
         login.register_servlets,
@@ -1071,11 +1073,11 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
         return config
 
     def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
-        # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
-        result: Union[bytes, str] = jwt.encode(payload, secret, "RS256")
-        if isinstance(result, bytes):
-            return result.decode("ascii")
-        return result
+        header = {"alg": "RS256"}
+        if secret.startswith("-----BEGIN RSA PRIVATE KEY-----"):
+            secret = jwk.dumps(secret, kty="RSA")
+        result: bytes = jwt.encode(header, payload, secret)
+        return result.decode("ascii")
 
     def jwt_login(self, *args: Any) -> FakeChannel:
         params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}