summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8320.feature1
-rw-r--r--synapse/rest/client/v1/login.py49
-rw-r--r--tests/rest/client/v1/test_login.py134
3 files changed, 173 insertions, 11 deletions
diff --git a/changelog.d/8320.feature b/changelog.d/8320.feature
new file mode 100644
index 0000000000..475a5fe62d
--- /dev/null
+++ b/changelog.d/8320.feature
@@ -0,0 +1 @@
+Add `uk.half-shot.msc2778.login.application_service` login type to allow appservices to login.
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index a14618ac84..dd8cdc0d9f 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -18,6 +18,7 @@ from typing import Awaitable, Callable, Dict, Optional
 
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
+from synapse.appservice import ApplicationService
 from synapse.handlers.auth import (
     convert_client_dict_legacy_fields_to_identifier,
     login_id_phone_to_thirdparty,
@@ -44,6 +45,7 @@ class LoginRestServlet(RestServlet):
     TOKEN_TYPE = "m.login.token"
     JWT_TYPE = "org.matrix.login.jwt"
     JWT_TYPE_DEPRECATED = "m.login.jwt"
+    APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
 
     def __init__(self, hs):
         super(LoginRestServlet, self).__init__()
@@ -61,6 +63,8 @@ class LoginRestServlet(RestServlet):
         self.cas_enabled = hs.config.cas_enabled
         self.oidc_enabled = hs.config.oidc_enabled
 
+        self.auth = hs.get_auth()
+
         self.auth_handler = self.hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
         self.handlers = hs.get_handlers()
@@ -107,6 +111,8 @@ class LoginRestServlet(RestServlet):
             ({"type": t} for t in self.auth_handler.get_supported_login_types())
         )
 
+        flows.append({"type": LoginRestServlet.APPSERVICE_TYPE})
+
         return 200, {"flows": flows}
 
     def on_OPTIONS(self, request: SynapseRequest):
@@ -116,8 +122,12 @@ class LoginRestServlet(RestServlet):
         self._address_ratelimiter.ratelimit(request.getClientIP())
 
         login_submission = parse_json_object_from_request(request)
+
         try:
