From f863a52ceacf69ab19b073383be80603a2f51c0a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 Jul 2016 13:19:07 +0100 Subject: Add device_id support to /login Add a 'devices' table to the storage, as well as a 'device_id' column to refresh_tokens. Allow the client to pass a device_id, and initial_device_display_name, to /login. If login is successful, then register the device in the devices table if it wasn't known already. If no device_id was supplied, make one up. Associate the device_id with the access token and refresh token, so that we can get at it again later. Ensure that the device_id is copied from the refresh token to the access_token when the token is refreshed. --- synapse/rest/client/v1/login.py | 39 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 3 deletions(-) (limited to 'synapse/rest/client/v1/login.py') diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index a1f2ba8773..e8b791519c 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -59,6 +59,7 @@ class LoginRestServlet(ClientV1RestServlet): self.servername = hs.config.server_name self.http_client = hs.get_simple_http_client() self.auth_handler = self.hs.get_auth_handler() + self.device_handler = self.hs.get_device_handler() def on_GET(self, request): flows = [] @@ -149,14 +150,16 @@ class LoginRestServlet(ClientV1RestServlet): user_id=user_id, password=login_submission["password"], ) + device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) ) result = { "user_id": user_id, # may have changed "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, + "device_id": device_id, } defer.returnValue((200, result)) @@ -168,14 +171,16 @@ class LoginRestServlet(ClientV1RestServlet): user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) ) + device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) ) result = { "user_id": user_id, # may have changed "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, + "device_id": device_id, } defer.returnValue((200, result)) @@ -252,8 +257,13 @@ class LoginRestServlet(ClientV1RestServlet): auth_handler = self.auth_handler registered_user_id = yield auth_handler.check_user_exists(user_id) if registered_user_id: + device_id = yield self._register_device( + registered_user_id, login_submission + ) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(registered_user_id) + yield auth_handler.get_login_tuple_for_user_id( + registered_user_id, device_id + ) ) result = { "user_id": registered_user_id, @@ -262,6 +272,9 @@ class LoginRestServlet(ClientV1RestServlet): "home_server": self.hs.hostname, } else: + # TODO: we should probably check that the register isn't going + # to fonx/change our user_id before registering the device + device_id = yield self._register_device(user_id, login_submission) user_id, access_token = ( yield self.handlers.registration_handler.register(localpart=user) ) @@ -300,6 +313,26 @@ class LoginRestServlet(ClientV1RestServlet): return (user, attributes) + def _register_device(self, user_id, login_submission): + """Register a device for a user. + + This is called after the user's credentials have been validated, but + before the access token has been issued. + + Args: + (str) user_id: full canonical @user:id + (object) login_submission: dictionary supplied to /login call, from + which we pull device_id and initial_device_name + Returns: + defer.Deferred: (str) device_id + """ + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get( + "initial_device_display_name") + return self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + class SAML2RestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/login/saml2", releases=()) -- cgit 1.4.1