summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest')
-rw-r--r--synapse/rest/client/v1/login.py39
-rw-r--r--synapse/rest/client/v2_alpha/tokenrefresh.py10
2 files changed, 43 insertions, 6 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index a1f2ba8773..e8b791519c 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -59,6 +59,7 @@ class LoginRestServlet(ClientV1RestServlet):
         self.servername = hs.config.server_name
         self.http_client = hs.get_simple_http_client()
         self.auth_handler = self.hs.get_auth_handler()
+        self.device_handler = self.hs.get_device_handler()
 
     def on_GET(self, request):
         flows = []
@@ -149,14 +150,16 @@ class LoginRestServlet(ClientV1RestServlet):
             user_id=user_id,
             password=login_submission["password"],
         )
+        device_id = yield self._register_device(user_id, login_submission)
         access_token, refresh_token = (
-            yield auth_handler.get_login_tuple_for_user_id(user_id)
+            yield auth_handler.get_login_tuple_for_user_id(user_id, device_id)
         )
         result = {
             "user_id": user_id,  # may have changed
             "access_token": access_token,
             "refresh_token": refresh_token,
             "home_server": self.hs.hostname,
+            "device_id": device_id,
         }
 
         defer.returnValue((200, result))
@@ -168,14 +171,16 @@ class LoginRestServlet(ClientV1RestServlet):
         user_id = (
             yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
         )
+        device_id = yield self._register_device(user_id, login_submission)
         access_token, refresh_token = (
-            yield auth_handler.get_login_tuple_for_user_id(user_id)
+            yield auth_handler.get_login_tuple_for_user_id(user_id, device_id)
         )
         result = {
             "user_id": user_id,  # may have changed
             "access_token": access_token,
             "refresh_token": refresh_token,
             "home_server": self.hs.hostname,
+            "device_id": device_id,
         }
 
         defer.returnValue((200, result))
@@ -252,8 +257,13 @@ class LoginRestServlet(ClientV1RestServlet):
         auth_handler = self.auth_handler
         registered_user_id = yield auth_handler.check_user_exists(user_id)
         if registered_user_id:
+            device_id = yield self._register_device(
+                registered_user_id, login_submission
+            )
             access_token, refresh_token = (
-                yield auth_handler.get_login_tuple_for_user_id(registered_user_id)
+                yield auth_handler.get_login_tuple_for_user_id(
+                    registered_user_id, device_id
+                )
             )
             result = {
                 "user_id": registered_user_id,
@@ -262,6 +272,9 @@ class LoginRestServlet(ClientV1RestServlet):
                 "home_server": self.hs.hostname,
             }
         else:
+            # TODO: we should probably check that the register isn't going
+            # to fonx/change our user_id before registering the device
+            device_id = yield self._register_device(user_id, login_submission)
             user_id, access_token = (
                 yield self.handlers.registration_handler.register(localpart=user)
             )
@@ -300,6 +313,26 @@ class LoginRestServlet(ClientV1RestServlet):
 
         return (user, attributes)
 
+    def _register_device(self, user_id, login_submission):
+        """Register a device for a user.
+
+        This is called after the user's credentials have been validated, but
+        before the access token has been issued.
+
+        Args:
+            (str) user_id: full canonical @user:id
+            (object) login_submission: dictionary supplied to /login call, from
+               which we pull device_id and initial_device_name
+        Returns:
+            defer.Deferred: (str) device_id
+        """
+        device_id = login_submission.get("device_id")
+        initial_display_name = login_submission.get(
+            "initial_device_display_name")
+        return self.device_handler.check_device_registered(
+            user_id, device_id, initial_display_name
+        )
+
 
 class SAML2RestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/login/saml2", releases=())
diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py
index 8270e8787f..0d312c91d4 100644
--- a/synapse/rest/client/v2_alpha/tokenrefresh.py
+++ b/synapse/rest/client/v2_alpha/tokenrefresh.py
@@ -39,9 +39,13 @@ class TokenRefreshRestServlet(RestServlet):
         try:
             old_refresh_token = body["refresh_token"]
             auth_handler = self.hs.get_auth_handler()
-            (user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
-                old_refresh_token, auth_handler.generate_refresh_token)
-            new_access_token = yield auth_handler.issue_access_token(user_id)
+            refresh_result = yield self.store.exchange_refresh_token(
+                old_refresh_token, auth_handler.generate_refresh_token
+            )
+            (user_id, new_refresh_token, device_id) = refresh_result
+            new_access_token = yield auth_handler.issue_access_token(
+                user_id, device_id
+            )
             defer.returnValue((200, {
                 "access_token": new_access_token,
                 "refresh_token": new_refresh_token,