diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py
index fce96b4acf..3252ad9e7f 100644
--- a/synapse/config/jwt_config.py
+++ b/synapse/config/jwt_config.py
@@ -32,6 +32,11 @@ class JWTConfig(Config):
self.jwt_secret = jwt_config["secret"]
self.jwt_algorithm = jwt_config["algorithm"]
+ # The issuer and audiences are optional, if provided, it is asserted
+ # that the claims exist on the JWT.
+ self.jwt_issuer = jwt_config.get("issuer")
+ self.jwt_audiences = jwt_config.get("audiences")
+
try:
import jwt
@@ -42,6 +47,8 @@ class JWTConfig(Config):
self.jwt_enabled = False
self.jwt_secret = None
self.jwt_algorithm = None
+ self.jwt_issuer = None
+ self.jwt_audiences = None
def generate_config_section(self, **kwargs):
return """\
@@ -52,6 +59,9 @@ class JWTConfig(Config):
# Each JSON Web Token needs to contain a "sub" (subject) claim, which is
# used as the localpart of the mxid.
#
+ # Additionally, the expiration time ("exp"), not before time ("nbf"),
+ # and issued at ("iat") claims are validated if present.
+ #
# Note that this is a non-standard login type and client support is
# expected to be non-existant.
#
@@ -78,4 +88,22 @@ class JWTConfig(Config):
# Required if 'enabled' is true.
#
#algorithm: "provided-by-your-issuer"
+
+ # The issuer to validate the "iss" claim against.
+ #
+ # Optional, if provided the "iss" claim will be required and
+ # validated for all JSON web tokens.
+ #
+ #issuer: "provided-by-your-issuer"
+
+ # A list of audiences to validate the "aud" claim against.
+ #
+ # Optional, if provided the "aud" claim will be required and
+ # validated for all JSON web tokens.
+ #
+ # Note that if the "aud" claim is included in a JSON web token then
+ # validation will fail without configuring audiences.
+ #
+ #audiences:
+ # - "provided-by-your-issuer"
"""
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 64d5c58b65..326ffa0056 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -89,12 +89,19 @@ class LoginRestServlet(RestServlet):
def __init__(self, hs):
super(LoginRestServlet, self).__init__()
self.hs = hs
+
+ # JWT configuration variables.
self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm
+ self.jwt_issuer = hs.config.jwt_issuer
+ self.jwt_audiences = hs.config.jwt_audiences
+
+ # SSO configuration.
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
self.oidc_enabled = hs.config.oidc_enabled
+
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers()
@@ -368,16 +375,22 @@ class LoginRestServlet(RestServlet):
)
import jwt
- from jwt.exceptions import InvalidTokenError
try:
payload = jwt.decode(
- token, self.jwt_secret, algorithms=[self.jwt_algorithm]
+ token,
+ self.jwt_secret,
+ algorithms=[self.jwt_algorithm],
+ issuer=self.jwt_issuer,
+ audience=self.jwt_audiences,
+ )
+ except jwt.PyJWTError as e:
+ # A JWT error occurred, return some info back to the client.
+ raise LoginError(
+ 401,
+ "JWT validation failed: %s" % (str(e),),
+ errcode=Codes.UNAUTHORIZED,
)
- except jwt.ExpiredSignatureError:
- raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
- except InvalidTokenError:
- raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user = payload.get("sub", None)
if user is None:
|