diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index a82d737e71..5dc3398300 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -444,3 +444,28 @@ class AuthTestCase(unittest.TestCase):
self.assertEqual("Guest access token used for regular user", cm.exception.msg)
self.store.get_user_by_id.assert_called_with(USER_ID)
+
+ @defer.inlineCallbacks
+ def test_blocking_mau(self):
+ self.hs.config.limit_usage_by_mau = False
+ self.hs.config.max_mau_value = 50
+ lots_of_users = 100
+ small_number_of_users = 1
+
+ # Ensure no error thrown
+ yield self.auth.check_auth_blocking()
+
+ self.hs.config.limit_usage_by_mau = True
+
+ self.store.get_monthly_active_count = Mock(
+ return_value=defer.succeed(lots_of_users)
+ )
+
+ with self.assertRaises(AuthError):
+ yield self.auth.check_auth_blocking()
+
+ # Ensure does not throw an error
+ self.store.get_monthly_active_count = Mock(
+ return_value=defer.succeed(small_number_of_users)
+ )
+ yield self.auth.check_auth_blocking()
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 55eab9e9cf..8a9bf2d5fd 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -132,14 +132,14 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_mau_limits_exceeded(self):
self.hs.config.limit_usage_by_mau = True
- self.hs.get_datastore().count_monthly_users = Mock(
+ self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users)
)
with self.assertRaises(AuthError):
yield self.auth_handler.get_access_token_for_user_id('user_a')
- self.hs.get_datastore().count_monthly_users = Mock(
+ self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users)
)
with self.assertRaises(AuthError):
@@ -151,13 +151,13 @@ class AuthTestCase(unittest.TestCase):
def test_mau_limits_not_exceeded(self):
self.hs.config.limit_usage_by_mau = True
- self.hs.get_datastore().count_monthly_users = Mock(
+ self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.small_number_of_users)
)
# Ensure does not raise exception
yield self.auth_handler.get_access_token_for_user_id('user_a')
- self.hs.get_datastore().count_monthly_users = Mock(
+ self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.small_number_of_users)
)
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 0937d71cf6..4ea59a58de 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -50,6 +50,10 @@ class RegistrationTestCase(unittest.TestCase):
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = self.hs.get_handlers().registration_handler
+ self.store = self.hs.get_datastore()
+ self.hs.config.max_mau_value = 50
+ self.lots_of_users = 100
+ self.small_number_of_users = 1
@defer.inlineCallbacks
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
@@ -80,51 +84,43 @@ class RegistrationTestCase(unittest.TestCase):
self.assertEquals(result_token, 'secret')
@defer.inlineCallbacks
- def test_cannot_register_when_mau_limits_exceeded(self):
- local_part = "someone"
- display_name = "someone"
- requester = create_requester("@as:test")
- store = self.hs.get_datastore()
+ def test_mau_limits_when_disabled(self):
self.hs.config.limit_usage_by_mau = False
- self.hs.config.max_mau_value = 50
- lots_of_users = 100
- small_number_users = 1
-
- store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
-
# Ensure does not throw exception
- yield self.handler.get_or_create_user(requester, 'a', display_name)
+ yield self.handler.get_or_create_user("requester", 'a', "display_name")
+ @defer.inlineCallbacks
+ def test_get_or_create_user_mau_not_blocked(self):
self.hs.config.limit_usage_by_mau = True
-
- with self.assertRaises(RegistrationError):
- yield self.handler.get_or_create_user(requester, 'b', display_name)
-
- store.count_monthly_users = Mock(return_value=defer.succeed(small_number_users))
-
- self._macaroon_mock_generator("another_secret")
-
+ self.store.count_monthly_users = Mock(
+ return_value=defer.succeed(self.small_number_of_users)
+ )
# Ensure does not throw exception
- yield self.handler.get_or_create_user("@neil:matrix.org", 'c', "Neil")
-
- self._macaroon_mock_generator("another another secret")
- store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
+ yield self.handler.get_or_create_user("@user:server", 'c', "User")
+ @defer.inlineCallbacks
+ def test_get_or_create_user_mau_blocked(self):
+ self.hs.config.limit_usage_by_mau = True
+ self.store.get_monthly_active_count = Mock(
+ return_value=defer.succeed(self.lots_of_users)
+ )
with self.assertRaises(RegistrationError):
- yield self.handler.register(localpart=local_part)
+ yield self.handler.get_or_create_user("requester", 'b', "display_name")
- self._macaroon_mock_generator("another another secret")
- store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
+ @defer.inlineCallbacks
+ def test_register_mau_blocked(self):
+ self.hs.config.limit_usage_by_mau = True
+ self.store.get_monthly_active_count = Mock(
+ return_value=defer.succeed(self.lots_of_users)
+ )
+ with self.assertRaises(RegistrationError):
+ yield self.handler.register(localpart="local_part")
+ @defer.inlineCallbacks
+ def test_register_saml2_mau_blocked(self):
+ self.hs.config.limit_usage_by_mau = True
+ self.store.get_monthly_active_count = Mock(
+ return_value=defer.succeed(self.lots_of_users)
+ )
with self.assertRaises(RegistrationError):
- yield self.handler.register_saml2(local_part)
-
- def _macaroon_mock_generator(self, secret):
- """
- Reset macaroon generator in the case where the test creates multiple users
- """
- macaroon_generator = Mock(
- generate_access_token=Mock(return_value=secret))
- self.hs.get_macaroon_generator = Mock(return_value=macaroon_generator)
- self.hs.handlers = RegistrationHandlers(self.hs)
- self.handler = self.hs.get_handlers().registration_handler
+ yield self.handler.register_saml2(localpart="local_part")
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 2c263af1a3..f422cf3c5a 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -48,7 +48,9 @@ def _expect_edu(destination, edu_type, content, origin="test"):
def _make_edu_json(origin, edu_type, content):
- return json.dumps(_expect_edu("test", edu_type, content, origin=origin))
+ return json.dumps(
+ _expect_edu("test", edu_type, content, origin=origin)
+ ).encode('utf8')
class TypingNotificationsTestCase(unittest.TestCase):
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index 34e68ae82f..d46c27e7e9 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -85,7 +85,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
try:
yield self.cache.fetch_or_execute(self.mock_key, cb)
except Exception as e:
- self.assertEqual(e.message, "boo")
+ self.assertEqual(e.args[0], "boo")
self.assertIs(LoggingContext.current_context(), test_context)
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
@@ -111,7 +111,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
try:
yield self.cache.fetch_or_execute(self.mock_key, cb)
except Exception as e:
- self.assertEqual(e.message, "boo")
+ self.assertEqual(e.args[0], "boo")
self.assertIs(LoggingContext.current_context(), test_context)
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py
index 8c90145601..fb28883d30 100644
--- a/tests/rest/client/v1/test_admin.py
+++ b/tests/rest/client/v1/test_admin.py
@@ -140,7 +140,7 @@ class UserRegisterTestCase(unittest.TestCase):
"admin": True,
"mac": want_mac,
}
- ).encode('utf8')
+ )
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
@@ -168,7 +168,7 @@ class UserRegisterTestCase(unittest.TestCase):
"admin": True,
"mac": want_mac,
}
- ).encode('utf8')
+ )
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
@@ -195,7 +195,7 @@ class UserRegisterTestCase(unittest.TestCase):
"admin": True,
"mac": want_mac,
}
- ).encode('utf8')
+ )
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
@@ -253,7 +253,7 @@ class UserRegisterTestCase(unittest.TestCase):
self.assertEqual('Invalid username', channel.json_body["error"])
# Must not have null bytes
- body = json.dumps({"nonce": nonce(), "username": b"abcd\x00"})
+ body = json.dumps({"nonce": nonce(), "username": u"abcd\u0000"})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
@@ -289,7 +289,7 @@ class UserRegisterTestCase(unittest.TestCase):
self.assertEqual('Invalid password', channel.json_body["error"])
# Must not have null bytes
- body = json.dumps({"nonce": nonce(), "username": "a", "password": b"abcd\x00"})
+ body = json.dumps({"nonce": nonce(), "username": "a", "password": u"abcd\u0000"})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index d71cc8e0db..0516ce3cfb 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -80,7 +80,7 @@ class ProfileTestCase(unittest.TestCase):
(code, response) = yield self.mock_resource.trigger(
"PUT",
"/profile/%s/displayname" % (myid),
- '{"displayname": "Frank Jr."}'
+ b'{"displayname": "Frank Jr."}'
)
self.assertEquals(200, code)
@@ -95,7 +95,7 @@ class ProfileTestCase(unittest.TestCase):
(code, response) = yield self.mock_resource.trigger(
"PUT", "/profile/%s/displayname" % ("@4567:test"),
- '{"displayname": "Frank Jr."}'
+ b'{"displayname": "Frank Jr."}'
)
self.assertTrue(
@@ -122,7 +122,7 @@ class ProfileTestCase(unittest.TestCase):
(code, response) = yield self.mock_resource.trigger(
"PUT", "/profile/%s/displayname" % ("@opaque:elsewhere"),
- '{"displayname":"bob"}'
+ b'{"displayname":"bob"}'
)
self.assertTrue(
@@ -151,7 +151,7 @@ class ProfileTestCase(unittest.TestCase):
(code, response) = yield self.mock_resource.trigger(
"PUT",
"/profile/%s/avatar_url" % (myid),
- '{"avatar_url": "http://my.server/pic.gif"}'
+ b'{"avatar_url": "http://my.server/pic.gif"}'
)
self.assertEquals(200, code)
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 41de8e0762..e3bc5f378d 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -105,7 +105,7 @@ class RestTestCase(unittest.TestCase):
"password": "test",
"type": "m.login.password"
}))
- self.assertEquals(200, code)
+ self.assertEquals(200, code, msg=response)
defer.returnValue(response)
@defer.inlineCallbacks
@@ -149,14 +149,14 @@ class RestHelper(object):
def create_room_as(self, room_creator, is_public=True, tok=None):
temp_id = self.auth_user_id
self.auth_user_id = room_creator
- path = b"/_matrix/client/r0/createRoom"
+ path = "/_matrix/client/r0/createRoom"
content = {}
if not is_public:
content["visibility"] = "private"
if tok:
- path = path + b"?access_token=%s" % tok.encode('ascii')
+ path = path + "?access_token=%s" % tok
- request, channel = make_request(b"POST", path, json.dumps(content).encode('utf8'))
+ request, channel = make_request("POST", path, json.dumps(content).encode('utf8'))
request.render(self.resource)
wait_until_result(self.hs.get_reactor(), channel)
@@ -205,7 +205,7 @@ class RestHelper(object):
data = {"membership": membership}
request, channel = make_request(
- b"PUT", path.encode('ascii'), json.dumps(data).encode('utf8')
+ "PUT", path, json.dumps(data).encode('utf8')
)
request.render(self.resource)
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index e890f0feac..de33b10a5f 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -33,7 +33,7 @@ PATH_PREFIX = "/_matrix/client/v2_alpha"
class FilterTestCase(unittest.TestCase):
- USER_ID = b"@apple:test"
+ USER_ID = "@apple:test"
EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
TO_REGISTER = [filter]
@@ -72,8 +72,8 @@ class FilterTestCase(unittest.TestCase):
def test_add_filter(self):
request, channel = make_request(
- b"POST",
- b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
+ "POST",
+ "/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
self.EXAMPLE_FILTER_JSON,
)
request.render(self.resource)
@@ -87,8 +87,8 @@ class FilterTestCase(unittest.TestCase):
def test_add_filter_for_other_user(self):
request, channel = make_request(
- b"POST",
- b"/_matrix/client/r0/user/%s/filter" % (b"@watermelon:test"),
+ "POST",
+ "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
self.EXAMPLE_FILTER_JSON,
)
request.render(self.resource)
@@ -101,8 +101,8 @@ class FilterTestCase(unittest.TestCase):
_is_mine = self.hs.is_mine
self.hs.is_mine = lambda target_user: False
request, channel = make_request(
- b"POST",
- b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
+ "POST",
+ "/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
self.EXAMPLE_FILTER_JSON,
)
request.render(self.resource)
@@ -119,7 +119,7 @@ class FilterTestCase(unittest.TestCase):
self.clock.advance(1)
filter_id = filter_id.result
request, channel = make_request(
- b"GET", b"/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id)
+ "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id)
)
request.render(self.resource)
wait_until_result(self.clock, channel)
@@ -129,7 +129,7 @@ class FilterTestCase(unittest.TestCase):
def test_get_filter_non_existant(self):
request, channel = make_request(
- b"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID)
+ "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID)
)
request.render(self.resource)
wait_until_result(self.clock, channel)
@@ -141,7 +141,7 @@ class FilterTestCase(unittest.TestCase):
# in errors.py
def test_get_filter_invalid_id(self):
request, channel = make_request(
- b"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID)
+ "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID)
)
request.render(self.resource)
wait_until_result(self.clock, channel)
@@ -151,7 +151,7 @@ class FilterTestCase(unittest.TestCase):
# No ID also returns an invalid_id error
def test_get_filter_no_id(self):
request, channel = make_request(
- b"GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID)
+ "GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID)
)
request.render(self.resource)
wait_until_result(self.clock, channel)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index e004d8fc73..f6293f11a8 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -81,7 +81,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
"access_token": token,
"home_server": self.hs.hostname,
}
- self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
+ self.assertDictContainsSubset(det_data, channel.json_body)
def test_POST_appservice_registration_invalid(self):
self.appservice = None # no application service exists
@@ -102,7 +102,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(
- json.loads(channel.result["body"])["error"], "Invalid password"
+ channel.json_body["error"], "Invalid password"
)
def test_POST_bad_username(self):
@@ -113,7 +113,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(
- json.loads(channel.result["body"])["error"], "Invalid username"
+ channel.json_body["error"], "Invalid username"
)
def test_POST_user_valid(self):
@@ -140,7 +140,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
"device_id": device_id,
}
self.assertEquals(channel.result["code"], b"200", channel.result)
- self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
+ self.assertDictContainsSubset(det_data, channel.json_body)
self.auth_handler.get_login_tuple_for_user_id(
user_id, device_id=device_id, initial_device_display_name=None
)
@@ -158,7 +158,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(
- json.loads(channel.result["body"])["error"],
+ channel.json_body["error"],
"Registration has been disabled",
)
@@ -178,7 +178,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
"device_id": "guest_device",
}
self.assertEquals(channel.result["code"], b"200", channel.result)
- self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
+ self.assertDictContainsSubset(det_data, channel.json_body)
def test_POST_disabled_guest_registration(self):
self.hs.config.allow_guest_access = False
@@ -189,5 +189,5 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(
- json.loads(channel.result["body"])["error"], "Guest access is disabled"
+ channel.json_body["error"], "Guest access is disabled"
)
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 03ec3993b2..bafc0d1df0 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -32,7 +32,7 @@ PATH_PREFIX = "/_matrix/client/v2_alpha"
class FilterTestCase(unittest.TestCase):
- USER_ID = b"@apple:test"
+ USER_ID = "@apple:test"
TO_REGISTER = [sync]
def setUp(self):
@@ -68,7 +68,7 @@ class FilterTestCase(unittest.TestCase):
r.register_servlets(self.hs, self.resource)
def test_sync_argless(self):
- request, channel = make_request(b"GET", b"/_matrix/client/r0/sync")
+ request, channel = make_request("GET", "/_matrix/client/r0/sync")
request.render(self.resource)
wait_until_result(self.clock, channel)
diff --git a/tests/server.py b/tests/server.py
index c611dd6059..e249668d21 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -11,6 +11,7 @@ from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactorClock
from synapse.http.site import SynapseRequest
+from synapse.util import Clock
from tests.utils import setup_test_homeserver as _sth
@@ -28,7 +29,13 @@ class FakeChannel(object):
def json_body(self):
if not self.result:
raise Exception("No result yet.")
- return json.loads(self.result["body"])
+ return json.loads(self.result["body"].decode('utf8'))
+
+ @property
+ def code(self):
+ if not self.result:
+ raise Exception("No result yet.")
+ return int(self.result["code"])
def writeHeaders(self, version, code, reason, headers):
self.result["version"] = version
@@ -79,11 +86,16 @@ def make_request(method, path, content=b""):
Make a web request using the given method and path, feed it the
content, and return the Request and the Channel underneath.
"""
+ if not isinstance(method, bytes):
+ method = method.encode('ascii')
+
+ if not isinstance(path, bytes):
+ path = path.encode('ascii')
# Decorate it to be the full path
if not path.startswith(b"/_matrix"):
path = b"/_matrix/client/r0/" + path
- path = path.replace("//", "/")
+ path = path.replace(b"//", b"/")
if isinstance(content, text_type):
content = content.encode('utf8')
@@ -191,3 +203,9 @@ def setup_test_homeserver(*args, **kwargs):
clock.threadpool = ThreadPool()
pool.threadpool = ThreadPool()
return d
+
+
+def get_clock():
+ clock = ThreadedMemoryReactorClock()
+ hs_clock = Clock(clock)
+ return (clock, hs_clock)
diff --git a/tests/storage/test__init__.py b/tests/storage/test__init__.py
deleted file mode 100644
index f19cb1265c..0000000000
--- a/tests/storage/test__init__.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from twisted.internet import defer
-
-import tests.utils
-
-
-class InitTestCase(tests.unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super(InitTestCase, self).__init__(*args, **kwargs)
- self.store = None # type: synapse.storage.DataStore
-
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver()
-
- hs.config.max_mau_value = 50
- hs.config.limit_usage_by_mau = True
- self.store = hs.get_datastore()
- self.clock = hs.get_clock()
-
- @defer.inlineCallbacks
- def test_count_monthly_users(self):
- count = yield self.store.count_monthly_users()
- self.assertEqual(0, count)
-
- yield self._insert_user_ips("@user:server1")
- yield self._insert_user_ips("@user:server2")
-
- count = yield self.store.count_monthly_users()
- self.assertEqual(2, count)
-
- @defer.inlineCallbacks
- def _insert_user_ips(self, user):
- """
- Helper function to populate user_ips without using batch insertion infra
- args:
- user (str): specify username i.e. @user:server.com
- """
- yield self.store._simple_upsert(
- table="user_ips",
- keyvalues={
- "user_id": user,
- "access_token": "access_token",
- "ip": "ip",
- "user_agent": "user_agent",
- "device_id": "device_id",
- },
- values={
- "last_seen": self.clock.time_msec(),
- }
- )
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index bd6fda6cb1..7a58c6eb24 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from mock import Mock
from twisted.internet import defer
@@ -27,9 +28,9 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
- hs = yield tests.utils.setup_test_homeserver()
- self.store = hs.get_datastore()
- self.clock = hs.get_clock()
+ self.hs = yield tests.utils.setup_test_homeserver()
+ self.store = self.hs.get_datastore()
+ self.clock = self.hs.get_clock()
@defer.inlineCallbacks
def test_insert_new_client_ip(self):
@@ -54,3 +55,62 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
},
r
)
+
+ @defer.inlineCallbacks
+ def test_disabled_monthly_active_user(self):
+ self.hs.config.limit_usage_by_mau = False
+ self.hs.config.max_mau_value = 50
+ user_id = "@user:server"
+ yield self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", "device_id",
+ )
+ active = yield self.store._user_last_seen_monthly_active(user_id)
+ self.assertFalse(active)
+
+ @defer.inlineCallbacks
+ def test_adding_monthly_active_user_when_full(self):
+ self.hs.config.limit_usage_by_mau = True
+ self.hs.config.max_mau_value = 50
+ lots_of_users = 100
+ user_id = "@user:server"
+
+ self.store.get_monthly_active_count = Mock(
+ return_value=defer.succeed(lots_of_users)
+ )
+ yield self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", "device_id",
+ )
+ active = yield self.store._user_last_seen_monthly_active(user_id)
+ self.assertFalse(active)
+
+ @defer.inlineCallbacks
+ def test_adding_monthly_active_user_when_space(self):
+ self.hs.config.limit_usage_by_mau = True
+ self.hs.config.max_mau_value = 50
+ user_id = "@user:server"
+ active = yield self.store._user_last_seen_monthly_active(user_id)
+ self.assertFalse(active)
+
+ yield self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", "device_id",
+ )
+ active = yield self.store._user_last_seen_monthly_active(user_id)
+ self.assertTrue(active)
+
+ @defer.inlineCallbacks
+ def test_updating_monthly_active_user_when_space(self):
+ self.hs.config.limit_usage_by_mau = True
+ self.hs.config.max_mau_value = 50
+ user_id = "@user:server"
+
+ active = yield self.store._user_last_seen_monthly_active(user_id)
+ self.assertFalse(active)
+
+ yield self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", "device_id",
+ )
+ yield self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", "device_id",
+ )
+ active = yield self.store._user_last_seen_monthly_active(user_id)
+ self.assertTrue(active)
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 30683e7888..69412c5aad 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -49,7 +49,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
'INSERT INTO event_reference_hashes '
'(event_id, algorithm, hash) '
"VALUES (?, 'sha256', ?)"
- ), (event_id, 'ffff'))
+ ), (event_id, b'ffff'))
for i in range(0, 11):
yield self.store.runInteraction("insert", insert_event, i)
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
new file mode 100644
index 0000000000..cbd480cd42
--- /dev/null
+++ b/tests/storage/test_monthly_active_users.py
@@ -0,0 +1,123 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+import tests.unittest
+import tests.utils
+from tests.utils import setup_test_homeserver
+
+FORTY_DAYS = 40 * 24 * 60 * 60
+
+
+class MonthlyActiveUsersTestCase(tests.unittest.TestCase):
+ def __init__(self, *args, **kwargs):
+ super(MonthlyActiveUsersTestCase, self).__init__(*args, **kwargs)
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.hs = yield setup_test_homeserver()
+ self.store = self.hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def test_initialise_reserved_users(self):
+
+ user1 = "@user1:server"
+ user1_email = "user1@matrix.org"
+ user2 = "@user2:server"
+ user2_email = "user2@matrix.org"
+ threepids = [
+ {'medium': 'email', 'address': user1_email},
+ {'medium': 'email', 'address': user2_email}
+ ]
+ user_num = len(threepids)
+
+ yield self.store.register(
+ user_id=user1,
+ token="123",
+ password_hash=None)
+
+ yield self.store.register(
+ user_id=user2,
+ token="456",
+ password_hash=None)
+
+ now = int(self.hs.get_clock().time_msec())
+ yield self.store.user_add_threepid(user1, "email", user1_email, now, now)
+ yield self.store.user_add_threepid(user2, "email", user2_email, now, now)
+ yield self.store.initialise_reserved_users(threepids)
+
+ active_count = yield self.store.get_monthly_active_count()
+
+ # Test total counts
+ self.assertEquals(active_count, user_num)
+
+ # Test user is marked as active
+
+ timestamp = yield self.store._user_last_seen_monthly_active(user1)
+ self.assertTrue(timestamp)
+ timestamp = yield self.store._user_last_seen_monthly_active(user2)
+ self.assertTrue(timestamp)
+
+ # Test that users are never removed from the db.
+ self.hs.config.max_mau_value = 0
+
+ self.hs.get_clock().advance_time(FORTY_DAYS)
+
+ yield self.store.reap_monthly_active_users()
+
+ active_count = yield self.store.get_monthly_active_count()
+ self.assertEquals(active_count, user_num)
+
+ @defer.inlineCallbacks
+ def test_can_insert_and_count_mau(self):
+ count = yield self.store.get_monthly_active_count()
+ self.assertEqual(0, count)
+
+ yield self.store.upsert_monthly_active_user("@user:server")
+ count = yield self.store.get_monthly_active_count()
+
+ self.assertEqual(1, count)
+
+ @defer.inlineCallbacks
+ def test__user_last_seen_monthly_active(self):
+ user_id1 = "@user1:server"
+ user_id2 = "@user2:server"
+ user_id3 = "@user3:server"
+ result = yield self.store._user_last_seen_monthly_active(user_id1)
+ self.assertFalse(result == 0)
+ yield self.store.upsert_monthly_active_user(user_id1)
+ yield self.store.upsert_monthly_active_user(user_id2)
+ result = yield self.store._user_last_seen_monthly_active(user_id1)
+ self.assertTrue(result > 0)
+ result = yield self.store._user_last_seen_monthly_active(user_id3)
+ self.assertFalse(result == 0)
+
+ @defer.inlineCallbacks
+ def test_reap_monthly_active_users(self):
+ self.hs.config.max_mau_value = 5
+ initial_users = 10
+ for i in range(initial_users):
+ yield self.store.upsert_monthly_active_user("@user%d:server" % i)
+ count = yield self.store.get_monthly_active_count()
+ self.assertTrue(count, initial_users)
+ yield self.store.reap_monthly_active_users()
+ count = yield self.store.get_monthly_active_count()
+ self.assertEquals(count, initial_users - self.hs.config.max_mau_value)
+
+ self.hs.get_clock().advance_time(FORTY_DAYS)
+ yield self.store.reap_monthly_active_users()
+ count = yield self.store.get_monthly_active_count()
+ self.assertEquals(count, 0)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 7a76d67b8c..f7871cd426 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -176,7 +176,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
room_id = self.room.to_string()
group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
- group = group_ids.keys()[0]
+ group = list(group_ids.keys())[0]
# test _get_some_state_from_cache correctly filters out members with types=[]
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
diff --git a/tests/test_server.py b/tests/test_server.py
index 7e063c0290..fc396226ea 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -1,4 +1,3 @@
-import json
import re
from twisted.internet.defer import Deferred
@@ -104,9 +103,8 @@ class JsonResourceTests(unittest.TestCase):
request.render(res)
self.assertEqual(channel.result["code"], b'403')
- reply_body = json.loads(channel.result["body"])
- self.assertEqual(reply_body["error"], "Forbidden!!one!")
- self.assertEqual(reply_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
def test_no_handler(self):
"""
@@ -126,6 +124,5 @@ class JsonResourceTests(unittest.TestCase):
request.render(res)
self.assertEqual(channel.result["code"], b'400')
- reply_body = json.loads(channel.result["body"])
- self.assertEqual(reply_body["error"], "Unrecognized request")
- self.assertEqual(reply_body["errcode"], "M_UNRECOGNIZED")
+ self.assertEqual(channel.json_body["error"], "Unrecognized request")
+ self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
diff --git a/tests/utils.py b/tests/utils.py
index 9bff3ff3b9..5d49692c58 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -73,6 +73,13 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
config.block_events_without_consent_error = None
config.media_storage_providers = []
config.auto_join_rooms = []
+ config.limit_usage_by_mau = False
+ config.max_mau_value = 50
+ config.mau_limits_reserved_threepids = []
+
+ # we need a sane default_room_version, otherwise attempts to create rooms will
+ # fail.
+ config.default_room_version = "1"
# disable user directory updates, because they get done in the
# background, which upsets the test runner.
@@ -146,8 +153,9 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
# Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg)
- hs.get_auth_handler().hash = lambda p: hashlib.md5(p).hexdigest()
- hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
+ hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode('utf8')).hexdigest()
+ hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(
+ p.encode('utf8')).hexdigest() == h
fed = kargs.get("resource_for_federation", None)
if fed:
@@ -220,8 +228,8 @@ class MockHttpResource(HttpServer):
mock_content.configure_mock(**config)
mock_request.content = mock_content
- mock_request.method = http_method
- mock_request.uri = path
+ mock_request.method = http_method.encode('ascii')
+ mock_request.uri = path.encode('ascii')
mock_request.getClientIP.return_value = "-"
|