summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py182
1 files changed, 149 insertions, 33 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 12c50f32f2..0e5be98daa 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -75,13 +75,17 @@ class AuthHandler(BaseHandler):
         logger.info("Extra password_providers: %r", self.password_providers)
 
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
-        self.device_handler = hs.get_device_handler()
         self.macaroon_gen = hs.get_macaroon_generator()
         self._password_enabled = hs.config.password_enabled
 
         login_types = set()
         if self._password_enabled:
             login_types.add(LoginType.PASSWORD)
+        for provider in self.password_providers:
+            if hasattr(provider, "get_supported_login_types"):
+                login_types.update(
+                    provider.get_supported_login_types().keys()
+                )
         self._supported_login_types = frozenset(login_types)
 
     @defer.inlineCallbacks
@@ -406,8 +410,7 @@ class AuthHandler(BaseHandler):
         return self.sessions[session_id]
 
     @defer.inlineCallbacks
-    def get_access_token_for_user_id(self, user_id, device_id=None,
-                                     initial_display_name=None):
+    def get_access_token_for_user_id(self, user_id, device_id=None):
         """
         Creates a new access token for the user with the given user ID.
 
@@ -421,13 +424,10 @@ class AuthHandler(BaseHandler):
             device_id (str|None): the device ID to associate with the tokens.
                None to leave the tokens unassociated with a device (deprecated:
                we should always have a device ID)
-            initial_display_name (str): display name to associate with the
-               device if it needs re-registering
         Returns:
               The access token for the user's session.
         Raises:
             StoreError if there was a problem storing the token.
-            LoginError if there was an authentication problem.
         """
         logger.info("Logging in user %s on device %s", user_id, device_id)
         access_token = yield self.issue_access_token(user_id, device_id)
@@ -437,9 +437,11 @@ class AuthHandler(BaseHandler):
         # really don't want is active access_tokens without a record of the
         # device, so we double-check it here.
         if device_id is not None:
-            yield self.device_handler.check_device_registered(
-                user_id, device_id, initial_display_name
-            )
+            try:
+                yield self.store.get_device(user_id, device_id)
+            except StoreError:
+                yield self.store.delete_access_token(access_token)
+                raise StoreError(400, "Login raced against device deletion")
 
         defer.returnValue(access_token)
 
@@ -504,14 +506,14 @@ class AuthHandler(BaseHandler):
         return self._supported_login_types
 
     @defer.inlineCallbacks
-    def validate_login(self, user_id, login_submission):
+    def validate_login(self, username, login_submission):
         """Authenticates the user for the /login API
 
         Also used by the user-interactive auth flow to validate
         m.login.password auth types.
 
         Args:
-            user_id (str): user_id supplied by the user
+            username (str): username supplied by the user
             login_submission (dict): the whole of the login submission
                 (including 'type' and other relevant fields)
         Returns:
@@ -522,32 +524,81 @@ class AuthHandler(BaseHandler):
             LoginError if there was an authentication problem.
         """
 
