summary refs log tree commit diff
path: root/tests/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers')
-rw-r--r--tests/handlers/test_admin.py10
-rw-r--r--tests/handlers/test_auth.py20
-rw-r--r--tests/handlers/test_register.py126
-rw-r--r--tests/handlers/test_user_directory.py16
4 files changed, 114 insertions, 58 deletions
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index 5e7d2d3361..fc37c4328c 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -55,7 +55,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
 
         writer = Mock()
 
-        self.get_success(self.admin_handler.exfiltrate_user_data(self.user2, writer))
+        self.get_success(self.admin_handler.export_user_data(self.user2, writer))
 
         writer.write_events.assert_called()
 
@@ -94,7 +94,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
 
         writer = Mock()
 
-        self.get_success(self.admin_handler.exfiltrate_user_data(self.user2, writer))
+        self.get_success(self.admin_handler.export_user_data(self.user2, writer))
 
         writer.write_events.assert_called()
 
@@ -127,7 +127,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
 
         writer = Mock()
 
-        self.get_success(self.admin_handler.exfiltrate_user_data(self.user2, writer))
+        self.get_success(self.admin_handler.export_user_data(self.user2, writer))
 
         writer.write_events.assert_called()
 
@@ -169,7 +169,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
 
         writer = Mock()
 
-        self.get_success(self.admin_handler.exfiltrate_user_data(self.user2, writer))
+        self.get_success(self.admin_handler.export_user_data(self.user2, writer))
 
         writer.write_events.assert_called_once()
 
@@ -198,7 +198,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
 
         writer = Mock()
 
-        self.get_success(self.admin_handler.exfiltrate_user_data(self.user2, writer))
+        self.get_success(self.admin_handler.export_user_data(self.user2, writer))
 
         writer.write_events.assert_not_called()
         writer.write_state.assert_not_called()
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index b204a0700d..b03103d96f 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -117,7 +117,9 @@ class AuthTestCase(unittest.TestCase):
     def test_mau_limits_disabled(self):
         self.hs.config.limit_usage_by_mau = False
         # Ensure does not throw exception
-        yield self.auth_handler.get_access_token_for_user_id("user_a")
+        yield self.auth_handler.get_access_token_for_user_id(
+            "user_a", device_id=None, valid_until_ms=None
+        )
 
         yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
             self._get_macaroon().serialize()
@@ -131,7 +133,9 @@ class AuthTestCase(unittest.TestCase):
         )
 
         with self.assertRaises(ResourceLimitError):
-            yield self.auth_handler.get_access_token_for_user_id("user_a")
+            yield self.auth_handler.get_access_token_for_user_id(
+                "user_a", device_id=None, valid_until_ms=None
+            )
 
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=defer.succeed(self.large_number_of_users)
@@ -150,7 +154,9 @@ class AuthTestCase(unittest.TestCase):
             return_value=defer.succeed(self.hs.config.max_mau_value)
         )
         with self.assertRaises(ResourceLimitError):
-            yield self.auth_handler.get_access_token_for_user_id("user_a")
+            yield self.auth_handler.get_access_token_for_user_id(
+                "user_a", device_id=None, valid_until_ms=None
+            )
 
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=defer.succeed(self.hs.config.max_mau_value)
@@ -166,7 +172,9 @@ class AuthTestCase(unittest.TestCase):
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=defer.succeed(self.hs.config.max_mau_value)
         )
-        yield self.auth_handler.get_access_token_for_user_id("user_a")
+        yield self.auth_handler.get_access_token_for_user_id(
+            "user_a", device_id=None, valid_until_ms=None
+        )
         self.hs.get_datastore().user_last_seen_monthly_active = Mock(
             return_value=defer.succeed(self.hs.get_clock().time_msec())
         )
@@ -185,7 +193,9 @@ class AuthTestCase(unittest.TestCase):
             return_value=defer.succeed(self.small_number_of_users)
         )
         # Ensure does not raise exception
-        yield self.auth_handler.get_access_token_for_user_id("user_a")
+        yield self.auth_handler.get_access_token_for_user_id(
+            "user_a", device_id=None, valid_until_ms=None
+        )
 
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=defer.succeed(self.small_number_of_users)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 4edce7af43..90d0129374 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -18,7 +18,7 @@ from mock import Mock
 from twisted.internet import defer
 
 from synapse.api.constants import UserTypes
-from synapse.api.errors import ResourceLimitError, SynapseError
+from synapse.api.errors import Codes, ResourceLimitError, SynapseError
 from synapse.handlers.register import RegistrationHandler
 from synapse.types import RoomAlias, UserID, create_requester
 
@@ -67,7 +67,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         user_id = frank.to_string()
         requester = create_requester(user_id)
         result_user_id, result_token = self.get_success(
-            self.handler.get_or_create_user(requester, frank.localpart, "Frankie")
+            self.get_or_create_user(requester, frank.localpart, "Frankie")
         )
         self.assertEquals(result_user_id, user_id)
         self.assertTrue(result_token is not None)
