diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 022d81ce3e..f65a27e5f1 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -457,8 +457,8 @@ class AuthTestCase(unittest.TestCase):
with self.assertRaises(ResourceLimitError) as e:
yield self.auth.check_auth_blocking()
- self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri)
- self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
+ self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
+ self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
# Ensure does not throw an error
@@ -468,11 +468,36 @@ class AuthTestCase(unittest.TestCase):
yield self.auth.check_auth_blocking()
@defer.inlineCallbacks
+ def test_reserved_threepid(self):
+ self.hs.config.limit_usage_by_mau = True
+ self.hs.config.max_mau_value = 1
+ threepid = {'medium': 'email', 'address': 'reserved@server.com'}
+ unknown_threepid = {'medium': 'email', 'address': 'unreserved@server.com'}
+ self.hs.config.mau_limits_reserved_threepids = [threepid]
+
+ yield self.store.register(user_id='user1', token="123", password_hash=None)
+ with self.assertRaises(ResourceLimitError):
+ yield self.auth.check_auth_blocking()
+
+ with self.assertRaises(ResourceLimitError):
+ yield self.auth.check_auth_blocking(threepid=unknown_threepid)
+
+ yield self.auth.check_auth_blocking(threepid=threepid)
+
+ @defer.inlineCallbacks
def test_hs_disabled(self):
self.hs.config.hs_disabled = True
self.hs.config.hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e:
yield self.auth.check_auth_blocking()
- self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri)
- self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
+ self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
+ self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
+
+ @defer.inlineCallbacks
+ def test_server_notices_mxid_special_cased(self):
+ self.hs.config.hs_disabled = True
+ user = "@user:server"
+ self.hs.config.server_notices_mxid = user
+ self.hs.config.hs_disabled_message = "Reason for being disabled"
+ yield self.auth.check_auth_blocking(user)
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 56e7acd37c..a3aa0a1cf2 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,79 +14,79 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
import synapse.api.errors
import synapse.handlers.device
import synapse.storage
-from tests import unittest, utils
+from tests import unittest
user1 = "@boris:aaa"
user2 = "@theresa:bbb"
-class DeviceTestCase(unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super(DeviceTestCase, self).__init__(*args, **kwargs)
- self.store = None # type: synapse.storage.DataStore
- self.handler = None # type: synapse.handlers.device.DeviceHandler
- self.clock = None # type: utils.MockClock
-
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield utils.setup_test_homeserver(self.addCleanup)
+class DeviceTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver("server", http_client=None)
self.handler = hs.get_device_handler()
self.store = hs.get_datastore()
- self.clock = hs.get_clock()
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+ # These tests assume that it starts 1000 seconds in.
+ self.reactor.advance(1000)
- @defer.inlineCallbacks
def test_device_is_created_if_doesnt_exist(self):
- res = yield self.handler.check_device_registered(
- user_id="@boris:foo",
- device_id="fco",
- initial_device_display_name="display name",
+ res = self.get_success(
+ self.handler.check_device_registered(
+ user_id="@boris:foo",
+ device_id="fco",
+ initial_device_display_name="display name",
+ )
)
self.assertEqual(res, "fco")
- dev = yield self.handler.store.get_device("@boris:foo", "fco")
+ dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
self.assertEqual(dev["display_name"], "display name")
- @defer.inlineCallbacks
def test_device_is_preserved_if_exists(self):
- res1 = yield self.handler.check_device_registered(
- user_id="@boris:foo",
- device_id="fco",
- initial_device_display_name="display name",
+ res1 = self.get_success(
+ self.handler.check_device_registered(
+ user_id="@boris:foo",
+ device_id="fco",
+ initial_device_display_name="display name",
+ )
)
self.assertEqual(res1, "fco")
- res2 = yield self.handler.check_device_registered(
- user_id="@boris:foo",
- device_id="fco",
- initial_device_display_name="new display name",
+ res2 = self.get_success(
+ self.handler.check_device_registered(
+ user_id="@boris:foo",
+ device_id="fco",
+ initial_device_display_name="new display name",
+ )
)
self.assertEqual(res2, "fco")
- dev = yield self.handler.store.get_device("@boris:foo", "fco")
+ dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
self.assertEqual(dev["display_name"], "display name")
- @defer.inlineCallbacks
def test_device_id_is_made_up_if_unspecified(self):
- device_id = yield self.handler.check_device_registered(
- user_id="@theresa:foo",
- device_id=None,
- initial_device_display_name="display",
+ device_id = self.get_success(
+ self.handler.check_device_registered(
+ user_id="@theresa:foo",
+ device_id=None,
+ initial_device_display_name="display",
+ )
)
- dev = yield self.handler.store.get_device("@theresa:foo", device_id)
+ dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
self.assertEqual(dev["display_name"], "display")
- @defer.inlineCallbacks
def test_get_devices_by_user(self):
- yield self._record_users()
+ self._record_users()
+
+ res = self.get_success(self.handler.get_devices_by_user(user1))
- res = yield self.handler.get_devices_by_user(user1)
self.assertEqual(3, len(res))
device_map = {d["device_id"]: d for d in res}
self.assertDictContainsSubset(
@@ -119,11 +120,10 @@ class DeviceTestCase(unittest.TestCase):
device_map["abc"],
)
- @defer.inlineCallbacks
def test_get_device(self):
- yield self._record_users()
+ self._record_users()
- res = yield self.handler.get_device(user1, "abc")
+ res = self.get_success(self.handler.get_device(user1, "abc"))
self.assertDictContainsSubset(
{
"user_id": user1,
@@ -135,59 +135,66 @@ class DeviceTestCase(unittest.TestCase):
res,
)
- @defer.inlineCallbacks
def test_delete_device(self):
- yield self._record_users()
+ self._record_users()
# delete the device
- yield self.handler.delete_device(user1, "abc")
+ self.get_success(self.handler.delete_device(user1, "abc"))
# check the device was deleted
- with self.assertRaises(synapse.api.errors.NotFoundError):
- yield self.handler.get_device(user1, "abc")
+ res = self.handler.get_device(user1, "abc")
+ self.pump()
+ self.assertIsInstance(
+ self.failureResultOf(res).value, synapse.api.errors.NotFoundError
+ )
# we'd like to check the access token was invalidated, but that's a
# bit of a PITA.
- @defer.inlineCallbacks
def test_update_device(self):
- yield self._record_users()
+ self._record_users()
update = {"display_name": "new display"}
- yield self.handler.update_device(user1, "abc", update)
+ self.get_success(self.handler.update_device(user1, "abc", update))
- res = yield self.handler.get_device(user1, "abc")
+ res = self.get_success(self.handler.get_device(user1, "abc"))
self.assertEqual(res["display_name"], "new display")
- @defer.inlineCallbacks
def test_update_unknown_device(self):
update = {"display_name": "new_display"}
- with self.assertRaises(synapse.api.errors.NotFoundError):
- yield self.handler.update_device("user_id", "unknown_device_id", update)
+ res = self.handler.update_device("user_id", "unknown_device_id", update)
+ self.pump()
+ self.assertIsInstance(
+ self.failureResultOf(res).value, synapse.api.errors.NotFoundError
+ )
- @defer.inlineCallbacks
def _record_users(self):
# check this works for both devices which have a recorded client_ip,
# and those which don't.
- yield self._record_user(user1, "xyz", "display 0")
- yield self._record_user(user1, "fco", "display 1", "token1", "ip1")
- yield self._record_user(user1, "abc", "display 2", "token2", "ip2")
- yield self._record_user(user1, "abc", "display 2", "token3", "ip3")
+ self._record_user(user1, "xyz", "display 0")
+ self._record_user(user1, "fco", "display 1", "token1", "ip1")
+ self._record_user(user1, "abc", "display 2", "token2", "ip2")
+ self._record_user(user1, "abc", "display 2", "token3", "ip3")
+
+ self._record_user(user2, "def", "dispkay", "token4", "ip4")
- yield self._record_user(user2, "def", "dispkay", "token4", "ip4")
+ self.reactor.advance(10000)
- @defer.inlineCallbacks
def _record_user(
self, user_id, device_id, display_name, access_token=None, ip=None
):
- device_id = yield self.handler.check_device_registered(
- user_id=user_id,
- device_id=device_id,
- initial_device_display_name=display_name,
+ device_id = self.get_success(
+ self.handler.check_device_registered(
+ user_id=user_id,
+ device_id=device_id,
+ initial_device_display_name=display_name,
+ )
)
if ip is not None:
- yield self.store.insert_client_ip(
- user_id, access_token, ip, "user_agent", device_id
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, access_token, ip, "user_agent", device_id
+ )
)
- self.clock.advance_time(1000)
+ self.reactor.advance(1000)
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 62dc69003c..80da1c8954 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -20,7 +20,7 @@ from twisted.internet import defer
import synapse.types
from synapse.api.errors import AuthError
-from synapse.handlers.profile import ProfileHandler
+from synapse.handlers.profile import MasterProfileHandler
from synapse.types import UserID
from tests import unittest
@@ -29,7 +29,7 @@ from tests.utils import setup_test_homeserver
class ProfileHandlers(object):
def __init__(self, hs):
- self.profile_handler = ProfileHandler(hs)
+ self.profile_handler = MasterProfileHandler(hs)
class ProfileTestCase(unittest.TestCase):
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index a01ab471f5..31f54bbd7d 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -51,7 +51,7 @@ class SyncTestCase(tests.unittest.TestCase):
self.hs.config.hs_disabled = True
with self.assertRaises(ResourceLimitError) as e:
yield self.sync_handler.wait_for_sync_for_user(sync_config)
- self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
+ self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.hs.config.hs_disabled = False
@@ -59,7 +59,7 @@ class SyncTestCase(tests.unittest.TestCase):
with self.assertRaises(ResourceLimitError) as e:
yield self.sync_handler.wait_for_sync_for_user(sync_config)
- self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
+ self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def _generate_sync_config(self, user_id):
return SyncConfig(
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 65df116efc..089cecfbee 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -1,4 +1,5 @@
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,89 +12,91 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import tempfile
from mock import Mock, NonCallableMock
-from twisted.internet import defer, reactor
-from twisted.internet.defer import Deferred
+import attr
from synapse.replication.tcp.client import (
ReplicationClientFactory,
ReplicationClientHandler,
)
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
-from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
from tests import unittest
-from tests.utils import setup_test_homeserver
-class TestReplicationClientHandler(ReplicationClientHandler):
- """Overrides on_rdata so that we can wait for it to happen"""
+class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
- def __init__(self, store):
- super(TestReplicationClientHandler, self).__init__(store)
- self._rdata_awaiters = []
-
- def await_replication(self):
- d = Deferred()
- self._rdata_awaiters.append(d)
- return make_deferred_yieldable(d)
-
- def on_rdata(self, stream_name, token, rows):
- awaiters = self._rdata_awaiters
- self._rdata_awaiters = []
- super(TestReplicationClientHandler, self).on_rdata(stream_name, token, rows)
- with PreserveLoggingContext():
- for a in awaiters:
- a.callback(None)
-
-
-class BaseSlavedStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- self.hs = yield setup_test_homeserver(
- self.addCleanup,
+ hs = self.setup_test_homeserver(
"blue",
- http_client=None,
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
)
- self.hs.get_ratelimiter().send_message.return_value = (True, 0)
+
+ hs.get_ratelimiter().send_message.return_value = (True, 0)
+
+ return hs
+
+ def prepare(self, reactor, clock, hs):
self.master_store = self.hs.get_datastore()
self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
self.event_id = 0
server_factory = ReplicationStreamProtocolFactory(self.hs)
- # XXX: mktemp is unsafe and should never be used. but we're just a test.
- path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket")
- listener = reactor.listenUNIX(path, server_factory)
- self.addCleanup(listener.stopListening)
self.streamer = server_factory.streamer
- self.replication_handler = TestReplicationClientHandler(self.slaved_store)
+ self.replication_handler = ReplicationClientHandler(self.slaved_store)
client_factory = ReplicationClientFactory(
self.hs, "client_name", self.replication_handler
)
- client_connector = reactor.connectUNIX(path, client_factory)
- self.addCleanup(client_factory.stopTrying)
- self.addCleanup(client_connector.disconnect)
+
+ server = server_factory.buildProtocol(None)
+ client = client_factory.buildProtocol(None)
+
+ @attr.s
+ class FakeTransport(object):
+
+ other = attr.ib()
+ disconnecting = False
+ buffer = attr.ib(default=b'')
+
+ def registerProducer(self, producer, streaming):
+
+ self.producer = producer
+
+ def _produce():
+ self.producer.resumeProducing()
+ reactor.callLater(0.1, _produce)
+
+ reactor.callLater(0.0, _produce)
+
+ def write(self, byt):
+ self.buffer = self.buffer + byt
+
+ if getattr(self.other, "transport") is not None:
+ self.other.dataReceived(self.buffer)
+ self.buffer = b""
+
+ def writeSequence(self, seq):
+ for x in seq:
+ self.write(x)
+
+ client.makeConnection(FakeTransport(server))
+ server.makeConnection(FakeTransport(client))
def replicate(self):
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
- # xxx: should we be more specific in what we wait for?
- d = self.replication_handler.await_replication()
self.streamer.on_notifier_poke()
- return d
+ self.pump(0.1)
- @defer.inlineCallbacks
def check(self, method, args, expected_result=None):
- master_result = yield getattr(self.master_store, method)(*args)
- slaved_result = yield getattr(self.slaved_store, method)(*args)
+ master_result = self.get_success(getattr(self.master_store, method)(*args))
+ slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
if expected_result is not None:
self.assertEqual(master_result, expected_result)
self.assertEqual(slaved_result, expected_result)
diff --git a/tests/replication/slave/storage/test_account_data.py b/tests/replication/slave/storage/test_account_data.py
index 87cc2b2fba..43e3248703 100644
--- a/tests/replication/slave/storage/test_account_data.py
+++ b/tests/replication/slave/storage/test_account_data.py
@@ -12,9 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from ._base import BaseSlavedStoreTestCase
@@ -27,16 +24,19 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase):
STORE_TYPE = SlavedAccountDataStore
- @defer.inlineCallbacks
def test_user_account_data(self):
- yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1})
- yield self.replicate()
- yield self.check(
+ self.get_success(
+ self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1})
+ )
+ self.replicate()
+ self.check(
"get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 1}
)
- yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2})
- yield self.replicate()
- yield self.check(
+ self.get_success(
+ self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2})
+ )
+ self.replicate()
+ self.check(
"get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 2}
)
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 622be2eef8..db44d33c68 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
from synapse.events import FrozenEvent, _EventInternalMetadata
from synapse.events.snapshot import EventContext
from synapse.replication.slave.storage.events import SlavedEventStore
@@ -55,69 +53,66 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
def tearDown(self):
[unpatch() for unpatch in self.unpatches]
- @defer.inlineCallbacks
def test_get_latest_event_ids_in_room(self):
- create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
- yield self.replicate()
- yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
+ create = self.persist(type="m.room.create", key="", creator=USER_ID)
+ self.replicate()
+ self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
- join = yield self.persist(
+ join = self.persist(
type="m.room.member",
key=USER_ID,
membership="join",
prev_events=[(create.event_id, {})],
)
- yield self.replicate()
- yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
+ self.replicate()
+ self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
- @defer.inlineCallbacks
def test_redactions(self):
- yield self.persist(type="m.room.create", key="", creator=USER_ID)
- yield self.persist(type="m.room.member", key=USER_ID, membership="join")
+ self.persist(type="m.room.create", key="", creator=USER_ID)
+ self.persist(type="m.room.member", key=USER_ID, membership="join")
- msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello")
- yield self.replicate()
- yield self.check("get_event", [msg.event_id], msg)
+ msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
+ self.replicate()
+ self.check("get_event", [msg.event_id], msg)
- redaction = yield self.persist(type="m.room.redaction", redacts=msg.event_id)
- yield self.replicate()
+ redaction = self.persist(type="m.room.redaction", redacts=msg.event_id)
+ self.replicate()
msg_dict = msg.get_dict()
msg_dict["content"] = {}
msg_dict["unsigned"]["redacted_by"] = redaction.event_id
msg_dict["unsigned"]["redacted_because"] = redaction
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
- yield self.check("get_event", [msg.event_id], redacted)
+ self.check("get_event", [msg.event_id], redacted)
- @defer.inlineCallbacks
def test_backfilled_redactions(self):
- yield self.persist(type="m.room.create", key="", creator=USER_ID)
- yield self.persist(type="m.room.member", key=USER_ID, membership="join")
+ self.persist(type="m.room.create", key="", creator=USER_ID)
+ self.persist(type="m.room.member", key=USER_ID, membership="join")
- msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello")
- yield self.replicate()
- yield self.check("get_event", [msg.event_id], msg)
+ msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
+ self.replicate()
+ self.check("get_event", [msg.event_id], msg)
- redaction = yield self.persist(
+ redaction = self.persist(
type="m.room.redaction", redacts=msg.event_id, backfill=True
)
- yield self.replicate()
+ self.replicate()
msg_dict = msg.get_dict()
msg_dict["content"] = {}
msg_dict["unsigned"]["redacted_by"] = redaction.event_id
msg_dict["unsigned"]["redacted_because"] = redaction
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
- yield self.check("get_event", [msg.event_id], redacted)
+ self.check("get_event", [msg.event_id], redacted)
- @defer.inlineCallbacks
def test_invites(self):
- yield self.check("get_invited_rooms_for_user", [USER_ID_2], [])
- event = yield self.persist(
- type="m.room.member", key=USER_ID_2, membership="invite"
- )
- yield self.replicate()
- yield self.check(
+ self.persist(type="m.room.create", key="", creator=USER_ID)
+ self.check("get_invited_rooms_for_user", [USER_ID_2], [])
+ event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
+
+ self.replicate()
+
+ self.check(
"get_invited_rooms_for_user",
[USER_ID_2],
[
@@ -131,37 +126,34 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
],
)
- @defer.inlineCallbacks
def test_push_actions_for_user(self):
- yield self.persist(type="m.room.create", creator=USER_ID)
- yield self.persist(type="m.room.join", key=USER_ID, membership="join")
- yield self.persist(
+ self.persist(type="m.room.create", key="", creator=USER_ID)
+ self.persist(type="m.room.join", key=USER_ID, membership="join")
+ self.persist(
type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join"
)
- event1 = yield self.persist(
- type="m.room.message", msgtype="m.text", body="hello"
- )
- yield self.replicate()
- yield self.check(
+ event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello")
+ self.replicate()
+ self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 0, "notify_count": 0},
)
- yield self.persist(
+ self.persist(
type="m.room.message",
msgtype="m.text",
body="world",
push_actions=[(USER_ID_2, ["notify"])],
)
- yield self.replicate()
- yield self.check(
+ self.replicate()
+ self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 0, "notify_count": 1},
)
- yield self.persist(
+ self.persist(
type="m.room.message",
msgtype="m.text",
body="world",
@@ -169,8 +161,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
(USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}])
],
)
- yield self.replicate()
- yield self.check(
+ self.replicate()
+ self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 1, "notify_count": 2},
@@ -178,7 +170,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
event_id = 0
- @defer.inlineCallbacks
def persist(
self,
sender=USER_ID,
@@ -205,8 +196,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
depth = self.event_id
if not prev_events:
- latest_event_ids = yield self.master_store.get_latest_event_ids_in_room(
- room_id
+ latest_event_ids = self.get_success(
+ self.master_store.get_latest_event_ids_in_room(room_id)
)
prev_events = [(ev_id, {}) for ev_id in latest_event_ids]
@@ -239,19 +230,23 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
)
else:
state_handler = self.hs.get_state_handler()
- context = yield state_handler.compute_event_context(event)
+ context = self.get_success(state_handler.compute_event_context(event))
- yield self.master_store.add_push_actions_to_staging(
+ self.master_store.add_push_actions_to_staging(
event.event_id, {user_id: actions for user_id, actions in push_actions}
)
ordering = None
if backfill:
- yield self.master_store.persist_events([(event, context)], backfilled=True)
+ self.get_success(
+ self.master_store.persist_events([(event, context)], backfilled=True)
+ )
else:
- ordering, _ = yield self.master_store.persist_event(event, context)
+ ordering, _ = self.get_success(
+ self.master_store.persist_event(event, context)
+ )
if ordering:
event.internal_metadata.stream_ordering = ordering
- defer.returnValue(event)
+ return event
diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py
index ae1adeded1..f47d94f690 100644
--- a/tests/replication/slave/storage/test_receipts.py
+++ b/tests/replication/slave/storage/test_receipts.py
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from ._base import BaseSlavedStoreTestCase
@@ -27,13 +25,10 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
STORE_TYPE = SlavedReceiptsStore
- @defer.inlineCallbacks
def test_receipt(self):
- yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
- yield self.master_store.insert_receipt(
- ROOM_ID, "m.read", USER_ID, [EVENT_ID], {}
- )
- yield self.replicate()
- yield self.check(
- "get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID}
+ self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
+ self.get_success(
+ self.master_store.insert_receipt(ROOM_ID, "m.read", USER_ID, [EVENT_ID], {})
)
+ self.replicate()
+ self.check("get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID})
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 40dc4ea256..530dc8ba6d 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -240,7 +240,6 @@ class RestHelper(object):
self.assertEquals(200, code)
defer.returnValue(response)
- @defer.inlineCallbacks
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
@@ -248,9 +247,16 @@ class RestHelper(object):
body = "body_text_here"
path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
- content = '{"msgtype":"m.text","body":"%s"}' % body
+ content = {"msgtype": "m.text", "body": body}
if tok:
path = path + "?access_token=%s" % tok
- (code, response) = yield self.mock_resource.trigger("PUT", path, content)
- self.assertEquals(expect_code, code, msg=str(response))
+ request, channel = make_request("PUT", path, json.dumps(content).encode('utf8'))
+ render(request, self.resource, self.hs.get_reactor())
+
+ assert int(channel.result["code"]) == expect_code, (
+ "Expected: %d, got: %d, resp: %r"
+ % (expect_code, int(channel.result["code"]), channel.result["body"])
+ )
+
+ return channel.json_body
diff --git a/tests/server.py b/tests/server.py
index c63b2c3100..615bba1b59 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -5,7 +5,7 @@ from six import text_type
import attr
-from twisted.internet import threads
+from twisted.internet import address, threads
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactorClock
@@ -63,7 +63,9 @@ class FakeChannel(object):
self.result["done"] = True
def getPeer(self):
- return None
+ # We give an address so that getClientIP returns a non null entry,
+ # causing us to record the MAU
+ return address.IPv4Address(b"TCP", "127.0.0.1", 3423)
def getHost(self):
return None
@@ -91,7 +93,7 @@ class FakeSite:
return FakeLogger()
-def make_request(method, path, content=b""):
+def make_request(method, path, content=b"", access_token=None):
"""
Make a web request using the given method and path, feed it the
content, and return the Request and the Channel underneath.
@@ -116,6 +118,11 @@ def make_request(method, path, content=b""):
req = SynapseRequest(site, channel)
req.process = lambda: b""
req.content = BytesIO(content)
+
+ if access_token:
+ req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token)
+
+ req.requestHeaders.addRawHeader(b"X-Forwarded-For", b"127.0.0.1")
req.requestReceived(method, path, b"1.1")
return req, channel
@@ -225,6 +232,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
clock.threadpool = ThreadPool()
pool.threadpool = ThreadPool()
+ pool.running = True
return d
diff --git a/tests/server_notices/__init__.py b/tests/server_notices/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/server_notices/__init__.py
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
new file mode 100644
index 0000000000..5cc7fff39b
--- /dev/null
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -0,0 +1,213 @@
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, ServerNoticeMsgType
+from synapse.api.errors import ResourceLimitError
+from synapse.handlers.auth import AuthHandler
+from synapse.server_notices.resource_limits_server_notices import (
+ ResourceLimitsServerNotices,
+)
+
+from tests import unittest
+from tests.utils import setup_test_homeserver
+
+
+class AuthHandlers(object):
+ def __init__(self, hs):
+ self.auth_handler = AuthHandler(hs)
+
+
+class TestResourceLimitsServerNotices(unittest.TestCase):
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None)
+ self.hs.handlers = AuthHandlers(self.hs)
+ self.auth_handler = self.hs.handlers.auth_handler
+ self.server_notices_sender = self.hs.get_server_notices_sender()
+
+ # relying on [1] is far from ideal, but the only case where
+ # ResourceLimitsServerNotices class needs to be isolated is this test,
+ # general code should never have a reason to do so ...
+ self._rlsn = self.server_notices_sender._server_notices[1]
+ if not isinstance(self._rlsn, ResourceLimitsServerNotices):
+ raise Exception("Failed to find reference to ResourceLimitsServerNotices")
+
+ self._rlsn._store.user_last_seen_monthly_active = Mock(
+ return_value=defer.succeed(1000)
+ )
+ self._send_notice = self._rlsn._server_notices_manager.send_notice
+ self._rlsn._server_notices_manager.send_notice = Mock()
+ self._rlsn._state.get_current_state = Mock(return_value=defer.succeed(None))
+ self._rlsn._store.get_events = Mock(return_value=defer.succeed({}))
+
+ self._send_notice = self._rlsn._server_notices_manager.send_notice
+
+ self.hs.config.limit_usage_by_mau = True
+ self.user_id = "@user_id:test"
+
+ # self.server_notices_mxid = "@server:test"
+ # self.server_notices_mxid_display_name = None
+ # self.server_notices_mxid_avatar_url = None
+ # self.server_notices_room_name = "Server Notices"
+
+ self._rlsn._server_notices_manager.get_notice_room_for_user = Mock(
+ returnValue=""
+ )
+ self._rlsn._store.add_tag_to_room = Mock()
+ self._rlsn._store.get_tags_for_room = Mock(return_value={})
+ self.hs.config.admin_contact = "mailto:user@test.com"
+
+ @defer.inlineCallbacks
+ def test_maybe_send_server_notice_to_user_flag_off(self):
+ """Tests cases where the flags indicate nothing to do"""
+ # test hs disabled case
+ self.hs.config.hs_disabled = True
+
+ yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+
+ self._send_notice.assert_not_called()
+ # Test when mau limiting disabled
+ self.hs.config.hs_disabled = False
+ self.hs.limit_usage_by_mau = False
+ yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+
+ self._send_notice.assert_not_called()
+
+ @defer.inlineCallbacks
+ def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
+ """Test when user has blocked notice, but should have it removed"""
+
+ self._rlsn._auth.check_auth_blocking = Mock()
+ mock_event = Mock(
+ type=EventTypes.Message,
+ content={"msgtype": ServerNoticeMsgType},
+ )
+ self._rlsn._store.get_events = Mock(return_value=defer.succeed(
+ {"123": mock_event}
+ ))
+
+ yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+ # Would be better to check the content, but once == remove blocking event
+ self._send_notice.assert_called_once()
+
+ @defer.inlineCallbacks
+ def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self):
+ """Test when user has blocked notice, but notice ought to be there (NOOP)"""
+ self._rlsn._auth.check_auth_blocking = Mock(
+ side_effect=ResourceLimitError(403, 'foo')
+ )
+
+ mock_event = Mock(
+ type=EventTypes.Message,
+ content={"msgtype": ServerNoticeMsgType},
+ )
+ self._rlsn._store.get_events = Mock(return_value=defer.succeed(
+ {"123": mock_event}
+ ))
+ yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+
+ self._send_notice.assert_not_called()
+
+ @defer.inlineCallbacks
+ def test_maybe_send_server_notice_to_user_add_blocked_notice(self):
+ """Test when user does not have blocked notice, but should have one"""
+
+ self._rlsn._auth.check_auth_blocking = Mock(
+ side_effect=ResourceLimitError(403, 'foo')
+ )
+ yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+
+ # Would be better to check contents, but 2 calls == set blocking event
+ self.assertTrue(self._send_notice.call_count == 2)
+
+ @defer.inlineCallbacks
+ def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self):
+ """Test when user does not have blocked notice, nor should they (NOOP)"""
+
+ self._rlsn._auth.check_auth_blocking = Mock()
+
+ yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+
+ self._send_notice.assert_not_called()
+
+ @defer.inlineCallbacks
+ def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self):
+
+ """Test when user is not part of the MAU cohort - this should not ever
+ happen - but ...
+ """
+
+ self._rlsn._auth.check_auth_blocking = Mock()
+ self._rlsn._store.user_last_seen_monthly_active = Mock(
+ return_value=defer.succeed(None)
+ )
+ yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+
+ self._send_notice.assert_not_called()
+
+
+class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase):
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.hs = yield setup_test_homeserver(self.addCleanup)
+ self.store = self.hs.get_datastore()
+ self.server_notices_sender = self.hs.get_server_notices_sender()
+ self.server_notices_manager = self.hs.get_server_notices_manager()
+ self.event_source = self.hs.get_event_sources()
+
+ # relying on [1] is far from ideal, but the only case where
+ # ResourceLimitsServerNotices class needs to be isolated is this test,
+ # general code should never have a reason to do so ...
+ self._rlsn = self.server_notices_sender._server_notices[1]
+ if not isinstance(self._rlsn, ResourceLimitsServerNotices):
+ raise Exception("Failed to find reference to ResourceLimitsServerNotices")
+
+ self.hs.config.limit_usage_by_mau = True
+ self.hs.config.hs_disabled = False
+ self.hs.config.max_mau_value = 5
+ self.hs.config.server_notices_mxid = "@server:test"
+ self.hs.config.server_notices_mxid_display_name = None
+ self.hs.config.server_notices_mxid_avatar_url = None
+ self.hs.config.server_notices_room_name = "Test Server Notice Room"
+
+ self.user_id = "@user_id:test"
+
+ self.hs.config.admin_contact = "mailto:user@test.com"
+
+ @defer.inlineCallbacks
+ def test_server_notice_only_sent_once(self):
+ self.store.get_monthly_active_count = Mock(
+ return_value=1000,
+ )
+
+ self.store.user_last_seen_monthly_active = Mock(
+ return_value=1000,
+ )
+
+ # Call the function multiple times to ensure we only send the notice once
+ yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+ yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+ yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+
+ # Now lets get the last load of messages in the service notice room and
+ # check that there is only one server notice
+ room_id = yield self.server_notices_manager.get_notice_room_for_user(
+ self.user_id,
+ )
+
+ token = yield self.event_source.get_current_token()
+ events, _ = yield self.store.get_recent_events_for_room(
+ room_id, limit=100, end_token=token.room_key,
+ )
+
+ count = 0
+ for event in events:
+ if event.type != EventTypes.Message:
+ continue
+ if event.content.get("msgtype") != ServerNoticeMsgType:
+ continue
+
+ count += 1
+
+ self.assertEqual(count, 1)
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index c893990454..3f0083831b 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -37,18 +37,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.as_yaml_files = []
- config = Mock(
- app_service_config_files=self.as_yaml_files,
- event_cache_size=1,
- password_providers=[],
- )
hs = yield setup_test_homeserver(
- self.addCleanup,
- config=config,
- federation_sender=Mock(),
- federation_client=Mock(),
+ self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
+ hs.config.app_service_config_files = self.as_yaml_files
+ hs.config.event_cache_size = 1
+ hs.config.password_providers = []
+
self.as_token = "token1"
self.as_url = "some_url"
self.as_id = "as1"
@@ -58,7 +54,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
# must be done after inserts
- self.store = ApplicationServiceStore(None, hs)
+ self.store = ApplicationServiceStore(hs.get_db_conn(), hs)
def tearDown(self):
# TODO: suboptimal that we need to create files for tests!
@@ -105,18 +101,16 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
def setUp(self):
self.as_yaml_files = []
- config = Mock(
- app_service_config_files=self.as_yaml_files,
- event_cache_size=1,
- password_providers=[],
- )
hs = yield setup_test_homeserver(
- self.addCleanup,
- config=config,
- federation_sender=Mock(),
- federation_client=Mock(),
+ self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
+
+ hs.config.app_service_config_files = self.as_yaml_files
+ hs.config.event_cache_size = 1
+ hs.config.password_providers = []
+
self.db_pool = hs.get_db_pool()
+ self.engine = hs.database_engine
self.as_list = [
{"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
@@ -129,7 +123,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.as_yaml_files = []
- self.store = TestTransactionStore(None, hs)
+ self.store = TestTransactionStore(hs.get_db_conn(), hs)
def _add_service(self, url, as_token, id):
as_yaml = dict(
@@ -146,29 +140,35 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.as_yaml_files.append(as_token)
def _set_state(self, id, state, txn=None):
- return self.db_pool.runQuery(
- "INSERT INTO application_services_state(as_id, state, last_txn) "
- "VALUES(?,?,?)",
+ return self.db_pool.runOperation(
+ self.engine.convert_param_style(
+ "INSERT INTO application_services_state(as_id, state, last_txn) "
+ "VALUES(?,?,?)"
+ ),
(id, state, txn),
)
def _insert_txn(self, as_id, txn_id, events):
- return self.db_pool.runQuery(
- "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
- "VALUES(?,?,?)",
+ return self.db_pool.runOperation(
+ self.engine.convert_param_style(
+ "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
+ "VALUES(?,?,?)"
+ ),
(as_id, txn_id, json.dumps([e.event_id for e in events])),
)
def _set_last_txn(self, as_id, txn_id):
- return self.db_pool.runQuery(
- "INSERT INTO application_services_state(as_id, last_txn, state) "
- "VALUES(?,?,?)",
+ return self.db_pool.runOperation(
+ self.engine.convert_param_style(
+ "INSERT INTO application_services_state(as_id, last_txn, state) "
+ "VALUES(?,?,?)"
+ ),
(as_id, txn_id, ApplicationServiceState.UP),
)
@defer.inlineCallbacks
def test_get_appservice_state_none(self):
- service = Mock(id=999)
+ service = Mock(id="999")
state = yield self.store.get_appservice_state(service)
self.assertEquals(None, state)
@@ -200,7 +200,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
service = Mock(id=self.as_list[1]["id"])
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
rows = yield self.db_pool.runQuery(
- "SELECT as_id FROM application_services_state WHERE state=?",
+ self.engine.convert_param_style(
+ "SELECT as_id FROM application_services_state WHERE state=?"
+ ),
(ApplicationServiceState.DOWN,),
)
self.assertEquals(service.id, rows[0][0])
@@ -212,7 +214,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
rows = yield self.db_pool.runQuery(
- "SELECT as_id FROM application_services_state WHERE state=?",
+ self.engine.convert_param_style(
+ "SELECT as_id FROM application_services_state WHERE state=?"
+ ),
(ApplicationServiceState.UP,),
)
self.assertEquals(service.id, rows[0][0])
@@ -279,14 +283,19 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
res = yield self.db_pool.runQuery(
- "SELECT last_txn FROM application_services_state WHERE as_id=?",
+ self.engine.convert_param_style(
+ "SELECT last_txn FROM application_services_state WHERE as_id=?"
+ ),
(service.id,),
)
self.assertEquals(1, len(res))
self.assertEquals(txn_id, res[0][0])
res = yield self.db_pool.runQuery(
- "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,)
+ self.engine.convert_param_style(
+ "SELECT * FROM application_services_txns WHERE txn_id=?"
+ ),
+ (txn_id,),
)
self.assertEquals(0, len(res))
@@ -300,7 +309,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
res = yield self.db_pool.runQuery(
- "SELECT last_txn, state FROM application_services_state WHERE " "as_id=?",
+ self.engine.convert_param_style(
+ "SELECT last_txn, state FROM application_services_state WHERE as_id=?"
+ ),
(service.id,),
)
self.assertEquals(1, len(res))
@@ -308,7 +319,10 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.assertEquals(ApplicationServiceState.UP, res[0][1])
res = yield self.db_pool.runQuery(
- "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,)
+ self.engine.convert_param_style(
+ "SELECT * FROM application_services_txns WHERE txn_id=?"
+ ),
+ (txn_id,),
)
self.assertEquals(0, len(res))
@@ -394,37 +408,31 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f1 = self._write_config(suffix="1")
f2 = self._write_config(suffix="2")
- config = Mock(
- app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
- )
hs = yield setup_test_homeserver(
- self.addCleanup,
- config=config,
- datastore=Mock(),
- federation_sender=Mock(),
- federation_client=Mock(),
+ self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
- ApplicationServiceStore(None, hs)
+ hs.config.app_service_config_files = [f1, f2]
+ hs.config.event_cache_size = 1
+ hs.config.password_providers = []
+
+ ApplicationServiceStore(hs.get_db_conn(), hs)
@defer.inlineCallbacks
def test_duplicate_ids(self):
f1 = self._write_config(id="id", suffix="1")
f2 = self._write_config(id="id", suffix="2")
- config = Mock(
- app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
- )
hs = yield setup_test_homeserver(
- self.addCleanup,
- config=config,
- datastore=Mock(),
- federation_sender=Mock(),
- federation_client=Mock(),
+ self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
+ hs.config.app_service_config_files = [f1, f2]
+ hs.config.event_cache_size = 1
+ hs.config.password_providers = []
+
with self.assertRaises(ConfigError) as cm:
- ApplicationServiceStore(None, hs)
+ ApplicationServiceStore(hs.get_db_conn(), hs)
e = cm.exception
self.assertIn(f1, str(e))
@@ -436,19 +444,16 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f1 = self._write_config(as_token="as_token", suffix="1")
f2 = self._write_config(as_token="as_token", suffix="2")
- config = Mock(
- app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
- )
hs = yield setup_test_homeserver(
- self.addCleanup,
- config=config,
- datastore=Mock(),
- federation_sender=Mock(),
- federation_client=Mock(),
+ self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
+ hs.config.app_service_config_files = [f1, f2]
+ hs.config.event_cache_size = 1
+ hs.config.password_providers = []
+
with self.assertRaises(ConfigError) as cm:
- ApplicationServiceStore(None, hs)
+ ApplicationServiceStore(hs.get_db_conn(), hs)
e = cm.exception
self.assertIn(f1, str(e))
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 7cb5f0e4cf..829f47d2e8 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -20,11 +20,11 @@ from mock import Mock
from twisted.internet import defer
-from synapse.server import HomeServer
from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import create_engine
from tests import unittest
+from tests.utils import TestHomeServer
class SQLBaseStoreTestCase(unittest.TestCase):
@@ -51,7 +51,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
config = Mock()
config.event_cache_size = 1
config.database_config = {"name": "sqlite3"}
- hs = HomeServer(
+ hs = TestHomeServer(
"test",
db_pool=self.db_pool,
config=config,
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index b4510c1c8d..4e128e1047 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -16,7 +16,6 @@
from twisted.internet import defer
-from synapse.storage.directory import DirectoryStore
from synapse.types import RoomAlias, RoomID
from tests import unittest
@@ -28,7 +27,7 @@ class DirectoryStoreTestCase(unittest.TestCase):
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
- self.store = DirectoryStore(None, hs)
+ self.store = hs.get_datastore()
self.room = RoomID.from_string("!abcde:test")
self.alias = RoomAlias.from_string("#my-room:test")
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 2fdf34fdf6..0d4e74d637 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -37,10 +37,10 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
(
"INSERT INTO events ("
" room_id, event_id, type, depth, topological_ordering,"
- " content, processed, outlier) "
- "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)"
+ " content, processed, outlier, stream_ordering) "
+ "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?, ?)"
),
- (room_id, event_id, i, i, True, False),
+ (room_id, event_id, i, i, True, False, i),
)
txn.execute(
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index f2ed866ae7..2036287288 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -13,25 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
-import tests.unittest
-import tests.utils
-from tests.utils import setup_test_homeserver
+from tests.unittest import HomeserverTestCase
FORTY_DAYS = 40 * 24 * 60 * 60
-class MonthlyActiveUsersTestCase(tests.unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super(MonthlyActiveUsersTestCase, self).__init__(*args, **kwargs)
+class MonthlyActiveUsersTestCase(HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+
+ hs = self.setup_test_homeserver()
+ self.store = hs.get_datastore()
+
+ # Advance the clock a bit
+ reactor.advance(FORTY_DAYS)
- @defer.inlineCallbacks
- def setUp(self):
- self.hs = yield setup_test_homeserver(self.addCleanup)
- self.store = self.hs.get_datastore()
+ return hs
- @defer.inlineCallbacks
def test_initialise_reserved_users(self):
self.hs.config.max_mau_value = 5
user1 = "@user1:server"
@@ -44,88 +41,101 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase):
]
user_num = len(threepids)
- yield self.store.register(user_id=user1, token="123", password_hash=None)
-
- yield self.store.register(user_id=user2, token="456", password_hash=None)
+ self.store.register(user_id=user1, token="123", password_hash=None)
+ self.store.register(user_id=user2, token="456", password_hash=None)
+ self.pump()
now = int(self.hs.get_clock().time_msec())
- yield self.store.user_add_threepid(user1, "email", user1_email, now, now)
- yield self.store.user_add_threepid(user2, "email", user2_email, now, now)
- yield self.store.initialise_reserved_users(threepids)
+ self.store.user_add_threepid(user1, "email", user1_email, now, now)
+ self.store.user_add_threepid(user2, "email", user2_email, now, now)
+ self.store.initialise_reserved_users(threepids)
+ self.pump()
- active_count = yield self.store.get_monthly_active_count()
+ active_count = self.store.get_monthly_active_count()
# Test total counts
- self.assertEquals(active_count, user_num)
+ self.assertEquals(self.get_success(active_count), user_num)
# Test user is marked as active
-
- timestamp = yield self.store.user_last_seen_monthly_active(user1)
- self.assertTrue(timestamp)
- timestamp = yield self.store.user_last_seen_monthly_active(user2)
- self.assertTrue(timestamp)
+ timestamp = self.store.user_last_seen_monthly_active(user1)
+ self.assertTrue(self.get_success(timestamp))
+ timestamp = self.store.user_last_seen_monthly_active(user2)
+ self.assertTrue(self.get_success(timestamp))
# Test that users are never removed from the db.
self.hs.config.max_mau_value = 0
- self.hs.get_clock().advance_time(FORTY_DAYS)
+ self.reactor.advance(FORTY_DAYS)
- yield self.store.reap_monthly_active_users()
+ self.store.reap_monthly_active_users()
+ self.pump()
- active_count = yield self.store.get_monthly_active_count()
- self.assertEquals(active_count, user_num)
+ active_count = self.store.get_monthly_active_count()
+ self.assertEquals(self.get_success(active_count), user_num)
# Test that regalar users are removed from the db
ru_count = 2
- yield self.store.upsert_monthly_active_user("@ru1:server")
- yield self.store.upsert_monthly_active_user("@ru2:server")
- active_count = yield self.store.get_monthly_active_count()
+ self.store.upsert_monthly_active_user("@ru1:server")
+ self.store.upsert_monthly_active_user("@ru2:server")
+ self.pump()
- self.assertEqual(active_count, user_num + ru_count)
+ active_count = self.store.get_monthly_active_count()
+ self.assertEqual(self.get_success(active_count), user_num + ru_count)
self.hs.config.max_mau_value = user_num
- yield self.store.reap_monthly_active_users()
+ self.store.reap_monthly_active_users()
+ self.pump()
- active_count = yield self.store.get_monthly_active_count()
- self.assertEquals(active_count, user_num)
+ active_count = self.store.get_monthly_active_count()
+ self.assertEquals(self.get_success(active_count), user_num)
- @defer.inlineCallbacks
def test_can_insert_and_count_mau(self):
- count = yield self.store.get_monthly_active_count()
- self.assertEqual(0, count)
+ count = self.store.get_monthly_active_count()
+ self.assertEqual(0, self.get_success(count))
- yield self.store.upsert_monthly_active_user("@user:server")
- count = yield self.store.get_monthly_active_count()
+ self.store.upsert_monthly_active_user("@user:server")
+ self.pump()
- self.assertEqual(1, count)
+ count = self.store.get_monthly_active_count()
+ self.assertEqual(1, self.get_success(count))
- @defer.inlineCallbacks
def test_user_last_seen_monthly_active(self):
user_id1 = "@user1:server"
user_id2 = "@user2:server"
user_id3 = "@user3:server"
- result = yield self.store.user_last_seen_monthly_active(user_id1)
- self.assertFalse(result == 0)
- yield self.store.upsert_monthly_active_user(user_id1)
- yield self.store.upsert_monthly_active_user(user_id2)
- result = yield self.store.user_last_seen_monthly_active(user_id1)
- self.assertTrue(result > 0)
- result = yield self.store.user_last_seen_monthly_active(user_id3)
- self.assertFalse(result == 0)
+ result = self.store.user_last_seen_monthly_active(user_id1)
+ self.assertFalse(self.get_success(result) == 0)
+
+ self.store.upsert_monthly_active_user(user_id1)
+ self.store.upsert_monthly_active_user(user_id2)
+ self.pump()
+
+ result = self.store.user_last_seen_monthly_active(user_id1)
+ self.assertGreater(self.get_success(result), 0)
+
+ result = self.store.user_last_seen_monthly_active(user_id3)
+ self.assertNotEqual(self.get_success(result), 0)
- @defer.inlineCallbacks
def test_reap_monthly_active_users(self):
self.hs.config.max_mau_value = 5
initial_users = 10
for i in range(initial_users):
- yield self.store.upsert_monthly_active_user("@user%d:server" % i)
- count = yield self.store.get_monthly_active_count()
- self.assertTrue(count, initial_users)
- yield self.store.reap_monthly_active_users()
- count = yield self.store.get_monthly_active_count()
- self.assertEquals(count, initial_users - self.hs.config.max_mau_value)
-
- self.hs.get_clock().advance_time(FORTY_DAYS)
- yield self.store.reap_monthly_active_users()
- count = yield self.store.get_monthly_active_count()
- self.assertEquals(count, 0)
+ self.store.upsert_monthly_active_user("@user%d:server" % i)
+ self.pump()
+
+ count = self.store.get_monthly_active_count()
+ self.assertTrue(self.get_success(count), initial_users)
+
+ self.store.reap_monthly_active_users()
+ self.pump()
+ count = self.store.get_monthly_active_count()
+ self.assertEquals(
+ self.get_success(count), initial_users - self.hs.config.max_mau_value
+ )
+
+ self.reactor.advance(FORTY_DAYS)
+ self.store.reap_monthly_active_users()
+ self.pump()
+
+ count = self.store.get_monthly_active_count()
+ self.assertEquals(self.get_success(count), 0)
diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py
index b5b58ff660..c7a63f39b9 100644
--- a/tests/storage/test_presence.py
+++ b/tests/storage/test_presence.py
@@ -16,19 +16,18 @@
from twisted.internet import defer
-from synapse.storage.presence import PresenceStore
from synapse.types import UserID
from tests import unittest
-from tests.utils import MockClock, setup_test_homeserver
+from tests.utils import setup_test_homeserver
class PresenceStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup, clock=MockClock())
+ hs = yield setup_test_homeserver(self.addCleanup)
- self.store = PresenceStore(None, hs)
+ self.store = hs.get_datastore()
self.u_apple = UserID.from_string("@apple:test")
self.u_banana = UserID.from_string("@banana:test")
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index a1f6618bf9..45824bd3b2 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -28,7 +28,7 @@ class ProfileStoreTestCase(unittest.TestCase):
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
- self.store = ProfileStore(None, hs)
+ self.store = ProfileStore(hs.get_db_conn(), hs)
self.u_frank = UserID.from_string("@frank:test")
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
new file mode 100644
index 0000000000..f671599cb8
--- /dev/null
+++ b/tests/storage/test_purge.py
@@ -0,0 +1,106 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.rest.client.v1 import room
+
+from tests.unittest import HomeserverTestCase
+
+
+class PurgeTests(HomeserverTestCase):
+
+ user_id = "@red:server"
+ servlets = [room.register_servlets]
+
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver("server", http_client=None)
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+ self.room_id = self.helper.create_room_as(self.user_id)
+
+ def test_purge(self):
+ """
+ Purging a room will delete everything before the topological point.
+ """
+ # Send four messages to the room
+ first = self.helper.send(self.room_id, body="test1")
+ second = self.helper.send(self.room_id, body="test2")
+ third = self.helper.send(self.room_id, body="test3")
+ last = self.helper.send(self.room_id, body="test4")
+
+ storage = self.hs.get_datastore()
+
+ # Get the topological token
+ event = storage.get_topological_token_for_event(last["event_id"])
+ self.pump()
+ event = self.successResultOf(event)
+
+ # Purge everything before this topological token
+ purge = storage.purge_history(self.room_id, event, True)
+ self.pump()
+ self.assertEqual(self.successResultOf(purge), None)
+
+ # Try and get the events
+ get_first = storage.get_event(first["event_id"])
+ get_second = storage.get_event(second["event_id"])
+ get_third = storage.get_event(third["event_id"])
+ get_last = storage.get_event(last["event_id"])
+ self.pump()
+
+ # 1-3 should fail and last will succeed, meaning that 1-3 are deleted
+ # and last is not.
+ self.failureResultOf(get_first)
+ self.failureResultOf(get_second)
+ self.failureResultOf(get_third)
+ self.successResultOf(get_last)
+
+ def test_purge_wont_delete_extrems(self):
+ """
+ Purging a room will delete everything before the topological point.
+ """
+ # Send four messages to the room
+ first = self.helper.send(self.room_id, body="test1")
+ second = self.helper.send(self.room_id, body="test2")
+ third = self.helper.send(self.room_id, body="test3")
+ last = self.helper.send(self.room_id, body="test4")
+
+ storage = self.hs.get_datastore()
+
+ # Set the topological token higher than it should be
+ event = storage.get_topological_token_for_event(last["event_id"])
+ self.pump()
+ event = self.successResultOf(event)
+ event = "t{}-{}".format(
+ *list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
+ )
+
+ # Purge everything before this topological token
+ purge = storage.purge_history(self.room_id, event, True)
+ self.pump()
+ f = self.failureResultOf(purge)
+ self.assertIn("greater than forward", f.value.args[0])
+
+ # Try and get the events
+ get_first = storage.get_event(first["event_id"])
+ get_second = storage.get_event(second["event_id"])
+ get_third = storage.get_event(third["event_id"])
+ get_last = storage.get_event(last["event_id"])
+ self.pump()
+
+ # Nothing is deleted.
+ self.successResultOf(get_first)
+ self.successResultOf(get_second)
+ self.successResultOf(get_third)
+ self.successResultOf(get_last)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index c4e9fb72bf..02bf975fbf 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -22,7 +22,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.types import RoomID, UserID
from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.utils import create_room, setup_test_homeserver
class RedactionTestCase(unittest.TestCase):
@@ -41,6 +41,8 @@ class RedactionTestCase(unittest.TestCase):
self.room1 = RoomID.from_string("!abc123:test")
+ yield create_room(hs, self.room1.to_string(), self.u_alice.to_string())
+
self.depth = 1
@defer.inlineCallbacks
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 4eda122edc..3dfb7b903a 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -46,6 +46,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
"consent_version": None,
"consent_server_notice_sent": None,
"appservice_id": None,
+ "creation_ts": 1000,
},
(yield self.store.get_user_by_id(self.user_id)),
)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index c83ef60062..978c66133d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -22,7 +22,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.types import RoomID, UserID
from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.utils import create_room, setup_test_homeserver
class RoomMemberStoreTestCase(unittest.TestCase):
@@ -45,6 +45,8 @@ class RoomMemberStoreTestCase(unittest.TestCase):
self.room = RoomID.from_string("!abc123:test")
+ yield create_room(hs, self.room.to_string(), self.u_alice.to_string())
+
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership, replaces_state=None):
builder = self.event_builder_factory.new(
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index ebfd969b36..d717b9f94e 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -185,6 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_some_state_from_cache correctly filters out members with types=[]
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_cache,
group, [], filtered_types=[EventTypes.Member]
)
@@ -197,8 +198,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
+ (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_members_cache,
+ group, [], filtered_types=[EventTypes.Member]
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual(
+ {},
+ state_dict,
+ )
+
# test _get_some_state_from_cache correctly filters in members with wildcard types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_cache,
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
)
@@ -207,6 +220,18 @@ class StateStoreTestCase(tests.unittest.TestCase):
{
(e1.type, e1.state_key): e1.event_id,
(e2.type, e2.state_key): e2.event_id,
+ },
+ state_dict,
+ )
+
+ (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_members_cache,
+ group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual(
+ {
(e3.type, e3.state_key): e3.event_id,
# e4 is overwritten by e5
(e5.type, e5.state_key): e5.event_id,
@@ -216,6 +241,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_some_state_from_cache correctly filters in members with specific types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_cache,
group,
[(EventTypes.Member, e5.state_key)],
filtered_types=[EventTypes.Member],
@@ -226,6 +252,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
{
(e1.type, e1.state_key): e1.event_id,
(e2.type, e2.state_key): e2.event_id,
+ },
+ state_dict,
+ )
+
+ (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_members_cache,
+ group,
+ [(EventTypes.Member, e5.state_key)],
+ filtered_types=[EventTypes.Member],
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual(
+ {
(e5.type, e5.state_key): e5.event_id,
},
state_dict,
@@ -234,6 +274,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_some_state_from_cache correctly filters in members with specific types
# and no filtered_types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_members_cache,
group, [(EventTypes.Member, e5.state_key)], filtered_types=None
)
@@ -254,9 +295,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
{
(e1.type, e1.state_key): e1.event_id,
(e2.type, e2.state_key): e2.event_id,
- (e3.type, e3.state_key): e3.event_id,
- # e4 is overwritten by e5
- (e5.type, e5.state_key): e5.event_id,
},
)
@@ -269,8 +307,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
# list fetched keys so it knows it's partial
fetched_keys=(
(e1.type, e1.state_key),
- (e3.type, e3.state_key),
- (e5.type, e5.state_key),
),
)
@@ -284,8 +320,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
set(
[
(e1.type, e1.state_key),
- (e3.type, e3.state_key),
- (e5.type, e5.state_key),
]
),
)
@@ -293,8 +327,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict_ids,
{
(e1.type, e1.state_key): e1.event_id,
- (e3.type, e3.state_key): e3.event_id,
- (e5.type, e5.state_key): e5.event_id,
},
)
@@ -304,14 +336,25 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_some_state_from_cache correctly filters out members with types=[]
room_id = self.room.to_string()
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_cache,
group, [], filtered_types=[EventTypes.Member]
)
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
+ room_id = self.room.to_string()
+ (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_members_cache,
+ group, [], filtered_types=[EventTypes.Member]
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual({}, state_dict)
+
# test _get_some_state_from_cache correctly filters in members wildcard types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_cache,
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
)
@@ -319,8 +362,19 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertDictEqual(
{
(e1.type, e1.state_key): e1.event_id,
+ },
+ state_dict,
+ )
+
+ (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_members_cache,
+ group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual(
+ {
(e3.type, e3.state_key): e3.event_id,
- # e4 is overwritten by e5
(e5.type, e5.state_key): e5.event_id,
},
state_dict,
@@ -328,6 +382,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_some_state_from_cache correctly filters in members with specific types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_cache,
group,
[(EventTypes.Member, e5.state_key)],
filtered_types=[EventTypes.Member],
@@ -337,6 +392,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertDictEqual(
{
(e1.type, e1.state_key): e1.event_id,
+ },
+ state_dict,
+ )
+
+ (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_members_cache,
+ group,
+ [(EventTypes.Member, e5.state_key)],
+ filtered_types=[EventTypes.Member],
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual(
+ {
(e5.type, e5.state_key): e5.event_id,
},
state_dict,
@@ -345,8 +414,22 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_some_state_from_cache correctly filters in members with specific types
# and no filtered_types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_cache,
+ group, [(EventTypes.Member, e5.state_key)], filtered_types=None
+ )
+
+ self.assertEqual(is_all, False)
+ self.assertDictEqual({}, state_dict)
+
+ (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_members_cache,
group, [(EventTypes.Member, e5.state_key)], filtered_types=None
)
self.assertEqual(is_all, True)
- self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
+ self.assertDictEqual(
+ {
+ (e5.type, e5.state_key): e5.event_id,
+ },
+ state_dict,
+ )
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index b46e0ea7e2..0dde1ab2fe 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -30,7 +30,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver(self.addCleanup)
- self.store = UserDirectoryStore(None, self.hs)
+ self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs)
# alice and bob are both in !room_id. bobby is not but shares
# a homeserver with alice.
diff --git a/tests/test_mau.py b/tests/test_mau.py
new file mode 100644
index 0000000000..0732615447
--- /dev/null
+++ b/tests/test_mau.py
@@ -0,0 +1,217 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests REST events for /rooms paths."""
+
+import json
+
+from mock import Mock, NonCallableMock
+
+from synapse.api.constants import LoginType
+from synapse.api.errors import Codes, HttpResponseException, SynapseError
+from synapse.http.server import JsonResource
+from synapse.rest.client.v2_alpha import register, sync
+from synapse.util import Clock
+
+from tests import unittest
+from tests.server import (
+ ThreadedMemoryReactorClock,
+ make_request,
+ render,
+ setup_test_homeserver,
+)
+
+
+class TestMauLimit(unittest.TestCase):
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+ self.clock = Clock(self.reactor)
+
+ self.hs = setup_test_homeserver(
+ self.addCleanup,
+ "red",
+ http_client=None,
+ clock=self.clock,
+ reactor=self.reactor,
+ federation_client=Mock(),
+ ratelimiter=NonCallableMock(spec_set=["send_message"]),
+ )
+
+ self.store = self.hs.get_datastore()
+
+ self.hs.config.registrations_require_3pid = []
+ self.hs.config.enable_registration_captcha = False
+ self.hs.config.recaptcha_public_key = []
+
+ self.hs.config.limit_usage_by_mau = True
+ self.hs.config.hs_disabled = False
+ self.hs.config.max_mau_value = 2
+ self.hs.config.mau_trial_days = 0
+ self.hs.config.server_notices_mxid = "@server:red"
+ self.hs.config.server_notices_mxid_display_name = None
+ self.hs.config.server_notices_mxid_avatar_url = None
+ self.hs.config.server_notices_room_name = "Test Server Notice Room"
+
+ self.resource = JsonResource(self.hs)
+ register.register_servlets(self.hs, self.resource)
+ sync.register_servlets(self.hs, self.resource)
+
+ def test_simple_deny_mau(self):
+ # Create and sync so that the MAU counts get updated
+ token1 = self.create_user("kermit1")
+ self.do_sync_for_user(token1)
+ token2 = self.create_user("kermit2")
+ self.do_sync_for_user(token2)
+
+ # We've created and activated two users, we shouldn't be able to
+ # register new users
+ with self.assertRaises(SynapseError) as cm:
+ self.create_user("kermit3")
+
+ e = cm.exception
+ self.assertEqual(e.code, 403)
+ self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+ def test_allowed_after_a_month_mau(self):
+ # Create and sync so that the MAU counts get updated
+ token1 = self.create_user("kermit1")
+ self.do_sync_for_user(token1)
+ token2 = self.create_user("kermit2")
+ self.do_sync_for_user(token2)
+
+ # Advance time by 31 days
+ self.reactor.advance(31 * 24 * 60 * 60)
+
+ self.store.reap_monthly_active_users()
+
+ self.reactor.advance(0)
+
+ # We should be able to register more users
+ token3 = self.create_user("kermit3")
+ self.do_sync_for_user(token3)
+
+ def test_trial_delay(self):
+ self.hs.config.mau_trial_days = 1
+
+ # We should be able to register more than the limit initially
+ token1 = self.create_user("kermit1")
+ self.do_sync_for_user(token1)
+ token2 = self.create_user("kermit2")
+ self.do_sync_for_user(token2)
+ token3 = self.create_user("kermit3")
+ self.do_sync_for_user(token3)
+
+ # Advance time by 2 days
+ self.reactor.advance(2 * 24 * 60 * 60)
+
+ # Two users should be able to sync
+ self.do_sync_for_user(token1)
+ self.do_sync_for_user(token2)
+
+ # But the third should fail
+ with self.assertRaises(SynapseError) as cm:
+ self.do_sync_for_user(token3)
+
+ e = cm.exception
+ self.assertEqual(e.code, 403)
+ self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+ # And new registrations are now denied too
+ with self.assertRaises(SynapseError) as cm:
+ self.create_user("kermit4")
+
+ e = cm.exception
+ self.assertEqual(e.code, 403)
+ self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+ def test_trial_users_cant_come_back(self):
+ self.hs.config.mau_trial_days = 1
+
+ # We should be able to register more than the limit initially
+ token1 = self.create_user("kermit1")
+ self.do_sync_for_user(token1)
+ token2 = self.create_user("kermit2")
+ self.do_sync_for_user(token2)
+ token3 = self.create_user("kermit3")
+ self.do_sync_for_user(token3)
+
+ # Advance time by 2 days
+ self.reactor.advance(2 * 24 * 60 * 60)
+
+ # Two users should be able to sync
+ self.do_sync_for_user(token1)
+ self.do_sync_for_user(token2)
+
+ # Advance by 2 months so everyone falls out of MAU
+ self.reactor.advance(60 * 24 * 60 * 60)
+ self.store.reap_monthly_active_users()
+ self.reactor.advance(0)
+
+ # We can create as many new users as we want
+ token4 = self.create_user("kermit4")
+ self.do_sync_for_user(token4)
+ token5 = self.create_user("kermit5")
+ self.do_sync_for_user(token5)
+ token6 = self.create_user("kermit6")
+ self.do_sync_for_user(token6)
+
+ # users 2 and 3 can come back to bring us back up to MAU limit
+ self.do_sync_for_user(token2)
+ self.do_sync_for_user(token3)
+
+ # New trial users can still sync
+ self.do_sync_for_user(token4)
+ self.do_sync_for_user(token5)
+ self.do_sync_for_user(token6)
+
+ # But old user cant
+ with self.assertRaises(SynapseError) as cm:
+ self.do_sync_for_user(token1)
+
+ e = cm.exception
+ self.assertEqual(e.code, 403)
+ self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+ def create_user(self, localpart):
+ request_data = json.dumps({
+ "username": localpart,
+ "password": "monkey",
+ "auth": {"type": LoginType.DUMMY},
+ })
+
+ request, channel = make_request(b"POST", b"/register", request_data)
+ render(request, self.resource, self.reactor)
+
+ if channel.result["code"] != b"200":
+ raise HttpResponseException(
+ int(channel.result["code"]),
+ channel.result["reason"],
+ channel.result["body"],
+ ).to_synapse_error()
+
+ access_token = channel.json_body["access_token"]
+
+ return access_token
+
+ def do_sync_for_user(self, token):
+ request, channel = make_request(b"GET", b"/sync", access_token=token)
+ render(request, self.resource, self.reactor)
+
+ if channel.result["code"] != b"200":
+ raise HttpResponseException(
+ int(channel.result["code"]),
+ channel.result["reason"],
+ channel.result["body"],
+ ).to_synapse_error()
diff --git a/tests/test_state.py b/tests/test_state.py
index 96fdb8636c..452a123c3a 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -18,7 +18,7 @@ from mock import Mock
from twisted.internet import defer
from synapse.api.auth import Auth
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RoomVersions
from synapse.events import FrozenEvent
from synapse.state import StateHandler, StateResolutionHandler
@@ -117,6 +117,9 @@ class StateGroupStore(object):
def register_event_id_state_group(self, event_id, state_group):
self._event_to_state_group[event_id] = state_group
+ def get_room_version(self, room_id):
+ return RoomVersions.V1
+
class DictObj(dict):
def __init__(self, **kwargs):
@@ -176,7 +179,9 @@ class StateTestCase(unittest.TestCase):
def test_branch_no_conflict(self):
graph = Graph(
nodes={
- "START": DictObj(type=EventTypes.Create, state_key="", depth=1),
+ "START": DictObj(
+ type=EventTypes.Create, state_key="", content={}, depth=1,
+ ),
"A": DictObj(type=EventTypes.Message, depth=2),
"B": DictObj(type=EventTypes.Message, depth=3),
"C": DictObj(type=EventTypes.Name, state_key="", depth=3),
diff --git a/tests/test_types.py b/tests/test_types.py
index be072d402b..0f5c8bfaf9 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -14,12 +14,12 @@
# limitations under the License.
from synapse.api.errors import SynapseError
-from synapse.server import HomeServer
from synapse.types import GroupID, RoomAlias, UserID
from tests import unittest
+from tests.utils import TestHomeServer
-mock_homeserver = HomeServer(hostname="my.domain")
+mock_homeserver = TestHomeServer(hostname="my.domain")
class UserIDTestCase(unittest.TestCase):
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 45a78338d6..2eea3b098b 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -21,7 +21,7 @@ from synapse.events import FrozenEvent
from synapse.visibility import filter_events_for_server
import tests.unittest
-from tests.utils import setup_test_homeserver
+from tests.utils import create_room, setup_test_homeserver
logger = logging.getLogger(__name__)
@@ -36,6 +36,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.event_builder_factory = self.hs.get_event_builder_factory()
self.store = self.hs.get_datastore()
+ yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
+
@defer.inlineCallbacks
def test_filtering(self):
#
@@ -94,7 +96,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
events_to_filter.append(evt)
# the erasey user gets erased
- self.hs.get_datastore().mark_user_erased("@erased:local_hs")
+ yield self.hs.get_datastore().mark_user_erased("@erased:local_hs")
# ... and the filtering happens.
filtered = yield filter_events_for_server(
diff --git a/tests/unittest.py b/tests/unittest.py
index d852e2465a..a3d39920db 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -22,6 +22,7 @@ from canonicaljson import json
import twisted
import twisted.logger
+from twisted.internet.defer import Deferred
from twisted.trial import unittest
from synapse.http.server import JsonResource
@@ -151,6 +152,7 @@ class HomeserverTestCase(TestCase):
hijack_auth (bool): Whether to hijack auth to return the user specified
in user_id.
"""
+
servlets = []
hijack_auth = True
@@ -279,3 +281,15 @@ class HomeserverTestCase(TestCase):
kwargs = dict(kwargs)
kwargs.update(self._hs_args)
return setup_test_homeserver(self.addCleanup, *args, **kwargs)
+
+ def pump(self, by=0.0):
+ """
+ Pump the reactor enough that Deferreds will fire.
+ """
+ self.reactor.pump([by] * 100)
+
+ def get_success(self, d):
+ if not isinstance(d, Deferred):
+ return d
+ self.pump()
+ return self.successResultOf(d)
diff --git a/tests/utils.py b/tests/utils.py
index 8de2898b2f..b85017d279 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -24,12 +24,14 @@ from six.moves.urllib import parse as urlparse
from twisted.internet import defer, reactor
+from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
+from synapse.config.server import ServerConfig
from synapse.federation.transport import server
from synapse.http.server import HttpServer
from synapse.server import HomeServer
-from synapse.storage import PostgresEngine
-from synapse.storage.engines import create_engine
+from synapse.storage import DataStore
+from synapse.storage.engines import PostgresEngine, create_engine
from synapse.storage.prepare_database import (
_get_or_create_schema_state,
_setup_new_database,
@@ -40,6 +42,7 @@ from synapse.util.ratelimitutils import FederationRateLimiter
# set this to True to run the tests against postgres instead of sqlite.
USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False)
+LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False)
POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres")
POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
@@ -91,10 +94,14 @@ def setupdb():
atexit.register(_cleanup)
+class TestHomeServer(HomeServer):
+ DATASTORE_CLASS = DataStore
+
+
@defer.inlineCallbacks
def setup_test_homeserver(
cleanup_func, name="test", datastore=None, config=None, reactor=None,
- homeserverToUse=HomeServer, **kargs
+ homeserverToUse=TestHomeServer, **kargs
):
"""
Setup a homeserver suitable for running tests against. Keyword arguments
@@ -141,7 +148,9 @@ def setup_test_homeserver(
config.hs_disabled_limit_type = ""
config.max_mau_value = 50
config.mau_limits_reserved_threepids = []
- config.admin_uri = None
+ config.admin_contact = None
+ config.rc_messages_per_second = 10000
+ config.rc_message_burst_count = 10000
# we need a sane default_room_version, otherwise attempts to create rooms will
# fail.
@@ -151,6 +160,11 @@ def setup_test_homeserver(
# background, which upsets the test runner.
config.update_user_directory = False
+ def is_threepid_reserved(threepid):
+ return ServerConfig.is_threepid_reserved(config, threepid)
+
+ config.is_threepid_reserved.side_effect = is_threepid_reserved
+
config.use_frozen_dicts = True
config.ldap_enabled = False
@@ -231,8 +245,9 @@ def setup_test_homeserver(
cur.close()
db_conn.close()
- # Register the cleanup hook
- cleanup_func(cleanup)
+ if not LEAVE_DB:
+ # Register the cleanup hook
+ cleanup_func(cleanup)
hs.setup()
else:
@@ -545,3 +560,32 @@ class DeferredMockCallable(object):
"Expected not to received any calls, got:\n"
+ "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
)
+
+
+@defer.inlineCallbacks
+def create_room(hs, room_id, creator_id):
+ """Creates and persist a creation event for the given room
+
+ Args:
+ hs
+ room_id (str)
+ creator_id (str)
+ """
+
+ store = hs.get_datastore()
+ event_builder_factory = hs.get_event_builder_factory()
+ event_creation_handler = hs.get_event_creation_handler()
+
+ builder = event_builder_factory.new({
+ "type": EventTypes.Create,
+ "state_key": "",
+ "sender": creator_id,
+ "room_id": room_id,
+ "content": {},
+ })
+
+ event, context = yield event_creation_handler.create_new_client_event(
+ builder
+ )
+
+ yield store.persist_event(event, context)
|