diff options
Diffstat (limited to 'tests/rest/client/v1')
-rw-r--r-- | tests/rest/client/v1/test_admin.py | 116 | ||||
-rw-r--r-- | tests/rest/client/v1/test_register.py | 10 | ||||
-rw-r--r-- | tests/rest/client/v1/utils.py | 22 |
3 files changed, 65 insertions, 83 deletions
diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py index 1a553fa3f9..e38eb628a9 100644 --- a/tests/rest/client/v1/test_admin.py +++ b/tests/rest/client/v1/test_admin.py @@ -19,24 +19,17 @@ import json from mock import Mock -from synapse.http.server import JsonResource from synapse.rest.client.v1.admin import register_servlets -from synapse.util import Clock from tests import unittest -from tests.server import ( - ThreadedMemoryReactorClock, - make_request, - render, - setup_test_homeserver, -) -class UserRegisterTestCase(unittest.TestCase): - def setUp(self): +class UserRegisterTestCase(unittest.HomeserverTestCase): + + servlets = [register_servlets] + + def make_homeserver(self, reactor, clock): - self.clock = ThreadedMemoryReactorClock() - self.hs_clock = Clock(self.clock) self.url = "/_matrix/client/r0/admin/register" self.registration_handler = Mock() @@ -50,17 +43,14 @@ class UserRegisterTestCase(unittest.TestCase): self.secrets = Mock() - self.hs = setup_test_homeserver( - self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock - ) + self.hs = self.setup_test_homeserver() self.hs.config.registration_shared_secret = u"shared" self.hs.get_media_repository = Mock() self.hs.get_deactivate_account_handler = Mock() - self.resource = JsonResource(self.hs) - register_servlets(self.hs, self.resource) + return self.hs def test_disabled(self): """ @@ -69,8 +59,8 @@ class UserRegisterTestCase(unittest.TestCase): """ self.hs.config.registration_shared_secret = None - request, channel = make_request("POST", self.url, b'{}') - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, b'{}') + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( @@ -87,8 +77,8 @@ class UserRegisterTestCase(unittest.TestCase): self.hs.get_secrets = Mock(return_value=secrets) - request, channel = make_request("GET", self.url) - render(request, self.resource, self.clock) + request, channel = self.make_request("GET", self.url) + self.render(request) self.assertEqual(channel.json_body, {"nonce": "abcd"}) @@ -97,25 +87,25 @@ class UserRegisterTestCase(unittest.TestCase): Calling GET on the endpoint will return a randomised nonce, which will only last for SALT_TIMEOUT (60s). """ - request, channel = make_request("GET", self.url) - render(request, self.resource, self.clock) + request, channel = self.make_request("GET", self.url) + self.render(request) nonce = channel.json_body["nonce"] # 59 seconds - self.clock.advance(59) + self.reactor.advance(59) body = json.dumps({"nonce": nonce}) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('username must be specified', channel.json_body["error"]) # 61 seconds - self.clock.advance(2) + self.reactor.advance(2) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('unrecognised nonce', channel.json_body["error"]) @@ -124,8 +114,8 @@ class UserRegisterTestCase(unittest.TestCase): """ Only the provided nonce can be used, as it's checked in the MAC. """ - request, channel = make_request("GET", self.url) - render(request, self.resource, self.clock) + request, channel = self.make_request("GET", self.url) + self.render(request) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -141,8 +131,8 @@ class UserRegisterTestCase(unittest.TestCase): "mac": want_mac, } ) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("HMAC incorrect", channel.json_body["error"]) @@ -152,8 +142,8 @@ class UserRegisterTestCase(unittest.TestCase): When the correct nonce is provided, and the right key is provided, the user is registered. """ - request, channel = make_request("GET", self.url) - render(request, self.resource, self.clock) + request, channel = self.make_request("GET", self.url) + self.render(request) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -169,8 +159,8 @@ class UserRegisterTestCase(unittest.TestCase): "mac": want_mac, } ) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -179,8 +169,8 @@ class UserRegisterTestCase(unittest.TestCase): """ A valid unrecognised nonce. """ - request, channel = make_request("GET", self.url) - render(request, self.resource, self.clock) + request, channel = self.make_request("GET", self.url) + self.render(request) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -196,15 +186,15 @@ class UserRegisterTestCase(unittest.TestCase): "mac": want_mac, } ) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["user_id"]) # Now, try and reuse it - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('unrecognised nonce', channel.json_body["error"]) @@ -217,8 +207,8 @@ class UserRegisterTestCase(unittest.TestCase): """ def nonce(): - request, channel = make_request("GET", self.url) - render(request, self.resource, self.clock) + request, channel = self.make_request("GET", self.url) + self.render(request) return channel.json_body["nonce"] # @@ -227,8 +217,8 @@ class UserRegisterTestCase(unittest.TestCase): # Must be present body = json.dumps({}) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('nonce must be specified', channel.json_body["error"]) @@ -239,32 +229,32 @@ class UserRegisterTestCase(unittest.TestCase): # Must be present body = json.dumps({"nonce": nonce()}) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('username must be specified', channel.json_body["error"]) # Must be a string body = json.dumps({"nonce": nonce(), "username": 1234}) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('Invalid username', channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": u"abcd\u0000"}) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('Invalid username', channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('Invalid username', channel.json_body["error"]) @@ -275,16 +265,16 @@ class UserRegisterTestCase(unittest.TestCase): # Must be present body = json.dumps({"nonce": nonce(), "username": "a"}) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('password must be specified', channel.json_body["error"]) # Must be a string body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('Invalid password', channel.json_body["error"]) @@ -293,16 +283,16 @@ class UserRegisterTestCase(unittest.TestCase): body = json.dumps( {"nonce": nonce(), "username": "a", "password": u"abcd\u0000"} ) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('Invalid password', channel.json_body["error"]) # Super long body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) - request, channel = make_request("POST", self.url, body.encode('utf8')) - render(request, self.resource, self.clock) + request, channel = self.make_request("POST", self.url, body.encode('utf8')) + self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('Invalid password', channel.json_body["error"]) diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py index 6b7ff813d5..f973eff8cf 100644 --- a/tests/rest/client/v1/test_register.py +++ b/tests/rest/client/v1/test_register.py @@ -45,11 +45,11 @@ class CreateUserServletTestCase(unittest.TestCase): ) handlers = Mock(registration_handler=self.registration_handler) - self.clock = MemoryReactorClock() - self.hs_clock = Clock(self.clock) + self.reactor = MemoryReactorClock() + self.hs_clock = Clock(self.reactor) self.hs = self.hs = setup_test_homeserver( - self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor ) self.hs.get_datastore = Mock(return_value=self.datastore) self.hs.get_handlers = Mock(return_value=handlers) @@ -76,8 +76,8 @@ class CreateUserServletTestCase(unittest.TestCase): return_value=(user_id, token) ) - request, channel = make_request(b"POST", url, request_data) - render(request, res, self.clock) + request, channel = make_request(self.reactor, b"POST", url, request_data) + render(request, res, self.reactor) self.assertEquals(channel.result["code"], b"200") diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 530dc8ba6d..9c401bf300 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -169,7 +169,7 @@ class RestHelper(object): path = path + "?access_token=%s" % tok request, channel = make_request( - "POST", path, json.dumps(content).encode('utf8') + self.hs.get_reactor(), "POST", path, json.dumps(content).encode('utf8') ) render(request, self.resource, self.hs.get_reactor()) @@ -217,7 +217,9 @@ class RestHelper(object): data = {"membership": membership} - request, channel = make_request("PUT", path, json.dumps(data).encode('utf8')) + request, channel = make_request( + self.hs.get_reactor(), "PUT", path, json.dumps(data).encode('utf8') + ) render(request, self.resource, self.hs.get_reactor()) @@ -228,18 +230,6 @@ class RestHelper(object): self.auth_user_id = temp_id - @defer.inlineCallbacks - def register(self, user_id): - (code, response) = yield self.mock_resource.trigger( - "POST", - "/_matrix/client/r0/register", - json.dumps( - {"user": user_id, "password": "test", "type": "m.login.password"} - ), - ) - self.assertEquals(200, code) - defer.returnValue(response) - def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -251,7 +241,9 @@ class RestHelper(object): if tok: path = path + "?access_token=%s" % tok - request, channel = make_request("PUT", path, json.dumps(content).encode('utf8')) + request, channel = make_request( + self.hs.get_reactor(), "PUT", path, json.dumps(content).encode('utf8') + ) render(request, self.resource, self.hs.get_reactor()) assert int(channel.result["code"]) == expect_code, ( |