diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index d3e9837c81..44e38b777a 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -612,7 +612,8 @@ class Auth(object):
def get_user_from_macaroon(self, macaroon_str):
try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
- self.validate_macaroon(macaroon, "access", False)
+
+ self.validate_macaroon(macaroon, "access", self.hs.config.expire_access_token)
user_prefix = "user_id = "
user = None
diff --git a/synapse/config/key.py b/synapse/config/key.py
index a072aec714..6ee643793e 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -57,6 +57,8 @@ class KeyConfig(Config):
seed = self.signing_key[0].seed
self.macaroon_secret_key = hashlib.sha256(seed)
+ self.expire_access_token = config.get("expire_access_token", False)
+
def default_config(self, config_dir_path, server_name, is_generating_file=False,
**kwargs):
base_key_name = os.path.join(config_dir_path, server_name)
@@ -69,6 +71,9 @@ class KeyConfig(Config):
return """\
macaroon_secret_key: "%(macaroon_secret_key)s"
+ # Used to enable access token expiration.
+ expire_access_token: False
+
## Signing Keys ##
# Path to the signing key to sign messages with
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 87e500c97a..cc3f879857 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -32,6 +32,7 @@ class RegistrationConfig(Config):
)
self.registration_shared_secret = config.get("registration_shared_secret")
+ self.user_creation_max_duration = int(config["user_creation_max_duration"])
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
@@ -54,6 +55,11 @@ class RegistrationConfig(Config):
# secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s"
+ # Sets the expiry for the short term user creation in
+ # milliseconds. For instance the bellow duration is two weeks
+ # in milliseconds.
+ user_creation_max_duration: 1209600000
+
# Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash.
# The default number of rounds is 12.
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 6e7d080ecc..68d0d78fc6 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -521,11 +521,11 @@ class AuthHandler(BaseHandler):
))
return m.serialize()
- def generate_short_term_login_token(self, user_id):
+ def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
- expiry = now + (2 * 60 * 1000)
+ expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index b0862067e1..5883b9111e 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -358,6 +358,59 @@ class RegistrationHandler(BaseHandler):
)
defer.returnValue(data)
+ @defer.inlineCallbacks
+ def get_or_create_user(self, localpart, displayname, duration_seconds):
+ """Creates a new user or returns an access token for an existing one
+
+ Args:
+ localpart : The local part of the user ID to register. If None,
+ one will be randomly generated.
+ Returns:
+ A tuple of (user_id, access_token).
+ Raises:
+ RegistrationError if there was a problem registering.
+ """
+ yield run_on_reactor()
+
+ if localpart is None:
+ raise SynapseError(400, "Request must include user id")
+
+ need_register = True
+
+ try:
+ yield self.check_username(localpart)
+ except SynapseError as e:
+ if e.errcode == Codes.USER_IN_USE:
+ need_register = False
+ else:
+ raise
+
+ user = UserID(localpart, self.hs.hostname)
+ user_id = user.to_string()
+ auth_handler = self.hs.get_handlers().auth_handler
+ token = auth_handler.generate_short_term_login_token(user_id, duration_seconds)
+
+ if need_register:
+ yield self.store.register(
+ user_id=user_id,
+ token=token,
+ password_hash=None
+ )
+
+ yield registered_user(self.distributor, user)
+ else:
+ yield self.store.flush_user(user_id=user_id)
+ yield self.store.add_access_token_to_user(user_id=user_id, token=token)
+
+ if displayname is not None:
+ logger.info("setting user display name: %s -> %s", user_id, displayname)
+ profile_handler = self.hs.get_handlers().profile_handler
+ yield profile_handler.set_displayname(
+ user, user, displayname
+ )
+
+ defer.returnValue((user_id, token))
+
def auth_handler(self):
return self.hs.get_handlers().auth_handler
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index c6a2ef2ccc..e3f4fbb0bb 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -355,5 +355,76 @@ class RegisterRestServlet(ClientV1RestServlet):
)
+class CreateUserRestServlet(ClientV1RestServlet):
+ """Handles user creation via a server-to-server interface
+ """
+
+ PATTERNS = client_path_patterns("/createUser$", releases=())
+
+ def __init__(self, hs):
+ super(CreateUserRestServlet, self).__init__(hs)
+ self.store = hs.get_datastore()
+ self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ user_json = parse_json_object_from_request(request)
+
+ if "access_token" not in request.args:
+ raise SynapseError(400, "Expected application service token.")
+
+ app_service = yield self.store.get_app_service_by_token(
+ request.args["access_token"][0]
+ )
+ if not app_service:
+ raise SynapseError(403, "Invalid application service token.")
+
+ logger.debug("creating user: %s", user_json)
+
+ response = yield self._do_create(user_json)
+
+ defer.returnValue((200, response))
+
+ def on_OPTIONS(self, request):
+ return 403, {}
+
+ @defer.inlineCallbacks
+ def _do_create(self, user_json):
+ yield run_on_reactor()
+
+ if "localpart" not in user_json:
+ raise SynapseError(400, "Expected 'localpart' key.")
+
+ if "displayname" not in user_json:
+ raise SynapseError(400, "Expected 'displayname' key.")
+
+ if "duration_seconds" not in user_json:
+ raise SynapseError(400, "Expected 'duration_seconds' key.")
+
+ localpart = user_json["localpart"].encode("utf-8")
+ displayname = user_json["displayname"].encode("utf-8")
+ duration_seconds = 0
+ try:
+ duration_seconds = int(user_json["duration_seconds"])
+ except ValueError:
+ raise SynapseError(400, "Failed to parse 'duration_seconds'")
+ if duration_seconds > self.direct_user_creation_max_duration:
+ duration_seconds = self.direct_user_creation_max_duration
+
+ handler = self.handlers.registration_handler
+ user_id, token = yield handler.get_or_create_user(
+ localpart=localpart,
+ displayname=displayname,
+ duration_seconds=duration_seconds
+ )
+
+ defer.returnValue({
+ "user_id": user_id,
+ "access_token": token,
+ "home_server": self.hs.hostname,
+ })
+
+
def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server)
+ CreateUserRestServlet(hs).register(http_server)
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 7e7b0b4b1d..ad269af0ec 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -284,12 +284,12 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("time < 1") # ms
self.hs.clock.now = 5000 # seconds
-
- yield self.auth.get_user_from_macaroon(macaroon.serialize())
+ self.hs.config.expire_access_token = True
+ # yield self.auth.get_user_from_macaroon(macaroon.serialize())
# TODO(daniel): Turn on the check that we validate expiration, when we
# validate expiration (and remove the above line, which will start
# throwing).
- # with self.assertRaises(AuthError) as cm:
- # yield self.auth.get_user_from_macaroon(macaroon.serialize())
- # self.assertEqual(401, cm.exception.code)
- # self.assertIn("Invalid macaroon", cm.exception.msg)
+ with self.assertRaises(AuthError) as cm:
+ yield self.auth.get_user_from_macaroon(macaroon.serialize())
+ self.assertEqual(401, cm.exception.code)
+ self.assertIn("Invalid macaroon", cm.exception.msg)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
new file mode 100644
index 0000000000..8b7be96bd9
--- /dev/null
+++ b/tests/handlers/test_register.py
@@ -0,0 +1,67 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+from .. import unittest
+
+from synapse.handlers.register import RegistrationHandler
+
+from tests.utils import setup_test_homeserver
+
+from mock import Mock
+
+
+class RegistrationHandlers(object):
+ def __init__(self, hs):
+ self.registration_handler = RegistrationHandler(hs)
+
+
+class RegistrationTestCase(unittest.TestCase):
+ """ Tests the RegistrationHandler. """
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.mock_distributor = Mock()
+ self.mock_distributor.declare("registered_user")
+ self.mock_captcha_client = Mock()
+ hs = yield setup_test_homeserver(
+ handlers=None,
+ http_client=None,
+ expire_access_token=True)
+ hs.handlers = RegistrationHandlers(hs)
+ self.handler = hs.get_handlers().registration_handler
+ hs.get_handlers().profile_handler = Mock()
+ self.mock_handler = Mock(spec=[
+ "generate_short_term_login_token",
+ ])
+
+ hs.get_handlers().auth_handler = self.mock_handler
+
+ @defer.inlineCallbacks
+ def test_user_is_created_and_logged_in_if_doesnt_exist(self):
+ """
+ Returns:
+ The user doess not exist in this case so it will register and log it in
+ """
+ duration_ms = 200
+ local_part = "someone"
+ display_name = "someone"
+ user_id = "@someone:test"
+ mock_token = self.mock_handler.generate_short_term_login_token
+ mock_token.return_value = 'secret'
+ result_user_id, result_token = yield self.handler.get_or_create_user(
+ local_part, display_name, duration_ms)
+ self.assertEquals(result_user_id, user_id)
+ self.assertEquals(result_token, 'secret')
diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py
new file mode 100644
index 0000000000..4a898a034f
--- /dev/null
+++ b/tests/rest/client/v1/test_register.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.rest.client.v1.register import CreateUserRestServlet
+from twisted.internet import defer
+from mock import Mock
+from tests import unittest
+import json
+
+
+class CreateUserServletTestCase(unittest.TestCase):
+
+ def setUp(self):
+ # do the dance to hook up request data to self.request_data
+ self.request_data = ""
+ self.request = Mock(
+ content=Mock(read=Mock(side_effect=lambda: self.request_data)),
+ path='/_matrix/client/api/v1/createUser'
+ )
+ self.request.args = {}
+
+ self.appservice = None
+ self.auth = Mock(get_appservice_by_req=Mock(
+ side_effect=lambda x: defer.succeed(self.appservice))
+ )
+
+ self.auth_result = (False, None, None, None)
+ self.auth_handler = Mock(
+ check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
+ get_session_data=Mock(return_value=None)
+ )
+ self.registration_handler = Mock()
+ self.identity_handler = Mock()
+ self.login_handler = Mock()
+
+ # do the dance to hook it up to the hs global
+ self.handlers = Mock(
+ auth_handler=self.auth_handler,
+ registration_handler=self.registration_handler,
+ identity_handler=self.identity_handler,
+ login_handler=self.login_handler
+ )
+ self.hs = Mock()
+ self.hs.hostname = "supergbig~testing~thing.com"
+ self.hs.get_auth = Mock(return_value=self.auth)
+ self.hs.get_handlers = Mock(return_value=self.handlers)
+ self.hs.config.enable_registration = True
+ # init the thing we're testing
+ self.servlet = CreateUserRestServlet(self.hs)
+
+ @defer.inlineCallbacks
+ def test_POST_createuser_with_valid_user(self):
+ user_id = "@someone:interesting"
+ token = "my token"
+ self.request.args = {
+ "access_token": "i_am_an_app_service"
+ }
+ self.request_data = json.dumps({
+ "localpart": "someone",
+ "displayname": "someone interesting",
+ "duration_seconds": 200
+ })
+
+ self.registration_handler.get_or_create_user = Mock(
+ return_value=(user_id, token)
+ )
+
+ (code, result) = yield self.servlet.on_POST(self.request)
+ self.assertEquals(code, 200)
+
+ det_data = {
+ "user_id": user_id,
+ "access_token": token,
+ "home_server": self.hs.hostname
+ }
+ self.assertDictContainsSubset(det_data, result)
diff --git a/tests/utils.py b/tests/utils.py
index c179df31ee..9d7978a642 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -49,6 +49,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.event_cache_size = 1
config.enable_registration = True
config.macaroon_secret_key = "not even a little secret"
+ config.expire_access_token = False
config.server_name = "server.under.test"
config.trusted_third_party_id_servers = []
config.room_invite_state_types = []
|