@@ -77,17 +77,13 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         store = self.hs.get_datastore()
         frank = UserID.from_string("@frank:test")
         self.get_success(
-            store.register(
-                user_id=frank.to_string(),
-                token="jkv;g498752-43gj['eamb!-5",
-                password_hash=None,
-            )
+            store.register_user(user_id=frank.to_string(), password_hash=None)
         )
         local_part = frank.localpart
         user_id = frank.to_string()
         requester = create_requester(user_id)
         result_user_id, result_token = self.get_success(
-            self.handler.get_or_create_user(requester, local_part, None)
+            self.get_or_create_user(requester, local_part, None)
         )
         self.assertEquals(result_user_id, user_id)
         self.assertTrue(result_token is not None)
@@ -95,9 +91,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     def test_mau_limits_when_disabled(self):
         self.hs.config.limit_usage_by_mau = False
         # Ensure does not throw exception
-        self.get_success(
-            self.handler.get_or_create_user(self.requester, "a", "display_name")
-        )
+        self.get_success(self.get_or_create_user(self.requester, "a", "display_name"))
 
     def test_get_or_create_user_mau_not_blocked(self):
         self.hs.config.limit_usage_by_mau = True
@@ -105,7 +99,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             return_value=defer.succeed(self.hs.config.max_mau_value - 1)
         )
         # Ensure does not throw exception
-        self.get_success(self.handler.get_or_create_user(self.requester, "c", "User"))
+        self.get_success(self.get_or_create_user(self.requester, "c", "User"))
 
     def test_get_or_create_user_mau_blocked(self):
         self.hs.config.limit_usage_by_mau = True
@@ -113,7 +107,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             return_value=defer.succeed(self.lots_of_users)
         )
         self.get_failure(
-            self.handler.get_or_create_user(self.requester, "b", "display_name"),
+            self.get_or_create_user(self.requester, "b", "display_name"),
             ResourceLimitError,
         )
 
@@ -121,7 +115,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             return_value=defer.succeed(self.hs.config.max_mau_value)
         )
         self.get_failure(
-            self.handler.get_or_create_user(self.requester, "b", "display_name"),
+            self.get_or_create_user(self.requester, "b", "display_name"),
             ResourceLimitError,
         )
 
@@ -131,21 +125,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             return_value=defer.succeed(self.lots_of_users)
         )
         self.get_failure(
-            self.handler.register(localpart="local_part"), ResourceLimitError
+            self.handler.register_user(localpart="local_part"), ResourceLimitError
         )
 
         self.store.get_monthly_active_count = Mock(
             return_value=defer.succeed(self.hs.config.max_mau_value)
         )
         self.get_failure(
-            self.handler.register(localpart="local_part"), ResourceLimitError
+            self.handler.register_user(localpart="local_part"), ResourceLimitError
         )
 
     def test_auto_create_auto_join_rooms(self):
         room_alias_str = "#room:test"
         self.hs.config.auto_join_rooms = [room_alias_str]
-        res = self.get_success(self.handler.register(localpart="jeff"))
-        rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
+        user_id = self.get_success(self.handler.register_user(localpart="jeff"))
+        rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         directory_handler = self.hs.get_handlers().directory_handler
         room_alias = RoomAlias.from_string(room_alias_str)
         room_id = self.get_success(directory_handler.get_association(room_alias))
@@ -156,25 +150,25 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     def test_auto_create_auto_join_rooms_with_no_rooms(self):
         self.hs.config.auto_join_rooms = []
         frank = UserID.from_string("@frank:test")
-        res = self.get_success(self.handler.register(frank.localpart))
-        self.assertEqual(res[0], frank.to_string())
-        rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
+        user_id = self.get_success(self.handler.register_user(frank.localpart))
+        self.assertEqual(user_id, frank.to_string())
+        rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         self.assertEqual(len(rooms), 0)
 
     def test_auto_create_auto_join_where_room_is_another_domain(self):
         self.hs.config.auto_join_rooms = ["#room:another"]
         frank = UserID.from_string("@frank:test")
-        res = self.get_success(self.handler.register(frank.localpart))
-        self.assertEqual(res[0], frank.to_string())
-        rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
+        user_id = self.get_success(self.handler.register_user(frank.localpart))
+        self.assertEqual(user_id, frank.to_string())
+        rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         self.assertEqual(len(rooms), 0)
 
     def test_auto_create_auto_join_where_auto_create_is_false(self):
         self.hs.config.autocreate_auto_join_rooms = False
         room_alias_str = "#room:test"
         self.hs.config.auto_join_rooms = [room_alias_str]
-        res = self.get_success(self.handler.register(localpart="jeff"))
-        rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
+        user_id = self.get_success(self.handler.register_user(localpart="jeff"))
+        rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         self.assertEqual(len(rooms), 0)
 
     def test_auto_create_auto_join_rooms_when_support_user_exists(self):