-            if self.jwt_enabled and (
+            if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
+                appservice = self.auth.get_appservice_by_req(request)
+                result = await self._do_appservice_login(login_submission, appservice)
+            elif self.jwt_enabled and (
                 login_submission["type"] == LoginRestServlet.JWT_TYPE
                 or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
             ):
@@ -134,6 +144,33 @@ class LoginRestServlet(RestServlet):
             result["well_known"] = well_known_data
         return 200, result
 
+    def _get_qualified_user_id(self, identifier):
+        if identifier["type"] != "m.id.user":
+            raise SynapseError(400, "Unknown login identifier type")
+        if "user" not in identifier:
+            raise SynapseError(400, "User identifier is missing 'user' key")
+
+        if identifier["user"].startswith("@"):
+            return identifier["user"]
+        else:
+            return UserID(identifier["user"], self.hs.hostname).to_string()
+
+    async def _do_appservice_login(
+        self, login_submission: JsonDict, appservice: ApplicationService
+    ):
+        logger.info(
+            "Got appservice login request with identifier: %r",
+            login_submission.get("identifier"),
+        )
+
+        identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
+        qualified_user_id = self._get_qualified_user_id(identifier)
+
+        if not appservice.is_interested_in_user(qualified_user_id):
+            raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
+
+        return await self._complete_login(qualified_user_id, login_submission)
+
     async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
         """Handle non-token/saml/jwt logins
 
@@ -219,15 +256,7 @@ class LoginRestServlet(RestServlet):
 
         # by this point, the identifier should be an m.id.user: if it's anything
         # else, we haven't understood it.
-        if identifier["type"] != "m.id.user":
-            raise SynapseError(400, "Unknown login identifier type")
-        if "user" not in identifier:
-            raise SynapseError(400, "User identifier is missing 'user' key")
-
-        if identifier["user"].startswith("@"):
-            qualified_user_id = identifier["user"]
-        else:
-            qualified_user_id = UserID(identifier["user"], self.hs.hostname).to_string()
+        qualified_user_id = self._get_qualified_user_id(identifier)
 
         # Check if we've hit the failed ratelimit (but don't update it)
         self._failed_attempts_ratelimiter.ratelimit(
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 2668662c9e..5d987a30c7 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -7,8 +7,9 @@ from mock import Mock
 import jwt
 
 import synapse.rest.admin
+from synapse.appservice import ApplicationService
 from synapse.rest.client.v1 import login, logout
-from synapse.rest.client.v2_alpha import devices
+from synapse.rest.client.v2_alpha import devices, register
 from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
 
 from tests import unittest
@@ -748,3 +749,134 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
             channel.json_body["error"],
             "JWT validation failed: Signature verification failed",
         )
+
+
+AS_USER = "as_user_alice"
+
+
+class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        login.register_servlets,
+        register.register_servlets,
+    ]
+
+    def register_as_user(self, username):
+        request, channel = self.make_request(
+            b"POST",
+            "/_matrix/client/r0/register?access_token=%s" % (self.service.token,),
+            {"username": username},
+        )
+        self.render(request)
+
+    def make_homeserver(self, reactor, clock):
+        self.hs = self.setup_test_homeserver()
+
+        self.service = ApplicationService(
+            id="unique_identifier",
+            token="some_token",
+            hostname="example.com",
+            sender="@asbot:example.com",
+            namespaces={
+                ApplicationService.NS_USERS: [
+                    {"regex": r"@as_user.*", "exclusive": False}
+                ],
+                ApplicationService.NS_ROOMS: [],
+                ApplicationService.NS_ALIASES: [],
+            },
+        )
+        self.another_service = ApplicationService(
+            id="another__identifier",
+            token="another_token",
+            hostname="example.com",
+            sender="@as2bot:example.com",
+            namespaces={
+                ApplicationService.NS_USERS: [
+                    {"regex": r"@as2_user.*", "exclusive": False}
+                ],
+                ApplicationService.NS_ROOMS: [],
+                ApplicationService.NS_ALIASES: [],
+            },
+        )
+
+        self.hs.get_datastore().services_cache.append(self.service)
+        self.hs.get_datastore().services_cache.append(self.another_service)
+        return self.hs
+
+    def test_login_appservice_user(self):
+        """Test that an appservice user can use /login
+        """
+        self.register_as_user(AS_USER)
+
+        params = {
+            "type": login.LoginRestServlet.APPSERVICE_TYPE,
+            "identifier": {"type": "m.id.user", "user": AS_USER},
+        }
+        request, channel = self.make_request(
+            b"POST", LOGIN_URL, params, access_token=self.service.token
+        )
+
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+    def test_login_appservice_user_bot(self):
+        """Test that the appservice bot can use /login
+        """
+        self.register_as_user(AS_USER)
+
+        params = {
+            "type": login.LoginRestServlet.APPSERVICE_TYPE,
+            "identifier": {"type": "m.id.user", "user": self.service.sender},
+        }
+        request, channel = self.make_request(
+            b"POST", LOGIN_URL, params, access_token=self.service.token
+        )
+
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+    def test_login_appservice_wrong_user(self):
+        """Test that non-as users cannot login with the as token
+        """
+        self.register_as_user(AS_USER)
+
+        params = {
+            "type": login.LoginRestServlet.APPSERVICE_TYPE,
+            "identifier": {"type": "m.id.user", "user": "fibble_wibble"},
+        }
+        request, channel = self.make_request(
+            b"POST", LOGIN_URL, params, access_token=self.service.token
+        )
+
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"403", channel.result)
+
+    def test_login_appservice_wrong_as(self):
+        """Test that as users cannot login with wrong as token
+        """
+        self.register_as_user(AS_USER)
+
+        params = {
+            "type": login.LoginRestServlet.APPSERVICE_TYPE,
+            "identifier": {"type": "m.id.user", "user": AS_USER},
+        }
+        request, channel = self.make_request(
+            b"POST", LOGIN_URL, params, access_token=self.another_service.token
+        )
+
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"403", channel.result)
+
+    def test_login_appservice_no_token(self):
+        """Test that users must provide a token when using the appservice
+           login method
+        """
+        self.register_as_user(AS_USER)
+
+        params = {
+            "type": login.LoginRestServlet.APPSERVICE_TYPE,
+            "identifier": {"type": "m.id.user", "user": AS_USER},
+        }
+        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"401", channel.result)