diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 7e7b0b4b1d..ad269af0ec 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -284,12 +284,12 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("time < 1") # ms
self.hs.clock.now = 5000 # seconds
-
- yield self.auth.get_user_from_macaroon(macaroon.serialize())
+ self.hs.config.expire_access_token = True
+ # yield self.auth.get_user_from_macaroon(macaroon.serialize())
# TODO(daniel): Turn on the check that we validate expiration, when we
# validate expiration (and remove the above line, which will start
# throwing).
- # with self.assertRaises(AuthError) as cm:
- # yield self.auth.get_user_from_macaroon(macaroon.serialize())
- # self.assertEqual(401, cm.exception.code)
- # self.assertIn("Invalid macaroon", cm.exception.msg)
+ with self.assertRaises(AuthError) as cm:
+ yield self.auth.get_user_from_macaroon(macaroon.serialize())
+ self.assertEqual(401, cm.exception.code)
+ self.assertIn("Invalid macaroon", cm.exception.msg)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
new file mode 100644
index 0000000000..8b7be96bd9
--- /dev/null
+++ b/tests/handlers/test_register.py
@@ -0,0 +1,67 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+from .. import unittest
+
+from synapse.handlers.register import RegistrationHandler
+
+from tests.utils import setup_test_homeserver
+
+from mock import Mock
+
+
+class RegistrationHandlers(object):
+ def __init__(self, hs):
+ self.registration_handler = RegistrationHandler(hs)
+
+
+class RegistrationTestCase(unittest.TestCase):
+ """ Tests the RegistrationHandler. """
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.mock_distributor = Mock()
+ self.mock_distributor.declare("registered_user")
+ self.mock_captcha_client = Mock()
+ hs = yield setup_test_homeserver(
+ handlers=None,
+ http_client=None,
+ expire_access_token=True)
+ hs.handlers = RegistrationHandlers(hs)
+ self.handler = hs.get_handlers().registration_handler
+ hs.get_handlers().profile_handler = Mock()
+ self.mock_handler = Mock(spec=[
+ "generate_short_term_login_token",
+ ])
+
+ hs.get_handlers().auth_handler = self.mock_handler
+
+ @defer.inlineCallbacks
+ def test_user_is_created_and_logged_in_if_doesnt_exist(self):
+ """
+ Returns:
+ The user doess not exist in this case so it will register and log it in
+ """
+ duration_ms = 200
+ local_part = "someone"
+ display_name = "someone"
+ user_id = "@someone:test"
+ mock_token = self.mock_handler.generate_short_term_login_token
+ mock_token.return_value = 'secret'
+ result_user_id, result_token = yield self.handler.get_or_create_user(
+ local_part, display_name, duration_ms)
+ self.assertEquals(result_user_id, user_id)
+ self.assertEquals(result_token, 'secret')
diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py
new file mode 100644
index 0000000000..4a898a034f
--- /dev/null
+++ b/tests/rest/client/v1/test_register.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.rest.client.v1.register import CreateUserRestServlet
+from twisted.internet import defer
+from mock import Mock
+from tests import unittest
+import json
+
+
+class CreateUserServletTestCase(unittest.TestCase):
+
+ def setUp(self):
+ # do the dance to hook up request data to self.request_data
+ self.request_data = ""
+ self.request = Mock(
+ content=Mock(read=Mock(side_effect=lambda: self.request_data)),
+ path='/_matrix/client/api/v1/createUser'
+ )
+ self.request.args = {}
+
+ self.appservice = None
+ self.auth = Mock(get_appservice_by_req=Mock(
+ side_effect=lambda x: defer.succeed(self.appservice))
+ )
+
+ self.auth_result = (False, None, None, None)
+ self.auth_handler = Mock(
+ check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
+ get_session_data=Mock(return_value=None)
+ )
+ self.registration_handler = Mock()
+ self.identity_handler = Mock()
+ self.login_handler = Mock()
+
+ # do the dance to hook it up to the hs global
+ self.handlers = Mock(
+ auth_handler=self.auth_handler,
+ registration_handler=self.registration_handler,
+ identity_handler=self.identity_handler,
+ login_handler=self.login_handler
+ )
+ self.hs = Mock()
+ self.hs.hostname = "supergbig~testing~thing.com"
+ self.hs.get_auth = Mock(return_value=self.auth)
+ self.hs.get_handlers = Mock(return_value=self.handlers)
+ self.hs.config.enable_registration = True
+ # init the thing we're testing
+ self.servlet = CreateUserRestServlet(self.hs)
+
+ @defer.inlineCallbacks
+ def test_POST_createuser_with_valid_user(self):
+ user_id = "@someone:interesting"
+ token = "my token"
+ self.request.args = {
+ "access_token": "i_am_an_app_service"
+ }
+ self.request_data = json.dumps({
+ "localpart": "someone",
+ "displayname": "someone interesting",
+ "duration_seconds": 200
+ })
+
+ self.registration_handler.get_or_create_user = Mock(
+ return_value=(user_id, token)
+ )
+
+ (code, result) = yield self.servlet.on_POST(self.request)
+ self.assertEquals(code, 200)
+
+ det_data = {
+ "user_id": user_id,
+ "access_token": token,
+ "home_server": self.hs.hostname
+ }
+ self.assertDictContainsSubset(det_data, result)
diff --git a/tests/utils.py b/tests/utils.py
index c179df31ee..9d7978a642 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -49,6 +49,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.event_cache_size = 1
config.enable_registration = True
config.macaroon_secret_key = "not even a little secret"
+ config.expire_access_token = False
config.server_name = "server.under.test"
config.trusted_third_party_id_servers = []
config.room_invite_state_types = []
|