diff options
Diffstat (limited to 'tests')
31 files changed, 1112 insertions, 371 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 022d81ce3e..f65a27e5f1 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -457,8 +457,8 @@ class AuthTestCase(unittest.TestCase): with self.assertRaises(ResourceLimitError) as e: yield self.auth.check_auth_blocking() - self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) + self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) + self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.code, 403) # Ensure does not throw an error @@ -468,11 +468,36 @@ class AuthTestCase(unittest.TestCase): yield self.auth.check_auth_blocking() @defer.inlineCallbacks + def test_reserved_threepid(self): + self.hs.config.limit_usage_by_mau = True + self.hs.config.max_mau_value = 1 + threepid = {'medium': 'email', 'address': 'reserved@server.com'} + unknown_threepid = {'medium': 'email', 'address': 'unreserved@server.com'} + self.hs.config.mau_limits_reserved_threepids = [threepid] + + yield self.store.register(user_id='user1', token="123", password_hash=None) + with self.assertRaises(ResourceLimitError): + yield self.auth.check_auth_blocking() + + with self.assertRaises(ResourceLimitError): + yield self.auth.check_auth_blocking(threepid=unknown_threepid) + + yield self.auth.check_auth_blocking(threepid=threepid) + + @defer.inlineCallbacks def test_hs_disabled(self): self.hs.config.hs_disabled = True self.hs.config.hs_disabled_message = "Reason for being disabled" with self.assertRaises(ResourceLimitError) as e: yield self.auth.check_auth_blocking() - self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) + self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) + self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.code, 403) + + @defer.inlineCallbacks + def test_server_notices_mxid_special_cased(self): + self.hs.config.hs_disabled = True + user = "@user:server" + self.hs.config.server_notices_mxid = user + self.hs.config.hs_disabled_message = "Reason for being disabled" + yield self.auth.check_auth_blocking(user) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 56e7acd37c..a3aa0a1cf2 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,79 +14,79 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - import synapse.api.errors import synapse.handlers.device import synapse.storage -from tests import unittest, utils +from tests import unittest user1 = "@boris:aaa" user2 = "@theresa:bbb" -class DeviceTestCase(unittest.TestCase): - def __init__(self, *args, **kwargs): - super(DeviceTestCase, self).__init__(*args, **kwargs) - self.store = None # type: synapse.storage.DataStore - self.handler = None # type: synapse.handlers.device.DeviceHandler - self.clock = None # type: utils.MockClock - - @defer.inlineCallbacks - def setUp(self): - hs = yield utils.setup_test_homeserver(self.addCleanup) +class DeviceTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver("server", http_client=None) self.handler = hs.get_device_handler() self.store = hs.get_datastore() - self.clock = hs.get_clock() + return hs + + def prepare(self, reactor, clock, hs): + # These tests assume that it starts 1000 seconds in. + self.reactor.advance(1000) - @defer.inlineCallbacks def test_device_is_created_if_doesnt_exist(self): - res = yield self.handler.check_device_registered( - user_id="@boris:foo", - device_id="fco", - initial_device_display_name="display name", + res = self.get_success( + self.handler.check_device_registered( + user_id="@boris:foo", + device_id="fco", + initial_device_display_name="display name", + ) ) self.assertEqual(res, "fco") - dev = yield self.handler.store.get_device("@boris:foo", "fco") + dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) self.assertEqual(dev["display_name"], "display name") - @defer.inlineCallbacks def test_device_is_preserved_if_exists(self): - res1 = yield self.handler.check_device_registered( - user_id="@boris:foo", - device_id="fco", - initial_device_display_name="display name", + res1 = self.get_success( + self.handler.check_device_registered( + user_id="@boris:foo", + device_id="fco", + initial_device_display_name="display name", + ) ) self.assertEqual(res1, "fco") - res2 = yield self.handler.check_device_registered( - user_id="@boris:foo", - device_id="fco", - initial_device_display_name="new display name", + res2 = self.get_success( + self.handler.check_device_registered( + user_id="@boris:foo", + device_id="fco", + initial_device_display_name="new display name", + ) ) self.assertEqual(res2, "fco") - dev = yield self.handler.store.get_device("@boris:foo", "fco") + dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) self.assertEqual(dev["display_name"], "display name") - @defer.inlineCallbacks def test_device_id_is_made_up_if_unspecified(self): - device_id = yield self.handler.check_device_registered( - user_id="@theresa:foo", - device_id=None, - initial_device_display_name="display", + device_id = self.get_success( + self.handler.check_device_registered( + user_id="@theresa:foo", + device_id=None, + initial_device_display_name="display", + ) ) - dev = yield self.handler.store.get_device("@theresa:foo", device_id) + dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id)) self.assertEqual(dev["display_name"], "display") - @defer.inlineCallbacks def test_get_devices_by_user(self): - yield self._record_users() + self._record_users() + + res = self.get_success(self.handler.get_devices_by_user(user1)) - res = yield self.handler.get_devices_by_user(user1) self.assertEqual(3, len(res)) device_map = {d["device_id"]: d for d in res} self.assertDictContainsSubset( @@ -119,11 +120,10 @@ class DeviceTestCase(unittest.TestCase): device_map["abc"], ) - @defer.inlineCallbacks def test_get_device(self): - yield self._record_users() + self._record_users() - res = yield self.handler.get_device(user1, "abc") + res = self.get_success(self.handler.get_device(user1, "abc")) self.assertDictContainsSubset( { "user_id": user1, @@ -135,59 +135,66 @@ class DeviceTestCase(unittest.TestCase): res, ) - @defer.inlineCallbacks def test_delete_device(self): - yield self._record_users() + self._record_users() # delete the device - yield self.handler.delete_device(user1, "abc") + self.get_success(self.handler.delete_device(user1, "abc")) # check the device was deleted - with self.assertRaises(synapse.api.errors.NotFoundError): - yield self.handler.get_device(user1, "abc") + res = self.handler.get_device(user1, "abc") + self.pump() + self.assertIsInstance( + self.failureResultOf(res).value, synapse.api.errors.NotFoundError + ) # we'd like to check the access token was invalidated, but that's a # bit of a PITA. - @defer.inlineCallbacks def test_update_device(self): - yield self._record_users() + self._record_users() update = {"display_name": "new display"} - yield self.handler.update_device(user1, "abc", update) + self.get_success(self.handler.update_device(user1, "abc", update)) - res = yield self.handler.get_device(user1, "abc") + res = self.get_success(self.handler.get_device(user1, "abc")) self.assertEqual(res["display_name"], "new display") - @defer.inlineCallbacks def test_update_unknown_device(self): update = {"display_name": "new_display"} - with self.assertRaises(synapse.api.errors.NotFoundError): - yield self.handler.update_device("user_id", "unknown_device_id", update) + res = self.handler.update_device("user_id", "unknown_device_id", update) + self.pump() + self.assertIsInstance( + self.failureResultOf(res).value, synapse.api.errors.NotFoundError + ) - @defer.inlineCallbacks def _record_users(self): # check this works for both devices which have a recorded client_ip, # and those which don't. - yield self._record_user(user1, "xyz", "display 0") - yield self._record_user(user1, "fco", "display 1", "token1", "ip1") - yield self._record_user(user1, "abc", "display 2", "token2", "ip2") - yield self._record_user(user1, "abc", "display 2", "token3", "ip3") + self._record_user(user1, "xyz", "display 0") + self._record_user(user1, "fco", "display 1", "token1", "ip1") + self._record_user(user1, "abc", "display 2", "token2", "ip2") + self._record_user(user1, "abc", "display 2", "token3", "ip3") + + self._record_user(user2, "def", "dispkay", "token4", "ip4") - yield self._record_user(user2, "def", "dispkay", "token4", "ip4") + self.reactor.advance(10000) - @defer.inlineCallbacks def _record_user( self, user_id, device_id, display_name, access_token=None, ip=None ): - device_id = yield self.handler.check_device_registered( - user_id=user_id, - device_id=device_id, - initial_device_display_name=display_name, + device_id = self.get_success( + self.handler.check_device_registered( + user_id=user_id, + device_id=device_id, + initial_device_display_name=display_name, + ) ) if ip is not None: - yield self.store.insert_client_ip( - user_id, access_token, ip, "user_agent", device_id + self.get_success( + self.store.insert_client_ip( + user_id, access_token, ip, "user_agent", device_id + ) ) - self.clock.advance_time(1000) + self.reactor.advance(1000) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 62dc69003c..80da1c8954 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -20,7 +20,7 @@ from twisted.internet import defer import synapse.types from synapse.api.errors import AuthError -from synapse.handlers.profile import ProfileHandler +from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID from tests import unittest @@ -29,7 +29,7 @@ from tests.utils import setup_test_homeserver class ProfileHandlers(object): def __init__(self, hs): - self.profile_handler = ProfileHandler(hs) + self.profile_handler = MasterProfileHandler(hs) class ProfileTestCase(unittest.TestCase): diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index a01ab471f5..31f54bbd7d 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -51,7 +51,7 @@ class SyncTestCase(tests.unittest.TestCase): self.hs.config.hs_disabled = True with self.assertRaises(ResourceLimitError) as e: yield self.sync_handler.wait_for_sync_for_user(sync_config) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) + self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.hs.config.hs_disabled = False @@ -59,7 +59,7 @@ class SyncTestCase(tests.unittest.TestCase): with self.assertRaises(ResourceLimitError) as e: yield self.sync_handler.wait_for_sync_for_user(sync_config) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) + self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def _generate_sync_config(self, user_id): return SyncConfig( diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 65df116efc..089cecfbee 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -1,4 +1,5 @@ # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,89 +12,91 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import tempfile from mock import Mock, NonCallableMock -from twisted.internet import defer, reactor -from twisted.internet.defer import Deferred +import attr from synapse.replication.tcp.client import ( ReplicationClientFactory, ReplicationClientHandler, ) from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory -from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable from tests import unittest -from tests.utils import setup_test_homeserver -class TestReplicationClientHandler(ReplicationClientHandler): - """Overrides on_rdata so that we can wait for it to happen""" +class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): - def __init__(self, store): - super(TestReplicationClientHandler, self).__init__(store) - self._rdata_awaiters = [] - - def await_replication(self): - d = Deferred() - self._rdata_awaiters.append(d) - return make_deferred_yieldable(d) - - def on_rdata(self, stream_name, token, rows): - awaiters = self._rdata_awaiters - self._rdata_awaiters = [] - super(TestReplicationClientHandler, self).on_rdata(stream_name, token, rows) - with PreserveLoggingContext(): - for a in awaiters: - a.callback(None) - - -class BaseSlavedStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver( - self.addCleanup, + hs = self.setup_test_homeserver( "blue", - http_client=None, federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=["send_message"]), ) - self.hs.get_ratelimiter().send_message.return_value = (True, 0) + + hs.get_ratelimiter().send_message.return_value = (True, 0) + + return hs + + def prepare(self, reactor, clock, hs): self.master_store = self.hs.get_datastore() self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs) self.event_id = 0 server_factory = ReplicationStreamProtocolFactory(self.hs) - # XXX: mktemp is unsafe and should never be used. but we're just a test. - path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket") - listener = reactor.listenUNIX(path, server_factory) - self.addCleanup(listener.stopListening) self.streamer = server_factory.streamer - self.replication_handler = TestReplicationClientHandler(self.slaved_store) + self.replication_handler = ReplicationClientHandler(self.slaved_store) client_factory = ReplicationClientFactory( self.hs, "client_name", self.replication_handler ) - client_connector = reactor.connectUNIX(path, client_factory) - self.addCleanup(client_factory.stopTrying) - self.addCleanup(client_connector.disconnect) + + server = server_factory.buildProtocol(None) + client = client_factory.buildProtocol(None) + + @attr.s + class FakeTransport(object): + + other = attr.ib() + disconnecting = False + buffer = attr.ib(default=b'') + + def registerProducer(self, producer, streaming): + + self.producer = producer + + def _produce(): + self.producer.resumeProducing() + reactor.callLater(0.1, _produce) + + reactor.callLater(0.0, _produce) + + def write(self, byt): + self.buffer = self.buffer + byt + + if getattr(self.other, "transport") is not None: + self.other.dataReceived(self.buffer) + self.buffer = b"" + + def writeSequence(self, seq): + for x in seq: + self.write(x) + + client.makeConnection(FakeTransport(server)) + server.makeConnection(FakeTransport(client)) def replicate(self): """Tell the master side of replication that something has happened, and then wait for the replication to occur. """ - # xxx: should we be more specific in what we wait for? - d = self.replication_handler.await_replication() self.streamer.on_notifier_poke() - return d + self.pump(0.1) - @defer.inlineCallbacks def check(self, method, args, expected_result=None): - master_result = yield getattr(self.master_store, method)(*args) - slaved_result = yield getattr(self.slaved_store, method)(*args) + master_result = self.get_success(getattr(self.master_store, method)(*args)) + slaved_result = self.get_success(getattr(self.slaved_store, method)(*args)) if expected_result is not None: self.assertEqual(master_result, expected_result) self.assertEqual(slaved_result, expected_result) diff --git a/tests/replication/slave/storage/test_account_data.py b/tests/replication/slave/storage/test_account_data.py index 87cc2b2fba..43e3248703 100644 --- a/tests/replication/slave/storage/test_account_data.py +++ b/tests/replication/slave/storage/test_account_data.py @@ -12,9 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from twisted.internet import defer - from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from ._base import BaseSlavedStoreTestCase @@ -27,16 +24,19 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase): STORE_TYPE = SlavedAccountDataStore - @defer.inlineCallbacks def test_user_account_data(self): - yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1}) - yield self.replicate() - yield self.check( + self.get_success( + self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1}) + ) + self.replicate() + self.check( "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 1} ) - yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2}) - yield self.replicate() - yield self.check( + self.get_success( + self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2}) + ) + self.replicate() + self.check( "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 2} ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 622be2eef8..db44d33c68 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.events import FrozenEvent, _EventInternalMetadata from synapse.events.snapshot import EventContext from synapse.replication.slave.storage.events import SlavedEventStore @@ -55,69 +53,66 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): def tearDown(self): [unpatch() for unpatch in self.unpatches] - @defer.inlineCallbacks def test_get_latest_event_ids_in_room(self): - create = yield self.persist(type="m.room.create", key="", creator=USER_ID) - yield self.replicate() - yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]) + create = self.persist(type="m.room.create", key="", creator=USER_ID) + self.replicate() + self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]) - join = yield self.persist( + join = self.persist( type="m.room.member", key=USER_ID, membership="join", prev_events=[(create.event_id, {})], ) - yield self.replicate() - yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]) + self.replicate() + self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]) - @defer.inlineCallbacks def test_redactions(self): - yield self.persist(type="m.room.create", key="", creator=USER_ID) - yield self.persist(type="m.room.member", key=USER_ID, membership="join") + self.persist(type="m.room.create", key="", creator=USER_ID) + self.persist(type="m.room.member", key=USER_ID, membership="join") - msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello") - yield self.replicate() - yield self.check("get_event", [msg.event_id], msg) + msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello") + self.replicate() + self.check("get_event", [msg.event_id], msg) - redaction = yield self.persist(type="m.room.redaction", redacts=msg.event_id) - yield self.replicate() + redaction = self.persist(type="m.room.redaction", redacts=msg.event_id) + self.replicate() msg_dict = msg.get_dict() msg_dict["content"] = {} msg_dict["unsigned"]["redacted_by"] = redaction.event_id msg_dict["unsigned"]["redacted_because"] = redaction redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) - yield self.check("get_event", [msg.event_id], redacted) + self.check("get_event", [msg.event_id], redacted) - @defer.inlineCallbacks def test_backfilled_redactions(self): - yield self.persist(type="m.room.create", key="", creator=USER_ID) - yield self.persist(type="m.room.member", key=USER_ID, membership="join") + self.persist(type="m.room.create", key="", creator=USER_ID) + self.persist(type="m.room.member", key=USER_ID, membership="join") - msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello") - yield self.replicate() - yield self.check("get_event", [msg.event_id], msg) + msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello") + self.replicate() + self.check("get_event", [msg.event_id], msg) - redaction = yield self.persist( + redaction = self.persist( type="m.room.redaction", redacts=msg.event_id, backfill=True ) - yield self.replicate() + self.replicate() msg_dict = msg.get_dict() msg_dict["content"] = {} msg_dict["unsigned"]["redacted_by"] = redaction.event_id msg_dict["unsigned"]["redacted_because"] = redaction redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) - yield self.check("get_event", [msg.event_id], redacted) + self.check("get_event", [msg.event_id], redacted) - @defer.inlineCallbacks def test_invites(self): - yield self.check("get_invited_rooms_for_user", [USER_ID_2], []) - event = yield self.persist( - type="m.room.member", key=USER_ID_2, membership="invite" - ) - yield self.replicate() - yield self.check( + self.persist(type="m.room.create", key="", creator=USER_ID) + self.check("get_invited_rooms_for_user", [USER_ID_2], []) + event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite") + + self.replicate() + + self.check( "get_invited_rooms_for_user", [USER_ID_2], [ @@ -131,37 +126,34 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): ], ) - @defer.inlineCallbacks def test_push_actions_for_user(self): - yield self.persist(type="m.room.create", creator=USER_ID) - yield self.persist(type="m.room.join", key=USER_ID, membership="join") - yield self.persist( + self.persist(type="m.room.create", key="", creator=USER_ID) + self.persist(type="m.room.join", key=USER_ID, membership="join") + self.persist( type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join" ) - event1 = yield self.persist( - type="m.room.message", msgtype="m.text", body="hello" - ) - yield self.replicate() - yield self.check( + event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello") + self.replicate() + self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], {"highlight_count": 0, "notify_count": 0}, ) - yield self.persist( + self.persist( type="m.room.message", msgtype="m.text", body="world", push_actions=[(USER_ID_2, ["notify"])], ) - yield self.replicate() - yield self.check( + self.replicate() + self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], {"highlight_count": 0, "notify_count": 1}, ) - yield self.persist( + self.persist( type="m.room.message", msgtype="m.text", body="world", @@ -169,8 +161,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): (USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}]) ], ) - yield self.replicate() - yield self.check( + self.replicate() + self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], {"highlight_count": 1, "notify_count": 2}, @@ -178,7 +170,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): event_id = 0 - @defer.inlineCallbacks def persist( self, sender=USER_ID, @@ -205,8 +196,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): depth = self.event_id if not prev_events: - latest_event_ids = yield self.master_store.get_latest_event_ids_in_room( - room_id + latest_event_ids = self.get_success( + self.master_store.get_latest_event_ids_in_room(room_id) ) prev_events = [(ev_id, {}) for ev_id in latest_event_ids] @@ -239,19 +230,23 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): ) else: state_handler = self.hs.get_state_handler() - context = yield state_handler.compute_event_context(event) + context = self.get_success(state_handler.compute_event_context(event)) - yield self.master_store.add_push_actions_to_staging( + self.master_store.add_push_actions_to_staging( event.event_id, {user_id: actions for user_id, actions in push_actions} ) ordering = None if backfill: - yield self.master_store.persist_events([(event, context)], backfilled=True) + self.get_success( + self.master_store.persist_events([(event, context)], backfilled=True) + ) else: - ordering, _ = yield self.master_store.persist_event(event, context) + ordering, _ = self.get_success( + self.master_store.persist_event(event, context) + ) if ordering: event.internal_metadata.stream_ordering = ordering - defer.returnValue(event) + return event diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py index ae1adeded1..f47d94f690 100644 --- a/tests/replication/slave/storage/test_receipts.py +++ b/tests/replication/slave/storage/test_receipts.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from ._base import BaseSlavedStoreTestCase @@ -27,13 +25,10 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): STORE_TYPE = SlavedReceiptsStore - @defer.inlineCallbacks def test_receipt(self): - yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {}) - yield self.master_store.insert_receipt( - ROOM_ID, "m.read", USER_ID, [EVENT_ID], {} - ) - yield self.replicate() - yield self.check( - "get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID} + self.check("get_receipts_for_user", [USER_ID, "m.read"], {}) + self.get_success( + self.master_store.insert_receipt(ROOM_ID, "m.read", USER_ID, [EVENT_ID], {}) ) + self.replicate() + self.check("get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID}) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 40dc4ea256..530dc8ba6d 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -240,7 +240,6 @@ class RestHelper(object): self.assertEquals(200, code) defer.returnValue(response) - @defer.inlineCallbacks def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -248,9 +247,16 @@ class RestHelper(object): body = "body_text_here" path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id) - content = '{"msgtype":"m.text","body":"%s"}' % body + content = {"msgtype": "m.text", "body": body} if tok: path = path + "?access_token=%s" % tok - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(expect_code, code, msg=str(response)) + request, channel = make_request("PUT", path, json.dumps(content).encode('utf8')) + render(request, self.resource, self.hs.get_reactor()) + + assert int(channel.result["code"]) == expect_code, ( + "Expected: %d, got: %d, resp: %r" + % (expect_code, int(channel.result["code"]), channel.result["body"]) + ) + + return channel.json_body diff --git a/tests/server.py b/tests/server.py index c63b2c3100..615bba1b59 100644 --- a/tests/server.py +++ b/tests/server.py @@ -5,7 +5,7 @@ from six import text_type import attr -from twisted.internet import threads +from twisted.internet import address, threads from twisted.internet.defer import Deferred from twisted.python.failure import Failure from twisted.test.proto_helpers import MemoryReactorClock @@ -63,7 +63,9 @@ class FakeChannel(object): self.result["done"] = True def getPeer(self): - return None + # We give an address so that getClientIP returns a non null entry, + # causing us to record the MAU + return address.IPv4Address(b"TCP", "127.0.0.1", 3423) def getHost(self): return None @@ -91,7 +93,7 @@ class FakeSite: return FakeLogger() -def make_request(method, path, content=b""): +def make_request(method, path, content=b"", access_token=None): """ Make a web request using the given method and path, feed it the content, and return the Request and the Channel underneath. @@ -116,6 +118,11 @@ def make_request(method, path, content=b""): req = SynapseRequest(site, channel) req.process = lambda: b"" req.content = BytesIO(content) + + if access_token: + req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token) + + req.requestHeaders.addRawHeader(b"X-Forwarded-For", b"127.0.0.1") req.requestReceived(method, path, b"1.1") return req, channel @@ -225,6 +232,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): clock.threadpool = ThreadPool() pool.threadpool = ThreadPool() + pool.running = True return d diff --git a/tests/server_notices/__init__.py b/tests/server_notices/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/tests/server_notices/__init__.py diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py new file mode 100644 index 0000000000..5cc7fff39b --- /dev/null +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -0,0 +1,213 @@ +from mock import Mock + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, ServerNoticeMsgType +from synapse.api.errors import ResourceLimitError +from synapse.handlers.auth import AuthHandler +from synapse.server_notices.resource_limits_server_notices import ( + ResourceLimitsServerNotices, +) + +from tests import unittest +from tests.utils import setup_test_homeserver + + +class AuthHandlers(object): + def __init__(self, hs): + self.auth_handler = AuthHandler(hs) + + +class TestResourceLimitsServerNotices(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None) + self.hs.handlers = AuthHandlers(self.hs) + self.auth_handler = self.hs.handlers.auth_handler + self.server_notices_sender = self.hs.get_server_notices_sender() + + # relying on [1] is far from ideal, but the only case where + # ResourceLimitsServerNotices class needs to be isolated is this test, + # general code should never have a reason to do so ... + self._rlsn = self.server_notices_sender._server_notices[1] + if not isinstance(self._rlsn, ResourceLimitsServerNotices): + raise Exception("Failed to find reference to ResourceLimitsServerNotices") + + self._rlsn._store.user_last_seen_monthly_active = Mock( + return_value=defer.succeed(1000) + ) + self._send_notice = self._rlsn._server_notices_manager.send_notice + self._rlsn._server_notices_manager.send_notice = Mock() + self._rlsn._state.get_current_state = Mock(return_value=defer.succeed(None)) + self._rlsn._store.get_events = Mock(return_value=defer.succeed({})) + + self._send_notice = self._rlsn._server_notices_manager.send_notice + + self.hs.config.limit_usage_by_mau = True + self.user_id = "@user_id:test" + + # self.server_notices_mxid = "@server:test" + # self.server_notices_mxid_display_name = None + # self.server_notices_mxid_avatar_url = None + # self.server_notices_room_name = "Server Notices" + + self._rlsn._server_notices_manager.get_notice_room_for_user = Mock( + returnValue="" + ) + self._rlsn._store.add_tag_to_room = Mock() + self._rlsn._store.get_tags_for_room = Mock(return_value={}) + self.hs.config.admin_contact = "mailto:user@test.com" + + @defer.inlineCallbacks + def test_maybe_send_server_notice_to_user_flag_off(self): + """Tests cases where the flags indicate nothing to do""" + # test hs disabled case + self.hs.config.hs_disabled = True + + yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + + self._send_notice.assert_not_called() + # Test when mau limiting disabled + self.hs.config.hs_disabled = False + self.hs.limit_usage_by_mau = False + yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + + self._send_notice.assert_not_called() + + @defer.inlineCallbacks + def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): + """Test when user has blocked notice, but should have it removed""" + + self._rlsn._auth.check_auth_blocking = Mock() + mock_event = Mock( + type=EventTypes.Message, + content={"msgtype": ServerNoticeMsgType}, + ) + self._rlsn._store.get_events = Mock(return_value=defer.succeed( + {"123": mock_event} + )) + + yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + # Would be better to check the content, but once == remove blocking event + self._send_notice.assert_called_once() + + @defer.inlineCallbacks + def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self): + """Test when user has blocked notice, but notice ought to be there (NOOP)""" + self._rlsn._auth.check_auth_blocking = Mock( + side_effect=ResourceLimitError(403, 'foo') + ) + + mock_event = Mock( + type=EventTypes.Message, + content={"msgtype": ServerNoticeMsgType}, + ) + self._rlsn._store.get_events = Mock(return_value=defer.succeed( + {"123": mock_event} + )) + yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + + self._send_notice.assert_not_called() + + @defer.inlineCallbacks + def test_maybe_send_server_notice_to_user_add_blocked_notice(self): + """Test when user does not have blocked notice, but should have one""" + + self._rlsn._auth.check_auth_blocking = Mock( + side_effect=ResourceLimitError(403, 'foo') + ) + yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + + # Would be better to check contents, but 2 calls == set blocking event + self.assertTrue(self._send_notice.call_count == 2) + + @defer.inlineCallbacks + def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self): + """Test when user does not have blocked notice, nor should they (NOOP)""" + + self._rlsn._auth.check_auth_blocking = Mock() + + yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + + self._send_notice.assert_not_called() + + @defer.inlineCallbacks + def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self): + + """Test when user is not part of the MAU cohort - this should not ever + happen - but ... + """ + + self._rlsn._auth.check_auth_blocking = Mock() + self._rlsn._store.user_last_seen_monthly_active = Mock( + return_value=defer.succeed(None) + ) + yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + + self._send_notice.assert_not_called() + + +class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver(self.addCleanup) + self.store = self.hs.get_datastore() + self.server_notices_sender = self.hs.get_server_notices_sender() + self.server_notices_manager = self.hs.get_server_notices_manager() + self.event_source = self.hs.get_event_sources() + + # relying on [1] is far from ideal, but the only case where + # ResourceLimitsServerNotices class needs to be isolated is this test, + # general code should never have a reason to do so ... + self._rlsn = self.server_notices_sender._server_notices[1] + if not isinstance(self._rlsn, ResourceLimitsServerNotices): + raise Exception("Failed to find reference to ResourceLimitsServerNotices") + + self.hs.config.limit_usage_by_mau = True + self.hs.config.hs_disabled = False + self.hs.config.max_mau_value = 5 + self.hs.config.server_notices_mxid = "@server:test" + self.hs.config.server_notices_mxid_display_name = None + self.hs.config.server_notices_mxid_avatar_url = None + self.hs.config.server_notices_room_name = "Test Server Notice Room" + + self.user_id = "@user_id:test" + + self.hs.config.admin_contact = "mailto:user@test.com" + + @defer.inlineCallbacks + def test_server_notice_only_sent_once(self): + self.store.get_monthly_active_count = Mock( + return_value=1000, + ) + + self.store.user_last_seen_monthly_active = Mock( + return_value=1000, + ) + + # Call the function multiple times to ensure we only send the notice once + yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + + # Now lets get the last load of messages in the service notice room and + # check that there is only one server notice + room_id = yield self.server_notices_manager.get_notice_room_for_user( + self.user_id, + ) + + token = yield self.event_source.get_current_token() + events, _ = yield self.store.get_recent_events_for_room( + room_id, limit=100, end_token=token.room_key, + ) + + count = 0 + for event in events: + if event.type != EventTypes.Message: + continue + if event.content.get("msgtype") != ServerNoticeMsgType: + continue + + count += 1 + + self.assertEqual(count, 1) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index c893990454..3f0083831b 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -37,18 +37,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.as_yaml_files = [] - config = Mock( - app_service_config_files=self.as_yaml_files, - event_cache_size=1, - password_providers=[], - ) hs = yield setup_test_homeserver( - self.addCleanup, - config=config, - federation_sender=Mock(), - federation_client=Mock(), + self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) + hs.config.app_service_config_files = self.as_yaml_files + hs.config.event_cache_size = 1 + hs.config.password_providers = [] + self.as_token = "token1" self.as_url = "some_url" self.as_id = "as1" @@ -58,7 +54,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts - self.store = ApplicationServiceStore(None, hs) + self.store = ApplicationServiceStore(hs.get_db_conn(), hs) def tearDown(self): # TODO: suboptimal that we need to create files for tests! @@ -105,18 +101,16 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): def setUp(self): self.as_yaml_files = [] - config = Mock( - app_service_config_files=self.as_yaml_files, - event_cache_size=1, - password_providers=[], - ) hs = yield setup_test_homeserver( - self.addCleanup, - config=config, - federation_sender=Mock(), - federation_client=Mock(), + self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) + + hs.config.app_service_config_files = self.as_yaml_files + hs.config.event_cache_size = 1 + hs.config.password_providers = [] + self.db_pool = hs.get_db_pool() + self.engine = hs.database_engine self.as_list = [ {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"}, @@ -129,7 +123,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.as_yaml_files = [] - self.store = TestTransactionStore(None, hs) + self.store = TestTransactionStore(hs.get_db_conn(), hs) def _add_service(self, url, as_token, id): as_yaml = dict( @@ -146,29 +140,35 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.as_yaml_files.append(as_token) def _set_state(self, id, state, txn=None): - return self.db_pool.runQuery( - "INSERT INTO application_services_state(as_id, state, last_txn) " - "VALUES(?,?,?)", + return self.db_pool.runOperation( + self.engine.convert_param_style( + "INSERT INTO application_services_state(as_id, state, last_txn) " + "VALUES(?,?,?)" + ), (id, state, txn), ) def _insert_txn(self, as_id, txn_id, events): - return self.db_pool.runQuery( - "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " - "VALUES(?,?,?)", + return self.db_pool.runOperation( + self.engine.convert_param_style( + "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " + "VALUES(?,?,?)" + ), (as_id, txn_id, json.dumps([e.event_id for e in events])), ) def _set_last_txn(self, as_id, txn_id): - return self.db_pool.runQuery( - "INSERT INTO application_services_state(as_id, last_txn, state) " - "VALUES(?,?,?)", + return self.db_pool.runOperation( + self.engine.convert_param_style( + "INSERT INTO application_services_state(as_id, last_txn, state) " + "VALUES(?,?,?)" + ), (as_id, txn_id, ApplicationServiceState.UP), ) @defer.inlineCallbacks def test_get_appservice_state_none(self): - service = Mock(id=999) + service = Mock(id="999") state = yield self.store.get_appservice_state(service) self.assertEquals(None, state) @@ -200,7 +200,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): service = Mock(id=self.as_list[1]["id"]) yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) rows = yield self.db_pool.runQuery( - "SELECT as_id FROM application_services_state WHERE state=?", + self.engine.convert_param_style( + "SELECT as_id FROM application_services_state WHERE state=?" + ), (ApplicationServiceState.DOWN,), ) self.assertEquals(service.id, rows[0][0]) @@ -212,7 +214,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) yield self.store.set_appservice_state(service, ApplicationServiceState.UP) rows = yield self.db_pool.runQuery( - "SELECT as_id FROM application_services_state WHERE state=?", + self.engine.convert_param_style( + "SELECT as_id FROM application_services_state WHERE state=?" + ), (ApplicationServiceState.UP,), ) self.assertEquals(service.id, rows[0][0]) @@ -279,14 +283,19 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self.store.complete_appservice_txn(txn_id=txn_id, service=service) res = yield self.db_pool.runQuery( - "SELECT last_txn FROM application_services_state WHERE as_id=?", + self.engine.convert_param_style( + "SELECT last_txn FROM application_services_state WHERE as_id=?" + ), (service.id,), ) self.assertEquals(1, len(res)) self.assertEquals(txn_id, res[0][0]) res = yield self.db_pool.runQuery( - "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,) + self.engine.convert_param_style( + "SELECT * FROM application_services_txns WHERE txn_id=?" + ), + (txn_id,), ) self.assertEquals(0, len(res)) @@ -300,7 +309,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self.store.complete_appservice_txn(txn_id=txn_id, service=service) res = yield self.db_pool.runQuery( - "SELECT last_txn, state FROM application_services_state WHERE " "as_id=?", + self.engine.convert_param_style( + "SELECT last_txn, state FROM application_services_state WHERE as_id=?" + ), (service.id,), ) self.assertEquals(1, len(res)) @@ -308,7 +319,10 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.assertEquals(ApplicationServiceState.UP, res[0][1]) res = yield self.db_pool.runQuery( - "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,) + self.engine.convert_param_style( + "SELECT * FROM application_services_txns WHERE txn_id=?" + ), + (txn_id,), ) self.assertEquals(0, len(res)) @@ -394,37 +408,31 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): f1 = self._write_config(suffix="1") f2 = self._write_config(suffix="2") - config = Mock( - app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] - ) hs = yield setup_test_homeserver( - self.addCleanup, - config=config, - datastore=Mock(), - federation_sender=Mock(), - federation_client=Mock(), + self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) - ApplicationServiceStore(None, hs) + hs.config.app_service_config_files = [f1, f2] + hs.config.event_cache_size = 1 + hs.config.password_providers = [] + + ApplicationServiceStore(hs.get_db_conn(), hs) @defer.inlineCallbacks def test_duplicate_ids(self): f1 = self._write_config(id="id", suffix="1") f2 = self._write_config(id="id", suffix="2") - config = Mock( - app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] - ) hs = yield setup_test_homeserver( - self.addCleanup, - config=config, - datastore=Mock(), - federation_sender=Mock(), - federation_client=Mock(), + self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) + hs.config.app_service_config_files = [f1, f2] + hs.config.event_cache_size = 1 + hs.config.password_providers = [] + with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(None, hs) + ApplicationServiceStore(hs.get_db_conn(), hs) e = cm.exception self.assertIn(f1, str(e)) @@ -436,19 +444,16 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): f1 = self._write_config(as_token="as_token", suffix="1") f2 = self._write_config(as_token="as_token", suffix="2") - config = Mock( - app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] - ) hs = yield setup_test_homeserver( - self.addCleanup, - config=config, - datastore=Mock(), - federation_sender=Mock(), - federation_client=Mock(), + self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) + hs.config.app_service_config_files = [f1, f2] + hs.config.event_cache_size = 1 + hs.config.password_providers = [] + with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(None, hs) + ApplicationServiceStore(hs.get_db_conn(), hs) e = cm.exception self.assertIn(f1, str(e)) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 7cb5f0e4cf..829f47d2e8 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -20,11 +20,11 @@ from mock import Mock from twisted.internet import defer -from synapse.server import HomeServer from synapse.storage._base import SQLBaseStore from synapse.storage.engines import create_engine from tests import unittest +from tests.utils import TestHomeServer class SQLBaseStoreTestCase(unittest.TestCase): @@ -51,7 +51,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): config = Mock() config.event_cache_size = 1 config.database_config = {"name": "sqlite3"} - hs = HomeServer( + hs = TestHomeServer( "test", db_pool=self.db_pool, config=config, diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index b4510c1c8d..4e128e1047 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -16,7 +16,6 @@ from twisted.internet import defer -from synapse.storage.directory import DirectoryStore from synapse.types import RoomAlias, RoomID from tests import unittest @@ -28,7 +27,7 @@ class DirectoryStoreTestCase(unittest.TestCase): def setUp(self): hs = yield setup_test_homeserver(self.addCleanup) - self.store = DirectoryStore(None, hs) + self.store = hs.get_datastore() self.room = RoomID.from_string("!abcde:test") self.alias = RoomAlias.from_string("#my-room:test") diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 2fdf34fdf6..0d4e74d637 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -37,10 +37,10 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): ( "INSERT INTO events (" " room_id, event_id, type, depth, topological_ordering," - " content, processed, outlier) " - "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)" + " content, processed, outlier, stream_ordering) " + "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?, ?)" ), - (room_id, event_id, i, i, True, False), + (room_id, event_id, i, i, True, False, i), ) txn.execute( diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index f2ed866ae7..2036287288 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -13,25 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - -import tests.unittest -import tests.utils -from tests.utils import setup_test_homeserver +from tests.unittest import HomeserverTestCase FORTY_DAYS = 40 * 24 * 60 * 60 -class MonthlyActiveUsersTestCase(tests.unittest.TestCase): - def __init__(self, *args, **kwargs): - super(MonthlyActiveUsersTestCase, self).__init__(*args, **kwargs) +class MonthlyActiveUsersTestCase(HomeserverTestCase): + def make_homeserver(self, reactor, clock): + + hs = self.setup_test_homeserver() + self.store = hs.get_datastore() + + # Advance the clock a bit + reactor.advance(FORTY_DAYS) - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup) - self.store = self.hs.get_datastore() + return hs - @defer.inlineCallbacks def test_initialise_reserved_users(self): self.hs.config.max_mau_value = 5 user1 = "@user1:server" @@ -44,88 +41,101 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase): ] user_num = len(threepids) - yield self.store.register(user_id=user1, token="123", password_hash=None) - - yield self.store.register(user_id=user2, token="456", password_hash=None) + self.store.register(user_id=user1, token="123", password_hash=None) + self.store.register(user_id=user2, token="456", password_hash=None) + self.pump() now = int(self.hs.get_clock().time_msec()) - yield self.store.user_add_threepid(user1, "email", user1_email, now, now) - yield self.store.user_add_threepid(user2, "email", user2_email, now, now) - yield self.store.initialise_reserved_users(threepids) + self.store.user_add_threepid(user1, "email", user1_email, now, now) + self.store.user_add_threepid(user2, "email", user2_email, now, now) + self.store.initialise_reserved_users(threepids) + self.pump() - active_count = yield self.store.get_monthly_active_count() + active_count = self.store.get_monthly_active_count() # Test total counts - self.assertEquals(active_count, user_num) + self.assertEquals(self.get_success(active_count), user_num) # Test user is marked as active - - timestamp = yield self.store.user_last_seen_monthly_active(user1) - self.assertTrue(timestamp) - timestamp = yield self.store.user_last_seen_monthly_active(user2) - self.assertTrue(timestamp) + timestamp = self.store.user_last_seen_monthly_active(user1) + self.assertTrue(self.get_success(timestamp)) + timestamp = self.store.user_last_seen_monthly_active(user2) + self.assertTrue(self.get_success(timestamp)) # Test that users are never removed from the db. self.hs.config.max_mau_value = 0 - self.hs.get_clock().advance_time(FORTY_DAYS) + self.reactor.advance(FORTY_DAYS) - yield self.store.reap_monthly_active_users() + self.store.reap_monthly_active_users() + self.pump() - active_count = yield self.store.get_monthly_active_count() - self.assertEquals(active_count, user_num) + active_count = self.store.get_monthly_active_count() + self.assertEquals(self.get_success(active_count), user_num) # Test that regalar users are removed from the db ru_count = 2 - yield self.store.upsert_monthly_active_user("@ru1:server") - yield self.store.upsert_monthly_active_user("@ru2:server") - active_count = yield self.store.get_monthly_active_count() + self.store.upsert_monthly_active_user("@ru1:server") + self.store.upsert_monthly_active_user("@ru2:server") + self.pump() - self.assertEqual(active_count, user_num + ru_count) + active_count = self.store.get_monthly_active_count() + self.assertEqual(self.get_success(active_count), user_num + ru_count) self.hs.config.max_mau_value = user_num - yield self.store.reap_monthly_active_users() + self.store.reap_monthly_active_users() + self.pump() - active_count = yield self.store.get_monthly_active_count() - self.assertEquals(active_count, user_num) + active_count = self.store.get_monthly_active_count() + self.assertEquals(self.get_success(active_count), user_num) - @defer.inlineCallbacks def test_can_insert_and_count_mau(self): - count = yield self.store.get_monthly_active_count() - self.assertEqual(0, count) + count = self.store.get_monthly_active_count() + self.assertEqual(0, self.get_success(count)) - yield self.store.upsert_monthly_active_user("@user:server") - count = yield self.store.get_monthly_active_count() + self.store.upsert_monthly_active_user("@user:server") + self.pump() - self.assertEqual(1, count) + count = self.store.get_monthly_active_count() + self.assertEqual(1, self.get_success(count)) - @defer.inlineCallbacks def test_user_last_seen_monthly_active(self): user_id1 = "@user1:server" user_id2 = "@user2:server" user_id3 = "@user3:server" - result = yield self.store.user_last_seen_monthly_active(user_id1) - self.assertFalse(result == 0) - yield self.store.upsert_monthly_active_user(user_id1) - yield self.store.upsert_monthly_active_user(user_id2) - result = yield self.store.user_last_seen_monthly_active(user_id1) - self.assertTrue(result > 0) - result = yield self.store.user_last_seen_monthly_active(user_id3) - self.assertFalse(result == 0) + result = self.store.user_last_seen_monthly_active(user_id1) + self.assertFalse(self.get_success(result) == 0) + + self.store.upsert_monthly_active_user(user_id1) + self.store.upsert_monthly_active_user(user_id2) + self.pump() + + result = self.store.user_last_seen_monthly_active(user_id1) + self.assertGreater(self.get_success(result), 0) + + result = self.store.user_last_seen_monthly_active(user_id3) + self.assertNotEqual(self.get_success(result), 0) - @defer.inlineCallbacks def test_reap_monthly_active_users(self): self.hs.config.max_mau_value = 5 initial_users = 10 for i in range(initial_users): - yield self.store.upsert_monthly_active_user("@user%d:server" % i) - count = yield self.store.get_monthly_active_count() - self.assertTrue(count, initial_users) - yield self.store.reap_monthly_active_users() - count = yield self.store.get_monthly_active_count() - self.assertEquals(count, initial_users - self.hs.config.max_mau_value) - - self.hs.get_clock().advance_time(FORTY_DAYS) - yield self.store.reap_monthly_active_users() - count = yield self.store.get_monthly_active_count() - self.assertEquals(count, 0) + self.store.upsert_monthly_active_user("@user%d:server" % i) + self.pump() + + count = self.store.get_monthly_active_count() + self.assertTrue(self.get_success(count), initial_users) + + self.store.reap_monthly_active_users() + self.pump() + count = self.store.get_monthly_active_count() + self.assertEquals( + self.get_success(count), initial_users - self.hs.config.max_mau_value + ) + + self.reactor.advance(FORTY_DAYS) + self.store.reap_monthly_active_users() + self.pump() + + count = self.store.get_monthly_active_count() + self.assertEquals(self.get_success(count), 0) diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py index b5b58ff660..c7a63f39b9 100644 --- a/tests/storage/test_presence.py +++ b/tests/storage/test_presence.py @@ -16,19 +16,18 @@ from twisted.internet import defer -from synapse.storage.presence import PresenceStore from synapse.types import UserID from tests import unittest -from tests.utils import MockClock, setup_test_homeserver +from tests.utils import setup_test_homeserver class PresenceStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver(self.addCleanup, clock=MockClock()) + hs = yield setup_test_homeserver(self.addCleanup) - self.store = PresenceStore(None, hs) + self.store = hs.get_datastore() self.u_apple = UserID.from_string("@apple:test") self.u_banana = UserID.from_string("@banana:test") diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index a1f6618bf9..45824bd3b2 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -28,7 +28,7 @@ class ProfileStoreTestCase(unittest.TestCase): def setUp(self): hs = yield setup_test_homeserver(self.addCleanup) - self.store = ProfileStore(None, hs) + self.store = ProfileStore(hs.get_db_conn(), hs) self.u_frank = UserID.from_string("@frank:test") diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py new file mode 100644 index 0000000000..f671599cb8 --- /dev/null +++ b/tests/storage/test_purge.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest.client.v1 import room + +from tests.unittest import HomeserverTestCase + + +class PurgeTests(HomeserverTestCase): + + user_id = "@red:server" + servlets = [room.register_servlets] + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver("server", http_client=None) + return hs + + def prepare(self, reactor, clock, hs): + self.room_id = self.helper.create_room_as(self.user_id) + + def test_purge(self): + """ + Purging a room will delete everything before the topological point. + """ + # Send four messages to the room + first = self.helper.send(self.room_id, body="test1") + second = self.helper.send(self.room_id, body="test2") + third = self.helper.send(self.room_id, body="test3") + last = self.helper.send(self.room_id, body="test4") + + storage = self.hs.get_datastore() + + # Get the topological token + event = storage.get_topological_token_for_event(last["event_id"]) + self.pump() + event = self.successResultOf(event) + + # Purge everything before this topological token + purge = storage.purge_history(self.room_id, event, True) + self.pump() + self.assertEqual(self.successResultOf(purge), None) + + # Try and get the events + get_first = storage.get_event(first["event_id"]) + get_second = storage.get_event(second["event_id"]) + get_third = storage.get_event(third["event_id"]) + get_last = storage.get_event(last["event_id"]) + self.pump() + + # 1-3 should fail and last will succeed, meaning that 1-3 are deleted + # and last is not. + self.failureResultOf(get_first) + self.failureResultOf(get_second) + self.failureResultOf(get_third) + self.successResultOf(get_last) + + def test_purge_wont_delete_extrems(self): + """ + Purging a room will delete everything before the topological point. + """ + # Send four messages to the room + first = self.helper.send(self.room_id, body="test1") + second = self.helper.send(self.room_id, body="test2") + third = self.helper.send(self.room_id, body="test3") + last = self.helper.send(self.room_id, body="test4") + + storage = self.hs.get_datastore() + + # Set the topological token higher than it should be + event = storage.get_topological_token_for_event(last["event_id"]) + self.pump() + event = self.successResultOf(event) + event = "t{}-{}".format( + *list(map(lambda x: x + 1, map(int, event[1:].split("-")))) + ) + + # Purge everything before this topological token + purge = storage.purge_history(self.room_id, event, True) + self.pump() + f = self.failureResultOf(purge) + self.assertIn("greater than forward", f.value.args[0]) + + # Try and get the events + get_first = storage.get_event(first["event_id"]) + get_second = storage.get_event(second["event_id"]) + get_third = storage.get_event(third["event_id"]) + get_last = storage.get_event(last["event_id"]) + self.pump() + + # Nothing is deleted. + self.successResultOf(get_first) + self.successResultOf(get_second) + self.successResultOf(get_third) + self.successResultOf(get_last) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index c4e9fb72bf..02bf975fbf 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -22,7 +22,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.types import RoomID, UserID from tests import unittest -from tests.utils import setup_test_homeserver +from tests.utils import create_room, setup_test_homeserver class RedactionTestCase(unittest.TestCase): @@ -41,6 +41,8 @@ class RedactionTestCase(unittest.TestCase): self.room1 = RoomID.from_string("!abc123:test") + yield create_room(hs, self.room1.to_string(), self.u_alice.to_string()) + self.depth = 1 @defer.inlineCallbacks diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 4eda122edc..3dfb7b903a 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -46,6 +46,7 @@ class RegistrationStoreTestCase(unittest.TestCase): "consent_version": None, "consent_server_notice_sent": None, "appservice_id": None, + "creation_ts": 1000, }, (yield self.store.get_user_by_id(self.user_id)), ) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index c83ef60062..978c66133d 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -22,7 +22,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.types import RoomID, UserID from tests import unittest -from tests.utils import setup_test_homeserver +from tests.utils import create_room, setup_test_homeserver class RoomMemberStoreTestCase(unittest.TestCase): @@ -45,6 +45,8 @@ class RoomMemberStoreTestCase(unittest.TestCase): self.room = RoomID.from_string("!abc123:test") + yield create_room(hs, self.room.to_string(), self.u_alice.to_string()) + @defer.inlineCallbacks def inject_room_member(self, room, user, membership, replaces_state=None): builder = self.event_builder_factory.new( diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index ebfd969b36..d717b9f94e 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -185,6 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_some_state_from_cache correctly filters out members with types=[] (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member] ) @@ -197,8 +198,20 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, + group, [], filtered_types=[EventTypes.Member] + ) + + self.assertEqual(is_all, True) + self.assertDictEqual( + {}, + state_dict, + ) + # test _get_some_state_from_cache correctly filters in members with wildcard types (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_cache, group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] ) @@ -207,6 +220,18 @@ class StateStoreTestCase(tests.unittest.TestCase): { (e1.type, e1.state_key): e1.event_id, (e2.type, e2.state_key): e2.event_id, + }, + state_dict, + ) + + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, + group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] + ) + + self.assertEqual(is_all, True) + self.assertDictEqual( + { (e3.type, e3.state_key): e3.event_id, # e4 is overwritten by e5 (e5.type, e5.state_key): e5.event_id, @@ -216,6 +241,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_some_state_from_cache correctly filters in members with specific types (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_cache, group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member], @@ -226,6 +252,20 @@ class StateStoreTestCase(tests.unittest.TestCase): { (e1.type, e1.state_key): e1.event_id, (e2.type, e2.state_key): e2.event_id, + }, + state_dict, + ) + + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, + group, + [(EventTypes.Member, e5.state_key)], + filtered_types=[EventTypes.Member], + ) + + self.assertEqual(is_all, True) + self.assertDictEqual( + { (e5.type, e5.state_key): e5.event_id, }, state_dict, @@ -234,6 +274,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_some_state_from_cache correctly filters in members with specific types # and no filtered_types (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, group, [(EventTypes.Member, e5.state_key)], filtered_types=None ) @@ -254,9 +295,6 @@ class StateStoreTestCase(tests.unittest.TestCase): { (e1.type, e1.state_key): e1.event_id, (e2.type, e2.state_key): e2.event_id, - (e3.type, e3.state_key): e3.event_id, - # e4 is overwritten by e5 - (e5.type, e5.state_key): e5.event_id, }, ) @@ -269,8 +307,6 @@ class StateStoreTestCase(tests.unittest.TestCase): # list fetched keys so it knows it's partial fetched_keys=( (e1.type, e1.state_key), - (e3.type, e3.state_key), - (e5.type, e5.state_key), ), ) @@ -284,8 +320,6 @@ class StateStoreTestCase(tests.unittest.TestCase): set( [ (e1.type, e1.state_key), - (e3.type, e3.state_key), - (e5.type, e5.state_key), ] ), ) @@ -293,8 +327,6 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict_ids, { (e1.type, e1.state_key): e1.event_id, - (e3.type, e3.state_key): e3.event_id, - (e5.type, e5.state_key): e5.event_id, }, ) @@ -304,14 +336,25 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_some_state_from_cache correctly filters out members with types=[] room_id = self.room.to_string() (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member] ) self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) + room_id = self.room.to_string() + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, + group, [], filtered_types=[EventTypes.Member] + ) + + self.assertEqual(is_all, True) + self.assertDictEqual({}, state_dict) + # test _get_some_state_from_cache correctly filters in members wildcard types (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_cache, group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] ) @@ -319,8 +362,19 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertDictEqual( { (e1.type, e1.state_key): e1.event_id, + }, + state_dict, + ) + + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, + group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] + ) + + self.assertEqual(is_all, True) + self.assertDictEqual( + { (e3.type, e3.state_key): e3.event_id, - # e4 is overwritten by e5 (e5.type, e5.state_key): e5.event_id, }, state_dict, @@ -328,6 +382,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_some_state_from_cache correctly filters in members with specific types (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_cache, group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member], @@ -337,6 +392,20 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertDictEqual( { (e1.type, e1.state_key): e1.event_id, + }, + state_dict, + ) + + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, + group, + [(EventTypes.Member, e5.state_key)], + filtered_types=[EventTypes.Member], + ) + + self.assertEqual(is_all, True) + self.assertDictEqual( + { (e5.type, e5.state_key): e5.event_id, }, state_dict, @@ -345,8 +414,22 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_some_state_from_cache correctly filters in members with specific types # and no filtered_types (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_cache, + group, [(EventTypes.Member, e5.state_key)], filtered_types=None + ) + + self.assertEqual(is_all, False) + self.assertDictEqual({}, state_dict) + + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, group, [(EventTypes.Member, e5.state_key)], filtered_types=None ) self.assertEqual(is_all, True) - self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) + self.assertDictEqual( + { + (e5.type, e5.state_key): e5.event_id, + }, + state_dict, + ) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index b46e0ea7e2..0dde1ab2fe 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -30,7 +30,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield setup_test_homeserver(self.addCleanup) - self.store = UserDirectoryStore(None, self.hs) + self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs) # alice and bob are both in !room_id. bobby is not but shares # a homeserver with alice. diff --git a/tests/test_mau.py b/tests/test_mau.py new file mode 100644 index 0000000000..0732615447 --- /dev/null +++ b/tests/test_mau.py @@ -0,0 +1,217 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests REST events for /rooms paths.""" + +import json + +from mock import Mock, NonCallableMock + +from synapse.api.constants import LoginType +from synapse.api.errors import Codes, HttpResponseException, SynapseError +from synapse.http.server import JsonResource +from synapse.rest.client.v2_alpha import register, sync +from synapse.util import Clock + +from tests import unittest +from tests.server import ( + ThreadedMemoryReactorClock, + make_request, + render, + setup_test_homeserver, +) + + +class TestMauLimit(unittest.TestCase): + def setUp(self): + self.reactor = ThreadedMemoryReactorClock() + self.clock = Clock(self.reactor) + + self.hs = setup_test_homeserver( + self.addCleanup, + "red", + http_client=None, + clock=self.clock, + reactor=self.reactor, + federation_client=Mock(), + ratelimiter=NonCallableMock(spec_set=["send_message"]), + ) + + self.store = self.hs.get_datastore() + + self.hs.config.registrations_require_3pid = [] + self.hs.config.enable_registration_captcha = False + self.hs.config.recaptcha_public_key = [] + + self.hs.config.limit_usage_by_mau = True + self.hs.config.hs_disabled = False + self.hs.config.max_mau_value = 2 + self.hs.config.mau_trial_days = 0 + self.hs.config.server_notices_mxid = "@server:red" + self.hs.config.server_notices_mxid_display_name = None + self.hs.config.server_notices_mxid_avatar_url = None + self.hs.config.server_notices_room_name = "Test Server Notice Room" + + self.resource = JsonResource(self.hs) + register.register_servlets(self.hs, self.resource) + sync.register_servlets(self.hs, self.resource) + + def test_simple_deny_mau(self): + # Create and sync so that the MAU counts get updated + token1 = self.create_user("kermit1") + self.do_sync_for_user(token1) + token2 = self.create_user("kermit2") + self.do_sync_for_user(token2) + + # We've created and activated two users, we shouldn't be able to + # register new users + with self.assertRaises(SynapseError) as cm: + self.create_user("kermit3") + + e = cm.exception + self.assertEqual(e.code, 403) + self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + def test_allowed_after_a_month_mau(self): + # Create and sync so that the MAU counts get updated + token1 = self.create_user("kermit1") + self.do_sync_for_user(token1) + token2 = self.create_user("kermit2") + self.do_sync_for_user(token2) + + # Advance time by 31 days + self.reactor.advance(31 * 24 * 60 * 60) + + self.store.reap_monthly_active_users() + + self.reactor.advance(0) + + # We should be able to register more users + token3 = self.create_user("kermit3") + self.do_sync_for_user(token3) + + def test_trial_delay(self): + self.hs.config.mau_trial_days = 1 + + # We should be able to register more than the limit initially + token1 = self.create_user("kermit1") + self.do_sync_for_user(token1) + token2 = self.create_user("kermit2") + self.do_sync_for_user(token2) + token3 = self.create_user("kermit3") + self.do_sync_for_user(token3) + + # Advance time by 2 days + self.reactor.advance(2 * 24 * 60 * 60) + + # Two users should be able to sync + self.do_sync_for_user(token1) + self.do_sync_for_user(token2) + + # But the third should fail + with self.assertRaises(SynapseError) as cm: + self.do_sync_for_user(token3) + + e = cm.exception + self.assertEqual(e.code, 403) + self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + # And new registrations are now denied too + with self.assertRaises(SynapseError) as cm: + self.create_user("kermit4") + + e = cm.exception + self.assertEqual(e.code, 403) + self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + def test_trial_users_cant_come_back(self): + self.hs.config.mau_trial_days = 1 + + # We should be able to register more than the limit initially + token1 = self.create_user("kermit1") + self.do_sync_for_user(token1) + token2 = self.create_user("kermit2") + self.do_sync_for_user(token2) + token3 = self.create_user("kermit3") + self.do_sync_for_user(token3) + + # Advance time by 2 days + self.reactor.advance(2 * 24 * 60 * 60) + + # Two users should be able to sync + self.do_sync_for_user(token1) + self.do_sync_for_user(token2) + + # Advance by 2 months so everyone falls out of MAU + self.reactor.advance(60 * 24 * 60 * 60) + self.store.reap_monthly_active_users() + self.reactor.advance(0) + + # We can create as many new users as we want + token4 = self.create_user("kermit4") + self.do_sync_for_user(token4) + token5 = self.create_user("kermit5") + self.do_sync_for_user(token5) + token6 = self.create_user("kermit6") + self.do_sync_for_user(token6) + + # users 2 and 3 can come back to bring us back up to MAU limit + self.do_sync_for_user(token2) + self.do_sync_for_user(token3) + + # New trial users can still sync + self.do_sync_for_user(token4) + self.do_sync_for_user(token5) + self.do_sync_for_user(token6) + + # But old user cant + with self.assertRaises(SynapseError) as cm: + self.do_sync_for_user(token1) + + e = cm.exception + self.assertEqual(e.code, 403) + self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + def create_user(self, localpart): + request_data = json.dumps({ + "username": localpart, + "password": "monkey", + "auth": {"type": LoginType.DUMMY}, + }) + + request, channel = make_request(b"POST", b"/register", request_data) + render(request, self.resource, self.reactor) + + if channel.result["code"] != b"200": + raise HttpResponseException( + int(channel.result["code"]), + channel.result["reason"], + channel.result["body"], + ).to_synapse_error() + + access_token = channel.json_body["access_token"] + + return access_token + + def do_sync_for_user(self, token): + request, channel = make_request(b"GET", b"/sync", access_token=token) + render(request, self.resource, self.reactor) + + if channel.result["code"] != b"200": + raise HttpResponseException( + int(channel.result["code"]), + channel.result["reason"], + channel.result["body"], + ).to_synapse_error() diff --git a/tests/test_state.py b/tests/test_state.py index 96fdb8636c..452a123c3a 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -18,7 +18,7 @@ from mock import Mock from twisted.internet import defer from synapse.api.auth import Auth -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, Membership, RoomVersions from synapse.events import FrozenEvent from synapse.state import StateHandler, StateResolutionHandler @@ -117,6 +117,9 @@ class StateGroupStore(object): def register_event_id_state_group(self, event_id, state_group): self._event_to_state_group[event_id] = state_group + def get_room_version(self, room_id): + return RoomVersions.V1 + class DictObj(dict): def __init__(self, **kwargs): @@ -176,7 +179,9 @@ class StateTestCase(unittest.TestCase): def test_branch_no_conflict(self): graph = Graph( nodes={ - "START": DictObj(type=EventTypes.Create, state_key="", depth=1), + "START": DictObj( + type=EventTypes.Create, state_key="", content={}, depth=1, + ), "A": DictObj(type=EventTypes.Message, depth=2), "B": DictObj(type=EventTypes.Message, depth=3), "C": DictObj(type=EventTypes.Name, state_key="", depth=3), diff --git a/tests/test_types.py b/tests/test_types.py index be072d402b..0f5c8bfaf9 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -14,12 +14,12 @@ # limitations under the License. from synapse.api.errors import SynapseError -from synapse.server import HomeServer from synapse.types import GroupID, RoomAlias, UserID from tests import unittest +from tests.utils import TestHomeServer -mock_homeserver = HomeServer(hostname="my.domain") +mock_homeserver = TestHomeServer(hostname="my.domain") class UserIDTestCase(unittest.TestCase): diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 45a78338d6..2eea3b098b 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -21,7 +21,7 @@ from synapse.events import FrozenEvent from synapse.visibility import filter_events_for_server import tests.unittest -from tests.utils import setup_test_homeserver +from tests.utils import create_room, setup_test_homeserver logger = logging.getLogger(__name__) @@ -36,6 +36,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.event_builder_factory = self.hs.get_event_builder_factory() self.store = self.hs.get_datastore() + yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM") + @defer.inlineCallbacks def test_filtering(self): # @@ -94,7 +96,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): events_to_filter.append(evt) # the erasey user gets erased - self.hs.get_datastore().mark_user_erased("@erased:local_hs") + yield self.hs.get_datastore().mark_user_erased("@erased:local_hs") # ... and the filtering happens. filtered = yield filter_events_for_server( diff --git a/tests/unittest.py b/tests/unittest.py index d852e2465a..a3d39920db 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -22,6 +22,7 @@ from canonicaljson import json import twisted import twisted.logger +from twisted.internet.defer import Deferred from twisted.trial import unittest from synapse.http.server import JsonResource @@ -151,6 +152,7 @@ class HomeserverTestCase(TestCase): hijack_auth (bool): Whether to hijack auth to return the user specified in user_id. """ + servlets = [] hijack_auth = True @@ -279,3 +281,15 @@ class HomeserverTestCase(TestCase): kwargs = dict(kwargs) kwargs.update(self._hs_args) return setup_test_homeserver(self.addCleanup, *args, **kwargs) + + def pump(self, by=0.0): + """ + Pump the reactor enough that Deferreds will fire. + """ + self.reactor.pump([by] * 100) + + def get_success(self, d): + if not isinstance(d, Deferred): + return d + self.pump() + return self.successResultOf(d) diff --git a/tests/utils.py b/tests/utils.py index 8de2898b2f..b85017d279 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -24,12 +24,14 @@ from six.moves.urllib import parse as urlparse from twisted.internet import defer, reactor +from synapse.api.constants import EventTypes from synapse.api.errors import CodeMessageException, cs_error +from synapse.config.server import ServerConfig from synapse.federation.transport import server from synapse.http.server import HttpServer from synapse.server import HomeServer -from synapse.storage import PostgresEngine -from synapse.storage.engines import create_engine +from synapse.storage import DataStore +from synapse.storage.engines import PostgresEngine, create_engine from synapse.storage.prepare_database import ( _get_or_create_schema_state, _setup_new_database, @@ -40,6 +42,7 @@ from synapse.util.ratelimitutils import FederationRateLimiter # set this to True to run the tests against postgres instead of sqlite. USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False) +LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False) POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres") POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),) @@ -91,10 +94,14 @@ def setupdb(): atexit.register(_cleanup) +class TestHomeServer(HomeServer): + DATASTORE_CLASS = DataStore + + @defer.inlineCallbacks def setup_test_homeserver( cleanup_func, name="test", datastore=None, config=None, reactor=None, - homeserverToUse=HomeServer, **kargs + homeserverToUse=TestHomeServer, **kargs ): """ Setup a homeserver suitable for running tests against. Keyword arguments @@ -141,7 +148,9 @@ def setup_test_homeserver( config.hs_disabled_limit_type = "" config.max_mau_value = 50 config.mau_limits_reserved_threepids = [] - config.admin_uri = None + config.admin_contact = None + config.rc_messages_per_second = 10000 + config.rc_message_burst_count = 10000 # we need a sane default_room_version, otherwise attempts to create rooms will # fail. @@ -151,6 +160,11 @@ def setup_test_homeserver( # background, which upsets the test runner. config.update_user_directory = False + def is_threepid_reserved(threepid): + return ServerConfig.is_threepid_reserved(config, threepid) + + config.is_threepid_reserved.side_effect = is_threepid_reserved + config.use_frozen_dicts = True config.ldap_enabled = False @@ -231,8 +245,9 @@ def setup_test_homeserver( cur.close() db_conn.close() - # Register the cleanup hook - cleanup_func(cleanup) + if not LEAVE_DB: + # Register the cleanup hook + cleanup_func(cleanup) hs.setup() else: @@ -545,3 +560,32 @@ class DeferredMockCallable(object): "Expected not to received any calls, got:\n" + "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls]) ) + + +@defer.inlineCallbacks +def create_room(hs, room_id, creator_id): + """Creates and persist a creation event for the given room + + Args: + hs + room_id (str) + creator_id (str) + """ + + store = hs.get_datastore() + event_builder_factory = hs.get_event_builder_factory() + event_creation_handler = hs.get_event_creation_handler() + + builder = event_builder_factory.new({ + "type": EventTypes.Create, + "state_key": "", + "sender": creator_id, + "room_id": room_id, + "content": {}, + }) + + event, context = yield event_creation_handler.create_new_client_event( + builder + ) + + yield store.persist_event(event, context) |