summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/register.py6
-rw-r--r--synapse/rest/client/v1/register.py9
-rw-r--r--tests/handlers/test_register.py8
-rw-r--r--tests/rest/client/v1/test_register.py30
4 files changed, 22 insertions, 31 deletions
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 19329057d5..7e119f13b1 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -19,7 +19,6 @@ import urllib
 
 from twisted.internet import defer
 
-import synapse.types
 from synapse.api.errors import (
     AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
 )
@@ -370,7 +369,7 @@ class RegistrationHandler(BaseHandler):
         defer.returnValue(data)
 
     @defer.inlineCallbacks
-    def get_or_create_user(self, localpart, displayname, duration_in_ms,
+    def get_or_create_user(self, requester, localpart, displayname, duration_in_ms,
                            password_hash=None):
         """Creates a new user if the user does not exist,
         else revokes all previous access tokens and generates a new one.
@@ -417,9 +416,8 @@ class RegistrationHandler(BaseHandler):
         if displayname is not None:
             logger.info("setting user display name: %s -> %s", user_id, displayname)
             profile_handler = self.hs.get_handlers().profile_handler
-            requester = synapse.types.create_requester(user)
             yield profile_handler.set_displayname(
-                user, requester, displayname
+                user, requester, displayname, by_admin=True,
             )
 
         defer.returnValue((user_id, token))
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index fe4480b363..b5a76fefac 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -22,6 +22,7 @@ from synapse.api.auth import get_access_token_from_request
 from .base import ClientV1RestServlet, client_path_patterns
 import synapse.util.stringutils as stringutils
 from synapse.http.servlet import parse_json_object_from_request
+from synapse.types import create_requester
 
 from synapse.util.async import run_on_reactor
 
@@ -397,9 +398,10 @@ class CreateUserRestServlet(ClientV1RestServlet):
         if not app_service:
             raise SynapseError(403, "Invalid application service token.")
 
-        logger.debug("creating user: %s", user_json)
+        requester = create_requester(app_service.sender)
 
-        response = yield self._do_create(user_json)
+        logger.debug("creating user: %s", user_json)
+        response = yield self._do_create(requester, user_json)
 
         defer.returnValue((200, response))
 
@@ -407,7 +409,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
         return 403, {}
 
     @defer.inlineCallbacks
-    def _do_create(self, user_json):
+    def _do_create(self, requester, user_json):
         yield run_on_reactor()
 
         if "localpart" not in user_json:
@@ -433,6 +435,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
 
         handler = self.handlers.registration_handler
         user_id, token = yield handler.get_or_create_user(
+            requester=requester,
             localpart=localpart,
             displayname=displayname,
             duration_in_ms=(duration_seconds * 1000),
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index a7de3c7c17..9c9d144690 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -17,7 +17,7 @@ from twisted.internet import defer
 from .. import unittest
 
 from synapse.handlers.register import RegistrationHandler
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
 
 from tests.utils import setup_test_homeserver
 
@@ -57,8 +57,9 @@ class RegistrationTestCase(unittest.TestCase):
         local_part = "someone"
         display_name = "someone"
         user_id = "@someone:test"
+        requester = create_requester("@as:test")
         result_user_id, result_token = yield self.handler.get_or_create_user(
-            local_part, display_name, duration_ms)
+            requester, local_part, display_name, duration_ms)
         self.assertEquals(result_user_id, user_id)
         self.assertEquals(result_token, 'secret')
 
@@ -74,7 +75,8 @@ class RegistrationTestCase(unittest.TestCase):
         local_part = "frank"
         display_name = "Frank"
         user_id = "@frank:test"
+        requester = create_requester("@as:test")
         result_user_id, result_token = yield self.handler.get_or_create_user(
-            local_part, display_name, duration_ms)
+            requester, local_part, display_name, duration_ms)
         self.assertEquals(result_user_id, user_id)
         self.assertEquals(result_token, 'secret')
diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py
index 4a898a034f..44ba9ff58f 100644
--- a/tests/rest/client/v1/test_register.py
+++ b/tests/rest/client/v1/test_register.py
@@ -31,33 +31,21 @@ class CreateUserServletTestCase(unittest.TestCase):
         )
         self.request.args = {}
 
-        self.appservice = None
-        self.auth = Mock(get_appservice_by_req=Mock(
-            side_effect=lambda x: defer.succeed(self.appservice))
-        )
+        self.registration_handler = Mock()
 
-        self.auth_result = (False, None, None, None)
-        self.auth_handler = Mock(
-            check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
-            get_session_data=Mock(return_value=None)
+        self.appservice = Mock(sender="@as:test")
+        self.datastore = Mock(
+            get_app_service_by_token=Mock(return_value=self.appservice)
         )
-        self.registration_handler = Mock()
-        self.identity_handler = Mock()
-        self.login_handler = Mock()
 
-        # do the dance to hook it up to the hs global
-        self.handlers = Mock(
-            auth_handler=self.auth_handler,
+        # do the dance to hook things up to the hs global
+        handlers = Mock(
             registration_handler=self.registration_handler,
-            identity_handler=self.identity_handler,
-            login_handler=self.login_handler
         )
         self.hs = Mock()
-        self.hs.hostname = "supergbig~testing~thing.com"
-        self.hs.get_auth = Mock(return_value=self.auth)
-        self.hs.get_handlers = Mock(return_value=self.handlers)
-        self.hs.config.enable_registration = True
-        # init the thing we're testing
+        self.hs.hostname = "superbig~testing~thing.com"
+        self.hs.get_datastore = Mock(return_value=self.datastore)
+        self.hs.get_handlers = Mock(return_value=handlers)
         self.servlet = CreateUserRestServlet(self.hs)
 
     @defer.inlineCallbacks