summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/rest/client/v2_alpha/test_register.py134
-rw-r--r--tests/test_state.py2
-rw-r--r--tests/util/test_dict_cache.py101
3 files changed, 236 insertions, 1 deletions
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
new file mode 100644

index 0000000000..66fd25964d --- /dev/null +++ b/tests/rest/client/v2_alpha/test_register.py
@@ -0,0 +1,134 @@ +from synapse.rest.client.v2_alpha.register import RegisterRestServlet +from synapse.api.errors import SynapseError +from twisted.internet import defer +from mock import Mock, MagicMock +from tests import unittest +import json + + +class RegisterRestServletTestCase(unittest.TestCase): + + def setUp(self): + # do the dance to hook up request data to self.request_data + self.request_data = "" + self.request = Mock( + content=Mock(read=Mock(side_effect=lambda: self.request_data)), + ) + self.request.args = {} + + self.appservice = None + self.auth = Mock(get_appservice_by_req=Mock( + side_effect=lambda x: defer.succeed(self.appservice)) + ) + + self.auth_result = (False, None, None) + self.auth_handler = Mock( + check_auth=Mock(side_effect=lambda x,y,z: self.auth_result) + ) + self.registration_handler = Mock() + self.identity_handler = Mock() + self.login_handler = Mock() + + # do the dance to hook it up to the hs global + self.handlers = Mock( + auth_handler=self.auth_handler, + registration_handler=self.registration_handler, + identity_handler=self.identity_handler, + login_handler=self.login_handler + ) + self.hs = Mock() + self.hs.hostname = "superbig~testing~thing.com" + self.hs.get_auth = Mock(return_value=self.auth) + self.hs.get_handlers = Mock(return_value=self.handlers) + self.hs.config.disable_registration = False + + # init the thing we're testing + self.servlet = RegisterRestServlet(self.hs) + + @defer.inlineCallbacks + def test_POST_appservice_registration_valid(self): + user_id = "@kermit:muppet" + token = "kermits_access_token" + self.request.args = { + "access_token": "i_am_an_app_service" + } + self.request_data = json.dumps({ + "username": "kermit" + }) + self.appservice = { + "id": "1234" + } + self.registration_handler.appservice_register = Mock( + return_value=(user_id, token) + ) + result = yield self.servlet.on_POST(self.request) + self.assertEquals(result, (200, { + "user_id": user_id, + "access_token": token, + "home_server": self.hs.hostname + })) + + @defer.inlineCallbacks + def test_POST_appservice_registration_invalid(self): + self.request.args = { + "access_token": "i_am_an_app_service" + } + self.request_data = json.dumps({ + "username": "kermit" + }) + self.appservice = None # no application service exists + result = yield self.servlet.on_POST(self.request) + self.assertEquals(result, (401, None)) + + def test_POST_bad_password(self): + self.request_data = json.dumps({ + "username": "kermit", + "password": 666 + }) + d = self.servlet.on_POST(self.request) + return self.assertFailure(d, SynapseError) + + def test_POST_bad_username(self): + self.request_data = json.dumps({ + "username": 777, + "password": "monkey" + }) + d = self.servlet.on_POST(self.request) + return self.assertFailure(d, SynapseError) + + @defer.inlineCallbacks + def test_POST_user_valid(self): + user_id = "@kermit:muppet" + token = "kermits_access_token" + self.request_data = json.dumps({ + "username": "kermit", + "password": "monkey" + }) + self.registration_handler.check_username = Mock(return_value=True) + self.auth_result = (True, None, { + "username": "kermit", + "password": "monkey" + }) + self.registration_handler.register = Mock(return_value=(user_id, token)) + + result = yield self.servlet.on_POST(self.request) + self.assertEquals(result, (200, { + "user_id": user_id, + "access_token": token, + "home_server": self.hs.hostname + })) + + def test_POST_disabled_registration(self): + self.hs.config.disable_registration = True + self.request_data = json.dumps({ + "username": "kermit", + "password": "monkey" + }) + self.registration_handler.check_username = Mock(return_value=True) + self.auth_result = (True, None, { + "username": "kermit", + "password": "monkey" + }) + self.registration_handler.register = Mock(return_value=("@user:id", "t")) + d = self.servlet.on_POST(self.request) + return self.assertFailure(d, SynapseError) \ No newline at end of file diff --git a/tests/test_state.py b/tests/test_state.py
index fea25f7021..5845358754 100644 --- a/tests/test_state.py +++ b/tests/test_state.py
@@ -69,7 +69,7 @@ class StateGroupStore(object): self._next_group = 1 - def get_state_groups(self, event_ids): + def get_state_groups(self, room_id, event_ids): groups = {} for event_id in event_ids: group = self._event_to_state_group.get(event_id) diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py new file mode 100644
index 0000000000..79bc1225d6 --- /dev/null +++ b/tests/util/test_dict_cache.py
@@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer +from tests import unittest + +from synapse.util.dictionary_cache import DictionaryCache + + +class DictCacheTestCase(unittest.TestCase): + + def setUp(self): + self.cache = DictionaryCache("foobar") + + def test_simple_cache_hit_full(self): + key = "test_simple_cache_hit_full" + + v = self.cache.get(key) + self.assertEqual((False, {}), v) + + seq = self.cache.sequence + test_value = {"test": "test_simple_cache_hit_full"} + self.cache.update(seq, key, test_value, full=True) + + c = self.cache.get(key) + self.assertEqual(test_value, c.value) + + def test_simple_cache_hit_partial(self): + key = "test_simple_cache_hit_partial" + + seq = self.cache.sequence + test_value = { + "test": "test_simple_cache_hit_partial" + } + self.cache.update(seq, key, test_value, full=True) + + c = self.cache.get(key, ["test"]) + self.assertEqual(test_value, c.value) + + def test_simple_cache_miss_partial(self): + key = "test_simple_cache_miss_partial" + + seq = self.cache.sequence + test_value = { + "test": "test_simple_cache_miss_partial" + } + self.cache.update(seq, key, test_value, full=True) + + c = self.cache.get(key, ["test2"]) + self.assertEqual({}, c.value) + + def test_simple_cache_hit_miss_partial(self): + key = "test_simple_cache_hit_miss_partial" + + seq = self.cache.sequence + test_value = { + "test": "test_simple_cache_hit_miss_partial", + "test2": "test_simple_cache_hit_miss_partial2", + "test3": "test_simple_cache_hit_miss_partial3", + } + self.cache.update(seq, key, test_value, full=True) + + c = self.cache.get(key, ["test2"]) + self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value) + + def test_multi_insert(self): + key = "test_simple_cache_hit_miss_partial" + + seq = self.cache.sequence + test_value_1 = { + "test": "test_simple_cache_hit_miss_partial", + } + self.cache.update(seq, key, test_value_1, full=False) + + seq = self.cache.sequence + test_value_2 = { + "test2": "test_simple_cache_hit_miss_partial2", + } + self.cache.update(seq, key, test_value_2, full=False) + + c = self.cache.get(key) + self.assertEqual( + { + "test": "test_simple_cache_hit_miss_partial", + "test2": "test_simple_cache_hit_miss_partial2", + }, + c.value + )