diff options
-rw-r--r-- | synapse/handlers/register.py | 5 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/register.py | 13 | ||||
-rw-r--r-- | synapse/storage/registration.py | 40 | ||||
-rw-r--r-- | synapse/storage/schema/delta/30/as_users.py | 23 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_register.py | 18 |
5 files changed, 72 insertions, 27 deletions
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index e2ace6a4e5..6ffb8c0da6 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -182,6 +182,8 @@ class RegistrationHandler(BaseHandler): errcode=Codes.EXCLUSIVE ) + service_id = service.id if service.is_exclusive_user(user_id) else None + yield self.check_user_id_not_appservice_exclusive( user_id, allowed_appservice=service ) @@ -190,7 +192,8 @@ class RegistrationHandler(BaseHandler): yield self.store.register( user_id=user_id, token=token, - password_hash="" + password_hash="", + appservice_id=service_id, ) yield registered_user(self.distributor, user) defer.returnValue((user_id, token)) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index b090e66bcf..533ff136eb 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -187,7 +187,7 @@ class RegisterRestServlet(RestServlet): else: logger.info("bind_email not specified: not binding email") - result = self._create_registration_details(user_id, token) + result = yield self._create_registration_details(user_id, token) defer.returnValue((200, result)) def on_OPTIONS(self, _): @@ -198,7 +198,7 @@ class RegisterRestServlet(RestServlet): (user_id, token) = yield self.registration_handler.appservice_register( username, as_token ) - defer.returnValue(self._create_registration_details(user_id, token)) + defer.returnValue((yield self._create_registration_details(user_id, token))) @defer.inlineCallbacks def _do_shared_secret_registration(self, username, password, mac): @@ -225,14 +225,17 @@ class RegisterRestServlet(RestServlet): (user_id, token) = yield self.registration_handler.register( localpart=username, password=password ) - defer.returnValue(self._create_registration_details(user_id, token)) + defer.returnValue((yield self._create_registration_details(user_id, token))) + @defer.inlineCallbacks def _create_registration_details(self, user_id, token): - return { + refresh_token = yield self.auth_handler.issue_refresh_token(user_id) + defer.returnValue({ "user_id": user_id, "access_token": token, "home_server": self.hs.hostname, - } + "refresh_token": refresh_token, + }) @defer.inlineCallbacks def onEmailTokenRequest(self, request): diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index ad1157f979..aa49f53458 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -76,7 +76,7 @@ class RegistrationStore(SQLBaseStore): @defer.inlineCallbacks def register(self, user_id, token, password_hash, - was_guest=False, make_guest=False): + was_guest=False, make_guest=False, appservice_id=None): """Attempts to register an account. Args: @@ -87,16 +87,32 @@ class RegistrationStore(SQLBaseStore): upgraded to a non-guest account. make_guest (boolean): True if the the new user should be guest, false to add a regular user account. + appservice_id (str): The ID of the appservice registering the user. Raises: StoreError if the user_id could not be registered. """ yield self.runInteraction( "register", - self._register, user_id, token, password_hash, was_guest, make_guest + self._register, + user_id, + token, + password_hash, + was_guest, + make_guest, + appservice_id ) self.is_guest.invalidate((user_id,)) - def _register(self, txn, user_id, token, password_hash, was_guest, make_guest): + def _register( + self, + txn, + user_id, + token, + password_hash, + was_guest, + make_guest, + appservice_id + ): now = int(self.clock.time()) next_id = self._access_tokens_id_gen.get_next() @@ -111,9 +127,21 @@ class RegistrationStore(SQLBaseStore): [password_hash, now, 1 if make_guest else 0, user_id]) else: txn.execute("INSERT INTO users " - "(name, password_hash, creation_ts, is_guest) " - "VALUES (?,?,?,?)", - [user_id, password_hash, now, 1 if make_guest else 0]) + "(" + " name," + " password_hash," + " creation_ts," + " is_guest," + " appservice_id" + ") " + "VALUES (?,?,?,?,?)", + [ + user_id, + password_hash, + now, + 1 if make_guest else 0, + appservice_id, + ]) except self.database_engine.module.IntegrityError: raise StoreError( 400, "User ID already taken.", errcode=Codes.USER_IN_USE diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py index 4da3c59de2..4f6e9dd540 100644 --- a/synapse/storage/schema/delta/30/as_users.py +++ b/synapse/storage/schema/delta/30/as_users.py @@ -52,12 +52,17 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): " service (IDs %s and %s); assigning arbitrarily to %s" % (user_id, owned[user_id], appservice.id, owned[user_id]) ) - owned[user_id] = appservice.id - - for user_id, as_id in owned.items(): - cur.execute( - database_engine.convert_param_style( - "UPDATE users SET appservice_id = ? WHERE name = ?" - ), - (as_id, user_id) - ) + owned.setdefault(appservice.id, []).append(user_id) + + for as_id, user_ids in owned.items(): + n = 100 + user_chunks = (user_ids[i:i + 100] for i in xrange(0, len(user_ids), n)) + for chunk in user_chunks: + cur.execute( + database_engine.convert_param_style( + "UPDATE users SET appservice_id = ? WHERE name IN (%s)" % ( + ",".join("?" for _ in chunk), + ) + ), + [as_id] + chunk + ) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index b867599079..9a202e9dd7 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -62,12 +62,15 @@ class RegisterRestServletTestCase(unittest.TestCase): self.registration_handler.appservice_register = Mock( return_value=(user_id, token) ) - result = yield self.servlet.on_POST(self.request) - self.assertEquals(result, (200, { + (code, result) = yield self.servlet.on_POST(self.request) + self.assertEquals(code, 200) + det_data = { "user_id": user_id, "access_token": token, "home_server": self.hs.hostname - })) + } + self.assertDictContainsSubset(det_data, result) + self.assertIn("refresh_token", result) @defer.inlineCallbacks def test_POST_appservice_registration_invalid(self): @@ -112,12 +115,15 @@ class RegisterRestServletTestCase(unittest.TestCase): }) self.registration_handler.register = Mock(return_value=(user_id, token)) - result = yield self.servlet.on_POST(self.request) - self.assertEquals(result, (200, { + (code, result) = yield self.servlet.on_POST(self.request) + self.assertEquals(code, 200) + det_data = { "user_id": user_id, "access_token": token, "home_server": self.hs.hostname - })) + } + self.assertDictContainsSubset(det_data, result) + self.assertIn("refresh_token", result) def test_POST_disabled_registration(self): self.hs.config.enable_registration = False |