summary refs log tree commit diff
path: root/tests/handlers
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2019-07-08 14:52:26 +0100
committerAmber Brown <hawkowl@atleastfornow.net>2019-07-08 23:52:26 +1000
commit1af2fcd492c81075a9d75a5578d8d599af3cea0c (patch)
treedeb79e3d7fdb0ef304bf76e8a61cd7a9d8e5aa9a /tests/handlers
parentAdd a few more common environment directory names to black exclusion (#5630) (diff)
downloadsynapse-1af2fcd492c81075a9d75a5578d8d599af3cea0c.tar.xz
Move get_or_create_user to test code (#5628)
This is only used in tests, so...
Diffstat (limited to 'tests/handlers')
-rw-r--r--tests/handlers/test_register.py68
1 files changed, 59 insertions, 9 deletions
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 4edce7af43..1c7ded7397 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -18,7 +18,7 @@ from mock import Mock
 from twisted.internet import defer
 
 from synapse.api.constants import UserTypes
-from synapse.api.errors import ResourceLimitError, SynapseError
+from synapse.api.errors import Codes, ResourceLimitError, SynapseError
 from synapse.handlers.register import RegistrationHandler
 from synapse.types import RoomAlias, UserID, create_requester
 
@@ -67,7 +67,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         user_id = frank.to_string()
         requester = create_requester(user_id)
         result_user_id, result_token = self.get_success(
-            self.handler.get_or_create_user(requester, frank.localpart, "Frankie")
+            self.get_or_create_user(requester, frank.localpart, "Frankie")
         )
         self.assertEquals(result_user_id, user_id)
         self.assertTrue(result_token is not None)
@@ -87,7 +87,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         user_id = frank.to_string()
         requester = create_requester(user_id)
         result_user_id, result_token = self.get_success(
-            self.handler.get_or_create_user(requester, local_part, None)
+            self.get_or_create_user(requester, local_part, None)
         )
         self.assertEquals(result_user_id, user_id)
         self.assertTrue(result_token is not None)
@@ -95,9 +95,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     def test_mau_limits_when_disabled(self):
         self.hs.config.limit_usage_by_mau = False
         # Ensure does not throw exception
-        self.get_success(
-            self.handler.get_or_create_user(self.requester, "a", "display_name")
-        )
+        self.get_success(self.get_or_create_user(self.requester, "a", "display_name"))
 
     def test_get_or_create_user_mau_not_blocked(self):
         self.hs.config.limit_usage_by_mau = True
@@ -105,7 +103,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             return_value=defer.succeed(self.hs.config.max_mau_value - 1)
         )
         # Ensure does not throw exception
-        self.get_success(self.handler.get_or_create_user(self.requester, "c", "User"))
+        self.get_success(self.get_or_create_user(self.requester, "c", "User"))
 
     def test_get_or_create_user_mau_blocked(self):
         self.hs.config.limit_usage_by_mau = True
@@ -113,7 +111,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             return_value=defer.succeed(self.lots_of_users)
         )
         self.get_failure(
-            self.handler.get_or_create_user(self.requester, "b", "display_name"),
+            self.get_or_create_user(self.requester, "b", "display_name"),
             ResourceLimitError,
         )
 
@@ -121,7 +119,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             return_value=defer.succeed(self.hs.config.max_mau_value)
         )
         self.get_failure(
-            self.handler.get_or_create_user(self.requester, "b", "display_name"),
+            self.get_or_create_user(self.requester, "b", "display_name"),
             ResourceLimitError,
         )
 
@@ -232,3 +230,55 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     def test_invalid_user_id_length(self):
         invalid_user_id = "x" * 256
         self.get_failure(self.handler.register(localpart=invalid_user_id), SynapseError)
+
+    @defer.inlineCallbacks
+    def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
+        """Creates a new user if the user does not exist,
+        else revokes all previous access tokens and generates a new one.
+
+        XXX: this used to be in the main codebase, but was only used by this file,
+        so got moved here. TODO: get rid of it, probably
+
+        Args:
+            localpart : The local part of the user ID to register. If None,
+              one will be randomly generated.
+        Returns:
+            A tuple of (user_id, access_token).
+        Raises:
+            RegistrationError if there was a problem registering.
+        """
+        if localpart is None:
+            raise SynapseError(400, "Request must include user id")
+        yield self.hs.get_auth().check_auth_blocking()
+        need_register = True
+
+        try:
+            yield self.handler.check_username(localpart)
+        except SynapseError as e:
+            if e.errcode == Codes.USER_IN_USE:
+                need_register = False
+            else:
+                raise
+
+        user = UserID(localpart, self.hs.hostname)
+        user_id = user.to_string()
+        token = self.macaroon_generator.generate_access_token(user_id)
+
+        if need_register:
+            yield self.handler.register_with_store(
+                user_id=user_id,
+                token=token,
+                password_hash=password_hash,
+                create_profile_with_displayname=user.localpart,
+            )
+        else:
+            yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
+            yield self.store.add_access_token_to_user(user_id=user_id, token=token)
+
+        if displayname is not None:
+            # logger.info("setting user display name: %s -> %s", user_id, displayname)
+            yield self.hs.get_profile_handler().set_displayname(
+                user, requester, displayname, by_admin=True
+            )
+
+        defer.returnValue((user_id, token))