summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/7438.feature1
-rw-r--r--synapse/handlers/auth.py216
-rw-r--r--synapse/rest/client/v1/login.py159
-rw-r--r--tests/rest/client/v1/test_login.py4
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py8
-rw-r--r--tests/rest/client/v2_alpha/test_register.py92
6 files changed, 333 insertions, 147 deletions
diff --git a/changelog.d/7438.feature b/changelog.d/7438.feature
new file mode 100644
index 0000000000..b00529790e
--- /dev/null
+++ b/changelog.d/7438.feature
@@ -0,0 +1 @@
+Support `identifier` dictionary fields in User-Interactive Authentication flows. Relax requirement of the `user` parameter.
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 654f58ddae..2d64ee5e44 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -38,12 +38,14 @@ from synapse.api.ratelimiting import Ratelimiter
 from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
 from synapse.http.server import finish_request, respond_with_html
+from synapse.http.servlet import assert_params_in_dict
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.module_api import ModuleApi
 from synapse.types import Requester, UserID
 from synapse.util import stringutils as stringutils
+from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.threepids import canonicalise_email
 
 from ._base import BaseHandler
@@ -51,6 +53,82 @@ from ._base import BaseHandler
 logger = logging.getLogger(__name__)
 
 
+def client_dict_convert_legacy_fields_to_identifier(
+    submission: Dict[str, Union[str, Dict]]
+):
+    """
+    Convert a legacy-formatted login submission to an identifier dict.
+
+    Legacy login submissions (used in both login and user-interactive authentication)
+    provide user-identifying information at the top-level instead of in an `indentifier`
+    property. This is now deprecated and replaced with identifiers:
+    https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types
+
+    Args:
+        submission: The client dict to convert. Passed by reference and modified
+
+    Raises:
+        SynapseError: If the format of the client dict is invalid
+    """
+    if "user" in submission:
+        submission["identifier"] = {"type": "m.id.user", "user": submission.pop("user")}
+
+    if "medium" in submission and "address" in submission:
+        submission["identifier"] = {
+            "type": "m.id.thirdparty",
+            "medium": submission.pop("medium"),
+            "address": submission.pop("address"),
+        }
+
+    # We've converted valid, legacy login submissions to an identifier. If the
+    # dict still doesn't have an identifier, it's invalid
+    assert_params_in_dict(submission, required=["identifier"])
+
+    # Ensure the identifier has a type
+    if "type" not in submission["identifier"]:
+        raise SynapseError(
+            400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM,
+        )
+
+
+def login_id_phone_to_thirdparty(identifier: Dict[str, str]) -> Dict[str, str]:
+    """Convert a phone login identifier type to a generic threepid identifier.
+
+    Args:
+        identifier: Login identifier dict of type 'm.id.phone'
+
+    Returns:
+        An equivalent m.id.thirdparty identifier dict.
+    """
+    if "type" not in identifier:
+        raise SynapseError(
+            400, "Invalid phone-type identifier", errcode=Codes.MISSING_PARAM
+        )
+
+    if "country" not in identifier or (
+        # XXX: We used to require `number` instead of `phone`. The spec
+        # defines `phone`. So accept both
+        "phone" not in identifier
+        and "number" not in identifier
+    ):
+        raise SynapseError(
+            400, "Invalid phone-type identifier", errcode=Codes.INVALID_PARAM
+        )
+
+    # Accept both "phone" and "number" as valid keys in m.id.phone
+    phone_number = identifier.get("phone", identifier.get("number"))
+
+    # Convert user-provided phone number to a consistent representation
+    msisdn = phone_number_to_msisdn(identifier["country"], phone_number)
+
+    # Return the new dictionary
+    return {
+        "type": "m.id.thirdparty",
+        "medium": "msisdn",
+        "address": msisdn,
+    }
+
+
 class AuthHandler(BaseHandler):
     SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
 
@@ -319,7 +397,7 @@ class AuthHandler(BaseHandler):
             # otherwise use whatever was last provided.
             #
             # This was designed to allow the client to omit the parameters
-            # and just supply the session in subsequent calls so it split
+            # and just supply the session in subsequent calls. So it splits
             # auth between devices by just sharing the session, (eg. so you
             # could continue registration from your phone having clicked the
             # email auth link on there). It's probably too open to abuse
@@ -524,16 +602,129 @@ class AuthHandler(BaseHandler):
             res = await checker.check_auth(authdict, clientip=clientip)
             return res
 
