summary refs log tree commit diff
diff options
context:
space:
mode:
authorNegi Fazeli <negar.fazeli@ericsson.com>2016-04-20 16:21:40 +0200
committerNegar Fazeli <negar.fazeli@ericsson.com>2016-05-13 15:34:15 +0200
commit40aa6e8349b348802d6f87084c31c3895f728708 (patch)
tree2150d48917b62e0317a73e7892061d73e2fe62ec
parentPass through _get_state_group_for_events (diff)
downloadsynapse-40aa6e8349b348802d6f87084c31c3895f728708.tar.xz
Create user with expiry
 - Add unittests for client, api and handler

Signed-off-by: Negar Fazeli <negar.fazeli@ericsson.com>
-rw-r--r--synapse/api/auth.py3
-rw-r--r--synapse/config/key.py5
-rw-r--r--synapse/config/registration.py6
-rw-r--r--synapse/handlers/auth.py4
-rw-r--r--synapse/handlers/register.py53
-rw-r--r--synapse/rest/client/v1/register.py71
-rw-r--r--tests/api/test_auth.py12
-rw-r--r--tests/handlers/test_register.py67
-rw-r--r--tests/rest/client/v1/test_register.py88
-rw-r--r--tests/utils.py1
10 files changed, 301 insertions, 9 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index d3e9837c81..44e38b777a 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -612,7 +612,8 @@ class Auth(object):
     def get_user_from_macaroon(self, macaroon_str):
         try:
             macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
-            self.validate_macaroon(macaroon, "access", False)
+
+            self.validate_macaroon(macaroon, "access", self.hs.config.expire_access_token)
 
             user_prefix = "user_id = "
             user = None
diff --git a/synapse/config/key.py b/synapse/config/key.py
index a072aec714..6ee643793e 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -57,6 +57,8 @@ class KeyConfig(Config):
             seed = self.signing_key[0].seed
             self.macaroon_secret_key = hashlib.sha256(seed)
 
+        self.expire_access_token = config.get("expire_access_token", False)
+
     def default_config(self, config_dir_path, server_name, is_generating_file=False,
                        **kwargs):
         base_key_name = os.path.join(config_dir_path, server_name)
