From dcfd71aa4c4a1d3d71356fd2f5d854fb1db8fafa Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 Jul 2016 12:34:23 +0100 Subject: Refactor login flow Make sure that we have the canonical user_id *before* calling get_login_tuple_for_user_id. Replace login_with_password with a method which just validates the password, and have the caller call get_login_tuple_for_user_id. This brings the password flow into line with the other flows, and will give us a place to register the device_id if necessary. --- synapse/rest/client/v1/login.py | 41 +++++++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 18 deletions(-) (limited to 'synapse/rest/client/v1/login.py') diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 8df9d10efa..a1f2ba8773 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -145,10 +145,13 @@ class LoginRestServlet(ClientV1RestServlet): ).to_string() auth_handler = self.auth_handler - user_id, access_token, refresh_token = yield auth_handler.login_with_password( + user_id = yield auth_handler.validate_password_login( user_id=user_id, - password=login_submission["password"]) - + password=login_submission["password"], + ) + access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id(user_id) + ) result = { "user_id": user_id, # may have changed "access_token": access_token, @@ -165,7 +168,7 @@ class LoginRestServlet(ClientV1RestServlet): user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) ) - user_id, access_token, refresh_token = ( + access_token, refresh_token = ( yield auth_handler.get_login_tuple_for_user_id(user_id) ) result = { @@ -196,13 +199,15 @@ class LoginRestServlet(ClientV1RestServlet): user_id = UserID.create(user, self.hs.hostname).to_string() auth_handler = self.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if user_exists: - user_id, access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + registered_user_id = yield auth_handler.check_user_exists(user_id) + if registered_user_id: + access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id( + registered_user_id + ) ) result = { - "user_id": user_id, # may have changed + "user_id": registered_user_id, # may have changed "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, @@ -245,13 +250,13 @@ class LoginRestServlet(ClientV1RestServlet): user_id = UserID.create(user, self.hs.hostname).to_string() auth_handler = self.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if user_exists: - user_id, access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + registered_user_id = yield auth_handler.check_user_exists(user_id) + if registered_user_id: + access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id(registered_user_id) ) result = { - "user_id": user_id, # may have changed + "user_id": registered_user_id, "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, @@ -414,13 +419,13 @@ class CasTicketServlet(ClientV1RestServlet): user_id = UserID.create(user, self.hs.hostname).to_string() auth_handler = self.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if not user_exists: - user_id, _ = ( + registered_user_id = yield auth_handler.check_user_exists(user_id) + if not registered_user_id: + registered_user_id, _ = ( yield self.handlers.registration_handler.register(localpart=user) ) - login_token = auth_handler.generate_short_term_login_token(user_id) + login_token = auth_handler.generate_short_term_login_token(registered_user_id) redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, login_token) request.redirect(redirect_url) -- cgit 1.4.1 From f863a52ceacf69ab19b073383be80603a2f51c0a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 Jul 2016 13:19:07 +0100 Subject: Add device_id support to /login Add a 'devices' table to the storage, as well as a 'device_id' column to refresh_tokens. Allow the client to pass a device_id, and initial_device_display_name, to /login. If login is successful, then register the device in the devices table if it wasn't known already. If no device_id was supplied, make one up. Associate the device_id with the access token and refresh token, so that we can get at it again later. Ensure that the device_id is copied from the refresh token to the access_token when the token is refreshed. --- synapse/handlers/auth.py | 19 +++--- synapse/handlers/device.py | 71 ++++++++++++++++++++ synapse/rest/client/v1/login.py | 39 ++++++++++- synapse/rest/client/v2_alpha/tokenrefresh.py | 10 ++- synapse/server.py | 5 ++ synapse/storage/__init__.py | 3 + synapse/storage/devices.py | 77 ++++++++++++++++++++++ synapse/storage/registration.py | 28 +++++--- synapse/storage/schema/delta/33/devices.sql | 21 ++++++ .../schema/delta/33/refreshtoken_device.sql | 16 +++++ tests/handlers/test_device.py | 75 +++++++++++++++++++++ tests/storage/test_registration.py | 21 ++++-- 12 files changed, 354 insertions(+), 31 deletions(-) create mode 100644 synapse/handlers/device.py create mode 100644 synapse/storage/devices.py create mode 100644 synapse/storage/schema/delta/33/devices.sql create mode 100644 synapse/storage/schema/delta/33/refreshtoken_device.sql create mode 100644 tests/handlers/test_device.py (limited to 'synapse/rest/client/v1/login.py') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 983994fa95..ce9bc18849 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -361,7 +361,7 @@ class AuthHandler(BaseHandler): return self._check_password(user_id, password) @defer.inlineCallbacks - def get_login_tuple_for_user_id(self, user_id): + def get_login_tuple_for_user_id(self, user_id, device_id=None): """ Gets login tuple for the user with the given user ID. @@ -372,6 +372,7 @@ class AuthHandler(BaseHandler): Args: user_id (str): canonical User ID + device_id (str): the device ID to associate with the access token Returns: A tuple of: The access token for the user's session. @@ -380,9 +381,9 @@ class AuthHandler(BaseHandler): StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ - logger.info("Logging in user %s", user_id) - access_token = yield self.issue_access_token(user_id) - refresh_token = yield self.issue_refresh_token(user_id) + logger.info("Logging in user %s on device %s", user_id, device_id) + access_token = yield self.issue_access_token(user_id, device_id) + refresh_token = yield self.issue_refresh_token(user_id, device_id) defer.returnValue((access_token, refresh_token)) @defer.inlineCallbacks @@ -638,15 +639,17 @@ class AuthHandler(BaseHandler): defer.returnValue(False) @defer.inlineCallbacks - def issue_access_token(self, user_id): + def issue_access_token(self, user_id, device_id=None): access_token = self.generate_access_token(user_id) - yield self.store.add_access_token_to_user(user_id, access_token) + yield self.store.add_access_token_to_user(user_id, access_token, + device_id) defer.returnValue(access_token) @defer.inlineCallbacks - def issue_refresh_token(self, user_id): + def issue_refresh_token(self, user_id, device_id=None): refresh_token = self.generate_refresh_token(user_id) - yield self.store.add_refresh_token_to_user(user_id, refresh_token) + yield self.store.add_refresh_token_to_user(user_id, refresh_token, + device_id) defer.returnValue(refresh_token) def generate_access_token(self, user_id, extra_caveats=None, diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py new file mode 100644 index 0000000000..8d7d9874f8 --- /dev/null +++ b/synapse/handlers/device.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.api.errors import StoreError +from synapse.util import stringutils +from twisted.internet import defer +from ._base import BaseHandler + +import logging + +logger = logging.getLogger(__name__) + + +class DeviceHandler(BaseHandler): + def __init__(self, hs): + super(DeviceHandler, self).__init__(hs) + + @defer.inlineCallbacks + def check_device_registered(self, user_id, device_id, + initial_device_display_name): + """ + If the given device has not been registered, register it with the + supplied display name. + + If no device_id is supplied, we make one up. + + Args: + user_id (str): @user:id + device_id (str | None): device id supplied by client + initial_device_display_name (str | None): device display name from + client + Returns: + str: device id (generated if none was supplied) + """ + if device_id is not None: + yield self.store.store_device( + user_id=user_id, + device_id=device_id, + initial_device_display_name=initial_device_display_name, + ignore_if_known=True, + ) + defer.returnValue(device_id) + + # if the device id is not specified, we'll autogen one, but loop a few + # times in case of a clash. + attempts = 0 + while attempts < 5: + try: + device_id = stringutils.random_string_with_symbols(16) + yield self.store.store_device( + user_id=user_id, + device_id=device_id, + initial_device_display_name=initial_device_display_name, + ignore_if_known=False, + ) + defer.returnValue(device_id) + except StoreError: + attempts += 1 + + raise StoreError(500, "Couldn't generate a device ID.") diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index a1f2ba8773..e8b791519c 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -59,6 +59,7 @@ class LoginRestServlet(ClientV1RestServlet): self.servername = hs.config.server_name self.http_client = hs.get_simple_http_client() self.auth_handler = self.hs.get_auth_handler() + self.device_handler = self.hs.get_device_handler() def on_GET(self, request): flows = [] @@ -149,14 +150,16 @@ class LoginRestServlet(ClientV1RestServlet): user_id=user_id, password=login_submission["password"], ) + device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) ) result = { "user_id": user_id, # may have changed "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, + "device_id": device_id, } defer.returnValue((200, result)) @@ -168,14 +171,16 @@ class LoginRestServlet(ClientV1RestServlet): user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) ) + device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) ) result = { "user_id": user_id, # may have changed "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, + "device_id": device_id, } defer.returnValue((200, result)) @@ -252,8 +257,13 @@ class LoginRestServlet(ClientV1RestServlet): auth_handler = self.auth_handler registered_user_id = yield auth_handler.check_user_exists(user_id) if registered_user_id: + device_id = yield self._register_device( + registered_user_id, login_submission + ) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(registered_user_id) + yield auth_handler.get_login_tuple_for_user_id( + registered_user_id, device_id + ) ) result = { "user_id": registered_user_id, @@ -262,6 +272,9 @@ class LoginRestServlet(ClientV1RestServlet): "home_server": self.hs.hostname, } else: + # TODO: we should probably check that the register isn't going + # to fonx/change our user_id before registering the device + device_id = yield self._register_device(user_id, login_submission) user_id, access_token = ( yield self.handlers.registration_handler.register(localpart=user) ) @@ -300,6 +313,26 @@ class LoginRestServlet(ClientV1RestServlet): return (user, attributes) + def _register_device(self, user_id, login_submission): + """Register a device for a user. + + This is called after the user's credentials have been validated, but + before the access token has been issued. + + Args: + (str) user_id: full canonical @user:id + (object) login_submission: dictionary supplied to /login call, from + which we pull device_id and initial_device_name + Returns: + defer.Deferred: (str) device_id + """ + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get( + "initial_device_display_name") + return self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + class SAML2RestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/login/saml2", releases=()) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 8270e8787f..0d312c91d4 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -39,9 +39,13 @@ class TokenRefreshRestServlet(RestServlet): try: old_refresh_token = body["refresh_token"] auth_handler = self.hs.get_auth_handler() - (user_id, new_refresh_token) = yield self.store.exchange_refresh_token( - old_refresh_token, auth_handler.generate_refresh_token) - new_access_token = yield auth_handler.issue_access_token(user_id) + refresh_result = yield self.store.exchange_refresh_token( + old_refresh_token, auth_handler.generate_refresh_token + ) + (user_id, new_refresh_token, device_id) = refresh_result + new_access_token = yield auth_handler.issue_access_token( + user_id, device_id + ) defer.returnValue((200, { "access_token": new_access_token, "refresh_token": new_refresh_token, diff --git a/synapse/server.py b/synapse/server.py index d49a1a8a96..e8b166990d 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -25,6 +25,7 @@ from twisted.enterprise import adbapi from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.appservice.api import ApplicationServiceApi from synapse.federation import initialize_http_replication +from synapse.handlers.device import DeviceHandler from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.notifier import Notifier from synapse.api.auth import Auth @@ -92,6 +93,7 @@ class HomeServer(object): 'typing_handler', 'room_list_handler', 'auth_handler', + 'device_handler', 'application_service_api', 'application_service_scheduler', 'application_service_handler', @@ -197,6 +199,9 @@ class HomeServer(object): def build_auth_handler(self): return AuthHandler(self) + def build_device_handler(self): + return DeviceHandler(self) + def build_application_service_api(self): return ApplicationServiceApi(self) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 1c93e18f9d..73fb334dd6 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -14,6 +14,8 @@ # limitations under the License. from twisted.internet import defer + +from synapse.storage.devices import DeviceStore from .appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore ) @@ -80,6 +82,7 @@ class DataStore(RoomMemberStore, RoomStore, EventPushActionsStore, OpenIdStore, ClientIpStore, + DeviceStore, ): def __init__(self, db_conn, hs): diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py new file mode 100644 index 0000000000..9065e96d28 --- /dev/null +++ b/synapse/storage/devices.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer + +from synapse.api.errors import StoreError +from ._base import SQLBaseStore + +logger = logging.getLogger(__name__) + + +class DeviceStore(SQLBaseStore): + @defer.inlineCallbacks + def store_device(self, user_id, device_id, + initial_device_display_name, + ignore_if_known=True): + """Ensure the given device is known; add it to the store if not + + Args: + user_id (str): id of user associated with the device + device_id (str): id of device + initial_device_display_name (str): initial displayname of the + device + ignore_if_known (bool): ignore integrity errors which mean the + device is already known + Returns: + defer.Deferred + Raises: + StoreError: if ignore_if_known is False and the device was already + known + """ + try: + yield self._simple_insert( + "devices", + values={ + "user_id": user_id, + "device_id": device_id, + "display_name": initial_device_display_name + }, + desc="store_device", + or_ignore=ignore_if_known, + ) + except Exception as e: + logger.error("store_device with device_id=%s failed: %s", + device_id, e) + raise StoreError(500, "Problem storing device.") + + def get_device(self, user_id, device_id): + """Retrieve a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to retrieve + Returns: + defer.Deferred for a namedtuple containing the device information + Raises: + StoreError: if the device is not found + """ + return self._simple_select_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcols=("user_id", "device_id", "display_name"), + desc="get_device", + ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index d957a629dc..26ef1cfd8a 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -31,12 +31,14 @@ class RegistrationStore(SQLBaseStore): self.clock = hs.get_clock() @defer.inlineCallbacks - def add_access_token_to_user(self, user_id, token): + def add_access_token_to_user(self, user_id, token, device_id=None): """Adds an access token for the given user. Args: user_id (str): The user ID. token (str): The new access token to add. + device_id (str): ID of the device to associate with the access + token Raises: StoreError if there was a problem adding this. """ @@ -47,18 +49,21 @@ class RegistrationStore(SQLBaseStore): { "id": next_id, "user_id": user_id, - "token": token + "token": token, + "device_id": device_id, }, desc="add_access_token_to_user", ) @defer.inlineCallbacks - def add_refresh_token_to_user(self, user_id, token): + def add_refresh_token_to_user(self, user_id, token, device_id=None): """Adds a refresh token for the given user. Args: user_id (str): The user ID. token (str): The new refresh token to add. + device_id (str): ID of the device to associate with the access + token Raises: StoreError if there was a problem adding this. """ @@ -69,7 +74,8 @@ class RegistrationStore(SQLBaseStore): { "id": next_id, "user_id": user_id, - "token": token + "token": token, + "device_id": device_id, }, desc="add_refresh_token_to_user", ) @@ -291,18 +297,18 @@ class RegistrationStore(SQLBaseStore): ) def exchange_refresh_token(self, refresh_token, token_generator): - """Exchange a refresh token for a new access token and refresh token. + """Exchange a refresh token for a new one. Doing so invalidates the old refresh token - refresh tokens are single use. Args: - token (str): The refresh token of a user. + refresh_token (str): The refresh token of a user. token_generator (fn: str -> str): Function which, when given a user ID, returns a unique refresh token for that user. This function must never return the same value twice. Returns: - tuple of (user_id, refresh_token) + tuple of (user_id, new_refresh_token, device_id) Raises: StoreError if no user was found with that refresh token. """ @@ -314,12 +320,13 @@ class RegistrationStore(SQLBaseStore): ) def _exchange_refresh_token(self, txn, old_token, token_generator): - sql = "SELECT user_id FROM refresh_tokens WHERE token = ?" + sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?" txn.execute(sql, (old_token,)) rows = self.cursor_to_dict(txn) if not rows: raise StoreError(403, "Did not recognize refresh token") user_id = rows[0]["user_id"] + device_id = rows[0]["device_id"] # TODO(danielwh): Maybe perform a validation on the macaroon that # macaroon.user_id == user_id. @@ -328,7 +335,7 @@ class RegistrationStore(SQLBaseStore): sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?" txn.execute(sql, (new_token, old_token,)) - return user_id, new_token + return user_id, new_token, device_id @defer.inlineCallbacks def is_server_admin(self, user): @@ -356,7 +363,8 @@ class RegistrationStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.is_guest, access_tokens.id as token_id" + "SELECT users.name, users.is_guest, access_tokens.id as token_id," + " access_tokens.device_id" " FROM users" " INNER JOIN access_tokens on users.name = access_tokens.user_id" " WHERE token = ?" diff --git a/synapse/storage/schema/delta/33/devices.sql b/synapse/storage/schema/delta/33/devices.sql new file mode 100644 index 0000000000..eca7268d82 --- /dev/null +++ b/synapse/storage/schema/delta/33/devices.sql @@ -0,0 +1,21 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE devices ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + display_name TEXT, + CONSTRAINT device_uniqueness UNIQUE (user_id, device_id) +); diff --git a/synapse/storage/schema/delta/33/refreshtoken_device.sql b/synapse/storage/schema/delta/33/refreshtoken_device.sql new file mode 100644 index 0000000000..b21da00dde --- /dev/null +++ b/synapse/storage/schema/delta/33/refreshtoken_device.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE refresh_tokens ADD COLUMN device_id BIGINT; diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py new file mode 100644 index 0000000000..cc6512ccc7 --- /dev/null +++ b/tests/handlers/test_device.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.handlers.device import DeviceHandler +from tests import unittest +from tests.utils import setup_test_homeserver + + +class DeviceHandlers(object): + def __init__(self, hs): + self.device_handler = DeviceHandler(hs) + + +class DeviceTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver(handlers=None) + self.hs.handlers = handlers = DeviceHandlers(self.hs) + self.handler = handlers.device_handler + + @defer.inlineCallbacks + def test_device_is_created_if_doesnt_exist(self): + res = yield self.handler.check_device_registered( + user_id="boris", + device_id="fco", + initial_device_display_name="display name" + ) + self.assertEqual(res, "fco") + + dev = yield self.handler.store.get_device("boris", "fco") + self.assertEqual(dev["display_name"], "display name") + + @defer.inlineCallbacks + def test_device_is_preserved_if_exists(self): + res1 = yield self.handler.check_device_registered( + user_id="boris", + device_id="fco", + initial_device_display_name="display name" + ) + self.assertEqual(res1, "fco") + + res2 = yield self.handler.check_device_registered( + user_id="boris", + device_id="fco", + initial_device_display_name="new display name" + ) + self.assertEqual(res2, "fco") + + dev = yield self.handler.store.get_device("boris", "fco") + self.assertEqual(dev["display_name"], "display name") + + @defer.inlineCallbacks + def test_device_id_is_made_up_if_unspecified(self): + device_id = yield self.handler.check_device_registered( + user_id="theresa", + device_id=None, + initial_device_display_name="display" + ) + + dev = yield self.handler.store.get_device("theresa", device_id) + self.assertEqual(dev["display_name"], "display") diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index b8384c98d8..b03ca303a2 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -38,6 +38,7 @@ class RegistrationStoreTestCase(unittest.TestCase): "BcDeFgHiJkLmNoPqRsTuVwXyZa" ] self.pwhash = "{xx1}123456789" + self.device_id = "akgjhdjklgshg" @defer.inlineCallbacks def test_register(self): @@ -64,13 +65,15 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_add_tokens(self): yield self.store.register(self.user_id, self.tokens[0], self.pwhash) - yield self.store.add_access_token_to_user(self.user_id, self.tokens[1]) + yield self.store.add_access_token_to_user(self.user_id, self.tokens[1], + self.device_id) result = yield self.store.get_user_by_access_token(self.tokens[1]) self.assertDictContainsSubset( { "name": self.user_id, + "device_id": self.device_id, }, result ) @@ -80,20 +83,24 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_exchange_refresh_token_valid(self): uid = stringutils.random_string(32) + device_id = stringutils.random_string(16) generator = TokenGenerator() last_token = generator.generate(uid) self.db_pool.runQuery( - "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)", - (uid, last_token,)) + "INSERT INTO refresh_tokens(user_id, token, device_id) " + "VALUES(?,?,?)", + (uid, last_token, device_id)) - (found_user_id, refresh_token) = yield self.store.exchange_refresh_token( - last_token, generator.generate) + (found_user_id, refresh_token, device_id) = \ + yield self.store.exchange_refresh_token(last_token, + generator.generate) self.assertEqual(uid, found_user_id) rows = yield self.db_pool.runQuery( - "SELECT token FROM refresh_tokens WHERE user_id = ?", (uid, )) - self.assertEqual([(refresh_token,)], rows) + "SELECT token, device_id FROM refresh_tokens WHERE user_id = ?", + (uid, )) + self.assertEqual([(refresh_token, device_id)], rows) # We issued token 1, then exchanged it for token 2 expected_refresh_token = u"%s-%d" % (uid, 2,) self.assertEqual(expected_refresh_token, refresh_token) -- cgit 1.4.1 From 436bffd15fb8382a0d2dddd3c6f7a077ba751da2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 22 Jul 2016 14:52:53 +0100 Subject: Implement deleting devices --- synapse/handlers/auth.py | 22 ++++++++++++++++-- synapse/handlers/device.py | 27 +++++++++++++++++++++- synapse/rest/client/v1/login.py | 13 ++++++++--- synapse/rest/client/v2_alpha/devices.py | 14 +++++++++++ synapse/rest/client/v2_alpha/register.py | 10 ++++---- synapse/storage/devices.py | 15 ++++++++++++ synapse/storage/registration.py | 26 +++++++++++++++++---- .../schema/delta/33/access_tokens_device_index.sql | 17 ++++++++++++++ .../schema/delta/33/refreshtoken_device_index.sql | 17 ++++++++++++++ tests/handlers/test_device.py | 22 ++++++++++++++++-- tests/rest/client/v2_alpha/test_register.py | 14 +++++++---- 11 files changed, 176 insertions(+), 21 deletions(-) create mode 100644 synapse/storage/schema/delta/33/access_tokens_device_index.sql create mode 100644 synapse/storage/schema/delta/33/refreshtoken_device_index.sql (limited to 'synapse/rest/client/v1/login.py') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index d5d2072436..2e138f328f 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -77,6 +77,7 @@ class AuthHandler(BaseHandler): self.ldap_bind_password = hs.config.ldap_bind_password self.hs = hs # FIXME better possibility to access registrationHandler later? + self.device_handler = hs.get_device_handler() @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): @@ -374,7 +375,8 @@ class AuthHandler(BaseHandler): return self._check_password(user_id, password) @defer.inlineCallbacks - def get_login_tuple_for_user_id(self, user_id, device_id=None): + def get_login_tuple_for_user_id(self, user_id, device_id=None, + initial_display_name=None): """ Gets login tuple for the user with the given user ID. @@ -383,9 +385,15 @@ class AuthHandler(BaseHandler): The user is assumed to have been authenticated by some other machanism (e.g. CAS), and the user_id converted to the canonical case. + The device will be recorded in the table if it is not there already. + Args: user_id (str): canonical User ID - device_id (str): the device ID to associate with the access token + device_id (str|None): the device ID to associate with the tokens. + None to leave the tokens unassociated with a device (deprecated: + we should always have a device ID) + initial_display_name (str): display name to associate with the + device if it needs re-registering Returns: A tuple of: The access token for the user's session. @@ -397,6 +405,16 @@ class AuthHandler(BaseHandler): logger.info("Logging in user %s on device %s", user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id) refresh_token = yield self.issue_refresh_token(user_id, device_id) + + # the device *should* have been registered before we got here; however, + # it's possible we raced against a DELETE operation. The thing we + # really don't want is active access_tokens without a record of the + # device, so we double-check it here. + if device_id is not None: + yield self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + defer.returnValue((access_token, refresh_token)) @defer.inlineCallbacks diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 1f9e15c33c..a7a192e1c9 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -100,7 +100,7 @@ class DeviceHandler(BaseHandler): Args: user_id (str): - device_id (str) + device_id (str): Returns: defer.Deferred: dict[str, X]: info on the device @@ -117,6 +117,31 @@ class DeviceHandler(BaseHandler): _update_device_from_client_ips(device, ips) defer.returnValue(device) + @defer.inlineCallbacks + def delete_device(self, user_id, device_id): + """ Delete the given device + + Args: + user_id (str): + device_id (str): + + Returns: + defer.Deferred: + """ + + try: + yield self.store.delete_device(user_id, device_id) + except errors.StoreError, e: + if e.code == 404: + # no match + pass + else: + raise + + yield self.store.user_delete_access_tokens(user_id, + device_id=device_id) + + def _update_device_from_client_ips(device, client_ips): ip = client_ips.get((device["user_id"], device["device_id"]), {}) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index e8b791519c..92fcae674a 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -152,7 +152,10 @@ class LoginRestServlet(ClientV1RestServlet): ) device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) + yield auth_handler.get_login_tuple_for_user_id( + user_id, device_id, + login_submission.get("initial_device_display_name") + ) ) result = { "user_id": user_id, # may have changed @@ -173,7 +176,10 @@ class LoginRestServlet(ClientV1RestServlet): ) device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) + yield auth_handler.get_login_tuple_for_user_id( + user_id, device_id, + login_submission.get("initial_device_display_name") + ) ) result = { "user_id": user_id, # may have changed @@ -262,7 +268,8 @@ class LoginRestServlet(ClientV1RestServlet): ) access_token, refresh_token = ( yield auth_handler.get_login_tuple_for_user_id( - registered_user_id, device_id + registered_user_id, device_id, + login_submission.get("initial_device_display_name") ) ) result = { diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 8b9ab4f674..30ef8b3da9 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -70,6 +70,20 @@ class DeviceRestServlet(RestServlet): ) defer.returnValue((200, device)) + @defer.inlineCallbacks + def on_DELETE(self, request, device_id): + # XXX: it's not completely obvious we want to expose this endpoint. + # It allows the client to delete access tokens, which feels like a + # thing which merits extra auth. But if we want to do the interactive- + # auth dance, we should really make it possible to delete more than one + # device at a time. + requester = yield self.auth.get_user_by_req(request) + yield self.device_handler.delete_device( + requester.user.to_string(), + device_id, + ) + defer.returnValue((200, {})) + def register_servlets(hs, http_server): DevicesRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index c8c9395fc6..9f599ea8bb 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -374,13 +374,13 @@ class RegisterRestServlet(RestServlet): """ device_id = yield self._register_device(user_id, params) - access_token = yield self.auth_handler.issue_access_token( - user_id, device_id=device_id + access_token, refresh_token = ( + yield self.auth_handler.get_login_tuple_for_user_id( + user_id, device_id=device_id, + initial_display_name=params.get("initial_device_display_name") + ) ) - refresh_token = yield self.auth_handler.issue_refresh_token( - user_id, device_id=device_id - ) defer.returnValue({ "user_id": user_id, "access_token": access_token, diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 1cc6e07f2b..4689980f80 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -76,6 +76,21 @@ class DeviceStore(SQLBaseStore): desc="get_device", ) + def delete_device(self, user_id, device_id): + """Delete a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to retrieve + Returns: + defer.Deferred + """ + return self._simple_delete_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + desc="delete_device", + ) + @defer.inlineCallbacks def get_devices_by_user(self, user_id): """Retrieve all of a user's registered devices. diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 9a92b35361..935e82bf7a 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -18,18 +18,31 @@ import re from twisted.internet import defer from synapse.api.errors import StoreError, Codes - -from ._base import SQLBaseStore +from synapse.storage import background_updates from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -class RegistrationStore(SQLBaseStore): +class RegistrationStore(background_updates.BackgroundUpdateStore): def __init__(self, hs): super(RegistrationStore, self).__init__(hs) self.clock = hs.get_clock() + self.register_background_index_update( + "access_tokens_device_index", + index_name="access_tokens_device_id", + table="access_tokens", + columns=["user_id", "device_id"], + ) + + self.register_background_index_update( + "refresh_tokens_device_index", + index_name="refresh_tokens_device_id", + table="refresh_tokens", + columns=["user_id", "device_id"], + ) + @defer.inlineCallbacks def add_access_token_to_user(self, user_id, token, device_id=None): """Adds an access token for the given user. @@ -238,11 +251,16 @@ class RegistrationStore(SQLBaseStore): self.get_user_by_id.invalidate((user_id,)) @defer.inlineCallbacks - def user_delete_access_tokens(self, user_id, except_token_ids=[]): + def user_delete_access_tokens(self, user_id, except_token_ids=[], + device_id=None): def f(txn): sql = "SELECT token FROM access_tokens WHERE user_id = ?" clauses = [user_id] + if device_id is not None: + sql += " AND device_id = ?" + clauses.append(device_id) + if except_token_ids: sql += " AND id NOT IN (%s)" % ( ",".join(["?" for _ in except_token_ids]), diff --git a/synapse/storage/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/schema/delta/33/access_tokens_device_index.sql new file mode 100644 index 0000000000..61ad3fe3e8 --- /dev/null +++ b/synapse/storage/schema/delta/33/access_tokens_device_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('access_tokens_device_index', '{}'); diff --git a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql b/synapse/storage/schema/delta/33/refreshtoken_device_index.sql new file mode 100644 index 0000000000..bb225dafbf --- /dev/null +++ b/synapse/storage/schema/delta/33/refreshtoken_device_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('refresh_tokens_device_index', '{}'); diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 331aa13fed..214e722eb3 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -12,11 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse import types + from twisted.internet import defer +import synapse.api.errors import synapse.handlers.device + import synapse.storage +from synapse import types from tests import unittest, utils user1 = "@boris:aaa" @@ -27,7 +30,7 @@ class DeviceTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): super(DeviceTestCase, self).__init__(*args, **kwargs) self.store = None # type: synapse.storage.DataStore - self.handler = None # type: device.DeviceHandler + self.handler = None # type: synapse.handlers.device.DeviceHandler self.clock = None # type: utils.MockClock @defer.inlineCallbacks @@ -123,6 +126,21 @@ class DeviceTestCase(unittest.TestCase): "last_seen_ts": 3000000, }, res) + @defer.inlineCallbacks + def test_delete_device(self): + yield self._record_users() + + # delete the device + yield self.handler.delete_device(user1, "abc") + + # check the device was deleted + with self.assertRaises(synapse.api.errors.NotFoundError): + yield self.handler.get_device(user1, "abc") + + # we'd like to check the access token was invalidated, but that's a + # bit of a PITA. + + @defer.inlineCallbacks def _record_users(self): # check this works for both devices which have a recorded client_ip, diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 3bd7065e32..8ac56a1fb2 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -65,13 +65,16 @@ class RegisterRestServletTestCase(unittest.TestCase): self.registration_handler.appservice_register = Mock( return_value=user_id ) - self.auth_handler.issue_access_token = Mock(return_value=token) + self.auth_handler.get_login_tuple_for_user_id = Mock( + return_value=(token, "kermits_refresh_token") + ) (code, result) = yield self.servlet.on_POST(self.request) self.assertEquals(code, 200) det_data = { "user_id": user_id, "access_token": token, + "refresh_token": "kermits_refresh_token", "home_server": self.hs.hostname } self.assertDictContainsSubset(det_data, result) @@ -121,7 +124,9 @@ class RegisterRestServletTestCase(unittest.TestCase): "password": "monkey" }, None) self.registration_handler.register = Mock(return_value=(user_id, None)) - self.auth_handler.issue_access_token = Mock(return_value=token) + self.auth_handler.get_login_tuple_for_user_id = Mock( + return_value=(token, "kermits_refresh_token") + ) self.device_handler.check_device_registered = \ Mock(return_value=device_id) @@ -130,13 +135,14 @@ class RegisterRestServletTestCase(unittest.TestCase): det_data = { "user_id": user_id, "access_token": token, + "refresh_token": "kermits_refresh_token", "home_server": self.hs.hostname, "device_id": device_id, } self.assertDictContainsSubset(det_data, result) self.assertIn("refresh_token", result) - self.auth_handler.issue_access_token.assert_called_once_with( - user_id, device_id=device_id) + self.auth_handler.get_login_tuple_for_user_id( + user_id, device_id=device_id, initial_device_display_name=None) def test_POST_disabled_registration(self): self.hs.config.enable_registration = False -- cgit 1.4.1