summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/register.py51
-rw-r--r--synapse/replication/http/login.py17
-rw-r--r--synapse/rest/client/v1/login.py59
-rw-r--r--synapse/rest/client/v2_alpha/register.py49
-rw-r--r--tests/rest/client/v2_alpha/test_register.py93
5 files changed, 97 insertions, 172 deletions
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 8ea557a003..f92ab4d525 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -27,6 +27,7 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.http.client import CaptchaServerHttpClient
+from synapse.replication.http.login import RegisterDeviceReplicationServlet
 from synapse.replication.http.register import ReplicationRegisterServlet
 from synapse.types import RoomAlias, RoomID, UserID, create_requester
 from synapse.util.async_helpers import Linearizer
@@ -64,6 +65,11 @@ class RegistrationHandler(BaseHandler):
 
         if hs.config.worker_app:
             self._register_client = ReplicationRegisterServlet.make_client(hs)
+            self._register_device_client = (
+                RegisterDeviceReplicationServlet.make_client(hs)
+            )
+        else:
+            self.device_handler = hs.get_device_handler()
 
     @defer.inlineCallbacks
     def check_username(self, localpart, guest_access_token=None,
@@ -159,7 +165,7 @@ class RegistrationHandler(BaseHandler):
         yield self.auth.check_auth_blocking(threepid=threepid)
         password_hash = None
         if password:
-            password_hash = yield self.auth_handler().hash(password)
+            password_hash = yield self._auth_handler.hash(password)
 
         if localpart:
             yield self.check_username(localpart, guest_access_token=guest_access_token)
@@ -516,9 +522,6 @@ class RegistrationHandler(BaseHandler):
 
         defer.returnValue((user_id, token))
 
-    def auth_handler(self):
-        return self.hs.get_auth_handler()
-
     @defer.inlineCallbacks
     def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
         """Get a guest access token for a 3PID, creating a guest account if
@@ -628,3 +631,43 @@ class RegistrationHandler(BaseHandler):
                 admin=admin,
                 user_type=user_type,
             )
+
+    @defer.inlineCallbacks
+    def register_device(self, user_id, device_id, initial_display_name,
+                        is_guest=False):
+        """Register a device for a user and generate an access token.
+
+        Args:
+            user_id (str): full canonical @user:id
+            device_id (str|None): The device ID to check, or None to generate
+                a new one.
+            initial_display_name (str|None): An optional display name for the
+                device.
+            is_guest (bool): Whether this is a guest account
+
+        Returns:
+            defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
+        """
+
+        if self.hs.config.worker_app:
+            r = yield self._register_device_client(
+                user_id=user_id,
+                device_id=device_id,
+                initial_display_name=initial_display_name,
+                is_guest=is_guest,
+            )
+            defer.returnValue((r["device_id"], r["access_token"]))
+        else:
+            device_id = yield self.device_handler.check_device_registered(
+                user_id, device_id, initial_display_name
+            )
+            if is_guest:
+                access_token = self.macaroon_gen.generate_access_token(
+                    user_id, ["guest = true"]
+                )
+            else:
+                access_token = yield self._auth_handler.get_access_token_for_user_id(
+                    user_id, device_id=device_id,
+                )
+
+            defer.returnValue((device_id, access_token))
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 797f6aabd1..1590eca317 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -35,9 +35,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
 
     def __init__(self, hs):
         super(RegisterDeviceReplicationServlet, self).__init__(hs)
-        self.auth_handler = hs.get_auth_handler()
-        self.device_handler = hs.get_device_handler()
-        self.macaroon_gen = hs.get_macaroon_generator()
+        self.registration_handler = hs.get_handlers().registration_handler
 
     @staticmethod
     def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
@@ -62,19 +60,10 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
         initial_display_name = content["initial_display_name"]
         is_guest = content["is_guest"]
 
-        device_id = yield self.device_handler.check_device_registered(
-            user_id, device_id, initial_display_name,
+        device_id, access_token = yield self.registration_handler.register_device(
+            user_id, device_id, initial_display_name, is_guest,
         )
 
-        if is_guest:
-            access_token = self.macaroon_gen.generate_access_token(
-                user_id, ["guest = true"]
-            )
-        else:
-            access_token = yield self.auth_handler.get_access_token_for_user_id(
-                user_id, device_id=device_id,
-            )
-
         defer.returnValue((200, {
             "device_id": device_id,
             "access_token": access_token,
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 942e4d3816..4a5775083f 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -94,7 +94,7 @@ class LoginRestServlet(ClientV1RestServlet):
         self.jwt_algorithm = hs.config.jwt_algorithm
         self.cas_enabled = hs.config.cas_enabled
         self.auth_handler = self.hs.get_auth_handler()
-        self.device_handler = self.hs.get_device_handler()
+        self.registration_handler = hs.get_handlers().registration_handler
         self.handlers = hs.get_handlers()
         self._well_known_builder = WellKnownBuilder(hs)
 
@@ -220,11 +220,10 @@ class LoginRestServlet(ClientV1RestServlet):
             login_submission,
         )
 
-        device_id = yield self._register_device(
-            canonical_user_id, login_submission,
-        )
-        access_token = yield auth_handler.get_access_token_for_user_id(
-            canonical_user_id, device_id,
+        device_id = login_submission.get("device_id")
+        initial_display_name = login_submission.get("initial_device_display_name")
+        device_id, access_token = yield self.registration_handler.register_device(
+            canonical_user_id, device_id, initial_display_name,
         )
 
         result = {
@@ -246,10 +245,13 @@ class LoginRestServlet(ClientV1RestServlet):
         user_id = (
             yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
         )
-        device_id = yield self._register_device(user_id, login_submission)
-        access_token = yield auth_handler.get_access_token_for_user_id(
-            user_id, device_id,
+
+        device_id = login_submission.get("device_id")
+        initial_display_name = login_submission.get("initial_device_display_name")
+        device_id, access_token = yield self.registration_handler.register_device(
+            user_id, device_id, initial_display_name,
         )
+
         result = {
             "user_id": user_id,  # may have changed
             "access_token": access_token,
@@ -286,11 +288,10 @@ class LoginRestServlet(ClientV1RestServlet):
         auth_handler = self.auth_handler
         registered_user_id = yield auth_handler.check_user_exists(user_id)
         if registered_user_id:
-            device_id = yield self._register_device(
-                registered_user_id, login_submission
-            )
-            access_token = yield auth_handler.get_access_token_for_user_id(
-                registered_user_id, device_id,
+            device_id = login_submission.get("device_id")
+            initial_display_name = login_submission.get("initial_device_display_name")
+            device_id, access_token = yield self.registration_handler.register_device(
+                registered_user_id, device_id, initial_display_name,
             )
 
             result = {
@@ -299,12 +300,16 @@ class LoginRestServlet(ClientV1RestServlet):
                 "home_server": self.hs.hostname,
             }
         else:
-            # TODO: we should probably check that the register isn't going
-            # to fonx/change our user_id before registering the device
-            device_id = yield self._register_device(user_id, login_submission)
             user_id, access_token = (
                 yield self.handlers.registration_handler.register(localpart=user)
             )
+
+            device_id = login_submission.get("device_id")
+            initial_display_name = login_submission.get("initial_device_display_name")
+            device_id, access_token = yield self.registration_handler.register_device(
+                registered_user_id, device_id, initial_display_name,
+            )
+
             result = {
                 "user_id": user_id,  # may have changed
                 "access_token": access_token,
@@ -313,26 +318,6 @@ class LoginRestServlet(ClientV1RestServlet):
 
         defer.returnValue(result)
 
-    def _register_device(self, user_id, login_submission):
-        """Register a device for a user.
-
-        This is called after the user's credentials have been validated, but
-        before the access token has been issued.
-
-        Args:
-            (str) user_id: full canonical @user:id
-            (object) login_submission: dictionary supplied to /login call, from
-               which we pull device_id and initial_device_name
-        Returns:
-            defer.Deferred: (str) device_id
-        """
-        device_id = login_submission.get("device_id")
-        initial_display_name = login_submission.get(
-            "initial_device_display_name")
-        return self.device_handler.check_device_registered(
-            user_id, device_id, initial_display_name
-        )
-
 
 class CasRedirectServlet(RestServlet):
     PATTERNS = client_path_patterns("/login/(cas|sso)/redirect")
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index c52280c50c..c1cdb8f9c8 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -33,7 +33,6 @@ from synapse.http.servlet import (
     parse_json_object_from_request,
     parse_string,
 )
-from synapse.replication.http.login import RegisterDeviceReplicationServlet
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.threepids import check_3pid_allowed
@@ -193,13 +192,6 @@ class RegisterRestServlet(RestServlet):
         self.room_member_handler = hs.get_room_member_handler()
         self.macaroon_gen = hs.get_macaroon_generator()
 
-        if self.hs.config.worker_app:
-            self._register_device_client = (
-                RegisterDeviceReplicationServlet.make_client(hs)
-            )
-        else:
-            self.device_handler = hs.get_device_handler()
-
     @interactive_auth_handler
     @defer.inlineCallbacks
     def on_POST(self, request):
@@ -642,7 +634,7 @@ class RegisterRestServlet(RestServlet):
         if not params.get("inhibit_login", False):
             device_id = params.get("device_id")
             initial_display_name = params.get("initial_device_display_name")
-            device_id, access_token = yield self._register_device(
+            device_id, access_token = yield self.registration_handler.register_device(
                 user_id, device_id, initial_display_name, is_guest=False,
             )
 
@@ -653,43 +645,6 @@ class RegisterRestServlet(RestServlet):
         defer.returnValue(result)
 
     @defer.inlineCallbacks
-    def _register_device(self, user_id, device_id, initial_display_name,
-                         is_guest):
-        """Register a device for a user and generate an access token.
-
-        Args:
-            user_id (str): full canonical @user:id
-            device_id (str|None): The device ID to check, or None to generate
-                a new one.
-            initial_display_name (str|None): An optional display name for the
-                device.
-            is_guest (bool): Whether this is a guest account
-        Returns:
-            defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
-        """
-        if self.hs.config.worker_app:
-            r = yield self._register_device_client(
-                user_id=user_id,
-                device_id=device_id,
-                initial_display_name=initial_display_name,
-                is_guest=is_guest,
-            )
-            defer.returnValue((r["device_id"], r["access_token"]))
-        else:
-            device_id = yield self.device_handler.check_device_registered(
-                user_id, device_id, initial_display_name
-            )
-            if is_guest:
-                access_token = self.macaroon_gen.generate_access_token(
-                    user_id, ["guest = true"]
-                )
-            else:
-                access_token = yield self.auth_handler.get_access_token_for_user_id(
-                    user_id, device_id=device_id,
-                )
-            defer.returnValue((device_id, access_token))
-
-    @defer.inlineCallbacks
     def _do_guest_registration(self, params):
         if not self.hs.config.allow_guest_access:
             raise SynapseError(403, "Guest access is disabled")
@@ -702,7 +657,7 @@ class RegisterRestServlet(RestServlet):
         # we have nowhere to store it.
         device_id = synapse.api.auth.GUEST_DEVICE_ID
         initial_display_name = params.get("initial_device_display_name")
-        device_id, access_token = yield self._register_device(
+        device_id, access_token = yield self.registration_handler.register_device(
             user_id, device_id, initial_display_name, is_guest=True,
         )
 
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 18080ebfd6..906b348d3e 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -1,10 +1,7 @@
 import json
 
-from mock import Mock
-
-from twisted.python import failure
-
-from synapse.api.errors import InteractiveAuthIncompleteError
+from synapse.api.constants import LoginType
+from synapse.appservice import ApplicationService
 from synapse.rest.client.v2_alpha.register import register_servlets
 
 from tests import unittest
@@ -18,61 +15,28 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         self.url = b"/_matrix/client/r0/register"
 
-        self.appservice = None
-        self.auth = Mock(
-            get_appservice_by_req=Mock(side_effect=lambda x: self.appservice)
-        )
-
-        self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
-        self.auth_handler = Mock(
-            check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
-            get_session_data=Mock(return_value=None),
-        )
-        self.registration_handler = Mock()
-        self.identity_handler = Mock()
-        self.login_handler = Mock()
-        self.device_handler = Mock()
-
-        def check_device_registered(user_id, device_id, initial_display_name):
-            # Just echo back the given device ID, or return a new "FAKE" device
-            # ID
-            if device_id:
-                return device_id
-            else:
-                return "FAKE"
-
-        self.device_handler.check_device_registered = Mock(
-            side_effect=check_device_registered,
-        )
-
-        self.datastore = Mock(return_value=Mock())
-        self.datastore.get_current_state_deltas = Mock(return_value=[])
-
-        # do the dance to hook it up to the hs global
-        self.handlers = Mock(
-            registration_handler=self.registration_handler,
-            identity_handler=self.identity_handler,
-            login_handler=self.login_handler,
-        )
         self.hs = self.setup_test_homeserver()
-        self.hs.get_auth = Mock(return_value=self.auth)
-        self.hs.get_handlers = Mock(return_value=self.handlers)
-        self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
-        self.hs.get_device_handler = Mock(return_value=self.device_handler)
-        self.hs.get_datastore = Mock(return_value=self.datastore)
         self.hs.config.enable_registration = True
         self.hs.config.registrations_require_3pid = []
         self.hs.config.auto_join_rooms = []
+        self.hs.config.enable_registration_captcha = False
 
         return self.hs
 
     def test_POST_appservice_registration_valid(self):
-        user_id = "@kermit:muppet"
-        token = "kermits_access_token"
-        self.appservice = {"id": "1234"}
-        self.registration_handler.appservice_register = Mock(return_value=user_id)
-        self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
-        request_data = json.dumps({"username": "kermit"})
+        user_id = "@as_user_kermit:test"
+        as_token = "i_am_an_app_service"
+
+        appservice = ApplicationService(
+            as_token, self.hs.config.hostname,
+            id="1234",
+            namespaces={
+                "users": [{"regex": r"@as_user.*", "exclusive": True}],
+            },
+        )
+
+        self.hs.get_datastore().services_cache.append(appservice)
+        request_data = json.dumps({"username": "as_user_kermit"})
 
         request, channel = self.make_request(
             b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
@@ -82,7 +46,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.result["code"], b"200", channel.result)
         det_data = {
             "user_id": user_id,
-            "access_token": token,
             "home_server": self.hs.hostname,
         }
         self.assertDictContainsSubset(det_data, channel.json_body)
@@ -114,37 +77,30 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.json_body["error"], "Invalid username")
 
     def test_POST_user_valid(self):
-        user_id = "@kermit:muppet"
-        token = "kermits_access_token"
+        user_id = "@kermit:test"
         device_id = "frogfone"
-        params = {"username": "kermit", "password": "monkey", "device_id": device_id}
+        params = {
+            "username": "kermit",
+            "password": "monkey",
+            "device_id": device_id,
+            "auth": {"type": LoginType.DUMMY},
+        }
         request_data = json.dumps(params)
-        self.registration_handler.check_username = Mock(return_value=True)
-        self.auth_result = (None, params, None)
-        self.registration_handler.register = Mock(return_value=(user_id, None))
-        self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
-
         request, channel = self.make_request(b"POST", self.url, request_data)
         self.render(request)
 
         det_data = {
             "user_id": user_id,
-            "access_token": token,
             "home_server": self.hs.hostname,
             "device_id": device_id,
         }
         self.assertEquals(channel.result["code"], b"200", channel.result)
         self.assertDictContainsSubset(det_data, channel.json_body)
-        self.auth_handler.get_login_tuple_for_user_id(
-            user_id, device_id=device_id, initial_device_display_name=None
-        )
 
     def test_POST_disabled_registration(self):
         self.hs.config.enable_registration = False
         request_data = json.dumps({"username": "kermit", "password": "monkey"})
-        self.registration_handler.check_username = Mock(return_value=True)
         self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
-        self.registration_handler.register = Mock(return_value=("@user:id", "t"))
 
         request, channel = self.make_request(b"POST", self.url, request_data)
         self.render(request)
@@ -153,16 +109,13 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.json_body["error"], "Registration has been disabled")
 
     def test_POST_guest_registration(self):
-        user_id = "a@b"
         self.hs.config.macaroon_secret_key = "test"
         self.hs.config.allow_guest_access = True
-        self.registration_handler.register = Mock(return_value=(user_id, None))
 
         request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
         self.render(request)
 
         det_data = {
-            "user_id": user_id,
             "home_server": self.hs.hostname,
             "device_id": "guest_device",
         }