summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/api/auth.py6
-rw-r--r--synapse/handlers/auth.py6
-rw-r--r--synapse/handlers/register.py37
-rw-r--r--synapse/handlers/room.py2
-rw-r--r--synapse/handlers/sync.py2
-rw-r--r--synapse/rest/client/v2_alpha/register.py12
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py2
-rw-r--r--synapse/storage/prepare_database.py4
-rw-r--r--synapse/storage/registration.py23
-rw-r--r--synapse/storage/schema/delta/28/upgrade_times.sql21
-rw-r--r--tests/api/test_auth.py18
11 files changed, 93 insertions, 40 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index adb7d64482..b86c6c8399 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014, 2015 OpenMarket Ltd
+# Copyright 2014 - 2016 OpenMarket Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -583,7 +583,7 @@ class Auth(object):
             AuthError if no user by that token exists or the token is invalid.
         """
         try:
-            ret = yield self._get_user_from_macaroon(token)
+            ret = yield self.get_user_from_macaroon(token)
         except AuthError:
             # TODO(daniel): Remove this fallback when all existing access tokens
             # have been re-issued as macaroons.
@@ -591,7 +591,7 @@ class Auth(object):
         defer.returnValue(ret)
 
     @defer.inlineCallbacks
-    def _get_user_from_macaroon(self, macaroon_str):
+    def get_user_from_macaroon(self, macaroon_str):
         try:
             macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
             self.validate_macaroon(macaroon, "access", False)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index e64b67cdfd..62e82a2570 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014, 2015 OpenMarket Ltd
+# Copyright 2014 - 2016 OpenMarket Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -408,7 +408,7 @@ class AuthHandler(BaseHandler):
             macaroon = pymacaroons.Macaroon.deserialize(login_token)
             auth_api = self.hs.get_auth()
             auth_api.validate_macaroon(macaroon, "login", True)
-            return self._get_user_from_macaroon(macaroon)
+            return self.get_user_from_macaroon(macaroon)
         except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
             raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)
 
@@ -421,7 +421,7 @@ class AuthHandler(BaseHandler):
         macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
         return macaroon
 
-    def _get_user_from_macaroon(self, macaroon):
+    def get_user_from_macaroon(self, macaroon):
         user_prefix = "user_id = "
         for caveat in macaroon.caveats:
             if caveat.caveat_id.startswith(user_prefix):
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index baf7c14e40..6f111ff63e 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014, 2015 OpenMarket Ltd
+# Copyright 2014 - 2016 OpenMarket Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -40,12 +40,13 @@ class RegistrationHandler(BaseHandler):
     def __init__(self, hs):
         super(RegistrationHandler, self).__init__(hs)
 
+        self.auth = hs.get_auth()
         self.distributor = hs.get_distributor()
         self.distributor.declare("registered_user")
         self.captcha_client = CaptchaServerHttpClient(hs)
 
     @defer.inlineCallbacks
-    def check_username(self, localpart):
+    def check_username(self, localpart, guest_access_token=None):
         yield run_on_reactor()
 
         if urllib.quote(localpart) != localpart:
@@ -62,14 +63,29 @@ class RegistrationHandler(BaseHandler):
 
         users = yield self.store.get_users_by_id_case_insensitive(user_id)
         if users:
-            raise SynapseError(
-                400,
-                "User ID already taken.",
-                errcode=Codes.USER_IN_USE,
-            )
+            if not guest_access_token:
+                raise SynapseError(
+                    400,
+                    "User ID already taken.",
+                    errcode=Codes.USER_IN_USE,
+                )
+            user_data = yield self.auth.get_user_from_macaroon(guest_access_token)
+            if not user_data["is_guest"] or user_data["user"].localpart != localpart:
+                raise AuthError(
+                    403,
+                    "Cannot register taken user ID without valid guest "
+                    "credentials for that user.",
+                    errcode=Codes.FORBIDDEN,
+                )
 
     @defer.inlineCallbacks
-    def register(self, localpart=None, password=None, generate_token=True):
+    def register(
+        self,
+        localpart=None,
+        password=None,
+        generate_token=True,
+        guest_access_token=None
+    ):
         """Registers a new client on the server.
 
         Args:
