summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/rest/client/v2_alpha/test_register.py21
1 files changed, 15 insertions, 6 deletions
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 753d5c3e80..18080ebfd6 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -32,7 +32,18 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.identity_handler = Mock()
         self.login_handler = Mock()
         self.device_handler = Mock()
-        self.device_handler.check_device_registered = Mock(return_value="FAKE")
+
+        def check_device_registered(user_id, device_id, initial_display_name):
+            # Just echo back the given device ID, or return a new "FAKE" device
+            # ID
+            if device_id:
+                return device_id
+            else:
+                return "FAKE"
+
+        self.device_handler.check_device_registered = Mock(
+            side_effect=check_device_registered,
+        )
 
         self.datastore = Mock(return_value=Mock())
         self.datastore.get_current_state_deltas = Mock(return_value=[])
@@ -106,14 +117,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         user_id = "@kermit:muppet"
         token = "kermits_access_token"
         device_id = "frogfone"
-        request_data = json.dumps(
-            {"username": "kermit", "password": "monkey", "device_id": device_id}
-        )
+        params = {"username": "kermit", "password": "monkey", "device_id": device_id}
+        request_data = json.dumps(params)
         self.registration_handler.check_username = Mock(return_value=True)
-        self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
+        self.auth_result = (None, params, None)
         self.registration_handler.register = Mock(return_value=(user_id, None))
         self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
-        self.device_handler.check_device_registered = Mock(return_value=device_id)
 
         request, channel = self.make_request(b"POST", self.url, request_data)
         self.render(request)