diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index b90ad6d79e..3f116e5b44 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
@@ -27,7 +28,8 @@ from synapse.http.servlet import (
from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+from synapse.util.threepids import canonicalise_email
logger = logging.getLogger(__name__)
@@ -37,17 +39,25 @@ class LoginRestServlet(RestServlet):
CAS_TYPE = "m.login.cas"
SSO_TYPE = "m.login.sso"
TOKEN_TYPE = "m.login.token"
- JWT_TYPE = "m.login.jwt"
+ JWT_TYPE = "org.matrix.login.jwt"
+ JWT_TYPE_DEPRECATED = "m.login.jwt"
def __init__(self, hs):
super(LoginRestServlet, self).__init__()
self.hs = hs
+
+ # JWT configuration variables.
self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm
+ self.jwt_issuer = hs.config.jwt_issuer
+ self.jwt_audiences = hs.config.jwt_audiences
+
+ # SSO configuration.
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
self.oidc_enabled = hs.config.oidc_enabled
+
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers()
@@ -68,10 +78,11 @@ class LoginRestServlet(RestServlet):
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
- def on_GET(self, request):
+ def on_GET(self, request: SynapseRequest):
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
+ flows.append({"type": LoginRestServlet.JWT_TYPE_DEPRECATED})
if self.cas_enabled:
# we advertise CAS for backwards compat, though MSC1721 renamed it
@@ -95,20 +106,21 @@ class LoginRestServlet(RestServlet):
return 200, {"flows": flows}
- def on_OPTIONS(self, request):
+ def on_OPTIONS(self, request: SynapseRequest):
return 200, {}
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest):
self._address_ratelimiter.ratelimit(request.getClientIP())
login_submission = parse_json_object_from_request(request)
try:
if self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE
+ or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
- result = await self.do_jwt_login(login_submission)
+ result = await self._do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
- result = await self.do_token_login(login_submission)
+ result = await self._do_token_login(login_submission)
else:
result = await self._do_other_login(login_submission)
except KeyError as e:
@@ -120,14 +132,14 @@ class LoginRestServlet(RestServlet):
result["well_known"] = well_known_data
return 200, result
- async def _do_other_login(self, login_submission):
+ async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
"""Handle non-token/saml/jwt logins
Args:
login_submission:
Returns:
- dict: HTTP response
+ HTTP response
"""
# Log the request we got, but only certain fields to minimise the chance of
# logging someone's password (even if they accidentally put it in the wrong
@@ -147,9 +159,16 @@ class LoginRestServlet(RestServlet):
medium = login_submission.get("medium")
address = login_submission.get("address")
if medium and address:
- self._failed_attempts_ratelimiter.ratelimit(
- (medium, address.lower()), update=False
- )
+ # For emails, canonicalise the address.
+ # We store all email addresses canonicalised in the DB.
+ # (See add_threepid in synapse/handlers/auth.py)
+ if medium == "email":
+ try:
+ address = canonicalise_email(address)
+ except ValueError as e:
+ raise SynapseError(400, str(e))
+
+ self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False)
# Extract a localpart or user ID from the values in the identifier
username = await self.auth_handler.username_from_identifier(
@@ -192,8 +211,14 @@ class LoginRestServlet(RestServlet):
return result
async def _complete_login(
- self, user_id, login_submission, callback=None, create_non_existent_users=False
- ):
+ self,
+ user_id: str,
+ login_submission: JsonDict,
+ callback: Optional[
+ Callable[[Dict[str, str]], Awaitable[Dict[str, str]]]
+ ] = None,
+ create_non_existent_users: bool = False,
+ ) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
actually log them in (e.g. create devices). This gets called on
all successful logins.
@@ -202,15 +227,14 @@ class LoginRestServlet(RestServlet):
account.
Args:
- user_id (str): ID of the user to register.
- login_submission (dict): Dictionary of login information.
- callback (func|None): Callback function to run after registration.
- create_non_existent_users (bool): Whether to create the user if
- they don't exist. Defaults to False.
+ user_id: ID of the user to register.
+ login_submission: Dictionary of login information.
+ callback: Callback function to run after registration.
+ create_non_existent_users: Whether to create the user if they don't
+ exist. Defaults to False.
Returns:
- result (Dict[str,str]): Dictionary of account information after
- successful registration.
+ result: Dictionary of account information after successful registration.
"""
# Before we actually log them in we check if they've already logged in
@@ -244,7 +268,7 @@ class LoginRestServlet(RestServlet):
return result
- async def do_token_login(self, login_submission):
+ async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
token = login_submission["token"]
auth_handler = self.auth_handler
user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
@@ -254,28 +278,32 @@ class LoginRestServlet(RestServlet):
result = await self._complete_login(user_id, login_submission)
return result
- async def do_jwt_login(self, login_submission):
+ async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
token = login_submission.get("token", None)
if token is None:
raise LoginError(
- 401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED
+ 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
)
import jwt
- from jwt.exceptions import InvalidTokenError
try:
payload = jwt.decode(
- token, self.jwt_secret, algorithms=[self.jwt_algorithm]
+ token,
+ self.jwt_secret,
+ algorithms=[self.jwt_algorithm],
+ issuer=self.jwt_issuer,
+ audience=self.jwt_audiences,
+ )
+ except jwt.PyJWTError as e:
+ # A JWT error occurred, return some info back to the client.
+ raise LoginError(
+ 403, "JWT validation failed: %s" % (str(e),), errcode=Codes.FORBIDDEN,
)
- except jwt.ExpiredSignatureError:
- raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
- except InvalidTokenError:
- raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user = payload.get("sub", None)
if user is None:
- raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
+ raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
user_id = UserID(user, self.hs.hostname).to_string()
result = await self._complete_login(
|