From 68c7a6936f8921744d083e6dc8a2a085cce30b2a Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Fri, 18 Sep 2020 14:55:13 +0100 Subject: Allow appservice users to /login (#8320) Add ability for ASes to /login using the `uk.half-shot.msc2778.login.application_service` login `type`. Co-authored-by: Patrick Cloke --- synapse/rest/client/v1/login.py | 49 ++++++++++++++++++++++++++++++++--------- 1 file changed, 39 insertions(+), 10 deletions(-) (limited to 'synapse') diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index a14618ac84..dd8cdc0d9f 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -18,6 +18,7 @@ from typing import Awaitable, Callable, Dict, Optional from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter +from synapse.appservice import ApplicationService from synapse.handlers.auth import ( convert_client_dict_legacy_fields_to_identifier, login_id_phone_to_thirdparty, @@ -44,6 +45,7 @@ class LoginRestServlet(RestServlet): TOKEN_TYPE = "m.login.token" JWT_TYPE = "org.matrix.login.jwt" JWT_TYPE_DEPRECATED = "m.login.jwt" + APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service" def __init__(self, hs): super(LoginRestServlet, self).__init__() @@ -61,6 +63,8 @@ class LoginRestServlet(RestServlet): self.cas_enabled = hs.config.cas_enabled self.oidc_enabled = hs.config.oidc_enabled + self.auth = hs.get_auth() + self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.handlers = hs.get_handlers() @@ -107,6 +111,8 @@ class LoginRestServlet(RestServlet): ({"type": t} for t in self.auth_handler.get_supported_login_types()) ) + flows.append({"type": LoginRestServlet.APPSERVICE_TYPE}) + return 200, {"flows": flows} def on_OPTIONS(self, request: SynapseRequest): @@ -116,8 +122,12 @@ class LoginRestServlet(RestServlet): self._address_ratelimiter.ratelimit(request.getClientIP()) login_submission = parse_json_object_from_request(request) + try: - if self.jwt_enabled and ( + if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: + appservice = self.auth.get_appservice_by_req(request) + result = await self._do_appservice_login(login_submission, appservice) + elif self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED ): @@ -134,6 +144,33 @@ class LoginRestServlet(RestServlet): result["well_known"] = well_known_data return 200, result + def _get_qualified_user_id(self, identifier): + if identifier["type"] != "m.id.user": + raise SynapseError(400, "Unknown login identifier type") + if "user" not in identifier: + raise SynapseError(400, "User identifier is missing 'user' key") + + if identifier["user"].startswith("@"): + return identifier["user"] + else: + return UserID(identifier["user"], self.hs.hostname).to_string() + + async def _do_appservice_login( + self, login_submission: JsonDict, appservice: ApplicationService + ): + logger.info( + "Got appservice login request with identifier: %r", + login_submission.get("identifier"), + ) + + identifier = convert_client_dict_legacy_fields_to_identifier(login_submission) + qualified_user_id = self._get_qualified_user_id(identifier) + + if not appservice.is_interested_in_user(qualified_user_id): + raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN) + + return await self._complete_login(qualified_user_id, login_submission) + async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]: """Handle non-token/saml/jwt logins @@ -219,15 +256,7 @@ class LoginRestServlet(RestServlet): # by this point, the identifier should be an m.id.user: if it's anything # else, we haven't understood it. - if identifier["type"] != "m.id.user": - raise SynapseError(400, "Unknown login identifier type") - if "user" not in identifier: - raise SynapseError(400, "User identifier is missing 'user' key") - - if identifier["user"].startswith("@"): - qualified_user_id = identifier["user"] - else: - qualified_user_id = UserID(identifier["user"], self.hs.hostname).to_string() + qualified_user_id = self._get_qualified_user_id(identifier) # Check if we've hit the failed ratelimit (but don't update it) self._failed_attempts_ratelimiter.ratelimit( -- cgit 1.4.1