-        # build a v1-login-style dict out of the authdict and fall back to the
-        # v1 code
-        user_id = authdict.get("user")
+        # We don't have a checker for the auth type provided by the client
+        # Assume that it is `m.login.password`.
+        if login_type != LoginType.PASSWORD:
+            raise SynapseError(
+                400, "Unknown authentication type", errcode=Codes.INVALID_PARAM,
+            )
+
+        password = authdict.get("password")
+        if password is None:
+            raise SynapseError(
+                400,
+                "Missing parameter for m.login.password dict: 'password'",
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        # Retrieve the user ID using details provided in the authdict
+
+        # Deprecation notice: Clients used to be able to simply provide a
+        # `user` field which pointed to a user_id or localpart. This has
+        # been deprecated in favour of an `identifier` key, which is a
+        # dictionary providing information on how to identify a single
+        # user.
+        # https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types
+        #
+        # We convert old-style dicts to new ones here
+        client_dict_convert_legacy_fields_to_identifier(authdict)
+
+        # Extract a user ID from the values in the identifier
+        username = await self.username_from_identifier(authdict["identifier"], password)
 
-        if user_id is None:
-            raise SynapseError(400, "", Codes.MISSING_PARAM)
+        if username is None:
+            raise SynapseError(400, "Valid username not found")
+
+        # Now that we've found the username, validate that the password is correct
+        canonical_id, _ = await self.validate_login(username, authdict)
 
-        (canonical_id, callback) = await self.validate_login(user_id, authdict)
         return canonical_id
 
+    async def username_from_identifier(
+        self, identifier: Dict[str, str], password: Optional[str] = None
+    ) -> Optional[str]:
+        """Given a dictionary containing an identifier from a client, extract the
+        possibly unqualified username of the user that it identifies. Does *not*
+        guarantee that the user exists.
+
+        If this identifier dict contains a threepid, we attempt to ask password
+        auth providers about it or, failing that, look up an associated user in
+        the database.
+
+        Args:
+            identifier: The identifier dictionary provided by the client
+            password: The user provided password if one exists. Used for asking
+                password auth providers for usernames from 3pid+password combos.
+
+        Returns:
+            A username if one was found, or None otherwise
+
+        Raises:
+            SynapseError: If the identifier dict is invalid
+        """
+
+        # Convert phone type identifiers to generic threepid identifiers, which
+        # will be handled in the next step
+        if identifier["type"] == "m.id.phone":
+            identifier = login_id_phone_to_thirdparty(identifier)
+
+        # Convert a threepid identifier to an user identifier
+        if identifier["type"] == "m.id.thirdparty":
+            address = identifier.get("address")
+            medium = identifier.get("medium")
+
+            if not medium or not address:
+                # An error would've already been raised in
+                # `login_id_thirdparty_from_phone` if the original submission
+                # was a phone identifier
+                raise SynapseError(
+                    400, "Invalid thirdparty identifier", errcode=Codes.INVALID_PARAM,
+                )
+
+            if medium == "email":
+                # For emails, transform the address to lowercase.
+                # We store all email addresses as lowercase in the DB.
+                # (See add_threepid in synapse/handlers/auth.py)
+                address = address.lower()
+
+            # Check for auth providers that support 3pid login types
+            if password is not None:
+                canonical_user_id, _ = await self.check_password_provider_3pid(
+                    medium, address, password,
+                )
+                if canonical_user_id:
+                    # Authentication through password provider and 3pid succeeded
+                    return canonical_user_id
+
+            # Check local store
+            user_id = await self.hs.get_datastore().get_user_id_by_threepid(
+                medium, address
+            )
+            if not user_id:
+                # We were unable to find a user_id that belonged to the threepid returned
+                # by the password auth provider
+                return None
+
+            identifier = {"type": "m.id.user", "user": user_id}
+
+        # By this point, the identifier should be a `m.id.user`: if it's anything
+        # else, we haven't understood it.
+        if identifier["type"] != "m.id.user":
+            raise SynapseError(
+                400, "Unknown login identifier type", errcode=Codes.INVALID_PARAM,
+            )
+
+        # User identifiers have a "user" key
+        user = identifier.get("user")
+        if user is None:
+            raise SynapseError(
+                400,
+                "User identifier is missing 'user' key",
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        return user
+
     def _get_params_recaptcha(self) -> dict:
         return {"public_key": self.hs.config.recaptcha_public_key}
 
@@ -698,7 +889,8 @@ class AuthHandler(BaseHandler):
         m.login.password auth types.
 
         Args:
-            username: username supplied by the user
+            username: a localpart or fully qualified user ID - what is provided by the
+                client
             login_submission: the whole of the login submission
                 (including 'type' and other relevant fields)
         Returns:
@@ -710,10 +902,10 @@ class AuthHandler(BaseHandler):
             LoginError if there was an authentication problem.
         """
 
-        if username.startswith("@"):
-            qualified_user_id = username
-        else:
-            qualified_user_id = UserID(username, self.hs.hostname).to_string()
+        # We need a fully qualified User ID for some method calls here
+        qualified_user_id = username
+        if not qualified_user_id.startswith("@"):
+            qualified_user_id = UserID(qualified_user_id, self.hs.hostname).to_string()
 
         login_type = login_submission.get("type")
         known_login_type = False
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 379f668d6f..3f116e5b44 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -18,6 +18,7 @@ from typing import Awaitable, Callable, Dict, Optional
 
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
+from synapse.handlers.auth import client_dict_convert_legacy_fields_to_identifier
 from synapse.http.server import finish_request
 from synapse.http.servlet import (
     RestServlet,
@@ -28,56 +29,11 @@ from synapse.http.site import SynapseRequest
 from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.rest.well_known import WellKnownBuilder
 from synapse.types import JsonDict, UserID
-from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.threepids import canonicalise_email
 
 logger = logging.getLogger(__name__)
 
 
-def login_submission_legacy_convert(submission):
-    """
-    If the input login submission is an old style object
-    (ie. with top-level user / medium / address) convert it
-    to a typed object.
-    """
-    if "user" in submission:
-        submission["identifier"] = {"type": "m.id.user", "user": submission["user"]}
-        del submission["user"]
-
-    if "medium" in submission and "address" in submission:
-        submission["identifier"] = {
-            "type": "m.id.thirdparty",
-            "medium": submission["medium"],
-            "address": submission["address"],
-        }
-        del submission["medium"]
-        del submission["address"]
-
-
-def login_id_thirdparty_from_phone(identifier):
-    """
-    Convert a phone login identifier type to a generic threepid identifier
-    Args:
-        identifier(dict): Login identifier dict of type 'm.id.phone'
-
-    Returns: Login identifier dict of type 'm.id.threepid'
-    """
-    if "country" not in identifier or (
-        # The specification requires a "phone" field, while Synapse used to require a "number"
-        # field. Accept both for backwards compatibility.
-        "phone" not in identifier
-        and "number" not in identifier
-    ):
-        raise SynapseError(400, "Invalid phone-type identifier")
-
-    # Accept both "phone" and "number" as valid keys in m.id.phone
-    phone_number = identifier.get("phone", identifier["number"])
-
-    msisdn = phone_number_to_msisdn(identifier["country"], phone_number)
-
-    return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
-
-
 class LoginRestServlet(RestServlet):
     PATTERNS = client_patterns("/login$", v1=True)
     CAS_TYPE = "m.login.cas"
@@ -167,7 +123,8 @@ class LoginRestServlet(RestServlet):
                 result = await self._do_token_login(login_submission)
             else:
                 result = await self._do_other_login(login_submission)
-        except KeyError:
+        except KeyError as e:
+            logger.debug("KeyError during login: %s", e)
             raise SynapseError(400, "Missing JSON keys.")
 
         well_known_data = self._well_known_builder.get_well_known()
@@ -194,27 +151,14 @@ class LoginRestServlet(RestServlet):
             login_submission.get("address"),
             login_submission.get("user"),
         )
-        login_submission_legacy_convert(login_submission)
-
-        if "identifier" not in login_submission:
-            raise SynapseError(400, "Missing param: identifier")
-
-        identifier = login_submission["identifier"]
-        if "type" not in identifier:
-            raise SynapseError(400, "Login identifier has no type")
-
-        # convert phone type identifiers to generic threepids
-        if identifier["type"] == "m.id.phone":
-            identifier = login_id_thirdparty_from_phone(identifier)
-
-        # convert threepid identifiers to user IDs
-        if identifier["type"] == "m.id.thirdparty":
-            address = identifier.get("address")
-            medium = identifier.get("medium")
-
-            if medium is None or address is None:
-                raise SynapseError(400, "Invalid thirdparty identifier")
-
+        # Convert deprecated authdict formats to the current scheme
+        client_dict_convert_legacy_fields_to_identifier(login_submission)
+
+        # Check whether this attempt uses a threepid, if so, check if our failed attempt
+        # ratelimiter allows another attempt at this time
+        medium = login_submission.get("medium")
+        address = login_submission.get("address")
+        if medium and address:
             # For emails, canonicalise the address.
             # We store all email addresses canonicalised in the DB.
             # (See add_threepid in synapse/handlers/auth.py)
@@ -224,74 +168,41 @@ class LoginRestServlet(RestServlet):
                 except ValueError as e:
                     raise SynapseError(400, str(e))
 
-            # We also apply account rate limiting using the 3PID as a key, as
-            # otherwise using 3PID bypasses the ratelimiting based on user ID.
             self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False)
 
-            # Check for login providers that support 3pid login types
-            (
-                canonical_user_id,
-                callback_3pid,
-            ) = await self.auth_handler.check_password_provider_3pid(
-                medium, address, login_submission["password"]
-            )
-            if canonical_user_id:
-                # Authentication through password provider and 3pid succeeded
+        # Extract a localpart or user ID from the values in the identifier
+        username = await self.auth_handler.username_from_identifier(
+            login_submission["identifier"], login_submission.get("password")
+        )
 
-                result = await self._complete_login(
-                    canonical_user_id, login_submission, callback_3pid
+        if not username:
+            if medium and address:
+                # The user attempted to login via threepid and failed
+                # Record this failed attempt using the threepid as a key, as otherwise
+                # the user could bypass the ratelimiter by not providing a username
+                self._failed_attempts_ratelimiter.can_do_action(
+                    (medium, address.lower())
                 )
-                return result
 
-            # No password providers were able to handle this 3pid
-            # Check local store
-            user_id = await self.hs.get_datastore().get_user_id_by_threepid(
-                medium, address
-            )
-            if not user_id:
-                logger.warning(
-                    "unknown 3pid identifier medium %s, address %r", medium, address
-                )
-                # We mark that we've failed to log in here, as
-                # `check_password_provider_3pid` might have returned `None` due
-                # to an incorrect password, rather than the account not
-                # existing.
-                #
-                # If it returned None but the 3PID was bound then we won't hit
-                # this code path, which is fine as then the per-user ratelimit
-                # will kick in below.
-                self._failed_attempts_ratelimiter.can_do_action((medium, address))
-                raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-
-            identifier = {"type": "m.id.user", "user": user_id}
-
-        # by this point, the identifier should be an m.id.user: if it's anything
-        # else, we haven't understood it.
-        if identifier["type"] != "m.id.user":
-            raise SynapseError(400, "Unknown login identifier type")
-        if "user" not in identifier:
-            raise SynapseError(400, "User identifier is missing 'user' key")
-
-        if identifier["user"].startswith("@"):
-            qualified_user_id = identifier["user"]
-        else:
-            qualified_user_id = UserID(identifier["user"], self.hs.hostname).to_string()
-
-        # Check if we've hit the failed ratelimit (but don't update it)
-        self._failed_attempts_ratelimiter.ratelimit(
-            qualified_user_id.lower(), update=False
-        )
+                raise LoginError(403, "Unauthorized threepid", errcode=Codes.FORBIDDEN)
+
+            # The login failed for another reason
+            raise LoginError(403, "Invalid login", errcode=Codes.FORBIDDEN)
+
+        # We were able to extract a username successfully
+        # Check if we've hit the failed ratelimit for this user ID
+        self._failed_attempts_ratelimiter.ratelimit(username.lower(), update=False)
 
         try:
             canonical_user_id, callback = await self.auth_handler.validate_login(
-                identifier["user"], login_submission
+                username, login_submission
             )
         except LoginError:
             # The user has failed to log in, so we need to update the rate
             # limiter. Using `can_do_action` avoids us raising a ratelimit
-            # exception and masking the LoginError. The actual ratelimiting
-            # should have happened above.
-            self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower())
+            # exception and masking the LoginError. This just records the attempt.
+            # The actual rate-limiting happens above
+            self._failed_attempts_ratelimiter.can_do_action(username.lower())
             raise
 
         result = await self._complete_login(
@@ -309,7 +220,7 @@ class LoginRestServlet(RestServlet):
         create_non_existent_users: bool = False,
     ) -> Dict[str, str]:
         """Called when we've successfully authed the user and now need to
-        actually login them in (e.g. create devices). This gets called on
+        actually log them in (e.g. create devices). This gets called on
         all successful logins.
 
         Applies the ratelimiting for successful login attempts against an
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 2668662c9e..3ddcca288b 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -267,9 +267,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
 
         auth = {
             "type": "m.login.password",
-            # https://github.com/matrix-org/synapse/issues/5665
-            # "identifier": {"type": "m.id.user", "user": user_id},
-            "user": user_id,
+            "identifier": {"type": "m.id.user", "user": user_id},
             "password": password,
             "session": channel.json_body["session"],
         }
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 293ccfba2b..8f97dd0dd2 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -38,11 +38,6 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker):
         return succeed(True)
 
 
-class DummyPasswordChecker(UserInteractiveAuthChecker):
-    def check_auth(self, authdict, clientip):
-        return succeed(authdict["identifier"]["user"])
-
-
 class FallbackAuthTests(unittest.HomeserverTestCase):
 
     servlets = [
@@ -166,9 +161,6 @@ class UIAuthTests(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        auth_handler = hs.get_auth_handler()
-        auth_handler.checkers[LoginType.PASSWORD] = DummyPasswordChecker(hs)
-
         self.user_pass = "pass"
         self.user = self.register_user("test", self.user_pass)
         self.user_tok = self.login("test", self.user_pass)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 2fc3a60fc5..0f33c7806d 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -593,6 +593,89 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(len(self.email_attempts), 0)
 
+    def test_deactivated_user_using_user_identifier(self):
+        self.email_attempts = []
+
+        (user_id, tok) = self.create_user()
+
+        request_data = json.dumps(
+            {
+                "auth": {
+                    "type": "m.login.password",
+                    "identifier": {"type": "m.id.user", "user": user_id},
+                    "password": "monkey",
+                },
+                "erase": False,
+            }
+        )
+        request, channel = self.make_request(
+            "POST", "account/deactivate", request_data, access_token=tok
+        )
+        self.render(request)
+        self.assertEqual(request.code, 200)
+
+        self.reactor.advance(datetime.timedelta(days=8).total_seconds())
+
+        self.assertEqual(len(self.email_attempts), 0)
+
+    def test_deactivated_user_using_thirdparty_identifier(self):
+        self.email_attempts = []
+
+        (user_id, tok) = self.create_user()
+
+        request_data = json.dumps(
+            {
+                "auth": {
+                    "type": "m.login.password",
+                    "identifier": {
+                        "type": "m.id.thirdparty",
+                        "medium": "email",
+                        "address": "kermit@example.com",
+                    },
+                    "password": "monkey",
+                },
+                "erase": False,
+            }
+        )
+        request, channel = self.make_request(
+            "POST", "account/deactivate", request_data, access_token=tok
+        )
+        self.render(request)
+        self.assertEqual(request.code, 200)
+
+        self.reactor.advance(datetime.timedelta(days=8).total_seconds())
+
+        self.assertEqual(len(self.email_attempts), 0)
+
+    def test_deactivated_user_using_phone_identifier(self):
+        self.email_attempts = []
+
+        (user_id, tok) = self.create_user()
+
+        request_data = json.dumps(
+            {
+                "auth": {
+                    "type": "m.login.password",
+                    "identifier": {
+                        "type": "m.id.phone",
+                        "country": "GB",
+                        "phone": "077-009-00001",
+                    },
+                    "password": "monkey",
+                },
+                "erase": False,
+            }
+        )
+        request, channel = self.make_request(
+            "POST", "account/deactivate", request_data, access_token=tok
+        )
+        self.render(request)
+        self.assertEqual(request.code, 200)
+
+        self.reactor.advance(datetime.timedelta(days=8).total_seconds())
+
+        self.assertEqual(len(self.email_attempts), 0)
+
     def create_user(self):
         user_id = self.register_user("kermit", "monkey")
         tok = self.login("kermit", "monkey")
@@ -608,6 +691,15 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
                 added_at=now,
             )
         )
+        self.get_success(
+            self.store.user_add_threepid(
+                user_id=user_id,
+                medium="msisdn",
+                address="447700900001",
+                validated_at=now,
+                added_at=now,
+            )
+        )
         return user_id, tok
 
     def test_manual_email_send_expired_account(self):