@@ -182,8 +176,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         self.hs.config.auto_join_rooms = [room_alias_str]
 
         self.store.is_support_user = Mock(return_value=True)
-        res = self.get_success(self.handler.register(localpart="support"))
-        rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
+        user_id = self.get_success(self.handler.register_user(localpart="support"))
+        rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         self.assertEqual(len(rooms), 0)
         directory_handler = self.hs.get_handlers().directory_handler
         room_alias = RoomAlias.from_string(room_alias_str)
@@ -211,24 +205,82 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
         # When:-
         #   * the user is registered and post consent actions are called
-        res = self.get_success(self.handler.register(localpart="jeff"))
-        self.get_success(self.handler.post_consent_actions(res[0]))
+        user_id = self.get_success(self.handler.register_user(localpart="jeff"))
+        self.get_success(self.handler.post_consent_actions(user_id))
 
         # Then:-
         #   * Ensure that they have not been joined to the room
-        rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
+        rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         self.assertEqual(len(rooms), 0)
 
     def test_register_support_user(self):
-        res = self.get_success(
-            self.handler.register(localpart="user", user_type=UserTypes.SUPPORT)
+        user_id = self.get_success(
+            self.handler.register_user(localpart="user", user_type=UserTypes.SUPPORT)
         )
-        self.assertTrue(self.store.is_support_user(res[0]))
+        d = self.store.is_support_user(user_id)
+        self.assertTrue(self.get_success(d))
 
     def test_register_not_support_user(self):
-        res = self.get_success(self.handler.register(localpart="user"))
-        self.assertFalse(self.store.is_support_user(res[0]))
+        user_id = self.get_success(self.handler.register_user(localpart="user"))
+        d = self.store.is_support_user(user_id)
+        self.assertFalse(self.get_success(d))
 
     def test_invalid_user_id_length(self):
         invalid_user_id = "x" * 256
-        self.get_failure(self.handler.register(localpart=invalid_user_id), SynapseError)
+        self.get_failure(
+            self.handler.register_user(localpart=invalid_user_id), SynapseError
+        )
+
+    @defer.inlineCallbacks
+    def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
+        """Creates a new user if the user does not exist,
+        else revokes all previous access tokens and generates a new one.
+
+        XXX: this used to be in the main codebase, but was only used by this file,
+        so got moved here. TODO: get rid of it, probably
+
+        Args:
+            localpart : The local part of the user ID to register. If None,
+              one will be randomly generated.
+        Returns:
+            A tuple of (user_id, access_token).
+        Raises:
+            RegistrationError if there was a problem registering.
+        """
+        if localpart is None:
+            raise SynapseError(400, "Request must include user id")
+        yield self.hs.get_auth().check_auth_blocking()
+        need_register = True
+
+        try:
+            yield self.handler.check_username(localpart)
+        except SynapseError as e:
+            if e.errcode == Codes.USER_IN_USE:
+                need_register = False
+            else:
+                raise
+
+        user = UserID(localpart, self.hs.hostname)
+        user_id = user.to_string()
+        token = self.macaroon_generator.generate_access_token(user_id)
+
+        if need_register:
+            yield self.handler.register_with_store(
+                user_id=user_id,
+                password_hash=password_hash,
+                create_profile_with_displayname=user.localpart,
+            )
+        else:
+            yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
+
+        yield self.store.add_access_token_to_user(
+            user_id=user_id, token=token, device_id=None, valid_until_ms=None
+        )
+
+        if displayname is not None:
+            # logger.info("setting user display name: %s -> %s", user_id, displayname)
+            yield self.hs.get_profile_handler().set_displayname(
+                user, requester, displayname, by_admin=True
+            )
+
+        defer.returnValue((user_id, token))
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index b135486c48..c5e91a8c41 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -47,11 +47,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
     def test_handle_local_profile_change_with_support_user(self):
         support_user_id = "@support:test"
         self.get_success(
-            self.store.register(
-                user_id=support_user_id,
-                token="123",
-                password_hash=None,
-                user_type=UserTypes.SUPPORT,
+            self.store.register_user(
+                user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT
             )
         )
 
@@ -73,11 +70,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
     def test_handle_user_deactivated_support_user(self):
         s_user_id = "@support:test"
         self.get_success(
-            self.store.register(
-                user_id=s_user_id,
-                token="123",
-                password_hash=None,
-                user_type=UserTypes.SUPPORT,
+            self.store.register_user(
+                user_id=s_user_id, password_hash=None, user_type=UserTypes.SUPPORT
             )
         )
 
@@ -90,7 +84,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
     def test_handle_user_deactivated_regular_user(self):
         r_user_id = "@regular:test"
         self.get_success(
-            self.store.register(user_id=r_user_id, token="123", password_hash=None)
+            self.store.register_user(user_id=r_user_id, password_hash=None)
         )
         self.store.remove_from_user_dir = Mock()
         self.get_success(self.handler.handle_user_deactivated(r_user_id))