@@ -69,6 +71,9 @@ class KeyConfig(Config):
         return """\
         macaroon_secret_key: "%(macaroon_secret_key)s"
 
+        # Used to enable access token expiration.
+        expire_access_token: False
+
         ## Signing Keys ##
 
         # Path to the signing key to sign messages with
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 87e500c97a..cc3f879857 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -32,6 +32,7 @@ class RegistrationConfig(Config):
             )
 
         self.registration_shared_secret = config.get("registration_shared_secret")
+        self.user_creation_max_duration = int(config["user_creation_max_duration"])
 
         self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
         self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
@@ -54,6 +55,11 @@ class RegistrationConfig(Config):
         # secret, even if registration is otherwise disabled.
         registration_shared_secret: "%(registration_shared_secret)s"
 
+        # Sets the expiry for the short term user creation in
+        # milliseconds. For instance the bellow duration is two weeks
+        # in milliseconds.
+        user_creation_max_duration: 1209600000
+
         # Set the number of bcrypt rounds used to generate password hash.
         # Larger numbers increase the work factor needed to generate the hash.
         # The default number of rounds is 12.
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 61fe56032a..3d36d3460e 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -521,11 +521,11 @@ class AuthHandler(BaseHandler):
         ))
         return m.serialize()
 
-    def generate_short_term_login_token(self, user_id):
+    def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
         macaroon = self._generate_base_macaroon(user_id)
         macaroon.add_first_party_caveat("type = login")
         now = self.hs.get_clock().time_msec()
-        expiry = now + (2 * 60 * 1000)
+        expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
         return macaroon.serialize()
 
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index b0862067e1..5883b9111e 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -358,6 +358,59 @@ class RegistrationHandler(BaseHandler):
         )
         defer.returnValue(data)
 
+    @defer.inlineCallbacks
+    def get_or_create_user(self, localpart, displayname, duration_seconds):
+        """Creates a new user or returns an access token for an existing one
+
+        Args:
+            localpart : The local part of the user ID to register. If None,
+              one will be randomly generated.
+        Returns:
+            A tuple of (user_id, access_token).
+        Raises:
+            RegistrationError if there was a problem registering.
+        """
+        yield run_on_reactor()
+
+        if localpart is None:
+            raise SynapseError(400, "Request must include user id")
+
+        need_register = True
+
+        try:
+            yield self.check_username(localpart)
+        except SynapseError as e:
+            if e.errcode == Codes.USER_IN_USE:
+                need_register = False
+            else:
+                raise
+
+        user = UserID(localpart, self.hs.hostname)
+        user_id = user.to_string()
+        auth_handler = self.hs.get_handlers().auth_handler
+        token = auth_handler.generate_short_term_login_token(user_id, duration_seconds)
+
+        if need_register:
+            yield self.store.register(
+                user_id=user_id,
+                token=token,
+                password_hash=None
+            )
+
+            yield registered_user(self.distributor, user)
+        else:
+            yield self.store.flush_user(user_id=user_id)
+            yield self.store.add_access_token_to_user(user_id=user_id, token=token)
+
+        if displayname is not None:
+            logger.info("setting user display name: %s -> %s", user_id, displayname)
+            profile_handler = self.hs.get_handlers().profile_handler
+            yield profile_handler.set_displayname(
+                user, user, displayname
+            )
+
+        defer.returnValue((user_id, token))
+
     def auth_handler(self):
         return self.hs.get_handlers().auth_handler
 
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index c6a2ef2ccc..e3f4fbb0bb 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -355,5 +355,76 @@ class RegisterRestServlet(ClientV1RestServlet):
             )
 
 
+class CreateUserRestServlet(ClientV1RestServlet):
+    """Handles user creation via a server-to-server interface
+    """
+
+    PATTERNS = client_path_patterns("/createUser$", releases=())
+
+    def __init__(self, hs):
+        super(CreateUserRestServlet, self).__init__(hs)
+        self.store = hs.get_datastore()
+        self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        user_json = parse_json_object_from_request(request)
+
+        if "access_token" not in request.args:
+            raise SynapseError(400, "Expected application service token.")
+
+        app_service = yield self.store.get_app_service_by_token(
+            request.args["access_token"][0]
+        )
+        if not app_service:
+            raise SynapseError(403, "Invalid application service token.")
+
+        logger.debug("creating user: %s", user_json)
+
+        response = yield self._do_create(user_json)
+
+        defer.returnValue((200, response))
+
+    def on_OPTIONS(self, request):
+        return 403, {}
+
+    @defer.inlineCallbacks
+    def _do_create(self, user_json):
+        yield run_on_reactor()
+
+        if "localpart" not in user_json:
+            raise SynapseError(400, "Expected 'localpart' key.")
+
+        if "displayname" not in user_json:
+            raise SynapseError(400, "Expected 'displayname' key.")
+
+        if "duration_seconds" not in user_json:
+            raise SynapseError(400, "Expected 'duration_seconds' key.")
+
+        localpart = user_json["localpart"].encode("utf-8")
+        displayname = user_json["displayname"].encode("utf-8")
+        duration_seconds = 0
+        try:
+            duration_seconds = int(user_json["duration_seconds"])
+        except ValueError:
+            raise SynapseError(400, "Failed to parse 'duration_seconds'")
+        if duration_seconds > self.direct_user_creation_max_duration:
+            duration_seconds = self.direct_user_creation_max_duration
+
+        handler = self.handlers.registration_handler
+        user_id, token = yield handler.get_or_create_user(
+            localpart=localpart,
+            displayname=displayname,
+            duration_seconds=duration_seconds
+        )
+
+        defer.returnValue({
+            "user_id": user_id,
+            "access_token": token,
+            "home_server": self.hs.hostname,
+        })
+
+
 def register_servlets(hs, http_server):
     RegisterRestServlet(hs).register(http_server)
+    CreateUserRestServlet(hs).register(http_server)
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 7e7b0b4b1d..ad269af0ec 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -284,12 +284,12 @@ class AuthTestCase(unittest.TestCase):
         macaroon.add_first_party_caveat("time < 1")  # ms
 
         self.hs.clock.now = 5000  # seconds
-
-        yield self.auth.get_user_from_macaroon(macaroon.serialize())
+        self.hs.config.expire_access_token = True
+        # yield self.auth.get_user_from_macaroon(macaroon.serialize())
         # TODO(daniel): Turn on the check that we validate expiration, when we
         # validate expiration (and remove the above line, which will start
         # throwing).
-        # with self.assertRaises(AuthError) as cm:
-        #     yield self.auth.get_user_from_macaroon(macaroon.serialize())
-        # self.assertEqual(401, cm.exception.code)
-        # self.assertIn("Invalid macaroon", cm.exception.msg)
+        with self.assertRaises(AuthError) as cm:
+            yield self.auth.get_user_from_macaroon(macaroon.serialize())
+        self.assertEqual(401, cm.exception.code)
+        self.assertIn("Invalid macaroon", cm.exception.msg)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
new file mode 100644
index 0000000000..8b7be96bd9
--- /dev/null
+++ b/tests/handlers/test_register.py
@@ -0,0 +1,67 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+from .. import unittest
+
+from synapse.handlers.register import RegistrationHandler
+
+from tests.utils import setup_test_homeserver
+
+from mock import Mock
+
+
+class RegistrationHandlers(object):
+    def __init__(self, hs):
+        self.registration_handler = RegistrationHandler(hs)
+
+
+class RegistrationTestCase(unittest.TestCase):
+    """ Tests the RegistrationHandler. """
+
+    @defer.inlineCallbacks
+    def setUp(self):
+        self.mock_distributor = Mock()
+        self.mock_distributor.declare("registered_user")
+        self.mock_captcha_client = Mock()
+        hs = yield setup_test_homeserver(
+            handlers=None,
+            http_client=None,
+            expire_access_token=True)
+        hs.handlers = RegistrationHandlers(hs)
+        self.handler = hs.get_handlers().registration_handler
+        hs.get_handlers().profile_handler = Mock()
+        self.mock_handler = Mock(spec=[
+            "generate_short_term_login_token",
+        ])
+
+        hs.get_handlers().auth_handler = self.mock_handler
+
+    @defer.inlineCallbacks
+    def test_user_is_created_and_logged_in_if_doesnt_exist(self):
+        """
+        Returns:
+            The user doess not exist in this case so it will register and log it in
+        """
+        duration_ms = 200
+        local_part = "someone"
+        display_name = "someone"
+        user_id = "@someone:test"
+        mock_token = self.mock_handler.generate_short_term_login_token
+        mock_token.return_value = 'secret'
+        result_user_id, result_token = yield self.handler.get_or_create_user(
+            local_part, display_name, duration_ms)
+        self.assertEquals(result_user_id, user_id)
+        self.assertEquals(result_token, 'secret')
diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py
new file mode 100644
index 0000000000..4a898a034f
--- /dev/null
+++ b/tests/rest/client/v1/test_register.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.rest.client.v1.register import CreateUserRestServlet
+from twisted.internet import defer
+from mock import Mock
+from tests import unittest
+import json
+
+
+class CreateUserServletTestCase(unittest.TestCase):
+
+    def setUp(self):
+        # do the dance to hook up request data to self.request_data
+        self.request_data = ""
+        self.request = Mock(
+            content=Mock(read=Mock(side_effect=lambda: self.request_data)),
+            path='/_matrix/client/api/v1/createUser'
+        )
+        self.request.args = {}
+
+        self.appservice = None
+        self.auth = Mock(get_appservice_by_req=Mock(
+            side_effect=lambda x: defer.succeed(self.appservice))
+        )
+
+        self.auth_result = (False, None, None, None)
+        self.auth_handler = Mock(
+            check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
+            get_session_data=Mock(return_value=None)
+        )
+        self.registration_handler = Mock()
+        self.identity_handler = Mock()
+        self.login_handler = Mock()
+
+        # do the dance to hook it up to the hs global
+        self.handlers = Mock(
+            auth_handler=self.auth_handler,
+            registration_handler=self.registration_handler,
+            identity_handler=self.identity_handler,
+            login_handler=self.login_handler
+        )
+        self.hs = Mock()
+        self.hs.hostname = "supergbig~testing~thing.com"
+        self.hs.get_auth = Mock(return_value=self.auth)
+        self.hs.get_handlers = Mock(return_value=self.handlers)
+        self.hs.config.enable_registration = True
+        # init the thing we're testing
+        self.servlet = CreateUserRestServlet(self.hs)
+
+    @defer.inlineCallbacks
+    def test_POST_createuser_with_valid_user(self):
+        user_id = "@someone:interesting"
+        token = "my token"
+        self.request.args = {
+            "access_token": "i_am_an_app_service"
+        }
+        self.request_data = json.dumps({
+            "localpart": "someone",
+            "displayname": "someone interesting",
+            "duration_seconds": 200
+        })
+
+        self.registration_handler.get_or_create_user = Mock(
+            return_value=(user_id, token)
+        )
+
+        (code, result) = yield self.servlet.on_POST(self.request)
+        self.assertEquals(code, 200)
+
+        det_data = {
+            "user_id": user_id,
+            "access_token": token,
+            "home_server": self.hs.hostname
+        }
+        self.assertDictContainsSubset(det_data, result)
diff --git a/tests/utils.py b/tests/utils.py
index c179df31ee..9d7978a642 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -49,6 +49,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
         config.event_cache_size = 1
         config.enable_registration = True
         config.macaroon_secret_key = "not even a little secret"
+        config.expire_access_token = False
         config.server_name = "server.under.test"
         config.trusted_third_party_id_servers = []
         config.room_invite_state_types = []