summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/handlers/register.py9
-rw-r--r--tests/handlers/test_register.py34
2 files changed, 30 insertions, 13 deletions
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 5883b9111e..16f33f8371 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -16,7 +16,7 @@
 """Contains functions for registering clients."""
 from twisted.internet import defer
 
-from synapse.types import UserID
+from synapse.types import UserID, Requester
 from synapse.api.errors import (
     AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
 )
@@ -360,7 +360,8 @@ class RegistrationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def get_or_create_user(self, localpart, displayname, duration_seconds):
-        """Creates a new user or returns an access token for an existing one
+        """Creates a new user if the user does not exist,
+        else revokes all previous access tokens and generates a new one.
 
         Args:
             localpart : The local part of the user ID to register. If None,
@@ -399,14 +400,14 @@ class RegistrationHandler(BaseHandler):
 
             yield registered_user(self.distributor, user)
         else:
-            yield self.store.flush_user(user_id=user_id)
+            yield self.store.user_delete_access_tokens(user_id=user_id)
             yield self.store.add_access_token_to_user(user_id=user_id, token=token)
 
         if displayname is not None:
             logger.info("setting user display name: %s -> %s", user_id, displayname)
             profile_handler = self.hs.get_handlers().profile_handler
             yield profile_handler.set_displayname(
-                user, user, displayname
+                user, Requester(user, token, False), displayname
             )
 
         defer.returnValue((user_id, token))
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 8b7be96bd9..9d5c653b45 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -17,6 +17,7 @@ from twisted.internet import defer
 from .. import unittest
 
 from synapse.handlers.register import RegistrationHandler
+from synapse.types import UserID
 
 from tests.utils import setup_test_homeserver
 
@@ -36,25 +37,21 @@ class RegistrationTestCase(unittest.TestCase):
         self.mock_distributor = Mock()
         self.mock_distributor.declare("registered_user")
         self.mock_captcha_client = Mock()
-        hs = yield setup_test_homeserver(
+        self.hs = yield setup_test_homeserver(
             handlers=None,
             http_client=None,
             expire_access_token=True)
-        hs.handlers = RegistrationHandlers(hs)
-        self.handler = hs.get_handlers().registration_handler
-        hs.get_handlers().profile_handler = Mock()
+        self.hs.handlers = RegistrationHandlers(self.hs)
+        self.handler = self.hs.get_handlers().registration_handler
+        self.hs.get_handlers().profile_handler = Mock()
         self.mock_handler = Mock(spec=[
             "generate_short_term_login_token",
         ])
 
-        hs.get_handlers().auth_handler = self.mock_handler
+        self.hs.get_handlers().auth_handler = self.mock_handler
 
     @defer.inlineCallbacks
     def test_user_is_created_and_logged_in_if_doesnt_exist(self):
-        """
-        Returns:
-            The user doess not exist in this case so it will register and log it in
-        """
         duration_ms = 200
         local_part = "someone"
         display_name = "someone"
@@ -65,3 +62,22 @@ class RegistrationTestCase(unittest.TestCase):
             local_part, display_name, duration_ms)
         self.assertEquals(result_user_id, user_id)
         self.assertEquals(result_token, 'secret')
+
+    @defer.inlineCallbacks
+    def test_if_user_exists(self):
+        store = self.hs.get_datastore()
+        frank = UserID.from_string("@frank:test")
+        yield store.register(
+            user_id=frank.to_string(),
+            token="jkv;g498752-43gj['eamb!-5",
+            password_hash=None)
+        duration_ms = 200
+        local_part = "frank"
+        display_name = "Frank"
+        user_id = "@frank:test"
+        mock_token = self.mock_handler.generate_short_term_login_token
+        mock_token.return_value = 'secret'
+        result_user_id, result_token = yield self.handler.get_or_create_user(
+            local_part, display_name, duration_ms)
+        self.assertEquals(result_user_id, user_id)
+        self.assertEquals(result_token, 'secret')