summary refs log tree commit diff
path: root/synapse/rest/client/v1/register.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/v1/register.py')
-rw-r--r--synapse/rest/client/v1/register.py33
1 files changed, 27 insertions, 6 deletions
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index e3f4fbb0bb..2383b9df86 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -52,6 +52,10 @@ class RegisterRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False)
 
     def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
         super(RegisterRestServlet, self).__init__(hs)
         # sessions are stored as:
         # self.sessions = {
@@ -60,6 +64,7 @@ class RegisterRestServlet(ClientV1RestServlet):
         # TODO: persistent storage
         self.sessions = {}
         self.enable_registration = hs.config.enable_registration
+        self.auth_handler = hs.get_auth_handler()
 
     def on_GET(self, request):
         if self.hs.config.enable_registration_captcha:
@@ -299,9 +304,10 @@ class RegisterRestServlet(ClientV1RestServlet):
         user_localpart = register_json["user"].encode("utf-8")
 
         handler = self.handlers.registration_handler
-        (user_id, token) = yield handler.appservice_register(
+        user_id = yield handler.appservice_register(
             user_localpart, as_token
         )
+        token = yield self.auth_handler.issue_access_token(user_id)
         self._remove_session(session)
         defer.returnValue({
             "user_id": user_id,
@@ -324,6 +330,14 @@ class RegisterRestServlet(ClientV1RestServlet):
             raise SynapseError(400, "Shared secret registration is not enabled")
 
         user = register_json["user"].encode("utf-8")
+        password = register_json["password"].encode("utf-8")
+        admin = register_json.get("admin", None)
+
+        # Its important to check as we use null bytes as HMAC field separators
+        if "\x00" in user:
+            raise SynapseError(400, "Invalid user")
+        if "\x00" in password:
+            raise SynapseError(400, "Invalid password")
 
         # str() because otherwise hmac complains that 'unicode' does not
         # have the buffer interface
@@ -331,17 +345,21 @@ class RegisterRestServlet(ClientV1RestServlet):
 
         want_mac = hmac.new(
             key=self.hs.config.registration_shared_secret,
-            msg=user,
             digestmod=sha1,
-        ).hexdigest()
-
-        password = register_json["password"].encode("utf-8")
+        )
+        want_mac.update(user)
+        want_mac.update("\x00")
+        want_mac.update(password)
+        want_mac.update("\x00")
+        want_mac.update("admin" if admin else "notadmin")
+        want_mac = want_mac.hexdigest()
 
         if compare_digest(want_mac, got_mac):
             handler = self.handlers.registration_handler
             user_id, token = yield handler.register(
                 localpart=user,
                 password=password,
+                admin=bool(admin),
             )
             self._remove_session(session)
             defer.returnValue({
@@ -410,12 +428,15 @@ class CreateUserRestServlet(ClientV1RestServlet):
             raise SynapseError(400, "Failed to parse 'duration_seconds'")
         if duration_seconds > self.direct_user_creation_max_duration:
             duration_seconds = self.direct_user_creation_max_duration
+        password_hash = user_json["password_hash"].encode("utf-8") \
+            if user_json.get("password_hash") else None
 
         handler = self.handlers.registration_handler
         user_id, token = yield handler.get_or_create_user(
             localpart=localpart,
             displayname=displayname,
-            duration_seconds=duration_seconds
+            duration_in_ms=(duration_seconds * 1000),
+            password_hash=password_hash
         )
 
         defer.returnValue({