@@ -89,7 +105,7 @@ class RegistrationHandler(BaseHandler):
             password_hash = self.auth_handler().hash(password)
 
         if localpart:
-            yield self.check_username(localpart)
+            yield self.check_username(localpart, guest_access_token=guest_access_token)
 
             user = UserID(localpart, self.hs.hostname)
             user_id = user.to_string()
@@ -100,7 +116,8 @@ class RegistrationHandler(BaseHandler):
             yield self.store.register(
                 user_id=user_id,
                 token=token,
-                password_hash=password_hash
+                password_hash=password_hash,
+                was_guest=guest_access_token is not None,
             )
 
             yield registered_user(self.distributor, user)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 0cfeda10d8..6186c37c7c 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014, 2015 OpenMarket Ltd
+# Copyright 2014 - 2016 OpenMarket Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 9796f2a57f..41a42418a9 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2015 OpenMarket Ltd
+# Copyright 2015 - 2016 OpenMarket Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index b2b89652c6..25389ceded 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2015 OpenMarket Ltd
+# Copyright 2015 - 2016 OpenMarket Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -119,8 +119,13 @@ class RegisterRestServlet(RestServlet):
         if self.hs.config.disable_registration:
             raise SynapseError(403, "Registration has been disabled")
 
+        guest_access_token = body.get("guest_access_token", None)
+
         if desired_username is not None:
-            yield self.registration_handler.check_username(desired_username)
+            yield self.registration_handler.check_username(
+                desired_username,
+                guest_access_token=guest_access_token
+            )
 
         if self.hs.config.enable_registration_captcha:
             flows = [
@@ -150,7 +155,8 @@ class RegisterRestServlet(RestServlet):
 
         (user_id, token) = yield self.registration_handler.register(
             localpart=desired_username,
-            password=new_password
+            password=new_password,
+            guest_access_token=guest_access_token,
         )
 
         if result and LoginType.EMAIL_IDENTITY in result:
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 8b8fba3dc7..c18160534e 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014, 2015 OpenMarket Ltd
+# Copyright 2014 - 2016 OpenMarket Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 16eff62544..c1f5f99789 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014, 2015 OpenMarket Ltd
+# Copyright 2014 - 2016 OpenMarket Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 27
+SCHEMA_VERSION = 28
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 09a05b08ef..f0fa0bd33c 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014, 2015 OpenMarket Ltd
+# Copyright 2014 - 2016 OpenMarket Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -73,30 +73,39 @@ class RegistrationStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def register(self, user_id, token, password_hash):
+    def register(self, user_id, token, password_hash, was_guest=False):
         """Attempts to register an account.
 
         Args:
             user_id (str): The desired user ID to register.
             token (str): The desired access token to use for this user.
             password_hash (str): Optional. The password hash for this user.
+            was_guest (bool): Optional. Whether this is a guest account being
+                upgraded to a non-guest account.
         Raises:
             StoreError if the user_id could not be registered.
         """
         yield self.runInteraction(
             "register",
-            self._register, user_id, token, password_hash
+            self._register, user_id, token, password_hash, was_guest
         )
 
-    def _register(self, txn, user_id, token, password_hash):
+    def _register(self, txn, user_id, token, password_hash, was_guest):
         now = int(self.clock.time())
 
         next_id = self._access_tokens_id_gen.get_next_txn(txn)
 
         try:
-            txn.execute("INSERT INTO users(name, password_hash, creation_ts) "
-                        "VALUES (?,?,?)",
-                        [user_id, password_hash, now])
+            if was_guest:
+                txn.execute("UPDATE users SET"
+                            " password_hash = ?,"
+                            " upgrade_ts = ?"
+                            " WHERE name = ?",
+                            [password_hash, now, user_id])
+            else:
+                txn.execute("INSERT INTO users(name, password_hash, creation_ts) "
+                            "VALUES (?,?,?)",
+                            [user_id, password_hash, now])
         except self.database_engine.module.IntegrityError:
             raise StoreError(
                 400, "User ID already taken.", errcode=Codes.USER_IN_USE
diff --git a/synapse/storage/schema/delta/28/upgrade_times.sql b/synapse/storage/schema/delta/28/upgrade_times.sql
new file mode 100644
index 0000000000..3e4a9ab455
--- /dev/null
+++ b/synapse/storage/schema/delta/28/upgrade_times.sql
@@ -0,0 +1,21 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Stores the timestamp when a user upgraded from a guest to a full user, if
+ * that happened.
+ */
+
+ALTER TABLE users ADD COLUMN upgrade_ts BIGINT;
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 70d928defe..5ff4c8a873 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2015 OpenMarket Ltd
+# Copyright 2015 - 2016 OpenMarket Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -154,7 +154,7 @@ class AuthTestCase(unittest.TestCase):
         macaroon.add_first_party_caveat("gen = 1")
         macaroon.add_first_party_caveat("type = access")
         macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
-        user_info = yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize())
         user = user_info["user"]
         self.assertEqual(UserID.from_string(user_id), user)
 
@@ -171,7 +171,7 @@ class AuthTestCase(unittest.TestCase):
         macaroon.add_first_party_caveat("guest = true")
         serialized = macaroon.serialize()
 
-        user_info = yield self.auth._get_user_from_macaroon(serialized)
+        user_info = yield self.auth.get_user_from_macaroon(serialized)
         user = user_info["user"]
         is_guest = user_info["is_guest"]
         self.assertEqual(UserID.from_string(user_id), user)
@@ -192,7 +192,7 @@ class AuthTestCase(unittest.TestCase):
         macaroon.add_first_party_caveat("type = access")
         macaroon.add_first_party_caveat("user_id = %s" % (user,))
         with self.assertRaises(AuthError) as cm:
-            yield self.auth._get_user_from_macaroon(macaroon.serialize())
+            yield self.auth.get_user_from_macaroon(macaroon.serialize())
         self.assertEqual(401, cm.exception.code)
         self.assertIn("User mismatch", cm.exception.msg)
 
@@ -212,7 +212,7 @@ class AuthTestCase(unittest.TestCase):
         macaroon.add_first_party_caveat("type = access")
 
         with self.assertRaises(AuthError) as cm:
-            yield self.auth._get_user_from_macaroon(macaroon.serialize())
+            yield self.auth.get_user_from_macaroon(macaroon.serialize())
         self.assertEqual(401, cm.exception.code)
         self.assertIn("No user caveat", cm.exception.msg)
 
@@ -234,7 +234,7 @@ class AuthTestCase(unittest.TestCase):
         macaroon.add_first_party_caveat("user_id = %s" % (user,))
 
         with self.assertRaises(AuthError) as cm:
-            yield self.auth._get_user_from_macaroon(macaroon.serialize())
+            yield self.auth.get_user_from_macaroon(macaroon.serialize())
         self.assertEqual(401, cm.exception.code)
         self.assertIn("Invalid macaroon", cm.exception.msg)
 
@@ -257,7 +257,7 @@ class AuthTestCase(unittest.TestCase):
         macaroon.add_first_party_caveat("cunning > fox")
 
         with self.assertRaises(AuthError) as cm:
-            yield self.auth._get_user_from_macaroon(macaroon.serialize())
+            yield self.auth.get_user_from_macaroon(macaroon.serialize())
         self.assertEqual(401, cm.exception.code)
         self.assertIn("Invalid macaroon", cm.exception.msg)
 
@@ -285,11 +285,11 @@ class AuthTestCase(unittest.TestCase):
 
         self.hs.clock.now = 5000 # seconds
 
-        yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        yield self.auth.get_user_from_macaroon(macaroon.serialize())
         # TODO(daniel): Turn on the check that we validate expiration, when we
         # validate expiration (and remove the above line, which will start
         # throwing).
         # with self.assertRaises(AuthError) as cm:
-        #     yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        #     yield self.auth.get_user_from_macaroon(macaroon.serialize())
         # self.assertEqual(401, cm.exception.code)
         # self.assertIn("Invalid macaroon", cm.exception.msg)