summary refs log tree commit diff
path: root/synapse/rest/client/v1/login.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/v1/login.py')
-rw-r--r--synapse/rest/client/v1/login.py83
1 files changed, 64 insertions, 19 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 8df9d10efa..92fcae674a 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -59,6 +59,7 @@ class LoginRestServlet(ClientV1RestServlet):
         self.servername = hs.config.server_name
         self.http_client = hs.get_simple_http_client()
         self.auth_handler = self.hs.get_auth_handler()
+        self.device_handler = self.hs.get_device_handler()
 
     def on_GET(self, request):
         flows = []
@@ -145,15 +146,23 @@ class LoginRestServlet(ClientV1RestServlet):
             ).to_string()
 
         auth_handler = self.auth_handler
-        user_id, access_token, refresh_token = yield auth_handler.login_with_password(
+        user_id = yield auth_handler.validate_password_login(
             user_id=user_id,
-            password=login_submission["password"])
-
+            password=login_submission["password"],
+        )
+        device_id = yield self._register_device(user_id, login_submission)
+        access_token, refresh_token = (
+            yield auth_handler.get_login_tuple_for_user_id(
+                user_id, device_id,
+                login_submission.get("initial_device_display_name")
+            )
+        )
         result = {
             "user_id": user_id,  # may have changed
             "access_token": access_token,
             "refresh_token": refresh_token,
             "home_server": self.hs.hostname,
+            "device_id": device_id,
         }
 
         defer.returnValue((200, result))
@@ -165,14 +174,19 @@ class LoginRestServlet(ClientV1RestServlet):
         user_id = (
             yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
         )
-        user_id, access_token, refresh_token = (
-            yield auth_handler.get_login_tuple_for_user_id(user_id)
+        device_id = yield self._register_device(user_id, login_submission)
+        access_token, refresh_token = (
+            yield auth_handler.get_login_tuple_for_user_id(
+                user_id, device_id,
+                login_submission.get("initial_device_display_name")
+            )
         )
         result = {
             "user_id": user_id,  # may have changed
             "access_token": access_token,
             "refresh_token": refresh_token,
             "home_server": self.hs.hostname,
+            "device_id": device_id,
         }
 
         defer.returnValue((200, result))
@@ -196,13 +210,15 @@ class LoginRestServlet(ClientV1RestServlet):
 
         user_id = UserID.create(user, self.hs.hostname).to_string()
         auth_handler = self.auth_handler
-        user_exists = yield auth_handler.does_user_exist(user_id)
-        if user_exists:
-            user_id, access_token, refresh_token = (
-                yield auth_handler.get_login_tuple_for_user_id(user_id)
+        registered_user_id = yield auth_handler.check_user_exists(user_id)
+        if registered_user_id:
+            access_token, refresh_token = (
+                yield auth_handler.get_login_tuple_for_user_id(
+                    registered_user_id
+                )
             )
             result = {
-                "user_id": user_id,  # may have changed
+                "user_id": registered_user_id,  # may have changed
                 "access_token": access_token,
                 "refresh_token": refresh_token,
                 "home_server": self.hs.hostname,
@@ -245,18 +261,27 @@ class LoginRestServlet(ClientV1RestServlet):
 
         user_id = UserID.create(user, self.hs.hostname).to_string()
         auth_handler = self.auth_handler
-        user_exists = yield auth_handler.does_user_exist(user_id)
-        if user_exists:
-            user_id, access_token, refresh_token = (
-                yield auth_handler.get_login_tuple_for_user_id(user_id)
+        registered_user_id = yield auth_handler.check_user_exists(user_id)
+        if registered_user_id:
+            device_id = yield self._register_device(
+                registered_user_id, login_submission
+            )
+            access_token, refresh_token = (
+                yield auth_handler.get_login_tuple_for_user_id(
+                    registered_user_id, device_id,
+                    login_submission.get("initial_device_display_name")
+                )
             )
             result = {
-                "user_id": user_id,  # may have changed
+                "user_id": registered_user_id,
                 "access_token": access_token,
                 "refresh_token": refresh_token,
                 "home_server": self.hs.hostname,
             }
         else:
+            # TODO: we should probably check that the register isn't going
+            # to fonx/change our user_id before registering the device
+            device_id = yield self._register_device(user_id, login_submission)
             user_id, access_token = (
                 yield self.handlers.registration_handler.register(localpart=user)
             )
@@ -295,6 +320,26 @@ class LoginRestServlet(ClientV1RestServlet):
 
         return (user, attributes)
 
+    def _register_device(self, user_id, login_submission):
+        """Register a device for a user.
+
+        This is called after the user's credentials have been validated, but
+        before the access token has been issued.
+
+        Args:
+            (str) user_id: full canonical @user:id
+            (object) login_submission: dictionary supplied to /login call, from
+               which we pull device_id and initial_device_name
+        Returns:
+            defer.Deferred: (str) device_id
+        """
+        device_id = login_submission.get("device_id")
+        initial_display_name = login_submission.get(
+            "initial_device_display_name")
+        return self.device_handler.check_device_registered(
+            user_id, device_id, initial_display_name
+        )
+
 
 class SAML2RestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/login/saml2", releases=())
@@ -414,13 +459,13 @@ class CasTicketServlet(ClientV1RestServlet):
 
         user_id = UserID.create(user, self.hs.hostname).to_string()
         auth_handler = self.auth_handler
-        user_exists = yield auth_handler.does_user_exist(user_id)
-        if not user_exists:
-            user_id, _ = (
+        registered_user_id = yield auth_handler.check_user_exists(user_id)
+        if not registered_user_id:
+            registered_user_id, _ = (
                 yield self.handlers.registration_handler.register(localpart=user)
             )
 
-        login_token = auth_handler.generate_short_term_login_token(user_id)
+        login_token = auth_handler.generate_short_term_login_token(registered_user_id)
         redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
                                                             login_token)
         request.redirect(redirect_url)