diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/api/test_auth.py | 25 | ||||
-rw-r--r-- | tests/handlers/test_auth.py | 8 | ||||
-rw-r--r-- | tests/handlers/test_register.py | 72 | ||||
-rw-r--r-- | tests/handlers/test_typing.py | 4 | ||||
-rw-r--r-- | tests/rest/client/test_transactions.py | 4 | ||||
-rw-r--r-- | tests/rest/client/v1/test_admin.py | 10 | ||||
-rw-r--r-- | tests/rest/client/v1/test_profile.py | 8 | ||||
-rw-r--r-- | tests/rest/client/v1/utils.py | 10 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_filter.py | 22 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_register.py | 14 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_sync.py | 4 | ||||
-rw-r--r-- | tests/server.py | 22 | ||||
-rw-r--r-- | tests/storage/test__init__.py | 65 | ||||
-rw-r--r-- | tests/storage/test_client_ips.py | 66 | ||||
-rw-r--r-- | tests/storage/test_event_federation.py | 2 | ||||
-rw-r--r-- | tests/storage/test_monthly_active_users.py | 123 | ||||
-rw-r--r-- | tests/storage/test_state.py | 2 | ||||
-rw-r--r-- | tests/test_server.py | 11 | ||||
-rw-r--r-- | tests/utils.py | 16 |
19 files changed, 326 insertions, 162 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index a82d737e71..5dc3398300 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -444,3 +444,28 @@ class AuthTestCase(unittest.TestCase): self.assertEqual("Guest access token used for regular user", cm.exception.msg) self.store.get_user_by_id.assert_called_with(USER_ID) + + @defer.inlineCallbacks + def test_blocking_mau(self): + self.hs.config.limit_usage_by_mau = False + self.hs.config.max_mau_value = 50 + lots_of_users = 100 + small_number_of_users = 1 + + # Ensure no error thrown + yield self.auth.check_auth_blocking() + + self.hs.config.limit_usage_by_mau = True + + self.store.get_monthly_active_count = Mock( + return_value=defer.succeed(lots_of_users) + ) + + with self.assertRaises(AuthError): + yield self.auth.check_auth_blocking() + + # Ensure does not throw an error + self.store.get_monthly_active_count = Mock( + return_value=defer.succeed(small_number_of_users) + ) + yield self.auth.check_auth_blocking() diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 55eab9e9cf..8a9bf2d5fd 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -132,14 +132,14 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_mau_limits_exceeded(self): self.hs.config.limit_usage_by_mau = True - self.hs.get_datastore().count_monthly_users = Mock( + self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) ) with self.assertRaises(AuthError): yield self.auth_handler.get_access_token_for_user_id('user_a') - self.hs.get_datastore().count_monthly_users = Mock( + self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) ) with self.assertRaises(AuthError): @@ -151,13 +151,13 @@ class AuthTestCase(unittest.TestCase): def test_mau_limits_not_exceeded(self): self.hs.config.limit_usage_by_mau = True - self.hs.get_datastore().count_monthly_users = Mock( + self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.small_number_of_users) ) # Ensure does not raise exception yield self.auth_handler.get_access_token_for_user_id('user_a') - self.hs.get_datastore().count_monthly_users = Mock( + self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.small_number_of_users) ) yield self.auth_handler.validate_short_term_login_token_and_get_user_id( diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 0937d71cf6..4ea59a58de 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -50,6 +50,10 @@ class RegistrationTestCase(unittest.TestCase): self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) self.hs.handlers = RegistrationHandlers(self.hs) self.handler = self.hs.get_handlers().registration_handler + self.store = self.hs.get_datastore() + self.hs.config.max_mau_value = 50 + self.lots_of_users = 100 + self.small_number_of_users = 1 @defer.inlineCallbacks def test_user_is_created_and_logged_in_if_doesnt_exist(self): @@ -80,51 +84,43 @@ class RegistrationTestCase(unittest.TestCase): self.assertEquals(result_token, 'secret') @defer.inlineCallbacks - def test_cannot_register_when_mau_limits_exceeded(self): - local_part = "someone" - display_name = "someone" - requester = create_requester("@as:test") - store = self.hs.get_datastore() + def test_mau_limits_when_disabled(self): self.hs.config.limit_usage_by_mau = False - self.hs.config.max_mau_value = 50 - lots_of_users = 100 - small_number_users = 1 - - store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users)) - # Ensure does not throw exception - yield self.handler.get_or_create_user(requester, 'a', display_name) + yield self.handler.get_or_create_user("requester", 'a', "display_name") + @defer.inlineCallbacks + def test_get_or_create_user_mau_not_blocked(self): self.hs.config.limit_usage_by_mau = True - - with self.assertRaises(RegistrationError): - yield self.handler.get_or_create_user(requester, 'b', display_name) - - store.count_monthly_users = Mock(return_value=defer.succeed(small_number_users)) - - self._macaroon_mock_generator("another_secret") - + self.store.count_monthly_users = Mock( + return_value=defer.succeed(self.small_number_of_users) + ) # Ensure does not throw exception - yield self.handler.get_or_create_user("@neil:matrix.org", 'c', "Neil") - - self._macaroon_mock_generator("another another secret") - store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users)) + yield self.handler.get_or_create_user("@user:server", 'c', "User") + @defer.inlineCallbacks + def test_get_or_create_user_mau_blocked(self): + self.hs.config.limit_usage_by_mau = True + self.store.get_monthly_active_count = Mock( + return_value=defer.succeed(self.lots_of_users) + ) with self.assertRaises(RegistrationError): - yield self.handler.register(localpart=local_part) + yield self.handler.get_or_create_user("requester", 'b', "display_name") - self._macaroon_mock_generator("another another secret") - store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users)) + @defer.inlineCallbacks + def test_register_mau_blocked(self): + self.hs.config.limit_usage_by_mau = True + self.store.get_monthly_active_count = Mock( + return_value=defer.succeed(self.lots_of_users) + ) + with self.assertRaises(RegistrationError): + yield self.handler.register(localpart="local_part") + @defer.inlineCallbacks + def test_register_saml2_mau_blocked(self): + self.hs.config.limit_usage_by_mau = True + self.store.get_monthly_active_count = Mock( + return_value=defer.succeed(self.lots_of_users) + ) with self.assertRaises(RegistrationError): - yield self.handler.register_saml2(local_part) - - def _macaroon_mock_generator(self, secret): - """ - Reset macaroon generator in the case where the test creates multiple users - """ - macaroon_generator = Mock( - generate_access_token=Mock(return_value=secret)) - self.hs.get_macaroon_generator = Mock(return_value=macaroon_generator) - self.hs.handlers = RegistrationHandlers(self.hs) - self.handler = self.hs.get_handlers().registration_handler + yield self.handler.register_saml2(localpart="local_part") diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 2c263af1a3..f422cf3c5a 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -48,7 +48,9 @@ def _expect_edu(destination, edu_type, content, origin="test"): def _make_edu_json(origin, edu_type, content): - return json.dumps(_expect_edu("test", edu_type, content, origin=origin)) + return json.dumps( + _expect_edu("test", edu_type, content, origin=origin) + ).encode('utf8') class TypingNotificationsTestCase(unittest.TestCase): diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index 34e68ae82f..d46c27e7e9 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -85,7 +85,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): try: yield self.cache.fetch_or_execute(self.mock_key, cb) except Exception as e: - self.assertEqual(e.message, "boo") + self.assertEqual(e.args[0], "boo") self.assertIs(LoggingContext.current_context(), test_context) res = yield self.cache.fetch_or_execute(self.mock_key, cb) @@ -111,7 +111,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): try: yield self.cache.fetch_or_execute(self.mock_key, cb) except Exception as e: - self.assertEqual(e.message, "boo") + self.assertEqual(e.args[0], "boo") self.assertIs(LoggingContext.current_context(), test_context) res = yield self.cache.fetch_or_execute(self.mock_key, cb) diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py index 8c90145601..fb28883d30 100644 --- a/tests/rest/client/v1/test_admin.py +++ b/tests/rest/client/v1/test_admin.py @@ -140,7 +140,7 @@ class UserRegisterTestCase(unittest.TestCase): "admin": True, "mac": want_mac, } - ).encode('utf8') + ) request, channel = make_request("POST", self.url, body.encode('utf8')) render(request, self.resource, self.clock) @@ -168,7 +168,7 @@ class UserRegisterTestCase(unittest.TestCase): "admin": True, "mac": want_mac, } - ).encode('utf8') + ) request, channel = make_request("POST", self.url, body.encode('utf8')) render(request, self.resource, self.clock) @@ -195,7 +195,7 @@ class UserRegisterTestCase(unittest.TestCase): "admin": True, "mac": want_mac, } - ).encode('utf8') + ) request, channel = make_request("POST", self.url, body.encode('utf8')) render(request, self.resource, self.clock) @@ -253,7 +253,7 @@ class UserRegisterTestCase(unittest.TestCase): self.assertEqual('Invalid username', channel.json_body["error"]) # Must not have null bytes - body = json.dumps({"nonce": nonce(), "username": b"abcd\x00"}) + body = json.dumps({"nonce": nonce(), "username": u"abcd\u0000"}) request, channel = make_request("POST", self.url, body.encode('utf8')) render(request, self.resource, self.clock) @@ -289,7 +289,7 @@ class UserRegisterTestCase(unittest.TestCase): self.assertEqual('Invalid password', channel.json_body["error"]) # Must not have null bytes - body = json.dumps({"nonce": nonce(), "username": "a", "password": b"abcd\x00"}) + body = json.dumps({"nonce": nonce(), "username": "a", "password": u"abcd\u0000"}) request, channel = make_request("POST", self.url, body.encode('utf8')) render(request, self.resource, self.clock) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index d71cc8e0db..0516ce3cfb 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -80,7 +80,7 @@ class ProfileTestCase(unittest.TestCase): (code, response) = yield self.mock_resource.trigger( "PUT", "/profile/%s/displayname" % (myid), - '{"displayname": "Frank Jr."}' + b'{"displayname": "Frank Jr."}' ) self.assertEquals(200, code) @@ -95,7 +95,7 @@ class ProfileTestCase(unittest.TestCase): (code, response) = yield self.mock_resource.trigger( "PUT", "/profile/%s/displayname" % ("@4567:test"), - '{"displayname": "Frank Jr."}' + b'{"displayname": "Frank Jr."}' ) self.assertTrue( @@ -122,7 +122,7 @@ class ProfileTestCase(unittest.TestCase): (code, response) = yield self.mock_resource.trigger( "PUT", "/profile/%s/displayname" % ("@opaque:elsewhere"), - '{"displayname":"bob"}' + b'{"displayname":"bob"}' ) self.assertTrue( @@ -151,7 +151,7 @@ class ProfileTestCase(unittest.TestCase): (code, response) = yield self.mock_resource.trigger( "PUT", "/profile/%s/avatar_url" % (myid), - '{"avatar_url": "http://my.server/pic.gif"}' + b'{"avatar_url": "http://my.server/pic.gif"}' ) self.assertEquals(200, code) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 41de8e0762..e3bc5f378d 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -105,7 +105,7 @@ class RestTestCase(unittest.TestCase): "password": "test", "type": "m.login.password" })) - self.assertEquals(200, code) + self.assertEquals(200, code, msg=response) defer.returnValue(response) @defer.inlineCallbacks @@ -149,14 +149,14 @@ class RestHelper(object): def create_room_as(self, room_creator, is_public=True, tok=None): temp_id = self.auth_user_id self.auth_user_id = room_creator - path = b"/_matrix/client/r0/createRoom" + path = "/_matrix/client/r0/createRoom" content = {} if not is_public: content["visibility"] = "private" if tok: - path = path + b"?access_token=%s" % tok.encode('ascii') + path = path + "?access_token=%s" % tok - request, channel = make_request(b"POST", path, json.dumps(content).encode('utf8')) + request, channel = make_request("POST", path, json.dumps(content).encode('utf8')) request.render(self.resource) wait_until_result(self.hs.get_reactor(), channel) @@ -205,7 +205,7 @@ class RestHelper(object): data = {"membership": membership} request, channel = make_request( - b"PUT", path.encode('ascii'), json.dumps(data).encode('utf8') + "PUT", path, json.dumps(data).encode('utf8') ) request.render(self.resource) diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index e890f0feac..de33b10a5f 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -33,7 +33,7 @@ PATH_PREFIX = "/_matrix/client/v2_alpha" class FilterTestCase(unittest.TestCase): - USER_ID = b"@apple:test" + USER_ID = "@apple:test" EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}' TO_REGISTER = [filter] @@ -72,8 +72,8 @@ class FilterTestCase(unittest.TestCase): def test_add_filter(self): request, channel = make_request( - b"POST", - b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID), + "POST", + "/_matrix/client/r0/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON, ) request.render(self.resource) @@ -87,8 +87,8 @@ class FilterTestCase(unittest.TestCase): def test_add_filter_for_other_user(self): request, channel = make_request( - b"POST", - b"/_matrix/client/r0/user/%s/filter" % (b"@watermelon:test"), + "POST", + "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"), self.EXAMPLE_FILTER_JSON, ) request.render(self.resource) @@ -101,8 +101,8 @@ class FilterTestCase(unittest.TestCase): _is_mine = self.hs.is_mine self.hs.is_mine = lambda target_user: False request, channel = make_request( - b"POST", - b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID), + "POST", + "/_matrix/client/r0/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON, ) request.render(self.resource) @@ -119,7 +119,7 @@ class FilterTestCase(unittest.TestCase): self.clock.advance(1) filter_id = filter_id.result request, channel = make_request( - b"GET", b"/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id) + "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id) ) request.render(self.resource) wait_until_result(self.clock, channel) @@ -129,7 +129,7 @@ class FilterTestCase(unittest.TestCase): def test_get_filter_non_existant(self): request, channel = make_request( - b"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID) + "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID) ) request.render(self.resource) wait_until_result(self.clock, channel) @@ -141,7 +141,7 @@ class FilterTestCase(unittest.TestCase): # in errors.py def test_get_filter_invalid_id(self): request, channel = make_request( - b"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID) + "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID) ) request.render(self.resource) wait_until_result(self.clock, channel) @@ -151,7 +151,7 @@ class FilterTestCase(unittest.TestCase): # No ID also returns an invalid_id error def test_get_filter_no_id(self): request, channel = make_request( - b"GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID) + "GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID) ) request.render(self.resource) wait_until_result(self.clock, channel) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index e004d8fc73..f6293f11a8 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -81,7 +81,7 @@ class RegisterRestServletTestCase(unittest.TestCase): "access_token": token, "home_server": self.hs.hostname, } - self.assertDictContainsSubset(det_data, json.loads(channel.result["body"])) + self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_appservice_registration_invalid(self): self.appservice = None # no application service exists @@ -102,7 +102,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals( - json.loads(channel.result["body"])["error"], "Invalid password" + channel.json_body["error"], "Invalid password" ) def test_POST_bad_username(self): @@ -113,7 +113,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals( - json.loads(channel.result["body"])["error"], "Invalid username" + channel.json_body["error"], "Invalid username" ) def test_POST_user_valid(self): @@ -140,7 +140,7 @@ class RegisterRestServletTestCase(unittest.TestCase): "device_id": device_id, } self.assertEquals(channel.result["code"], b"200", channel.result) - self.assertDictContainsSubset(det_data, json.loads(channel.result["body"])) + self.assertDictContainsSubset(det_data, channel.json_body) self.auth_handler.get_login_tuple_for_user_id( user_id, device_id=device_id, initial_device_display_name=None ) @@ -158,7 +158,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals( - json.loads(channel.result["body"])["error"], + channel.json_body["error"], "Registration has been disabled", ) @@ -178,7 +178,7 @@ class RegisterRestServletTestCase(unittest.TestCase): "device_id": "guest_device", } self.assertEquals(channel.result["code"], b"200", channel.result) - self.assertDictContainsSubset(det_data, json.loads(channel.result["body"])) + self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_disabled_guest_registration(self): self.hs.config.allow_guest_access = False @@ -189,5 +189,5 @@ class RegisterRestServletTestCase(unittest.TestCase): self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals( - json.loads(channel.result["body"])["error"], "Guest access is disabled" + channel.json_body["error"], "Guest access is disabled" ) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 03ec3993b2..bafc0d1df0 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -32,7 +32,7 @@ PATH_PREFIX = "/_matrix/client/v2_alpha" class FilterTestCase(unittest.TestCase): - USER_ID = b"@apple:test" + USER_ID = "@apple:test" TO_REGISTER = [sync] def setUp(self): @@ -68,7 +68,7 @@ class FilterTestCase(unittest.TestCase): r.register_servlets(self.hs, self.resource) def test_sync_argless(self): - request, channel = make_request(b"GET", b"/_matrix/client/r0/sync") + request, channel = make_request("GET", "/_matrix/client/r0/sync") request.render(self.resource) wait_until_result(self.clock, channel) diff --git a/tests/server.py b/tests/server.py index c611dd6059..e249668d21 100644 --- a/tests/server.py +++ b/tests/server.py @@ -11,6 +11,7 @@ from twisted.python.failure import Failure from twisted.test.proto_helpers import MemoryReactorClock from synapse.http.site import SynapseRequest +from synapse.util import Clock from tests.utils import setup_test_homeserver as _sth @@ -28,7 +29,13 @@ class FakeChannel(object): def json_body(self): if not self.result: raise Exception("No result yet.") - return json.loads(self.result["body"]) + return json.loads(self.result["body"].decode('utf8')) + + @property + def code(self): + if not self.result: + raise Exception("No result yet.") + return int(self.result["code"]) def writeHeaders(self, version, code, reason, headers): self.result["version"] = version @@ -79,11 +86,16 @@ def make_request(method, path, content=b""): Make a web request using the given method and path, feed it the content, and return the Request and the Channel underneath. """ + if not isinstance(method, bytes): + method = method.encode('ascii') + + if not isinstance(path, bytes): + path = path.encode('ascii') # Decorate it to be the full path if not path.startswith(b"/_matrix"): path = b"/_matrix/client/r0/" + path - path = path.replace("//", "/") + path = path.replace(b"//", b"/") if isinstance(content, text_type): content = content.encode('utf8') @@ -191,3 +203,9 @@ def setup_test_homeserver(*args, **kwargs): clock.threadpool = ThreadPool() pool.threadpool = ThreadPool() return d + + +def get_clock(): + clock = ThreadedMemoryReactorClock() + hs_clock = Clock(clock) + return (clock, hs_clock) diff --git a/tests/storage/test__init__.py b/tests/storage/test__init__.py deleted file mode 100644 index f19cb1265c..0000000000 --- a/tests/storage/test__init__.py +++ /dev/null @@ -1,65 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from twisted.internet import defer - -import tests.utils - - -class InitTestCase(tests.unittest.TestCase): - def __init__(self, *args, **kwargs): - super(InitTestCase, self).__init__(*args, **kwargs) - self.store = None # type: synapse.storage.DataStore - - @defer.inlineCallbacks - def setUp(self): - hs = yield tests.utils.setup_test_homeserver() - - hs.config.max_mau_value = 50 - hs.config.limit_usage_by_mau = True - self.store = hs.get_datastore() - self.clock = hs.get_clock() - - @defer.inlineCallbacks - def test_count_monthly_users(self): - count = yield self.store.count_monthly_users() - self.assertEqual(0, count) - - yield self._insert_user_ips("@user:server1") - yield self._insert_user_ips("@user:server2") - - count = yield self.store.count_monthly_users() - self.assertEqual(2, count) - - @defer.inlineCallbacks - def _insert_user_ips(self, user): - """ - Helper function to populate user_ips without using batch insertion infra - args: - user (str): specify username i.e. @user:server.com - """ - yield self.store._simple_upsert( - table="user_ips", - keyvalues={ - "user_id": user, - "access_token": "access_token", - "ip": "ip", - "user_agent": "user_agent", - "device_id": "device_id", - }, - values={ - "last_seen": self.clock.time_msec(), - } - ) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index bd6fda6cb1..7a58c6eb24 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock from twisted.internet import defer @@ -27,9 +28,9 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() - self.store = hs.get_datastore() - self.clock = hs.get_clock() + self.hs = yield tests.utils.setup_test_homeserver() + self.store = self.hs.get_datastore() + self.clock = self.hs.get_clock() @defer.inlineCallbacks def test_insert_new_client_ip(self): @@ -54,3 +55,62 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): }, r ) + + @defer.inlineCallbacks + def test_disabled_monthly_active_user(self): + self.hs.config.limit_usage_by_mau = False + self.hs.config.max_mau_value = 50 + user_id = "@user:server" + yield self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id", + ) + active = yield self.store._user_last_seen_monthly_active(user_id) + self.assertFalse(active) + + @defer.inlineCallbacks + def test_adding_monthly_active_user_when_full(self): + self.hs.config.limit_usage_by_mau = True + self.hs.config.max_mau_value = 50 + lots_of_users = 100 + user_id = "@user:server" + + self.store.get_monthly_active_count = Mock( + return_value=defer.succeed(lots_of_users) + ) + yield self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id", + ) + active = yield self.store._user_last_seen_monthly_active(user_id) + self.assertFalse(active) + + @defer.inlineCallbacks + def test_adding_monthly_active_user_when_space(self): + self.hs.config.limit_usage_by_mau = True + self.hs.config.max_mau_value = 50 + user_id = "@user:server" + active = yield self.store._user_last_seen_monthly_active(user_id) + self.assertFalse(active) + + yield self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id", + ) + active = yield self.store._user_last_seen_monthly_active(user_id) + self.assertTrue(active) + + @defer.inlineCallbacks + def test_updating_monthly_active_user_when_space(self): + self.hs.config.limit_usage_by_mau = True + self.hs.config.max_mau_value = 50 + user_id = "@user:server" + + active = yield self.store._user_last_seen_monthly_active(user_id) + self.assertFalse(active) + + yield self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id", + ) + yield self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id", + ) + active = yield self.store._user_last_seen_monthly_active(user_id) + self.assertTrue(active) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 30683e7888..69412c5aad 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -49,7 +49,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): 'INSERT INTO event_reference_hashes ' '(event_id, algorithm, hash) ' "VALUES (?, 'sha256', ?)" - ), (event_id, 'ffff')) + ), (event_id, b'ffff')) for i in range(0, 11): yield self.store.runInteraction("insert", insert_event, i) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py new file mode 100644 index 0000000000..cbd480cd42 --- /dev/null +++ b/tests/storage/test_monthly_active_users.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +import tests.unittest +import tests.utils +from tests.utils import setup_test_homeserver + +FORTY_DAYS = 40 * 24 * 60 * 60 + + +class MonthlyActiveUsersTestCase(tests.unittest.TestCase): + def __init__(self, *args, **kwargs): + super(MonthlyActiveUsersTestCase, self).__init__(*args, **kwargs) + + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver() + self.store = self.hs.get_datastore() + + @defer.inlineCallbacks + def test_initialise_reserved_users(self): + + user1 = "@user1:server" + user1_email = "user1@matrix.org" + user2 = "@user2:server" + user2_email = "user2@matrix.org" + threepids = [ + {'medium': 'email', 'address': user1_email}, + {'medium': 'email', 'address': user2_email} + ] + user_num = len(threepids) + + yield self.store.register( + user_id=user1, + token="123", + password_hash=None) + + yield self.store.register( + user_id=user2, + token="456", + password_hash=None) + + now = int(self.hs.get_clock().time_msec()) + yield self.store.user_add_threepid(user1, "email", user1_email, now, now) + yield self.store.user_add_threepid(user2, "email", user2_email, now, now) + yield self.store.initialise_reserved_users(threepids) + + active_count = yield self.store.get_monthly_active_count() + + # Test total counts + self.assertEquals(active_count, user_num) + + # Test user is marked as active + + timestamp = yield self.store._user_last_seen_monthly_active(user1) + self.assertTrue(timestamp) + timestamp = yield self.store._user_last_seen_monthly_active(user2) + self.assertTrue(timestamp) + + # Test that users are never removed from the db. + self.hs.config.max_mau_value = 0 + + self.hs.get_clock().advance_time(FORTY_DAYS) + + yield self.store.reap_monthly_active_users() + + active_count = yield self.store.get_monthly_active_count() + self.assertEquals(active_count, user_num) + + @defer.inlineCallbacks + def test_can_insert_and_count_mau(self): + count = yield self.store.get_monthly_active_count() + self.assertEqual(0, count) + + yield self.store.upsert_monthly_active_user("@user:server") + count = yield self.store.get_monthly_active_count() + + self.assertEqual(1, count) + + @defer.inlineCallbacks + def test__user_last_seen_monthly_active(self): + user_id1 = "@user1:server" + user_id2 = "@user2:server" + user_id3 = "@user3:server" + result = yield self.store._user_last_seen_monthly_active(user_id1) + self.assertFalse(result == 0) + yield self.store.upsert_monthly_active_user(user_id1) + yield self.store.upsert_monthly_active_user(user_id2) + result = yield self.store._user_last_seen_monthly_active(user_id1) + self.assertTrue(result > 0) + result = yield self.store._user_last_seen_monthly_active(user_id3) + self.assertFalse(result == 0) + + @defer.inlineCallbacks + def test_reap_monthly_active_users(self): + self.hs.config.max_mau_value = 5 + initial_users = 10 + for i in range(initial_users): + yield self.store.upsert_monthly_active_user("@user%d:server" % i) + count = yield self.store.get_monthly_active_count() + self.assertTrue(count, initial_users) + yield self.store.reap_monthly_active_users() + count = yield self.store.get_monthly_active_count() + self.assertEquals(count, initial_users - self.hs.config.max_mau_value) + + self.hs.get_clock().advance_time(FORTY_DAYS) + yield self.store.reap_monthly_active_users() + count = yield self.store.get_monthly_active_count() + self.assertEquals(count, 0) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 7a76d67b8c..f7871cd426 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -176,7 +176,7 @@ class StateStoreTestCase(tests.unittest.TestCase): room_id = self.room.to_string() group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id]) - group = group_ids.keys()[0] + group = list(group_ids.keys())[0] # test _get_some_state_from_cache correctly filters out members with types=[] (state_dict, is_all) = yield self.store._get_some_state_from_cache( diff --git a/tests/test_server.py b/tests/test_server.py index 7e063c0290..fc396226ea 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,4 +1,3 @@ -import json import re from twisted.internet.defer import Deferred @@ -104,9 +103,8 @@ class JsonResourceTests(unittest.TestCase): request.render(res) self.assertEqual(channel.result["code"], b'403') - reply_body = json.loads(channel.result["body"]) - self.assertEqual(reply_body["error"], "Forbidden!!one!") - self.assertEqual(reply_body["errcode"], "M_FORBIDDEN") + self.assertEqual(channel.json_body["error"], "Forbidden!!one!") + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") def test_no_handler(self): """ @@ -126,6 +124,5 @@ class JsonResourceTests(unittest.TestCase): request.render(res) self.assertEqual(channel.result["code"], b'400') - reply_body = json.loads(channel.result["body"]) - self.assertEqual(reply_body["error"], "Unrecognized request") - self.assertEqual(reply_body["errcode"], "M_UNRECOGNIZED") + self.assertEqual(channel.json_body["error"], "Unrecognized request") + self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") diff --git a/tests/utils.py b/tests/utils.py index 9bff3ff3b9..5d49692c58 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -73,6 +73,13 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None config.block_events_without_consent_error = None config.media_storage_providers = [] config.auto_join_rooms = [] + config.limit_usage_by_mau = False + config.max_mau_value = 50 + config.mau_limits_reserved_threepids = [] + + # we need a sane default_room_version, otherwise attempts to create rooms will + # fail. + config.default_room_version = "1" # disable user directory updates, because they get done in the # background, which upsets the test runner. @@ -146,8 +153,9 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None # Need to let the HS build an auth handler and then mess with it # because AuthHandler's constructor requires the HS, so we can't make one # beforehand and pass it in to the HS's constructor (chicken / egg) - hs.get_auth_handler().hash = lambda p: hashlib.md5(p).hexdigest() - hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h + hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode('utf8')).hexdigest() + hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5( + p.encode('utf8')).hexdigest() == h fed = kargs.get("resource_for_federation", None) if fed: @@ -220,8 +228,8 @@ class MockHttpResource(HttpServer): mock_content.configure_mock(**config) mock_request.content = mock_content - mock_request.method = http_method - mock_request.uri = path + mock_request.method = http_method.encode('ascii') + mock_request.uri = path.encode('ascii') mock_request.getClientIP.return_value = "-" |