-        if not user_id.startswith('@'):
-            user_id = UserID(
-                user_id, self.hs.hostname
+        if username.startswith('@'):
+            qualified_user_id = username
+        else:
+            qualified_user_id = UserID(
+                username, self.hs.hostname
             ).to_string()
 
         login_type = login_submission.get("type")
+        known_login_type = False
 
-        if login_type != LoginType.PASSWORD:
-            raise SynapseError(400, "Bad login type.")
-        if not self._password_enabled:
-            raise SynapseError(400, "Password login has been disabled.")
-        if "password" not in login_submission:
-            raise SynapseError(400, "Missing parameter: password")
+        # special case to check for "password" for the check_password interface
+        # for the auth providers
+        password = login_submission.get("password")
+        if login_type == LoginType.PASSWORD:
+            if not self._password_enabled:
+                raise SynapseError(400, "Password login has been disabled.")
+            if not password:
+                raise SynapseError(400, "Missing parameter: password")
 
-        password = login_submission["password"]
         for provider in self.password_providers:
-            is_valid = yield provider.check_password(user_id, password)
-            if is_valid:
-                defer.returnValue(user_id)
+            if (hasattr(provider, "check_password")
+                    and login_type == LoginType.PASSWORD):
+                known_login_type = True
+                is_valid = yield provider.check_password(
+                    qualified_user_id, password,
+                )
+                if is_valid:
+                    defer.returnValue(qualified_user_id)
+
+            if (not hasattr(provider, "get_supported_login_types")
+                    or not hasattr(provider, "check_auth")):
+                # this password provider doesn't understand custom login types
+                continue
+
+            supported_login_types = provider.get_supported_login_types()
+            if login_type not in supported_login_types:
+                # this password provider doesn't understand this login type
+                continue
+
+            known_login_type = True
+            login_fields = supported_login_types[login_type]
+
+            missing_fields = []
+            login_dict = {}
+            for f in login_fields:
+                if f not in login_submission:
+                    missing_fields.append(f)
+                else:
+                    login_dict[f] = login_submission[f]
+            if missing_fields:
+                raise SynapseError(
+                    400, "Missing parameters for login type %s: %s" % (
+                        login_type,
+                        missing_fields,
+                    ),
+                )
 
-        canonical_user_id = yield self._check_local_password(
-            user_id, password,
-        )
+            returned_user_id = yield provider.check_auth(
+                username, login_type, login_dict,
+            )
+            if returned_user_id:
+                defer.returnValue(returned_user_id)
+
+        if login_type == LoginType.PASSWORD:
+            known_login_type = True
 
-        if canonical_user_id:
-            defer.returnValue(canonical_user_id)
+            canonical_user_id = yield self._check_local_password(
+                qualified_user_id, password,
+            )
+
+            if canonical_user_id:
+                defer.returnValue(canonical_user_id)
+
+        if not known_login_type:
+            raise SynapseError(400, "Unknown login type %s" % login_type)
 
         # unknown username or invalid password. We raise a 403 here, but note
         # that if we're doing user-interactive login, it turns all LoginErrors
@@ -608,14 +659,59 @@ class AuthHandler(BaseHandler):
             if e.code == 404:
                 raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
             raise e
-        yield self.store.user_delete_access_tokens(
-            user_id, except_access_token_id
+        yield self.delete_access_tokens_for_user(
+            user_id, except_token_id=except_access_token_id,
         )
         yield self.hs.get_pusherpool().remove_pushers_by_user(
             user_id, except_access_token_id
         )
 
     @defer.inlineCallbacks
+    def deactivate_account(self, user_id):
+        """Deactivate a user's account
+
+        Args:
+            user_id (str): ID of user to be deactivated
+
+        Returns:
+            Deferred
+        """
+        # FIXME: Theoretically there is a race here wherein user resets
+        # password using threepid.
+        yield self.delete_access_tokens_for_user(user_id)
+        yield self.store.user_delete_threepids(user_id)
+        yield self.store.user_set_password_hash(user_id, None)
+
+    def delete_access_token(self, access_token):
+        """Invalidate a single access token
+
+        Args:
+            access_token (str): access token to be deleted
+
+        Returns:
+            Deferred
+        """
+        return self.store.delete_access_token(access_token)
+
+    def delete_access_tokens_for_user(self, user_id, except_token_id=None,
+                                      device_id=None):
+        """Invalidate access tokens belonging to a user
+
+        Args:
+            user_id (str):  ID of user the tokens belong to
+            except_token_id (str|None): access_token ID which should *not* be
+                deleted
+            device_id (str|None):  ID of device the tokens are associated with.
+                If None, tokens associated with any device (or no device) will
+                be deleted
+        Returns:
+            Deferred
+        """
+        return self.store.user_delete_access_tokens(
+            user_id, except_token_id=except_token_id, device_id=device_id,
+        )
+
+    @defer.inlineCallbacks
     def add_threepid(self, user_id, medium, address, validated_at):
         # 'Canonicalise' email addresses down to lower case.
         # We've now moving towards the Home Server being the entity that
@@ -732,11 +828,31 @@ class _AccountHandler(object):
         self._check_user_exists = check_user_exists
         self._store = hs.get_datastore()
 
+    def get_qualified_user_id(self, username):
+        """Qualify a user id, if necessary
+
+        Takes a user id provided by the user and adds the @ and :domain to
+        qualify it, if necessary
+
+        Args:
+            username (str): provided user id
+
+        Returns:
+            str: qualified @user:id
+        """
+        if username.startswith('@'):
+            return username
+        return UserID(username, self.hs.hostname).to_string()
+
     def check_user_exists(self, user_id):
-        """Check if user exissts.
+        """Check if user exists.
+
+        Args:
+            user_id (str): Complete @user:id
 
         Returns:
-            Deferred(bool)
+            Deferred[str|None]: Canonical (case-corrected) user_id, or None
+               if the user is not registered.
         """
         return self._check_user_exists(user_id)