diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/rest/client/v2_alpha/test_register.py | 134 | ||||
-rw-r--r-- | tests/storage/test__base.py | 21 | ||||
-rw-r--r-- | tests/test_distributor.py | 4 |
3 files changed, 148 insertions, 11 deletions
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py new file mode 100644 index 0000000000..66fd25964d --- /dev/null +++ b/tests/rest/client/v2_alpha/test_register.py @@ -0,0 +1,134 @@ +from synapse.rest.client.v2_alpha.register import RegisterRestServlet +from synapse.api.errors import SynapseError +from twisted.internet import defer +from mock import Mock, MagicMock +from tests import unittest +import json + + +class RegisterRestServletTestCase(unittest.TestCase): + + def setUp(self): + # do the dance to hook up request data to self.request_data + self.request_data = "" + self.request = Mock( + content=Mock(read=Mock(side_effect=lambda: self.request_data)), + ) + self.request.args = {} + + self.appservice = None + self.auth = Mock(get_appservice_by_req=Mock( + side_effect=lambda x: defer.succeed(self.appservice)) + ) + + self.auth_result = (False, None, None) + self.auth_handler = Mock( + check_auth=Mock(side_effect=lambda x,y,z: self.auth_result) + ) + self.registration_handler = Mock() + self.identity_handler = Mock() + self.login_handler = Mock() + + # do the dance to hook it up to the hs global + self.handlers = Mock( + auth_handler=self.auth_handler, + registration_handler=self.registration_handler, + identity_handler=self.identity_handler, + login_handler=self.login_handler + ) + self.hs = Mock() + self.hs.hostname = "superbig~testing~thing.com" + self.hs.get_auth = Mock(return_value=self.auth) + self.hs.get_handlers = Mock(return_value=self.handlers) + self.hs.config.disable_registration = False + + # init the thing we're testing + self.servlet = RegisterRestServlet(self.hs) + + @defer.inlineCallbacks + def test_POST_appservice_registration_valid(self): + user_id = "@kermit:muppet" + token = "kermits_access_token" + self.request.args = { + "access_token": "i_am_an_app_service" + } + self.request_data = json.dumps({ + "username": "kermit" + }) + self.appservice = { + "id": "1234" + } + self.registration_handler.appservice_register = Mock( + return_value=(user_id, token) + ) + result = yield self.servlet.on_POST(self.request) + self.assertEquals(result, (200, { + "user_id": user_id, + "access_token": token, + "home_server": self.hs.hostname + })) + + @defer.inlineCallbacks + def test_POST_appservice_registration_invalid(self): + self.request.args = { + "access_token": "i_am_an_app_service" + } + self.request_data = json.dumps({ + "username": "kermit" + }) + self.appservice = None # no application service exists + result = yield self.servlet.on_POST(self.request) + self.assertEquals(result, (401, None)) + + def test_POST_bad_password(self): + self.request_data = json.dumps({ + "username": "kermit", + "password": 666 + }) + d = self.servlet.on_POST(self.request) + return self.assertFailure(d, SynapseError) + + def test_POST_bad_username(self): + self.request_data = json.dumps({ + "username": 777, + "password": "monkey" + }) + d = self.servlet.on_POST(self.request) + return self.assertFailure(d, SynapseError) + + @defer.inlineCallbacks + def test_POST_user_valid(self): + user_id = "@kermit:muppet" + token = "kermits_access_token" + self.request_data = json.dumps({ + "username": "kermit", + "password": "monkey" + }) + self.registration_handler.check_username = Mock(return_value=True) + self.auth_result = (True, None, { + "username": "kermit", + "password": "monkey" + }) + self.registration_handler.register = Mock(return_value=(user_id, token)) + + result = yield self.servlet.on_POST(self.request) + self.assertEquals(result, (200, { + "user_id": user_id, + "access_token": token, + "home_server": self.hs.hostname + })) + + def test_POST_disabled_registration(self): + self.hs.config.disable_registration = True + self.request_data = json.dumps({ + "username": "kermit", + "password": "monkey" + }) + self.registration_handler.check_username = Mock(return_value=True) + self.auth_result = (True, None, { + "username": "kermit", + "password": "monkey" + }) + self.registration_handler.register = Mock(return_value=("@user:id", "t")) + d = self.servlet.on_POST(self.request) + return self.assertFailure(d, SynapseError) \ No newline at end of file diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 8c3d2952bd..abee2f631d 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -17,6 +17,8 @@ from tests import unittest from twisted.internet import defer +from synapse.util.async import ObservableDeferred + from synapse.storage._base import Cache, cached @@ -40,12 +42,12 @@ class CacheTestCase(unittest.TestCase): self.assertEquals(self.cache.get("foo"), 123) def test_invalidate(self): - self.cache.prefill("foo", 123) - self.cache.invalidate("foo") + self.cache.prefill(("foo",), 123) + self.cache.invalidate(("foo",)) failed = False try: - self.cache.get("foo") + self.cache.get(("foo",)) except KeyError: failed = True @@ -139,7 +141,7 @@ class CacheDecoratorTestCase(unittest.TestCase): self.assertEquals(callcount[0], 1) - a.func.invalidate("foo") + a.func.invalidate(("foo",)) yield a.func("foo") @@ -151,7 +153,7 @@ class CacheDecoratorTestCase(unittest.TestCase): def func(self, key): return key - A().func.invalidate("what") + A().func.invalidate(("what",)) @defer.inlineCallbacks def test_max_entries(self): @@ -178,19 +180,20 @@ class CacheDecoratorTestCase(unittest.TestCase): self.assertTrue(callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])) - @defer.inlineCallbacks def test_prefill(self): callcount = [0] + d = defer.succeed(123) + class A(object): @cached() def func(self, key): callcount[0] += 1 - return key + return d a = A() - a.func.prefill("foo", 123) + a.func.prefill(("foo",), ObservableDeferred(d)) - self.assertEquals((yield a.func("foo")), 123) + self.assertEquals(a.func("foo").result, d.result) self.assertEquals(callcount[0], 0) diff --git a/tests/test_distributor.py b/tests/test_distributor.py index 6a0095d850..8ed48cfb70 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -73,8 +73,8 @@ class DistributorTestCase(unittest.TestCase): yield d self.assertTrue(d.called) - observers[0].assert_called_once("Go") - observers[1].assert_called_once("Go") + observers[0].assert_called_once_with("Go") + observers[1].assert_called_once_with("Go") self.assertEquals(mock_logger.warning.call_count, 1) self.assertIsInstance(mock_logger.warning.call_args[0][0], |