From 0a32208e5dde4980a5962f17e9b27f2e28e1f3f1 Mon Sep 17 00:00:00 2001 From: Martin Weinelt Date: Mon, 6 Jun 2016 02:05:57 +0200 Subject: Rework ldap integration with ldap3 Use the pure-python ldap3 library, which eliminates the need for a system dependency. Offer both a `search` and `simple_bind` mode, for more sophisticated ldap scenarios. - `search` tries to find a matching DN within the `user_base` while employing the `user_filter`, then tries the bind when a single matching DN was found. - `simple_bind` tries the bind against a specific DN by combining the localpart and `user_base` Offer support for STARTTLS on a plain connection. The configuration was changed to reflect these new possibilities. Signed-off-by: Martin Weinelt --- synapse/handlers/auth.py | 203 +++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 170 insertions(+), 33 deletions(-) (limited to 'synapse/handlers/auth.py') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index b38f81e999..968095c141 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -20,6 +20,7 @@ from synapse.api.constants import LoginType from synapse.types import UserID from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError from synapse.util.async import run_on_reactor +from synapse.config.ldap import LDAPMode from twisted.web.client import PartialDownloadError @@ -28,6 +29,12 @@ import bcrypt import pymacaroons import simplejson +try: + import ldap3 +except ImportError: + ldap3 = None + pass + import synapse.util.stringutils as stringutils @@ -50,17 +57,20 @@ class AuthHandler(BaseHandler): self.INVALID_TOKEN_HTTP_STATUS = 401 self.ldap_enabled = hs.config.ldap_enabled - self.ldap_server = hs.config.ldap_server - self.ldap_port = hs.config.ldap_port - self.ldap_tls = hs.config.ldap_tls - self.ldap_search_base = hs.config.ldap_search_base - self.ldap_search_property = hs.config.ldap_search_property - self.ldap_email_property = hs.config.ldap_email_property - self.ldap_full_name_property = hs.config.ldap_full_name_property - - if self.ldap_enabled is True: - import ldap - logger.info("Import ldap version: %s", ldap.__version__) + if self.ldap_enabled: + if not ldap3: + raise RuntimeError( + 'Missing ldap3 library. This is required for LDAP Authentication.' + ) + self.ldap_mode = hs.config.ldap_mode + self.ldap_uri = hs.config.ldap_uri + self.ldap_start_tls = hs.config.ldap_start_tls + self.ldap_base = hs.config.ldap_base + self.ldap_filter = hs.config.ldap_filter + self.ldap_attributes = hs.config.ldap_attributes + if self.ldap_mode == LDAPMode.SEARCH: + self.ldap_bind_dn = hs.config.ldap_bind_dn + self.ldap_bind_password = hs.config.ldap_bind_password self.hs = hs # FIXME better possibility to access registrationHandler later? @@ -452,40 +462,167 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def _check_ldap_password(self, user_id, password): - if not self.ldap_enabled: - logger.debug("LDAP not configured") + """ Attempt to authenticate a user against an LDAP Server + and register an account if none exists. + + Returns: + True if authentication against LDAP was successful + """ + + if not ldap3 or not self.ldap_enabled: defer.returnValue(False) - import ldap + if self.ldap_mode not in LDAPMode.LIST: + raise RuntimeError( + 'Invalid ldap mode specified: {mode}'.format( + mode=self.ldap_mode + ) + ) - logger.info("Authenticating %s with LDAP" % user_id) try: - ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port) - logger.debug("Connecting LDAP server at %s" % ldap_url) - l = ldap.initialize(ldap_url) - if self.ldap_tls: - logger.debug("Initiating TLS") - self._connection.start_tls_s() + server = ldap3.Server(self.ldap_uri) + logger.debug( + "Attempting ldap connection with %s", + self.ldap_uri + ) - local_name = UserID.from_string(user_id).localpart + localpart = UserID.from_string(user_id).localpart + if self.ldap_mode == LDAPMode.SIMPLE: + # bind with the the local users ldap credentials + bind_dn = "{prop}={value},{base}".format( + prop=self.ldap_attributes['uid'], + value=localpart, + base=self.ldap_base + ) + conn = ldap3.Connection(server, bind_dn, password) + logger.debug( + "Established ldap connection in simple mode: %s", + conn + ) - dn = "%s=%s, %s" % ( - self.ldap_search_property, - local_name, - self.ldap_search_base) - logger.debug("DN for LDAP authentication: %s" % dn) + if self.ldap_start_tls: + conn.start_tls() + logger.debug( + "Upgraded ldap connection in simple mode through StartTLS: %s", + conn + ) + + conn.bind() + + elif self.ldap_mode == LDAPMode.SEARCH: + # connect with preconfigured credentials and search for local user + conn = ldap3.Connection( + server, + self.ldap_bind_dn, + self.ldap_bind_password + ) + logger.debug( + "Established ldap connection in search mode: %s", + conn + ) + + if self.ldap_start_tls: + conn.start_tls() + logger.debug( + "Upgraded ldap connection in search mode through StartTLS: %s", + conn + ) - l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8')) + conn.bind() + # find matching dn + query = "({prop}={value})".format( + prop=self.ldap_attributes['uid'], + value=localpart + ) + if self.ldap_filter: + query = "(&{query}{filter})".format( + query=query, + filter=self.ldap_filter + ) + logger.debug("ldap search filter: %s", query) + result = conn.search(self.ldap_base, query) + + if result and len(conn.response) == 1: + # found exactly one result + user_dn = conn.response[0]['dn'] + logger.debug('ldap search found dn: %s', user_dn) + + # unbind and reconnect, rebind with found dn + conn.unbind() + conn = ldap3.Connection( + server, + user_dn, + password, + auto_bind=True + ) + else: + # found 0 or > 1 results, abort! + logger.warn( + "ldap search returned unexpected (%d!=1) amount of results", + len(conn.response) + ) + defer.returnValue(False) + + logger.info( + "User authenticated against ldap server: %s", + conn + ) + + # check for existing account, if none exists, create one if not (yield self.does_user_exist(user_id)): - handler = self.hs.get_handlers().registration_handler - user_id, access_token = ( - yield handler.register(localpart=local_name) + # query user metadata for account creation + query = "({prop}={value})".format( + prop=self.ldap_attributes['uid'], + value=localpart + ) + + if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter: + query = "(&{filter}{user_filter})".format( + filter=query, + user_filter=self.ldap_filter + ) + logger.debug("ldap registration filter: %s", query) + + result = conn.search( + search_base=self.ldap_base, + search_filter=query, + attributes=[ + self.ldap_attributes['name'], + self.ldap_attributes['mail'] + ] ) + if len(conn.response) == 1: + attrs = conn.response[0]['attributes'] + mail = attrs[self.ldap_attributes['mail']][0] + name = attrs[self.ldap_attributes['name']][0] + + # create account + registration_handler = self.hs.get_handlers().registration_handler + user_id, access_token = ( + yield registration_handler.register(localpart=localpart) + ) + + # TODO: bind email, set displayname with data from ldap directory + + logger.info( + "ldap registration successful: %d: %s (%s, %)", + user_id, + localpart, + name, + mail + ) + else: + logger.warn( + "ldap registration failed: unexpected (%d!=1) amount of results", + len(result) + ) + defer.returnValue(False) + defer.returnValue(True) - except ldap.LDAPError, e: - logger.warn("LDAP error: %s", e) + except ldap3.core.exceptions.LDAPException as e: + logger.warn("Error during ldap authentication: %s", e) defer.returnValue(False) @defer.inlineCallbacks -- cgit 1.4.1 From 8bdaf5f7afaee98a8cf25d2fb170fe4b2aa97f3d Mon Sep 17 00:00:00 2001 From: Kent Shikama Date: Tue, 5 Jul 2016 02:13:52 +0900 Subject: Add pepper to password hashing Signed-off-by: Kent Shikama --- synapse/config/password.py | 6 +++++- synapse/handlers/auth.py | 5 +++-- 2 files changed, 8 insertions(+), 3 deletions(-) (limited to 'synapse/handlers/auth.py') diff --git a/synapse/config/password.py b/synapse/config/password.py index dec801ef41..ea822f2bb5 100644 --- a/synapse/config/password.py +++ b/synapse/config/password.py @@ -23,10 +23,14 @@ class PasswordConfig(Config): def read_config(self, config): password_config = config.get("password_config", {}) self.password_enabled = password_config.get("enabled", True) + self.pepper = password_config.get("pepper", "") def default_config(self, config_dir_path, server_name, **kwargs): return """ # Enable password for login. password_config: enabled: true - """ + # Uncomment for extra security for your passwords. + # DO NOT CHANGE THIS AFTER INITIAL SETUP! + #pepper: "HR32t0xZcQnzn3O0ZkEVuetdFvH1W6TeEPw6JjH0Cl+qflVOseGyFJlJR7ACLnywjN9" + """ \ No newline at end of file diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 968095c141..fd5fadf73d 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -750,7 +750,7 @@ class AuthHandler(BaseHandler): Returns: Hashed password (str). """ - return bcrypt.hashpw(password, bcrypt.gensalt(self.bcrypt_rounds)) + return bcrypt.hashpw(password + self.hs.config.password_config.pepper, bcrypt.gensalt(self.bcrypt_rounds)) def validate_hash(self, password, stored_hash): """Validates that self.hash(password) == stored_hash. @@ -763,6 +763,7 @@ class AuthHandler(BaseHandler): Whether self.hash(password) == stored_hash (bool). """ if stored_hash: - return bcrypt.hashpw(password, stored_hash.encode('utf-8')) == stored_hash + return bcrypt.hashpw(password + self.hs.config.password_config.pepper, + stored_hash.encode('utf-8')) == stored_hash else: return False -- cgit 1.4.1 From 1ee258430724618c7014bb176186c23b0b5b06f0 Mon Sep 17 00:00:00 2001 From: Kent Shikama Date: Tue, 5 Jul 2016 19:01:00 +0900 Subject: Fix pep8 --- synapse/config/password.py | 2 +- synapse/handlers/auth.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) (limited to 'synapse/handlers/auth.py') diff --git a/synapse/config/password.py b/synapse/config/password.py index 7c5cb5f0e1..058a3a5346 100644 --- a/synapse/config/password.py +++ b/synapse/config/password.py @@ -34,4 +34,4 @@ class PasswordConfig(Config): # Change to a secret random string. # DO NOT CHANGE THIS AFTER INITIAL SETUP! #pepper: "HR32t0xZcQnzn3O0ZkEVuetdFvH1W6TeEPw6JjH0Cl+qflVOseGyFJlJR7ACLnywjN9" - """ \ No newline at end of file + """ diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index fd5fadf73d..be46681c64 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -750,7 +750,8 @@ class AuthHandler(BaseHandler): Returns: Hashed password (str). """ - return bcrypt.hashpw(password + self.hs.config.password_config.pepper, bcrypt.gensalt(self.bcrypt_rounds)) + return bcrypt.hashpw(password + self.hs.config.password_config.pepper, + bcrypt.gensalt(self.bcrypt_rounds)) def validate_hash(self, password, stored_hash): """Validates that self.hash(password) == stored_hash. -- cgit 1.4.1 From 14362bf3590eb95a50201a84c8e16d5626b86249 Mon Sep 17 00:00:00 2001 From: Kent Shikama Date: Tue, 5 Jul 2016 19:12:53 +0900 Subject: Fix password config --- synapse/config/password.py | 2 +- synapse/handlers/auth.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'synapse/handlers/auth.py') diff --git a/synapse/config/password.py b/synapse/config/password.py index 058a3a5346..00b1ea3df9 100644 --- a/synapse/config/password.py +++ b/synapse/config/password.py @@ -23,7 +23,7 @@ class PasswordConfig(Config): def read_config(self, config): password_config = config.get("password_config", {}) self.password_enabled = password_config.get("enabled", True) - self.pepper = password_config.get("pepper", "") + self.password_pepper = password_config.get("pepper", "") def default_config(self, config_dir_path, server_name, **kwargs): return """ diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index be46681c64..e259213a36 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -750,7 +750,7 @@ class AuthHandler(BaseHandler): Returns: Hashed password (str). """ - return bcrypt.hashpw(password + self.hs.config.password_config.pepper, + return bcrypt.hashpw(password + self.hs.config.password_pepper, bcrypt.gensalt(self.bcrypt_rounds)) def validate_hash(self, password, stored_hash): @@ -764,7 +764,7 @@ class AuthHandler(BaseHandler): Whether self.hash(password) == stored_hash (bool). """ if stored_hash: - return bcrypt.hashpw(password + self.hs.config.password_config.pepper, + return bcrypt.hashpw(password + self.hs.config.password_pepper, stored_hash.encode('utf-8')) == stored_hash else: return False -- cgit 1.4.1 From 0136a522b18a734db69171d60566f501c0ced663 Mon Sep 17 00:00:00 2001 From: Negar Fazeli Date: Fri, 8 Jul 2016 16:53:18 +0200 Subject: Bug fix: expire invalid access tokens --- synapse/api/auth.py | 3 +++ synapse/handlers/auth.py | 5 +++-- synapse/handlers/register.py | 6 +++--- synapse/rest/client/v1/register.py | 2 +- tests/api/test_auth.py | 31 ++++++++++++++++++++++++++++++- tests/handlers/test_register.py | 4 ++-- 6 files changed, 42 insertions(+), 9 deletions(-) (limited to 'synapse/handlers/auth.py') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index a4d658a9d0..521a52e001 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -629,7 +629,10 @@ class Auth(object): except AuthError: # TODO(daniel): Remove this fallback when all existing access tokens # have been re-issued as macaroons. + if self.hs.config.expire_access_token: + raise ret = yield self._look_up_user_by_access_token(token) + defer.returnValue(ret) @defer.inlineCallbacks diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index e259213a36..5a0ed9d6b9 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -637,12 +637,13 @@ class AuthHandler(BaseHandler): yield self.store.add_refresh_token_to_user(user_id, refresh_token) defer.returnValue(refresh_token) - def generate_access_token(self, user_id, extra_caveats=None): + def generate_access_token(self, user_id, extra_caveats=None, + duration_in_ms=(60 * 60 * 1000)): extra_caveats = extra_caveats or [] macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = access") now = self.hs.get_clock().time_msec() - expiry = now + (60 * 60 * 1000) + expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) for caveat in extra_caveats: macaroon.add_first_party_caveat(caveat) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 8c3381df8a..6b33b27149 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -360,7 +360,7 @@ class RegistrationHandler(BaseHandler): defer.returnValue(data) @defer.inlineCallbacks - def get_or_create_user(self, localpart, displayname, duration_seconds, + def get_or_create_user(self, localpart, displayname, duration_in_ms, password_hash=None): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -390,8 +390,8 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - token = self.auth_handler().generate_short_term_login_token( - user_id, duration_seconds) + token = self.auth_handler().generate_access_token( + user_id, None, duration_in_ms) if need_register: yield self.store.register( diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index ce7099b18f..8e1f1b7845 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -429,7 +429,7 @@ class CreateUserRestServlet(ClientV1RestServlet): user_id, token = yield handler.get_or_create_user( localpart=localpart, displayname=displayname, - duration_seconds=duration_seconds, + duration_in_ms=(duration_seconds * 1000), password_hash=password_hash ) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index ad269af0ec..960c23d631 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -281,7 +281,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user,)) - macaroon.add_first_party_caveat("time < 1") # ms + macaroon.add_first_party_caveat("time < -2000") # ms self.hs.clock.now = 5000 # seconds self.hs.config.expire_access_token = True @@ -293,3 +293,32 @@ class AuthTestCase(unittest.TestCase): yield self.auth.get_user_from_macaroon(macaroon.serialize()) self.assertEqual(401, cm.exception.code) self.assertIn("Invalid macaroon", cm.exception.msg) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_with_valid_duration(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + user_id = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) + macaroon.add_first_party_caveat("time < 900000000") # ms + + self.hs.clock.now = 5000 # seconds + self.hs.config.expire_access_token = True + + user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize()) + user = user_info["user"] + self.assertEqual(UserID.from_string(user_id), user) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 69a5e5b1d4..a7de3c7c17 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -42,12 +42,12 @@ class RegistrationTestCase(unittest.TestCase): http_client=None, expire_access_token=True) self.auth_handler = Mock( - generate_short_term_login_token=Mock(return_value='secret')) + generate_access_token=Mock(return_value='secret')) self.hs.handlers = RegistrationHandlers(self.hs) self.handler = self.hs.get_handlers().registration_handler self.hs.get_handlers().profile_handler = Mock() self.mock_handler = Mock(spec=[ - "generate_short_term_login_token", + "generate_access_token", ]) self.hs.get_auth_handler = Mock(return_value=self.auth_handler) -- cgit 1.4.1 From dcfd71aa4c4a1d3d71356fd2f5d854fb1db8fafa Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 Jul 2016 12:34:23 +0100 Subject: Refactor login flow Make sure that we have the canonical user_id *before* calling get_login_tuple_for_user_id. Replace login_with_password with a method which just validates the password, and have the caller call get_login_tuple_for_user_id. This brings the password flow into line with the other flows, and will give us a place to register the device_id if necessary. --- synapse/handlers/auth.py | 106 ++++++++++++++++++++++------------------ synapse/rest/client/v1/login.py | 41 +++++++++------- 2 files changed, 82 insertions(+), 65 deletions(-) (limited to 'synapse/handlers/auth.py') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 5a0ed9d6b9..983994fa95 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -230,7 +230,6 @@ class AuthHandler(BaseHandler): sess = self._get_session_info(session_id) return sess.setdefault('serverdict', {}).get(key, default) - @defer.inlineCallbacks def _check_password_auth(self, authdict, _): if "user" not in authdict or "password" not in authdict: raise LoginError(400, "", Codes.MISSING_PARAM) @@ -240,11 +239,7 @@ class AuthHandler(BaseHandler): if not user_id.startswith('@'): user_id = UserID.create(user_id, self.hs.hostname).to_string() - if not (yield self._check_password(user_id, password)): - logger.warn("Failed password login for user %s", user_id) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - - defer.returnValue(user_id) + return self._check_password(user_id, password) @defer.inlineCallbacks def _check_recaptcha(self, authdict, clientip): @@ -348,67 +343,66 @@ class AuthHandler(BaseHandler): return self.sessions[session_id] - @defer.inlineCallbacks - def login_with_password(self, user_id, password): + def validate_password_login(self, user_id, password): """ Authenticates the user with their username and password. Used only by the v1 login API. Args: - user_id (str): User ID + user_id (str): complete @user:id password (str): Password Returns: - A tuple of: - The user's ID. - The access token for the user's session. - The refresh token for the user's session. + defer.Deferred: (str) canonical user id Raises: - StoreError if there was a problem storing the token. + StoreError if there was a problem accessing the database LoginError if there was an authentication problem. """ - - if not (yield self._check_password(user_id, password)): - logger.warn("Failed password login for user %s", user_id) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - - logger.info("Logging in user %s", user_id) - access_token = yield self.issue_access_token(user_id) - refresh_token = yield self.issue_refresh_token(user_id) - defer.returnValue((user_id, access_token, refresh_token)) + return self._check_password(user_id, password) @defer.inlineCallbacks def get_login_tuple_for_user_id(self, user_id): """ Gets login tuple for the user with the given user ID. + + Creates a new access/refresh token for the user. + The user is assumed to have been authenticated by some other - machanism (e.g. CAS) + machanism (e.g. CAS), and the user_id converted to the canonical case. Args: - user_id (str): User ID + user_id (str): canonical User ID Returns: A tuple of: - The user's ID. The access token for the user's session. The refresh token for the user's session. Raises: StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ - user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id) - logger.info("Logging in user %s", user_id) access_token = yield self.issue_access_token(user_id) refresh_token = yield self.issue_refresh_token(user_id) - defer.returnValue((user_id, access_token, refresh_token)) + defer.returnValue((access_token, refresh_token)) @defer.inlineCallbacks - def does_user_exist(self, user_id): + def check_user_exists(self, user_id): + """ + Checks to see if a user with the given id exists. Will check case + insensitively, but return None if there are multiple inexact matches. + + Args: + (str) user_id: complete @user:id + + Returns: + defer.Deferred: (str) canonical_user_id, or None if zero or + multiple matches + """ try: - yield self._find_user_id_and_pwd_hash(user_id) - defer.returnValue(True) + res = yield self._find_user_id_and_pwd_hash(user_id) + defer.returnValue(res[0]) except LoginError: - defer.returnValue(False) + defer.returnValue(None) @defer.inlineCallbacks def _find_user_id_and_pwd_hash(self, user_id): @@ -438,27 +432,45 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def _check_password(self, user_id, password): - """ + """Authenticate a user against the LDAP and local databases. + + user_id is checked case insensitively against the local database, but + will throw if there are multiple inexact matches. + + Args: + user_id (str): complete @user:id Returns: - True if the user_id successfully authenticated + (str) the canonical_user_id + Raises: + LoginError if the password was incorrect """ valid_ldap = yield self._check_ldap_password(user_id, password) if valid_ldap: - defer.returnValue(True) + defer.returnValue(user_id) - valid_local_password = yield self._check_local_password(user_id, password) - if valid_local_password: - defer.returnValue(True) - - defer.returnValue(False) + result = yield self._check_local_password(user_id, password) + defer.returnValue(result) @defer.inlineCallbacks def _check_local_password(self, user_id, password): - try: - user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) - defer.returnValue(self.validate_hash(password, password_hash)) - except LoginError: - defer.returnValue(False) + """Authenticate a user against the local password database. + + user_id is checked case insensitively, but will throw if there are + multiple inexact matches. + + Args: + user_id (str): complete @user:id + Returns: + (str) the canonical_user_id + Raises: + LoginError if the password was incorrect + """ + user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) + result = self.validate_hash(password, password_hash) + if not result: + logger.warn("Failed password login for user %s", user_id) + raise LoginError(403, "", errcode=Codes.FORBIDDEN) + defer.returnValue(user_id) @defer.inlineCallbacks def _check_ldap_password(self, user_id, password): @@ -570,7 +582,7 @@ class AuthHandler(BaseHandler): ) # check for existing account, if none exists, create one - if not (yield self.does_user_exist(user_id)): + if not (yield self.check_user_exists(user_id)): # query user metadata for account creation query = "({prop}={value})".format( prop=self.ldap_attributes['uid'], diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 8df9d10efa..a1f2ba8773 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -145,10 +145,13 @@ class LoginRestServlet(ClientV1RestServlet): ).to_string() auth_handler = self.auth_handler - user_id, access_token, refresh_token = yield auth_handler.login_with_password( + user_id = yield auth_handler.validate_password_login( user_id=user_id, - password=login_submission["password"]) - + password=login_submission["password"], + ) + access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id(user_id) + ) result = { "user_id": user_id, # may have changed "access_token": access_token, @@ -165,7 +168,7 @@ class LoginRestServlet(ClientV1RestServlet): user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) ) - user_id, access_token, refresh_token = ( + access_token, refresh_token = ( yield auth_handler.get_login_tuple_for_user_id(user_id) ) result = { @@ -196,13 +199,15 @@ class LoginRestServlet(ClientV1RestServlet): user_id = UserID.create(user, self.hs.hostname).to_string() auth_handler = self.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if user_exists: - user_id, access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + registered_user_id = yield auth_handler.check_user_exists(user_id) + if registered_user_id: + access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id( + registered_user_id + ) ) result = { - "user_id": user_id, # may have changed + "user_id": registered_user_id, # may have changed "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, @@ -245,13 +250,13 @@ class LoginRestServlet(ClientV1RestServlet): user_id = UserID.create(user, self.hs.hostname).to_string() auth_handler = self.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if user_exists: - user_id, access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + registered_user_id = yield auth_handler.check_user_exists(user_id) + if registered_user_id: + access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id(registered_user_id) ) result = { - "user_id": user_id, # may have changed + "user_id": registered_user_id, "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, @@ -414,13 +419,13 @@ class CasTicketServlet(ClientV1RestServlet): user_id = UserID.create(user, self.hs.hostname).to_string() auth_handler = self.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if not user_exists: - user_id, _ = ( + registered_user_id = yield auth_handler.check_user_exists(user_id) + if not registered_user_id: + registered_user_id, _ = ( yield self.handlers.registration_handler.register(localpart=user) ) - login_token = auth_handler.generate_short_term_login_token(user_id) + login_token = auth_handler.generate_short_term_login_token(registered_user_id) redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, login_token) request.redirect(redirect_url) -- cgit 1.4.1 From f863a52ceacf69ab19b073383be80603a2f51c0a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 Jul 2016 13:19:07 +0100 Subject: Add device_id support to /login Add a 'devices' table to the storage, as well as a 'device_id' column to refresh_tokens. Allow the client to pass a device_id, and initial_device_display_name, to /login. If login is successful, then register the device in the devices table if it wasn't known already. If no device_id was supplied, make one up. Associate the device_id with the access token and refresh token, so that we can get at it again later. Ensure that the device_id is copied from the refresh token to the access_token when the token is refreshed. --- synapse/handlers/auth.py | 19 +++--- synapse/handlers/device.py | 71 ++++++++++++++++++++ synapse/rest/client/v1/login.py | 39 ++++++++++- synapse/rest/client/v2_alpha/tokenrefresh.py | 10 ++- synapse/server.py | 5 ++ synapse/storage/__init__.py | 3 + synapse/storage/devices.py | 77 ++++++++++++++++++++++ synapse/storage/registration.py | 28 +++++--- synapse/storage/schema/delta/33/devices.sql | 21 ++++++ .../schema/delta/33/refreshtoken_device.sql | 16 +++++ tests/handlers/test_device.py | 75 +++++++++++++++++++++ tests/storage/test_registration.py | 21 ++++-- 12 files changed, 354 insertions(+), 31 deletions(-) create mode 100644 synapse/handlers/device.py create mode 100644 synapse/storage/devices.py create mode 100644 synapse/storage/schema/delta/33/devices.sql create mode 100644 synapse/storage/schema/delta/33/refreshtoken_device.sql create mode 100644 tests/handlers/test_device.py (limited to 'synapse/handlers/auth.py') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 983994fa95..ce9bc18849 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -361,7 +361,7 @@ class AuthHandler(BaseHandler): return self._check_password(user_id, password) @defer.inlineCallbacks - def get_login_tuple_for_user_id(self, user_id): + def get_login_tuple_for_user_id(self, user_id, device_id=None): """ Gets login tuple for the user with the given user ID. @@ -372,6 +372,7 @@ class AuthHandler(BaseHandler): Args: user_id (str): canonical User ID + device_id (str): the device ID to associate with the access token Returns: A tuple of: The access token for the user's session. @@ -380,9 +381,9 @@ class AuthHandler(BaseHandler): StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ - logger.info("Logging in user %s", user_id) - access_token = yield self.issue_access_token(user_id) - refresh_token = yield self.issue_refresh_token(user_id) + logger.info("Logging in user %s on device %s", user_id, device_id) + access_token = yield self.issue_access_token(user_id, device_id) + refresh_token = yield self.issue_refresh_token(user_id, device_id) defer.returnValue((access_token, refresh_token)) @defer.inlineCallbacks @@ -638,15 +639,17 @@ class AuthHandler(BaseHandler): defer.returnValue(False) @defer.inlineCallbacks - def issue_access_token(self, user_id): + def issue_access_token(self, user_id, device_id=None): access_token = self.generate_access_token(user_id) - yield self.store.add_access_token_to_user(user_id, access_token) + yield self.store.add_access_token_to_user(user_id, access_token, + device_id) defer.returnValue(access_token) @defer.inlineCallbacks - def issue_refresh_token(self, user_id): + def issue_refresh_token(self, user_id, device_id=None): refresh_token = self.generate_refresh_token(user_id) - yield self.store.add_refresh_token_to_user(user_id, refresh_token) + yield self.store.add_refresh_token_to_user(user_id, refresh_token, + device_id) defer.returnValue(refresh_token) def generate_access_token(self, user_id, extra_caveats=None, diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py new file mode 100644 index 0000000000..8d7d9874f8 --- /dev/null +++ b/synapse/handlers/device.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.api.errors import StoreError +from synapse.util import stringutils +from twisted.internet import defer +from ._base import BaseHandler + +import logging + +logger = logging.getLogger(__name__) + + +class DeviceHandler(BaseHandler): + def __init__(self, hs): + super(DeviceHandler, self).__init__(hs) + + @defer.inlineCallbacks + def check_device_registered(self, user_id, device_id, + initial_device_display_name): + """ + If the given device has not been registered, register it with the + supplied display name. + + If no device_id is supplied, we make one up. + + Args: + user_id (str): @user:id + device_id (str | None): device id supplied by client + initial_device_display_name (str | None): device display name from + client + Returns: + str: device id (generated if none was supplied) + """ + if device_id is not None: + yield self.store.store_device( + user_id=user_id, + device_id=device_id, + initial_device_display_name=initial_device_display_name, + ignore_if_known=True, + ) + defer.returnValue(device_id) + + # if the device id is not specified, we'll autogen one, but loop a few + # times in case of a clash. + attempts = 0 + while attempts < 5: + try: + device_id = stringutils.random_string_with_symbols(16) + yield self.store.store_device( + user_id=user_id, + device_id=device_id, + initial_device_display_name=initial_device_display_name, + ignore_if_known=False, + ) + defer.returnValue(device_id) + except StoreError: + attempts += 1 + + raise StoreError(500, "Couldn't generate a device ID.") diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index a1f2ba8773..e8b791519c 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -59,6 +59,7 @@ class LoginRestServlet(ClientV1RestServlet): self.servername = hs.config.server_name self.http_client = hs.get_simple_http_client() self.auth_handler = self.hs.get_auth_handler() + self.device_handler = self.hs.get_device_handler() def on_GET(self, request): flows = [] @@ -149,14 +150,16 @@ class LoginRestServlet(ClientV1RestServlet): user_id=user_id, password=login_submission["password"], ) + device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) ) result = { "user_id": user_id, # may have changed "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, + "device_id": device_id, } defer.returnValue((200, result)) @@ -168,14 +171,16 @@ class LoginRestServlet(ClientV1RestServlet): user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) ) + device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) ) result = { "user_id": user_id, # may have changed "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, + "device_id": device_id, } defer.returnValue((200, result)) @@ -252,8 +257,13 @@ class LoginRestServlet(ClientV1RestServlet): auth_handler = self.auth_handler registered_user_id = yield auth_handler.check_user_exists(user_id) if registered_user_id: + device_id = yield self._register_device( + registered_user_id, login_submission + ) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(registered_user_id) + yield auth_handler.get_login_tuple_for_user_id( + registered_user_id, device_id + ) ) result = { "user_id": registered_user_id, @@ -262,6 +272,9 @@ class LoginRestServlet(ClientV1RestServlet): "home_server": self.hs.hostname, } else: + # TODO: we should probably check that the register isn't going + # to fonx/change our user_id before registering the device + device_id = yield self._register_device(user_id, login_submission) user_id, access_token = ( yield self.handlers.registration_handler.register(localpart=user) ) @@ -300,6 +313,26 @@ class LoginRestServlet(ClientV1RestServlet): return (user, attributes) + def _register_device(self, user_id, login_submission): + """Register a device for a user. + + This is called after the user's credentials have been validated, but + before the access token has been issued. + + Args: + (str) user_id: full canonical @user:id + (object) login_submission: dictionary supplied to /login call, from + which we pull device_id and initial_device_name + Returns: + defer.Deferred: (str) device_id + """ + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get( + "initial_device_display_name") + return self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + class SAML2RestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/login/saml2", releases=()) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 8270e8787f..0d312c91d4 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -39,9 +39,13 @@ class TokenRefreshRestServlet(RestServlet): try: old_refresh_token = body["refresh_token"] auth_handler = self.hs.get_auth_handler() - (user_id, new_refresh_token) = yield self.store.exchange_refresh_token( - old_refresh_token, auth_handler.generate_refresh_token) - new_access_token = yield auth_handler.issue_access_token(user_id) + refresh_result = yield self.store.exchange_refresh_token( + old_refresh_token, auth_handler.generate_refresh_token + ) + (user_id, new_refresh_token, device_id) = refresh_result + new_access_token = yield auth_handler.issue_access_token( + user_id, device_id + ) defer.returnValue((200, { "access_token": new_access_token, "refresh_token": new_refresh_token, diff --git a/synapse/server.py b/synapse/server.py index d49a1a8a96..e8b166990d 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -25,6 +25,7 @@ from twisted.enterprise import adbapi from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.appservice.api import ApplicationServiceApi from synapse.federation import initialize_http_replication +from synapse.handlers.device import DeviceHandler from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.notifier import Notifier from synapse.api.auth import Auth @@ -92,6 +93,7 @@ class HomeServer(object): 'typing_handler', 'room_list_handler', 'auth_handler', + 'device_handler', 'application_service_api', 'application_service_scheduler', 'application_service_handler', @@ -197,6 +199,9 @@ class HomeServer(object): def build_auth_handler(self): return AuthHandler(self) + def build_device_handler(self): + return DeviceHandler(self) + def build_application_service_api(self): return ApplicationServiceApi(self) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 1c93e18f9d..73fb334dd6 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -14,6 +14,8 @@ # limitations under the License. from twisted.internet import defer + +from synapse.storage.devices import DeviceStore from .appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore ) @@ -80,6 +82,7 @@ class DataStore(RoomMemberStore, RoomStore, EventPushActionsStore, OpenIdStore, ClientIpStore, + DeviceStore, ): def __init__(self, db_conn, hs): diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py new file mode 100644 index 0000000000..9065e96d28 --- /dev/null +++ b/synapse/storage/devices.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer + +from synapse.api.errors import StoreError +from ._base import SQLBaseStore + +logger = logging.getLogger(__name__) + + +class DeviceStore(SQLBaseStore): + @defer.inlineCallbacks + def store_device(self, user_id, device_id, + initial_device_display_name, + ignore_if_known=True): + """Ensure the given device is known; add it to the store if not + + Args: + user_id (str): id of user associated with the device + device_id (str): id of device + initial_device_display_name (str): initial displayname of the + device + ignore_if_known (bool): ignore integrity errors which mean the + device is already known + Returns: + defer.Deferred + Raises: + StoreError: if ignore_if_known is False and the device was already + known + """ + try: + yield self._simple_insert( + "devices", + values={ + "user_id": user_id, + "device_id": device_id, + "display_name": initial_device_display_name + }, + desc="store_device", + or_ignore=ignore_if_known, + ) + except Exception as e: + logger.error("store_device with device_id=%s failed: %s", + device_id, e) + raise StoreError(500, "Problem storing device.") + + def get_device(self, user_id, device_id): + """Retrieve a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to retrieve + Returns: + defer.Deferred for a namedtuple containing the device information + Raises: + StoreError: if the device is not found + """ + return self._simple_select_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcols=("user_id", "device_id", "display_name"), + desc="get_device", + ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index d957a629dc..26ef1cfd8a 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -31,12 +31,14 @@ class RegistrationStore(SQLBaseStore): self.clock = hs.get_clock() @defer.inlineCallbacks - def add_access_token_to_user(self, user_id, token): + def add_access_token_to_user(self, user_id, token, device_id=None): """Adds an access token for the given user. Args: user_id (str): The user ID. token (str): The new access token to add. + device_id (str): ID of the device to associate with the access + token Raises: StoreError if there was a problem adding this. """ @@ -47,18 +49,21 @@ class RegistrationStore(SQLBaseStore): { "id": next_id, "user_id": user_id, - "token": token + "token": token, + "device_id": device_id, }, desc="add_access_token_to_user", ) @defer.inlineCallbacks - def add_refresh_token_to_user(self, user_id, token): + def add_refresh_token_to_user(self, user_id, token, device_id=None): """Adds a refresh token for the given user. Args: user_id (str): The user ID. token (str): The new refresh token to add. + device_id (str): ID of the device to associate with the access + token Raises: StoreError if there was a problem adding this. """ @@ -69,7 +74,8 @@ class RegistrationStore(SQLBaseStore): { "id": next_id, "user_id": user_id, - "token": token + "token": token, + "device_id": device_id, }, desc="add_refresh_token_to_user", ) @@ -291,18 +297,18 @@ class RegistrationStore(SQLBaseStore): ) def exchange_refresh_token(self, refresh_token, token_generator): - """Exchange a refresh token for a new access token and refresh token. + """Exchange a refresh token for a new one. Doing so invalidates the old refresh token - refresh tokens are single use. Args: - token (str): The refresh token of a user. + refresh_token (str): The refresh token of a user. token_generator (fn: str -> str): Function which, when given a user ID, returns a unique refresh token for that user. This function must never return the same value twice. Returns: - tuple of (user_id, refresh_token) + tuple of (user_id, new_refresh_token, device_id) Raises: StoreError if no user was found with that refresh token. """ @@ -314,12 +320,13 @@ class RegistrationStore(SQLBaseStore): ) def _exchange_refresh_token(self, txn, old_token, token_generator): - sql = "SELECT user_id FROM refresh_tokens WHERE token = ?" + sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?" txn.execute(sql, (old_token,)) rows = self.cursor_to_dict(txn) if not rows: raise StoreError(403, "Did not recognize refresh token") user_id = rows[0]["user_id"] + device_id = rows[0]["device_id"] # TODO(danielwh): Maybe perform a validation on the macaroon that # macaroon.user_id == user_id. @@ -328,7 +335,7 @@ class RegistrationStore(SQLBaseStore): sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?" txn.execute(sql, (new_token, old_token,)) - return user_id, new_token + return user_id, new_token, device_id @defer.inlineCallbacks def is_server_admin(self, user): @@ -356,7 +363,8 @@ class RegistrationStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.is_guest, access_tokens.id as token_id" + "SELECT users.name, users.is_guest, access_tokens.id as token_id," + " access_tokens.device_id" " FROM users" " INNER JOIN access_tokens on users.name = access_tokens.user_id" " WHERE token = ?" diff --git a/synapse/storage/schema/delta/33/devices.sql b/synapse/storage/schema/delta/33/devices.sql new file mode 100644 index 0000000000..eca7268d82 --- /dev/null +++ b/synapse/storage/schema/delta/33/devices.sql @@ -0,0 +1,21 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE devices ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + display_name TEXT, + CONSTRAINT device_uniqueness UNIQUE (user_id, device_id) +); diff --git a/synapse/storage/schema/delta/33/refreshtoken_device.sql b/synapse/storage/schema/delta/33/refreshtoken_device.sql new file mode 100644 index 0000000000..b21da00dde --- /dev/null +++ b/synapse/storage/schema/delta/33/refreshtoken_device.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE refresh_tokens ADD COLUMN device_id BIGINT; diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py new file mode 100644 index 0000000000..cc6512ccc7 --- /dev/null +++ b/tests/handlers/test_device.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.handlers.device import DeviceHandler +from tests import unittest +from tests.utils import setup_test_homeserver + + +class DeviceHandlers(object): + def __init__(self, hs): + self.device_handler = DeviceHandler(hs) + + +class DeviceTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver(handlers=None) + self.hs.handlers = handlers = DeviceHandlers(self.hs) + self.handler = handlers.device_handler + + @defer.inlineCallbacks + def test_device_is_created_if_doesnt_exist(self): + res = yield self.handler.check_device_registered( + user_id="boris", + device_id="fco", + initial_device_display_name="display name" + ) + self.assertEqual(res, "fco") + + dev = yield self.handler.store.get_device("boris", "fco") + self.assertEqual(dev["display_name"], "display name") + + @defer.inlineCallbacks + def test_device_is_preserved_if_exists(self): + res1 = yield self.handler.check_device_registered( + user_id="boris", + device_id="fco", + initial_device_display_name="display name" + ) + self.assertEqual(res1, "fco") + + res2 = yield self.handler.check_device_registered( + user_id="boris", + device_id="fco", + initial_device_display_name="new display name" + ) + self.assertEqual(res2, "fco") + + dev = yield self.handler.store.get_device("boris", "fco") + self.assertEqual(dev["display_name"], "display name") + + @defer.inlineCallbacks + def test_device_id_is_made_up_if_unspecified(self): + device_id = yield self.handler.check_device_registered( + user_id="theresa", + device_id=None, + initial_device_display_name="display" + ) + + dev = yield self.handler.store.get_device("theresa", device_id) + self.assertEqual(dev["display_name"], "display") diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index b8384c98d8..b03ca303a2 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -38,6 +38,7 @@ class RegistrationStoreTestCase(unittest.TestCase): "BcDeFgHiJkLmNoPqRsTuVwXyZa" ] self.pwhash = "{xx1}123456789" + self.device_id = "akgjhdjklgshg" @defer.inlineCallbacks def test_register(self): @@ -64,13 +65,15 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_add_tokens(self): yield self.store.register(self.user_id, self.tokens[0], self.pwhash) - yield self.store.add_access_token_to_user(self.user_id, self.tokens[1]) + yield self.store.add_access_token_to_user(self.user_id, self.tokens[1], + self.device_id) result = yield self.store.get_user_by_access_token(self.tokens[1]) self.assertDictContainsSubset( { "name": self.user_id, + "device_id": self.device_id, }, result ) @@ -80,20 +83,24 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_exchange_refresh_token_valid(self): uid = stringutils.random_string(32) + device_id = stringutils.random_string(16) generator = TokenGenerator() last_token = generator.generate(uid) self.db_pool.runQuery( - "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)", - (uid, last_token,)) + "INSERT INTO refresh_tokens(user_id, token, device_id) " + "VALUES(?,?,?)", + (uid, last_token, device_id)) - (found_user_id, refresh_token) = yield self.store.exchange_refresh_token( - last_token, generator.generate) + (found_user_id, refresh_token, device_id) = \ + yield self.store.exchange_refresh_token(last_token, + generator.generate) self.assertEqual(uid, found_user_id) rows = yield self.db_pool.runQuery( - "SELECT token FROM refresh_tokens WHERE user_id = ?", (uid, )) - self.assertEqual([(refresh_token,)], rows) + "SELECT token, device_id FROM refresh_tokens WHERE user_id = ?", + (uid, )) + self.assertEqual([(refresh_token, device_id)], rows) # We issued token 1, then exchanged it for token 2 expected_refresh_token = u"%s-%d" % (uid, 2,) self.assertEqual(expected_refresh_token, refresh_token) -- cgit 1.4.1 From 3413f1e284593aa63723cdcd52f443d63771ef62 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 19 Jul 2016 10:21:42 +0100 Subject: Type annotations Add some type annotations to help PyCharm (in particular) to figure out the types of a bunch of things. --- synapse/handlers/_base.py | 4 ++++ synapse/handlers/auth.py | 4 ++++ synapse/rest/client/v1/base.py | 4 ++++ synapse/rest/client/v1/register.py | 4 ++++ synapse/rest/client/v2_alpha/register.py | 9 +++++++++ synapse/server.pyi | 21 +++++++++++++++++++++ 6 files changed, 46 insertions(+) create mode 100644 synapse/server.pyi (limited to 'synapse/handlers/auth.py') diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index d00685c389..6264aa0d9a 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -36,6 +36,10 @@ class BaseHandler(object): """ def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ self.store = hs.get_datastore() self.auth = hs.get_auth() self.notifier = hs.get_notifier() diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index ce9bc18849..8f83923ddb 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -45,6 +45,10 @@ class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ super(AuthHandler, self).__init__(hs) self.checkers = { LoginType.PASSWORD: self._check_password_auth, diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py index 1c020b7e2c..96b49b01f2 100644 --- a/synapse/rest/client/v1/base.py +++ b/synapse/rest/client/v1/base.py @@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet): """ def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ self.hs = hs self.handlers = hs.get_handlers() self.builder_factory = hs.get_event_builder_factory() diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 8e1f1b7845..efe796c65f 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -52,6 +52,10 @@ class RegisterRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False) def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ super(RegisterRestServlet, self).__init__(hs) # sessions are stored as: # self.sessions = { diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 5db953a1e3..2722a58e3e 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -45,6 +45,10 @@ class RegisterRequestTokenRestServlet(RestServlet): PATTERNS = client_v2_patterns("/register/email/requestToken$") def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ super(RegisterRequestTokenRestServlet, self).__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler @@ -77,7 +81,12 @@ class RegisterRestServlet(RestServlet): PATTERNS = client_v2_patterns("/register$") def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ super(RegisterRestServlet, self).__init__() + self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() diff --git a/synapse/server.pyi b/synapse/server.pyi new file mode 100644 index 0000000000..902f725c06 --- /dev/null +++ b/synapse/server.pyi @@ -0,0 +1,21 @@ +import synapse.handlers +import synapse.handlers.auth +import synapse.handlers.device +import synapse.storage +import synapse.state + +class HomeServer(object): + def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler: + pass + + def get_datastore(self) -> synapse.storage.DataStore: + pass + + def get_device_handler(self) -> synapse.handlers.device.DeviceHandler: + pass + + def get_handlers(self) -> synapse.handlers.Handlers: + pass + + def get_state_handler(self) -> synapse.state.StateHandler: + pass -- cgit 1.4.1 From dad2da7e54a4f0e92185e4f8553fb51b037c0bd3 Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 22 Jul 2016 17:00:56 +0100 Subject: Log the hostname the reCAPTCHA was completed on This could be useful information to have in the logs. Also comment about how & why we don't verify the hostname. --- synapse/handlers/auth.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) (limited to 'synapse/handlers/auth.py') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 8f83923ddb..6fff7e7d03 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -279,8 +279,17 @@ class AuthHandler(BaseHandler): data = pde.response resp_body = simplejson.loads(data) - if 'success' in resp_body and resp_body['success']: - defer.returnValue(True) + if 'success' in resp_body: + # Note that we do NOT check the hostname here: we explicitly + # intend the CAPTCHA to be presented by whatever client the + # user is using, we just care that they have completed a CAPTCHA. + logger.info( + "%s reCAPTCHA from hostname %s", + "Successful" if resp_body['success'] else "Failed", + resp_body['hostname'] + ) + if resp_body['success']: + defer.returnValue(True) raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) @defer.inlineCallbacks -- cgit 1.4.1 From 7ed58bb3476c4a18a9af97b8ee3358dac00098eb Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 22 Jul 2016 17:18:50 +0100 Subject: Use get to avoid KeyErrors --- synapse/handlers/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/handlers/auth.py') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 6fff7e7d03..d5d2072436 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -286,7 +286,7 @@ class AuthHandler(BaseHandler): logger.info( "%s reCAPTCHA from hostname %s", "Successful" if resp_body['success'] else "Failed", - resp_body['hostname'] + resp_body.get('hostname') ) if resp_body['success']: defer.returnValue(True) -- cgit 1.4.1 From 436bffd15fb8382a0d2dddd3c6f7a077ba751da2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 22 Jul 2016 14:52:53 +0100 Subject: Implement deleting devices --- synapse/handlers/auth.py | 22 ++++++++++++++++-- synapse/handlers/device.py | 27 +++++++++++++++++++++- synapse/rest/client/v1/login.py | 13 ++++++++--- synapse/rest/client/v2_alpha/devices.py | 14 +++++++++++ synapse/rest/client/v2_alpha/register.py | 10 ++++---- synapse/storage/devices.py | 15 ++++++++++++ synapse/storage/registration.py | 26 +++++++++++++++++---- .../schema/delta/33/access_tokens_device_index.sql | 17 ++++++++++++++ .../schema/delta/33/refreshtoken_device_index.sql | 17 ++++++++++++++ tests/handlers/test_device.py | 22 ++++++++++++++++-- tests/rest/client/v2_alpha/test_register.py | 14 +++++++---- 11 files changed, 176 insertions(+), 21 deletions(-) create mode 100644 synapse/storage/schema/delta/33/access_tokens_device_index.sql create mode 100644 synapse/storage/schema/delta/33/refreshtoken_device_index.sql (limited to 'synapse/handlers/auth.py') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index d5d2072436..2e138f328f 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -77,6 +77,7 @@ class AuthHandler(BaseHandler): self.ldap_bind_password = hs.config.ldap_bind_password self.hs = hs # FIXME better possibility to access registrationHandler later? + self.device_handler = hs.get_device_handler() @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): @@ -374,7 +375,8 @@ class AuthHandler(BaseHandler): return self._check_password(user_id, password) @defer.inlineCallbacks - def get_login_tuple_for_user_id(self, user_id, device_id=None): + def get_login_tuple_for_user_id(self, user_id, device_id=None, + initial_display_name=None): """ Gets login tuple for the user with the given user ID. @@ -383,9 +385,15 @@ class AuthHandler(BaseHandler): The user is assumed to have been authenticated by some other machanism (e.g. CAS), and the user_id converted to the canonical case. + The device will be recorded in the table if it is not there already. + Args: user_id (str): canonical User ID - device_id (str): the device ID to associate with the access token + device_id (str|None): the device ID to associate with the tokens. + None to leave the tokens unassociated with a device (deprecated: + we should always have a device ID) + initial_display_name (str): display name to associate with the + device if it needs re-registering Returns: A tuple of: The access token for the user's session. @@ -397,6 +405,16 @@ class AuthHandler(BaseHandler): logger.info("Logging in user %s on device %s", user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id) refresh_token = yield self.issue_refresh_token(user_id, device_id) + + # the device *should* have been registered before we got here; however, + # it's possible we raced against a DELETE operation. The thing we + # really don't want is active access_tokens without a record of the + # device, so we double-check it here. + if device_id is not None: + yield self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + defer.returnValue((access_token, refresh_token)) @defer.inlineCallbacks diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 1f9e15c33c..a7a192e1c9 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -100,7 +100,7 @@ class DeviceHandler(BaseHandler): Args: user_id (str): - device_id (str) + device_id (str): Returns: defer.Deferred: dict[str, X]: info on the device @@ -117,6 +117,31 @@ class DeviceHandler(BaseHandler): _update_device_from_client_ips(device, ips) defer.returnValue(device) + @defer.inlineCallbacks + def delete_device(self, user_id, device_id): + """ Delete the given device + + Args: + user_id (str): + device_id (str): + + Returns: + defer.Deferred: + """ + + try: + yield self.store.delete_device(user_id, device_id) + except errors.StoreError, e: + if e.code == 404: + # no match + pass + else: + raise + + yield self.store.user_delete_access_tokens(user_id, + device_id=device_id) + + def _update_device_from_client_ips(device, client_ips): ip = client_ips.get((device["user_id"], device["device_id"]), {}) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index e8b791519c..92fcae674a 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -152,7 +152,10 @@ class LoginRestServlet(ClientV1RestServlet): ) device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) + yield auth_handler.get_login_tuple_for_user_id( + user_id, device_id, + login_submission.get("initial_device_display_name") + ) ) result = { "user_id": user_id, # may have changed @@ -173,7 +176,10 @@ class LoginRestServlet(ClientV1RestServlet): ) device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) + yield auth_handler.get_login_tuple_for_user_id( + user_id, device_id, + login_submission.get("initial_device_display_name") + ) ) result = { "user_id": user_id, # may have changed @@ -262,7 +268,8 @@ class LoginRestServlet(ClientV1RestServlet): ) access_token, refresh_token = ( yield auth_handler.get_login_tuple_for_user_id( - registered_user_id, device_id + registered_user_id, device_id, + login_submission.get("initial_device_display_name") ) ) result = { diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 8b9ab4f674..30ef8b3da9 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -70,6 +70,20 @@ class DeviceRestServlet(RestServlet): ) defer.returnValue((200, device)) + @defer.inlineCallbacks + def on_DELETE(self, request, device_id): + # XXX: it's not completely obvious we want to expose this endpoint. + # It allows the client to delete access tokens, which feels like a + # thing which merits extra auth. But if we want to do the interactive- + # auth dance, we should really make it possible to delete more than one + # device at a time. + requester = yield self.auth.get_user_by_req(request) + yield self.device_handler.delete_device( + requester.user.to_string(), + device_id, + ) + defer.returnValue((200, {})) + def register_servlets(hs, http_server): DevicesRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index c8c9395fc6..9f599ea8bb 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -374,13 +374,13 @@ class RegisterRestServlet(RestServlet): """ device_id = yield self._register_device(user_id, params) - access_token = yield self.auth_handler.issue_access_token( - user_id, device_id=device_id + access_token, refresh_token = ( + yield self.auth_handler.get_login_tuple_for_user_id( + user_id, device_id=device_id, + initial_display_name=params.get("initial_device_display_name") + ) ) - refresh_token = yield self.auth_handler.issue_refresh_token( - user_id, device_id=device_id - ) defer.returnValue({ "user_id": user_id, "access_token": access_token, diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 1cc6e07f2b..4689980f80 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -76,6 +76,21 @@ class DeviceStore(SQLBaseStore): desc="get_device", ) + def delete_device(self, user_id, device_id): + """Delete a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to retrieve + Returns: + defer.Deferred + """ + return self._simple_delete_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + desc="delete_device", + ) + @defer.inlineCallbacks def get_devices_by_user(self, user_id): """Retrieve all of a user's registered devices. diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 9a92b35361..935e82bf7a 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -18,18 +18,31 @@ import re from twisted.internet import defer from synapse.api.errors import StoreError, Codes - -from ._base import SQLBaseStore +from synapse.storage import background_updates from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -class RegistrationStore(SQLBaseStore): +class RegistrationStore(background_updates.BackgroundUpdateStore): def __init__(self, hs): super(RegistrationStore, self).__init__(hs) self.clock = hs.get_clock() + self.register_background_index_update( + "access_tokens_device_index", + index_name="access_tokens_device_id", + table="access_tokens", + columns=["user_id", "device_id"], + ) + + self.register_background_index_update( + "refresh_tokens_device_index", + index_name="refresh_tokens_device_id", + table="refresh_tokens", + columns=["user_id", "device_id"], + ) + @defer.inlineCallbacks def add_access_token_to_user(self, user_id, token, device_id=None): """Adds an access token for the given user. @@ -238,11 +251,16 @@ class RegistrationStore(SQLBaseStore): self.get_user_by_id.invalidate((user_id,)) @defer.inlineCallbacks - def user_delete_access_tokens(self, user_id, except_token_ids=[]): + def user_delete_access_tokens(self, user_id, except_token_ids=[], + device_id=None): def f(txn): sql = "SELECT token FROM access_tokens WHERE user_id = ?" clauses = [user_id] + if device_id is not None: + sql += " AND device_id = ?" + clauses.append(device_id) + if except_token_ids: sql += " AND id NOT IN (%s)" % ( ",".join(["?" for _ in except_token_ids]), diff --git a/synapse/storage/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/schema/delta/33/access_tokens_device_index.sql new file mode 100644 index 0000000000..61ad3fe3e8 --- /dev/null +++ b/synapse/storage/schema/delta/33/access_tokens_device_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('access_tokens_device_index', '{}'); diff --git a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql b/synapse/storage/schema/delta/33/refreshtoken_device_index.sql new file mode 100644 index 0000000000..bb225dafbf --- /dev/null +++ b/synapse/storage/schema/delta/33/refreshtoken_device_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('refresh_tokens_device_index', '{}'); diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 331aa13fed..214e722eb3 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -12,11 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse import types + from twisted.internet import defer +import synapse.api.errors import synapse.handlers.device + import synapse.storage +from synapse import types from tests import unittest, utils user1 = "@boris:aaa" @@ -27,7 +30,7 @@ class DeviceTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): super(DeviceTestCase, self).__init__(*args, **kwargs) self.store = None # type: synapse.storage.DataStore - self.handler = None # type: device.DeviceHandler + self.handler = None # type: synapse.handlers.device.DeviceHandler self.clock = None # type: utils.MockClock @defer.inlineCallbacks @@ -123,6 +126,21 @@ class DeviceTestCase(unittest.TestCase): "last_seen_ts": 3000000, }, res) + @defer.inlineCallbacks + def test_delete_device(self): + yield self._record_users() + + # delete the device + yield self.handler.delete_device(user1, "abc") + + # check the device was deleted + with self.assertRaises(synapse.api.errors.NotFoundError): + yield self.handler.get_device(user1, "abc") + + # we'd like to check the access token was invalidated, but that's a + # bit of a PITA. + + @defer.inlineCallbacks def _record_users(self): # check this works for both devices which have a recorded client_ip, diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 3bd7065e32..8ac56a1fb2 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -65,13 +65,16 @@ class RegisterRestServletTestCase(unittest.TestCase): self.registration_handler.appservice_register = Mock( return_value=user_id ) - self.auth_handler.issue_access_token = Mock(return_value=token) + self.auth_handler.get_login_tuple_for_user_id = Mock( + return_value=(token, "kermits_refresh_token") + ) (code, result) = yield self.servlet.on_POST(self.request) self.assertEquals(code, 200) det_data = { "user_id": user_id, "access_token": token, + "refresh_token": "kermits_refresh_token", "home_server": self.hs.hostname } self.assertDictContainsSubset(det_data, result) @@ -121,7 +124,9 @@ class RegisterRestServletTestCase(unittest.TestCase): "password": "monkey" }, None) self.registration_handler.register = Mock(return_value=(user_id, None)) - self.auth_handler.issue_access_token = Mock(return_value=token) + self.auth_handler.get_login_tuple_for_user_id = Mock( + return_value=(token, "kermits_refresh_token") + ) self.device_handler.check_device_registered = \ Mock(return_value=device_id) @@ -130,13 +135,14 @@ class RegisterRestServletTestCase(unittest.TestCase): det_data = { "user_id": user_id, "access_token": token, + "refresh_token": "kermits_refresh_token", "home_server": self.hs.hostname, "device_id": device_id, } self.assertDictContainsSubset(det_data, result) self.assertIn("refresh_token", result) - self.auth_handler.issue_access_token.assert_called_once_with( - user_id, device_id=device_id) + self.auth_handler.get_login_tuple_for_user_id( + user_id, device_id=device_id, initial_device_display_name=None) def test_POST_disabled_registration(self): self.hs.config.enable_registration = False -- cgit 1.4.1