summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--MANIFEST.in1
-rw-r--r--synapse/handlers/_base.py4
-rw-r--r--synapse/handlers/auth.py4
-rw-r--r--synapse/handlers/register.py13
-rw-r--r--synapse/rest/client/v1/base.py4
-rw-r--r--synapse/rest/client/v1/register.py8
-rw-r--r--synapse/rest/client/v2_alpha/register.py59
-rw-r--r--synapse/server.pyi21
-rw-r--r--synapse/storage/registration.py6
-rw-r--r--tests/rest/client/v2_alpha/test_register.py6
10 files changed, 104 insertions, 22 deletions
diff --git a/MANIFEST.in b/MANIFEST.in
index dfb7c9d28d..216df265b5 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -14,6 +14,7 @@ recursive-include docs *
 recursive-include res *
 recursive-include scripts *
 recursive-include scripts-dev *
+recursive-include synapse *.pyi
 recursive-include tests *.py
 
 recursive-include synapse/static *.css
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index d00685c389..6264aa0d9a 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -36,6 +36,10 @@ class BaseHandler(object):
     """
 
     def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer):
+        """
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.notifier = hs.get_notifier()
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index ce9bc18849..8f83923ddb 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -45,6 +45,10 @@ class AuthHandler(BaseHandler):
     SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
 
     def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer):
+        """
         super(AuthHandler, self).__init__(hs)
         self.checkers = {
             LoginType.PASSWORD: self._check_password_auth,
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 6b33b27149..94b19d0cb0 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -99,8 +99,13 @@ class RegistrationHandler(BaseHandler):
             localpart : The local part of the user ID to register. If None,
               one will be generated.
             password (str) : The password to assign to this user so they can
-            login again. This can be None which means they cannot login again
-            via a password (e.g. the user is an application service user).
+              login again. This can be None which means they cannot login again
+              via a password (e.g. the user is an application service user).
+            generate_token (bool): Whether a new access token should be
+              generated. Having this be True should be considered deprecated,
+              since it offers no means of associating a device_id with the
+              access_token. Instead you should call auth_handler.issue_access_token
+              after registration.
         Returns:
             A tuple of (user_id, access_token).
         Raises:
@@ -196,15 +201,13 @@ class RegistrationHandler(BaseHandler):
             user_id, allowed_appservice=service
         )
 
-        token = self.auth_handler().generate_access_token(user_id)
         yield self.store.register(
             user_id=user_id,
-            token=token,
             password_hash="",
             appservice_id=service_id,
             create_profile_with_localpart=user.localpart,
         )
-        defer.returnValue((user_id, token))
+        defer.returnValue(user_id)
 
     @defer.inlineCallbacks
     def check_recaptcha(self, ip, private_key, challenge, response):
diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py
index 1c020b7e2c..96b49b01f2 100644
--- a/synapse/rest/client/v1/base.py
+++ b/synapse/rest/client/v1/base.py
@@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet):
     """
 
     def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer):
