diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 12c50f32f2..0e5be98daa 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -75,13 +75,17 @@ class AuthHandler(BaseHandler):
logger.info("Extra password_providers: %r", self.password_providers)
self.hs = hs # FIXME better possibility to access registrationHandler later?
- self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled
login_types = set()
if self._password_enabled:
login_types.add(LoginType.PASSWORD)
+ for provider in self.password_providers:
+ if hasattr(provider, "get_supported_login_types"):
+ login_types.update(
+ provider.get_supported_login_types().keys()
+ )
self._supported_login_types = frozenset(login_types)
@defer.inlineCallbacks
@@ -406,8 +410,7 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id]
@defer.inlineCallbacks
- def get_access_token_for_user_id(self, user_id, device_id=None,
- initial_display_name=None):
+ def get_access_token_for_user_id(self, user_id, device_id=None):
"""
Creates a new access token for the user with the given user ID.
@@ -421,13 +424,10 @@ class AuthHandler(BaseHandler):
device_id (str|None): the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID)
- initial_display_name (str): display name to associate with the
- device if it needs re-registering
Returns:
The access token for the user's session.
Raises:
StoreError if there was a problem storing the token.
- LoginError if there was an authentication problem.
"""
logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id)
@@ -437,9 +437,11 @@ class AuthHandler(BaseHandler):
# really don't want is active access_tokens without a record of the
# device, so we double-check it here.
if device_id is not None:
- yield self.device_handler.check_device_registered(
- user_id, device_id, initial_display_name
- )
+ try:
+ yield self.store.get_device(user_id, device_id)
+ except StoreError:
+ yield self.store.delete_access_token(access_token)
+ raise StoreError(400, "Login raced against device deletion")
defer.returnValue(access_token)
@@ -504,14 +506,14 @@ class AuthHandler(BaseHandler):
return self._supported_login_types
@defer.inlineCallbacks
- def validate_login(self, user_id, login_submission):
+ def validate_login(self, username, login_submission):
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate
m.login.password auth types.
Args:
- user_id (str): user_id supplied by the user
+ username (str): username supplied by the user
login_submission (dict): the whole of the login submission
(including 'type' and other relevant fields)
Returns:
@@ -522,32 +524,81 @@ class AuthHandler(BaseHandler):
LoginError if there was an authentication problem.
"""
- if not user_id.startswith('@'):
- user_id = UserID(
- user_id, self.hs.hostname
+ if username.startswith('@'):
+ qualified_user_id = username
+ else:
+ qualified_user_id = UserID(
+ username, self.hs.hostname
).to_string()
login_type = login_submission.get("type")
+ known_login_type = False
- if login_type != LoginType.PASSWORD:
- raise SynapseError(400, "Bad login type.")
- if not self._password_enabled:
- raise SynapseError(400, "Password login has been disabled.")
- if "password" not in login_submission:
- raise SynapseError(400, "Missing parameter: password")
+ # special case to check for "password" for the check_password interface
+ # for the auth providers
+ password = login_submission.get("password")
+ if login_type == LoginType.PASSWORD:
+ if not self._password_enabled:
+ raise SynapseError(400, "Password login has been disabled.")
+ if not password:
+ raise SynapseError(400, "Missing parameter: password")
- password = login_submission["password"]
for provider in self.password_providers:
- is_valid = yield provider.check_password(user_id, password)
- if is_valid:
- defer.returnValue(user_id)
+ if (hasattr(provider, "check_password")
+ and login_type == LoginType.PASSWORD):
+ known_login_type = True
+ is_valid = yield provider.check_password(
+ qualified_user_id, password,
+ )
+ if is_valid:
+ defer.returnValue(qualified_user_id)
+
+ if (not hasattr(provider, "get_supported_login_types")
+ or not hasattr(provider, "check_auth")):
+ # this password provider doesn't understand custom login types
+ continue
+
+ supported_login_types = provider.get_supported_login_types()
+ if login_type not in supported_login_types:
+ # this password provider doesn't understand this login type
+ continue
+
+ known_login_type = True
+ login_fields = supported_login_types[login_type]
+
+ missing_fields = []
+ login_dict = {}
+ for f in login_fields:
+ if f not in login_submission:
+ missing_fields.append(f)
+ else:
+ login_dict[f] = login_submission[f]
+ if missing_fields:
+ raise SynapseError(
+ 400, "Missing parameters for login type %s: %s" % (
+ login_type,
+ missing_fields,
+ ),
+ )
- canonical_user_id = yield self._check_local_password(
- user_id, password,
- )
+ returned_user_id = yield provider.check_auth(
+ username, login_type, login_dict,
+ )
+ if returned_user_id:
+ defer.returnValue(returned_user_id)
+
+ if login_type == LoginType.PASSWORD:
+ known_login_type = True
- if canonical_user_id:
- defer.returnValue(canonical_user_id)
+ canonical_user_id = yield self._check_local_password(
+ qualified_user_id, password,
+ )
+
+ if canonical_user_id:
+ defer.returnValue(canonical_user_id)
+
+ if not known_login_type:
+ raise SynapseError(400, "Unknown login type %s" % login_type)
# unknown username or invalid password. We raise a 403 here, but note
# that if we're doing user-interactive login, it turns all LoginErrors
@@ -608,14 +659,59 @@ class AuthHandler(BaseHandler):
if e.code == 404:
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e
- yield self.store.user_delete_access_tokens(
- user_id, except_access_token_id
+ yield self.delete_access_tokens_for_user(
+ user_id, except_token_id=except_access_token_id,
)
yield self.hs.get_pusherpool().remove_pushers_by_user(
user_id, except_access_token_id
)
@defer.inlineCallbacks
+ def deactivate_account(self, user_id):
+ """Deactivate a user's account
+
+ Args:
+ user_id (str): ID of user to be deactivated
+
+ Returns:
+ Deferred
+ """
+ # FIXME: Theoretically there is a race here wherein user resets
+ # password using threepid.
+ yield self.delete_access_tokens_for_user(user_id)
+ yield self.store.user_delete_threepids(user_id)
+ yield self.store.user_set_password_hash(user_id, None)
+
+ def delete_access_token(self, access_token):
+ """Invalidate a single access token
+
+ Args:
+ access_token (str): access token to be deleted
+
+ Returns:
+ Deferred
+ """
+ return self.store.delete_access_token(access_token)
+
+ def delete_access_tokens_for_user(self, user_id, except_token_id=None,
+ device_id=None):
+ """Invalidate access tokens belonging to a user
+
+ Args:
+ user_id (str): ID of user the tokens belong to
+ except_token_id (str|None): access_token ID which should *not* be
+ deleted
+ device_id (str|None): ID of device the tokens are associated with.
+ If None, tokens associated with any device (or no device) will
+ be deleted
+ Returns:
+ Deferred
+ """
+ return self.store.user_delete_access_tokens(
+ user_id, except_token_id=except_token_id, device_id=device_id,
+ )
+
+ @defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at):
# 'Canonicalise' email addresses down to lower case.
# We've now moving towards the Home Server being the entity that
@@ -732,11 +828,31 @@ class _AccountHandler(object):
self._check_user_exists = check_user_exists
self._store = hs.get_datastore()
+ def get_qualified_user_id(self, username):
+ """Qualify a user id, if necessary
+
+ Takes a user id provided by the user and adds the @ and :domain to
+ qualify it, if necessary
+
+ Args:
+ username (str): provided user id
+
+ Returns:
+ str: qualified @user:id
+ """
+ if username.startswith('@'):
+ return username
+ return UserID(username, self.hs.hostname).to_string()
+
def check_user_exists(self, user_id):
- """Check if user exissts.
+ """Check if user exists.
+
+ Args:
+ user_id (str): Complete @user:id
Returns:
- Deferred(bool)
+ Deferred[str|None]: Canonical (case-corrected) user_id, or None
+ if the user is not registered.
"""
return self._check_user_exists(user_id)
|