diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index acae4d9e0d..93d8ac0e04 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -77,6 +77,12 @@ class AuthHandler(BaseHandler):
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
+ self._password_enabled = hs.config.password_enabled
+
+ login_types = set()
+ if self._password_enabled:
+ login_types.add(LoginType.PASSWORD)
+ self._supported_login_types = frozenset(login_types)
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
@@ -266,10 +272,11 @@ class AuthHandler(BaseHandler):
user_id = authdict["user"]
password = authdict["password"]
- if not user_id.startswith('@'):
- user_id = UserID(user_id, self.hs.hostname).to_string()
- return self._check_password(user_id, password)
+ return self.validate_login(user_id, {
+ "type": LoginType.PASSWORD,
+ "password": password,
+ })
@defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip):
@@ -398,23 +405,6 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id]
- def validate_password_login(self, user_id, password):
- """
- Authenticates the user with their username and password.
-
- Used only by the v1 login API.
-
- Args:
- user_id (str): complete @user:id
- password (str): Password
- Returns:
- defer.Deferred: (str) canonical user id
- Raises:
- StoreError if there was a problem accessing the database
- LoginError if there was an authentication problem.
- """
- return self._check_password(user_id, password)
-
@defer.inlineCallbacks
def get_access_token_for_user_id(self, user_id, device_id=None,
initial_display_name=None):
@@ -501,26 +491,60 @@ class AuthHandler(BaseHandler):
)
defer.returnValue(result)
+ def get_supported_login_types(self):
+ """Get a the login types supported for the /login API
+
+ By default this is just 'm.login.password' (unless password_enabled is
+ False in the config file), but password auth providers can provide
+ other login types.
+
+ Returns:
+ Iterable[str]: login types
+ """
+ return self._supported_login_types
+
@defer.inlineCallbacks
- def _check_password(self, user_id, password):
- """Authenticate a user against the LDAP and local databases.
+ def validate_login(self, user_id, login_submission):
+ """Authenticates the user for the /login API
- user_id is checked case insensitively against the local database, but
- will throw if there are multiple inexact matches.
+ Also used by the user-interactive auth flow to validate
+ m.login.password auth types.
Args:
- user_id (str): complete @user:id
+ user_id (str): user_id supplied by the user
+ login_submission (dict): the whole of the login submission
+ (including 'type' and other relevant fields)
Returns:
- (str) the canonical_user_id
+ Deferred[str]: canonical user id
Raises:
- LoginError if login fails
+ StoreError if there was a problem accessing the database
+ SynapseError if there was a problem with the request
+ LoginError if there was an authentication problem.
"""
+
+ if not user_id.startswith('@'):
+ user_id = UserID(
+ user_id, self.hs.hostname
+ ).to_string()
+
+ login_type = login_submission.get("type")
+
+ if login_type != LoginType.PASSWORD:
+ raise SynapseError(400, "Bad login type.")
+ if not self._password_enabled:
+ raise SynapseError(400, "Password login has been disabled.")
+ if "password" not in login_submission:
+ raise SynapseError(400, "Missing parameter: password")
+
+ password = login_submission["password"]
for provider in self.password_providers:
is_valid = yield provider.check_password(user_id, password)
if is_valid:
defer.returnValue(user_id)
- canonical_user_id = yield self._check_local_password(user_id, password)
+ canonical_user_id = yield self._check_local_password(
+ user_id, password,
+ )
if canonical_user_id:
defer.returnValue(canonical_user_id)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 9536e8ade6..d24590011b 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -85,7 +85,6 @@ def login_id_thirdparty_from_phone(identifier):
class LoginRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login$")
- PASS_TYPE = "m.login.password"
SAML2_TYPE = "m.login.saml2"
CAS_TYPE = "m.login.cas"
TOKEN_TYPE = "m.login.token"
@@ -94,7 +93,6 @@ class LoginRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs)
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
- self.password_enabled = hs.config.password_enabled
self.saml2_enabled = hs.config.saml2_enabled
self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret
@@ -121,8 +119,10 @@ class LoginRestServlet(ClientV1RestServlet):
# fall back to the fallback API if they don't understand one of the
# login flow types returned.
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
- if self.password_enabled:
- flows.append({"type": LoginRestServlet.PASS_TYPE})
+
+ flows.extend((
+ {"type": t} for t in self.auth_handler.get_supported_login_types()
+ ))
return (200, {"flows": flows})
@@ -133,14 +133,8 @@ class LoginRestServlet(ClientV1RestServlet):
def on_POST(self, request):
login_submission = parse_json_object_from_request(request)
try:
- if login_submission["type"] == LoginRestServlet.PASS_TYPE:
- if not self.password_enabled:
- raise SynapseError(400, "Password login has been disabled.")
-
- result = yield self.do_password_login(login_submission)
- defer.returnValue(result)
- elif self.saml2_enabled and (login_submission["type"] ==
- LoginRestServlet.SAML2_TYPE):
+ if self.saml2_enabled and (login_submission["type"] ==
+ LoginRestServlet.SAML2_TYPE):
relay_state = ""
if "relay_state" in login_submission:
relay_state = "&RelayState=" + urllib.quote(
@@ -157,15 +151,21 @@ class LoginRestServlet(ClientV1RestServlet):
result = yield self.do_token_login(login_submission)
defer.returnValue(result)
else:
- raise SynapseError(400, "Bad login type.")
+ result = yield self._do_other_login(login_submission)
+ defer.returnValue(result)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@defer.inlineCallbacks
- def do_password_login(self, login_submission):
- if "password" not in login_submission:
- raise SynapseError(400, "Missing parameter: password")
+ def _do_other_login(self, login_submission):
+ """Handle non-token/saml/jwt logins
+ Args:
+ login_submission:
+
+ Returns:
+ (int, object): HTTP code/response
+ """
login_submission_legacy_convert(login_submission)
if "identifier" not in login_submission:
@@ -208,25 +208,22 @@ class LoginRestServlet(ClientV1RestServlet):
if "user" not in identifier:
raise SynapseError(400, "User identifier is missing 'user' key")
- user_id = identifier["user"]
-
- if not user_id.startswith('@'):
- user_id = UserID(
- user_id, self.hs.hostname
- ).to_string()
-
auth_handler = self.auth_handler
- user_id = yield auth_handler.validate_password_login(
- user_id=user_id,
- password=login_submission["password"],
+ canonical_user_id = yield auth_handler.validate_login(
+ identifier["user"],
+ login_submission,
+ )
+
+ device_id = yield self._register_device(
+ canonical_user_id, login_submission,
)
- device_id = yield self._register_device(user_id, login_submission)
access_token = yield auth_handler.get_access_token_for_user_id(
- user_id, device_id,
+ canonical_user_id, device_id,
login_submission.get("initial_device_display_name"),
)
+
result = {
- "user_id": user_id, # may have changed
+ "user_id": canonical_user_id,
"access_token": access_token,
"home_server": self.hs.hostname,
"device_id": device_id,
|