+        """
         self.hs = hs
         self.handlers = hs.get_handlers()
         self.builder_factory = hs.get_event_builder_factory()
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index 8e1f1b7845..2383b9df86 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -52,6 +52,10 @@ class RegisterRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False)
 
     def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
         super(RegisterRestServlet, self).__init__(hs)
         # sessions are stored as:
         # self.sessions = {
@@ -60,6 +64,7 @@ class RegisterRestServlet(ClientV1RestServlet):
         # TODO: persistent storage
         self.sessions = {}
         self.enable_registration = hs.config.enable_registration
+        self.auth_handler = hs.get_auth_handler()
 
     def on_GET(self, request):
         if self.hs.config.enable_registration_captcha:
@@ -299,9 +304,10 @@ class RegisterRestServlet(ClientV1RestServlet):
         user_localpart = register_json["user"].encode("utf-8")
 
         handler = self.handlers.registration_handler
-        (user_id, token) = yield handler.appservice_register(
+        user_id = yield handler.appservice_register(
             user_localpart, as_token
         )
+        token = yield self.auth_handler.issue_access_token(user_id)
         self._remove_session(session)
         defer.returnValue({
             "user_id": user_id,
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 5db953a1e3..b7e03ea9d1 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -45,6 +45,10 @@ class RegisterRequestTokenRestServlet(RestServlet):
     PATTERNS = client_v2_patterns("/register/email/requestToken$")
 
     def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
         super(RegisterRequestTokenRestServlet, self).__init__()
         self.hs = hs
         self.identity_handler = hs.get_handlers().identity_handler
@@ -77,7 +81,12 @@ class RegisterRestServlet(RestServlet):
     PATTERNS = client_v2_patterns("/register$")
 
     def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
         super(RegisterRestServlet, self).__init__()
+
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
@@ -226,19 +235,17 @@ class RegisterRestServlet(RestServlet):
 
             add_email = True
 
-        access_token = yield self.auth_handler.issue_access_token(
+        result = yield self._create_registration_details(
             registered_user_id
         )
 
         if add_email and result and LoginType.EMAIL_IDENTITY in result:
             threepid = result[LoginType.EMAIL_IDENTITY]
             yield self._register_email_threepid(
-                registered_user_id, threepid, access_token,
+                registered_user_id, threepid, result["access_token"],
                 params.get("bind_email")
             )
 
-        result = yield self._create_registration_details(registered_user_id,
-                                                         access_token)
         defer.returnValue((200, result))
 
     def on_OPTIONS(self, _):
@@ -246,10 +253,10 @@ class RegisterRestServlet(RestServlet):
 
     @defer.inlineCallbacks
     def _do_appservice_registration(self, username, as_token):
-        (user_id, token) = yield self.registration_handler.appservice_register(
+        user_id = yield self.registration_handler.appservice_register(
             username, as_token
         )
-        defer.returnValue((yield self._create_registration_details(user_id, token)))
+        defer.returnValue((yield self._create_registration_details(user_id)))
 
     @defer.inlineCallbacks
     def _do_shared_secret_registration(self, username, password, mac):
@@ -273,10 +280,12 @@ class RegisterRestServlet(RestServlet):
                 403, "HMAC incorrect",
             )
 
-        (user_id, token) = yield self.registration_handler.register(
-            localpart=username, password=password
+        (user_id, _) = yield self.registration_handler.register(
+            localpart=username, password=password, generate_token=False,
         )
-        defer.returnValue((yield self._create_registration_details(user_id, token)))
+
+        result = yield self._create_registration_details(user_id)
+        defer.returnValue(result)
 
     @defer.inlineCallbacks
     def _register_email_threepid(self, user_id, threepid, token, bind_email):
@@ -349,11 +358,31 @@ class RegisterRestServlet(RestServlet):
         defer.returnValue()
 
     @defer.inlineCallbacks
-    def _create_registration_details(self, user_id, token):
-        refresh_token = yield self.auth_handler.issue_refresh_token(user_id)
+    def _create_registration_details(self, user_id):
+        """Complete registration of newly-registered user
+
+        Issues access_token and refresh_token, and builds the success response
+        body.
+
+        Args:
+            (str) user_id: full canonical @user:id
+
+
+        Returns:
+            defer.Deferred: (object) dictionary for response from /register
+        """
+
+        access_token = yield self.auth_handler.issue_access_token(
+            user_id
+        )
+
+        refresh_token = yield self.auth_handler.issue_refresh_token(
+            user_id
+        )
+
         defer.returnValue({
             "user_id": user_id,
-            "access_token": token,
+            "access_token": access_token,
             "home_server": self.hs.hostname,
             "refresh_token": refresh_token,
         })
@@ -366,7 +395,11 @@ class RegisterRestServlet(RestServlet):
             generate_token=False,
             make_guest=True
         )
-        access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"])
+        access_token = self.auth_handler.generate_access_token(
+            user_id, ["guest = true"]
+        )
+        # XXX the "guest" caveat is not copied by /tokenrefresh. That's ok
+        # so long as we don't return a refresh_token here.
         defer.returnValue((200, {
             "user_id": user_id,
             "access_token": access_token,
diff --git a/synapse/server.pyi b/synapse/server.pyi
new file mode 100644
index 0000000000..902f725c06
--- /dev/null
+++ b/synapse/server.pyi
@@ -0,0 +1,21 @@
+import synapse.handlers
+import synapse.handlers.auth
+import synapse.handlers.device
+import synapse.storage
+import synapse.state
+
+class HomeServer(object):
+    def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler:
+        pass
+
+    def get_datastore(self) -> synapse.storage.DataStore:
+        pass
+
+    def get_device_handler(self) -> synapse.handlers.device.DeviceHandler:
+        pass
+
+    def get_handlers(self) -> synapse.handlers.Handlers:
+        pass
+
+    def get_state_handler(self) -> synapse.state.StateHandler:
+        pass
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 26ef1cfd8a..9a92b35361 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -81,14 +81,16 @@ class RegistrationStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def register(self, user_id, token, password_hash,
+    def register(self, user_id, token=None, password_hash=None,
                  was_guest=False, make_guest=False, appservice_id=None,
                  create_profile_with_localpart=None, admin=False):
         """Attempts to register an account.
 
         Args:
             user_id (str): The desired user ID to register.
-            token (str): The desired access token to use for this user.
+            token (str): The desired access token to use for this user. If this
+                is not None, the given access token is associated with the user
+                id.
             password_hash (str): Optional. The password hash for this user.
             was_guest (bool): Optional. Whether this is a guest account being
                 upgraded to a non-guest account.
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 9a4215fef7..ccbb8776d3 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -61,8 +61,10 @@ class RegisterRestServletTestCase(unittest.TestCase):
             "id": "1234"
         }
         self.registration_handler.appservice_register = Mock(
-            return_value=(user_id, token)
+            return_value=user_id
         )
+        self.auth_handler.issue_access_token = Mock(return_value=token)
+
         (code, result) = yield self.servlet.on_POST(self.request)
         self.assertEquals(code, 200)
         det_data = {
@@ -126,6 +128,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
         }
         self.assertDictContainsSubset(det_data, result)
         self.assertIn("refresh_token", result)
+        self.auth_handler.issue_access_token.assert_called_once_with(
+            user_id)
 
     def test_POST_disabled_registration(self):
         self.hs.config.enable_registration = False