diff options
Diffstat (limited to 'tests')
32 files changed, 1229 insertions, 786 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 54e396d19d..379e9c4ab1 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -468,6 +468,24 @@ class AuthTestCase(unittest.TestCase): yield self.auth.check_auth_blocking() @defer.inlineCallbacks + def test_reserved_threepid(self): + self.hs.config.limit_usage_by_mau = True + self.hs.config.max_mau_value = 1 + self.store.get_monthly_active_count = lambda: defer.succeed(2) + threepid = {'medium': 'email', 'address': 'reserved@server.com'} + unknown_threepid = {'medium': 'email', 'address': 'unreserved@server.com'} + self.hs.config.mau_limits_reserved_threepids = [threepid] + + yield self.store.register(user_id='user1', token="123", password_hash=None) + with self.assertRaises(ResourceLimitError): + yield self.auth.check_auth_blocking() + + with self.assertRaises(ResourceLimitError): + yield self.auth.check_auth_blocking(threepid=unknown_threepid) + + yield self.auth.check_auth_blocking(threepid=threepid) + + @defer.inlineCallbacks def test_hs_disabled(self): self.hs.config.hs_disabled = True self.hs.config.hs_disabled_message = "Reason for being disabled" diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py index 76b5090fff..a83f567ebd 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py @@ -47,7 +47,7 @@ class FrontendProxyTests(HomeserverTestCase): self.assertEqual(len(self.reactor.tcpServers), 1) site = self.reactor.tcpServers[0][1] self.resource = ( - site.resource.children["_matrix"].children["client"].children["r0"] + site.resource.children[b"_matrix"].children[b"client"].children[b"r0"] ) request, channel = self.make_request("PUT", "presence/a/status") @@ -77,7 +77,7 @@ class FrontendProxyTests(HomeserverTestCase): self.assertEqual(len(self.reactor.tcpServers), 1) site = self.reactor.tcpServers[0][1] self.resource = ( - site.resource.children["_matrix"].children["client"].children["r0"] + site.resource.children[b"_matrix"].children[b"client"].children[b"r0"] ) request, channel = self.make_request("PUT", "presence/a/status") diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 56e7acd37c..a3aa0a1cf2 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,79 +14,79 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - import synapse.api.errors import synapse.handlers.device import synapse.storage -from tests import unittest, utils +from tests import unittest user1 = "@boris:aaa" user2 = "@theresa:bbb" -class DeviceTestCase(unittest.TestCase): - def __init__(self, *args, **kwargs): - super(DeviceTestCase, self).__init__(*args, **kwargs) - self.store = None # type: synapse.storage.DataStore - self.handler = None # type: synapse.handlers.device.DeviceHandler - self.clock = None # type: utils.MockClock - - @defer.inlineCallbacks - def setUp(self): - hs = yield utils.setup_test_homeserver(self.addCleanup) +class DeviceTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver("server", http_client=None) self.handler = hs.get_device_handler() self.store = hs.get_datastore() - self.clock = hs.get_clock() + return hs + + def prepare(self, reactor, clock, hs): + # These tests assume that it starts 1000 seconds in. + self.reactor.advance(1000) - @defer.inlineCallbacks def test_device_is_created_if_doesnt_exist(self): - res = yield self.handler.check_device_registered( - user_id="@boris:foo", - device_id="fco", - initial_device_display_name="display name", + res = self.get_success( + self.handler.check_device_registered( + user_id="@boris:foo", + device_id="fco", + initial_device_display_name="display name", + ) ) self.assertEqual(res, "fco") - dev = yield self.handler.store.get_device("@boris:foo", "fco") + dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) self.assertEqual(dev["display_name"], "display name") - @defer.inlineCallbacks def test_device_is_preserved_if_exists(self): - res1 = yield self.handler.check_device_registered( - user_id="@boris:foo", - device_id="fco", - initial_device_display_name="display name", + res1 = self.get_success( + self.handler.check_device_registered( + user_id="@boris:foo", + device_id="fco", + initial_device_display_name="display name", + ) ) self.assertEqual(res1, "fco") - res2 = yield self.handler.check_device_registered( - user_id="@boris:foo", - device_id="fco", - initial_device_display_name="new display name", + res2 = self.get_success( + self.handler.check_device_registered( + user_id="@boris:foo", + device_id="fco", + initial_device_display_name="new display name", + ) ) self.assertEqual(res2, "fco") - dev = yield self.handler.store.get_device("@boris:foo", "fco") + dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) self.assertEqual(dev["display_name"], "display name") - @defer.inlineCallbacks def test_device_id_is_made_up_if_unspecified(self): - device_id = yield self.handler.check_device_registered( - user_id="@theresa:foo", - device_id=None, - initial_device_display_name="display", + device_id = self.get_success( + self.handler.check_device_registered( + user_id="@theresa:foo", + device_id=None, + initial_device_display_name="display", + ) ) - dev = yield self.handler.store.get_device("@theresa:foo", device_id) + dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id)) self.assertEqual(dev["display_name"], "display") - @defer.inlineCallbacks def test_get_devices_by_user(self): - yield self._record_users() + self._record_users() + + res = self.get_success(self.handler.get_devices_by_user(user1)) - res = yield self.handler.get_devices_by_user(user1) self.assertEqual(3, len(res)) device_map = {d["device_id"]: d for d in res} self.assertDictContainsSubset( @@ -119,11 +120,10 @@ class DeviceTestCase(unittest.TestCase): device_map["abc"], ) - @defer.inlineCallbacks def test_get_device(self): - yield self._record_users() + self._record_users() - res = yield self.handler.get_device(user1, "abc") + res = self.get_success(self.handler.get_device(user1, "abc")) self.assertDictContainsSubset( { "user_id": user1, @@ -135,59 +135,66 @@ class DeviceTestCase(unittest.TestCase): res, ) - @defer.inlineCallbacks def test_delete_device(self): - yield self._record_users() + self._record_users() # delete the device - yield self.handler.delete_device(user1, "abc") + self.get_success(self.handler.delete_device(user1, "abc")) # check the device was deleted - with self.assertRaises(synapse.api.errors.NotFoundError): - yield self.handler.get_device(user1, "abc") + res = self.handler.get_device(user1, "abc") + self.pump() + self.assertIsInstance( + self.failureResultOf(res).value, synapse.api.errors.NotFoundError + ) # we'd like to check the access token was invalidated, but that's a # bit of a PITA. - @defer.inlineCallbacks def test_update_device(self): - yield self._record_users() + self._record_users() update = {"display_name": "new display"} - yield self.handler.update_device(user1, "abc", update) + self.get_success(self.handler.update_device(user1, "abc", update)) - res = yield self.handler.get_device(user1, "abc") + res = self.get_success(self.handler.get_device(user1, "abc")) self.assertEqual(res["display_name"], "new display") - @defer.inlineCallbacks def test_update_unknown_device(self): update = {"display_name": "new_display"} - with self.assertRaises(synapse.api.errors.NotFoundError): - yield self.handler.update_device("user_id", "unknown_device_id", update) + res = self.handler.update_device("user_id", "unknown_device_id", update) + self.pump() + self.assertIsInstance( + self.failureResultOf(res).value, synapse.api.errors.NotFoundError + ) - @defer.inlineCallbacks def _record_users(self): # check this works for both devices which have a recorded client_ip, # and those which don't. - yield self._record_user(user1, "xyz", "display 0") - yield self._record_user(user1, "fco", "display 1", "token1", "ip1") - yield self._record_user(user1, "abc", "display 2", "token2", "ip2") - yield self._record_user(user1, "abc", "display 2", "token3", "ip3") + self._record_user(user1, "xyz", "display 0") + self._record_user(user1, "fco", "display 1", "token1", "ip1") + self._record_user(user1, "abc", "display 2", "token2", "ip2") + self._record_user(user1, "abc", "display 2", "token3", "ip3") + + self._record_user(user2, "def", "dispkay", "token4", "ip4") - yield self._record_user(user2, "def", "dispkay", "token4", "ip4") + self.reactor.advance(10000) - @defer.inlineCallbacks def _record_user( self, user_id, device_id, display_name, access_token=None, ip=None ): - device_id = yield self.handler.check_device_registered( - user_id=user_id, - device_id=device_id, - initial_device_display_name=display_name, + device_id = self.get_success( + self.handler.check_device_registered( + user_id=user_id, + device_id=device_id, + initial_device_display_name=display_name, + ) ) if ip is not None: - yield self.store.insert_client_ip( - user_id, access_token, ip, "user_agent", device_id + self.get_success( + self.store.insert_client_ip( + user_id, access_token, ip, "user_agent", device_id + ) ) - self.clock.advance_time(1000) + self.reactor.advance(1000) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index ad58073a14..36e136cded 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -33,7 +33,7 @@ from ..utils import ( ) -def _expect_edu(destination, edu_type, content, origin="test"): +def _expect_edu_transaction(edu_type, content, origin="test"): return { "origin": origin, "origin_server_ts": 1000000, @@ -42,10 +42,8 @@ def _expect_edu(destination, edu_type, content, origin="test"): } -def _make_edu_json(origin, edu_type, content): - return json.dumps(_expect_edu("test", edu_type, content, origin=origin)).encode( - 'utf8' - ) +def _make_edu_transaction_json(edu_type, content): + return json.dumps(_expect_edu_transaction(edu_type, content)).encode('utf8') class TypingNotificationsTestCase(unittest.TestCase): @@ -190,8 +188,7 @@ class TypingNotificationsTestCase(unittest.TestCase): call( "farm", path="/_matrix/federation/v1/send/1000000/", - data=_expect_edu( - "farm", + data=_expect_edu_transaction( "m.typing", content={ "room_id": self.room_id, @@ -221,11 +218,10 @@ class TypingNotificationsTestCase(unittest.TestCase): self.assertEquals(self.event_source.get_current_key(), 0) - yield self.mock_federation_resource.trigger( + (code, response) = yield self.mock_federation_resource.trigger( "PUT", "/_matrix/federation/v1/send/1000000/", - _make_edu_json( - "farm", + _make_edu_transaction_json( "m.typing", content={ "room_id": self.room_id, @@ -233,7 +229,7 @@ class TypingNotificationsTestCase(unittest.TestCase): "typing": True, }, ), - federation_auth=True, + federation_auth_origin=b'farm', ) self.on_new_event.assert_has_calls( @@ -264,8 +260,7 @@ class TypingNotificationsTestCase(unittest.TestCase): call( "farm", path="/_matrix/federation/v1/send/1000000/", - data=_expect_edu( - "farm", + data=_expect_edu_transaction( "m.typing", content={ "room_id": self.room_id, diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py new file mode 100644 index 0000000000..1c46c9cfeb --- /dev/null +++ b/tests/http/test_fedclient.py @@ -0,0 +1,157 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import Mock + +from twisted.internet.defer import TimeoutError +from twisted.internet.error import ConnectingCancelledError, DNSLookupError +from twisted.web.client import ResponseNeverReceived + +from synapse.http.matrixfederationclient import MatrixFederationHttpClient + +from tests.unittest import HomeserverTestCase + + +class FederationClientTests(HomeserverTestCase): + def make_homeserver(self, reactor, clock): + + hs = self.setup_test_homeserver(reactor=reactor, clock=clock) + hs.tls_client_options_factory = None + return hs + + def prepare(self, reactor, clock, homeserver): + + self.cl = MatrixFederationHttpClient(self.hs) + self.reactor.lookups["testserv"] = "1.2.3.4" + + def test_dns_error(self): + """ + If the DNS raising returns an error, it will bubble up. + """ + d = self.cl._request("testserv2:8008", "GET", "foo/bar", timeout=10000) + self.pump() + + f = self.failureResultOf(d) + self.assertIsInstance(f.value, DNSLookupError) + + def test_client_never_connect(self): + """ + If the HTTP request is not connected and is timed out, it'll give a + ConnectingCancelledError. + """ + d = self.cl._request("testserv:8008", "GET", "foo/bar", timeout=10000) + + self.pump() + + # Nothing happened yet + self.assertFalse(d.called) + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + self.assertEqual(clients[0][0], '1.2.3.4') + self.assertEqual(clients[0][1], 8008) + + # Deferred is still without a result + self.assertFalse(d.called) + + # Push by enough to time it out + self.reactor.advance(10.5) + f = self.failureResultOf(d) + + self.assertIsInstance(f.value, ConnectingCancelledError) + + def test_client_connect_no_response(self): + """ + If the HTTP request is connected, but gets no response before being + timed out, it'll give a ResponseNeverReceived. + """ + d = self.cl._request("testserv:8008", "GET", "foo/bar", timeout=10000) + + self.pump() + + # Nothing happened yet + self.assertFalse(d.called) + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + self.assertEqual(clients[0][0], '1.2.3.4') + self.assertEqual(clients[0][1], 8008) + + conn = Mock() + client = clients[0][2].buildProtocol(None) + client.makeConnection(conn) + + # Deferred is still without a result + self.assertFalse(d.called) + + # Push by enough to time it out + self.reactor.advance(10.5) + f = self.failureResultOf(d) + + self.assertIsInstance(f.value, ResponseNeverReceived) + + def test_client_gets_headers(self): + """ + Once the client gets the headers, _request returns successfully. + """ + d = self.cl._request("testserv:8008", "GET", "foo/bar", timeout=10000) + + self.pump() + + conn = Mock() + clients = self.reactor.tcpClients + client = clients[0][2].buildProtocol(None) + client.makeConnection(conn) + + # Deferred does not have a result + self.assertFalse(d.called) + + # Send it the HTTP response + client.dataReceived(b"HTTP/1.1 200 OK\r\nServer: Fake\r\n\r\n") + + # We should get a successful response + r = self.successResultOf(d) + self.assertEqual(r.code, 200) + + def test_client_headers_no_body(self): + """ + If the HTTP request is connected, but gets no response before being + timed out, it'll give a ResponseNeverReceived. + """ + d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000) + + self.pump() + + conn = Mock() + clients = self.reactor.tcpClients + client = clients[0][2].buildProtocol(None) + client.makeConnection(conn) + + # Deferred does not have a result + self.assertFalse(d.called) + + # Send it the HTTP response + client.dataReceived( + (b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n" + b"Server: Fake\r\n\r\n") + ) + + # Push by enough to time it out + self.reactor.advance(10.5) + f = self.failureResultOf(d) + + self.assertIsInstance(f.value, TimeoutError) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 65df116efc..089cecfbee 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -1,4 +1,5 @@ # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,89 +12,91 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import tempfile from mock import Mock, NonCallableMock -from twisted.internet import defer, reactor -from twisted.internet.defer import Deferred +import attr from synapse.replication.tcp.client import ( ReplicationClientFactory, ReplicationClientHandler, ) from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory -from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable from tests import unittest -from tests.utils import setup_test_homeserver -class TestReplicationClientHandler(ReplicationClientHandler): - """Overrides on_rdata so that we can wait for it to happen""" +class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): - def __init__(self, store): - super(TestReplicationClientHandler, self).__init__(store) - self._rdata_awaiters = [] - - def await_replication(self): - d = Deferred() - self._rdata_awaiters.append(d) - return make_deferred_yieldable(d) - - def on_rdata(self, stream_name, token, rows): - awaiters = self._rdata_awaiters - self._rdata_awaiters = [] - super(TestReplicationClientHandler, self).on_rdata(stream_name, token, rows) - with PreserveLoggingContext(): - for a in awaiters: - a.callback(None) - - -class BaseSlavedStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver( - self.addCleanup, + hs = self.setup_test_homeserver( "blue", - http_client=None, federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=["send_message"]), ) - self.hs.get_ratelimiter().send_message.return_value = (True, 0) + + hs.get_ratelimiter().send_message.return_value = (True, 0) + + return hs + + def prepare(self, reactor, clock, hs): self.master_store = self.hs.get_datastore() self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs) self.event_id = 0 server_factory = ReplicationStreamProtocolFactory(self.hs) - # XXX: mktemp is unsafe and should never be used. but we're just a test. - path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket") - listener = reactor.listenUNIX(path, server_factory) - self.addCleanup(listener.stopListening) self.streamer = server_factory.streamer - self.replication_handler = TestReplicationClientHandler(self.slaved_store) + self.replication_handler = ReplicationClientHandler(self.slaved_store) client_factory = ReplicationClientFactory( self.hs, "client_name", self.replication_handler ) - client_connector = reactor.connectUNIX(path, client_factory) - self.addCleanup(client_factory.stopTrying) - self.addCleanup(client_connector.disconnect) + + server = server_factory.buildProtocol(None) + client = client_factory.buildProtocol(None) + + @attr.s + class FakeTransport(object): + + other = attr.ib() + disconnecting = False + buffer = attr.ib(default=b'') + + def registerProducer(self, producer, streaming): + + self.producer = producer + + def _produce(): + self.producer.resumeProducing() + reactor.callLater(0.1, _produce) + + reactor.callLater(0.0, _produce) + + def write(self, byt): + self.buffer = self.buffer + byt + + if getattr(self.other, "transport") is not None: + self.other.dataReceived(self.buffer) + self.buffer = b"" + + def writeSequence(self, seq): + for x in seq: + self.write(x) + + client.makeConnection(FakeTransport(server)) + server.makeConnection(FakeTransport(client)) def replicate(self): """Tell the master side of replication that something has happened, and then wait for the replication to occur. """ - # xxx: should we be more specific in what we wait for? - d = self.replication_handler.await_replication() self.streamer.on_notifier_poke() - return d + self.pump(0.1) - @defer.inlineCallbacks def check(self, method, args, expected_result=None): - master_result = yield getattr(self.master_store, method)(*args) - slaved_result = yield getattr(self.slaved_store, method)(*args) + master_result = self.get_success(getattr(self.master_store, method)(*args)) + slaved_result = self.get_success(getattr(self.slaved_store, method)(*args)) if expected_result is not None: self.assertEqual(master_result, expected_result) self.assertEqual(slaved_result, expected_result) diff --git a/tests/replication/slave/storage/test_account_data.py b/tests/replication/slave/storage/test_account_data.py index 87cc2b2fba..43e3248703 100644 --- a/tests/replication/slave/storage/test_account_data.py +++ b/tests/replication/slave/storage/test_account_data.py @@ -12,9 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from twisted.internet import defer - from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from ._base import BaseSlavedStoreTestCase @@ -27,16 +24,19 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase): STORE_TYPE = SlavedAccountDataStore - @defer.inlineCallbacks def test_user_account_data(self): - yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1}) - yield self.replicate() - yield self.check( + self.get_success( + self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1}) + ) + self.replicate() + self.check( "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 1} ) - yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2}) - yield self.replicate() - yield self.check( + self.get_success( + self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2}) + ) + self.replicate() + self.check( "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 2} ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 2ba80ccdcf..db44d33c68 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.events import FrozenEvent, _EventInternalMetadata from synapse.events.snapshot import EventContext from synapse.replication.slave.storage.events import SlavedEventStore @@ -55,70 +53,66 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): def tearDown(self): [unpatch() for unpatch in self.unpatches] - @defer.inlineCallbacks def test_get_latest_event_ids_in_room(self): - create = yield self.persist(type="m.room.create", key="", creator=USER_ID) - yield self.replicate() - yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]) + create = self.persist(type="m.room.create", key="", creator=USER_ID) + self.replicate() + self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]) - join = yield self.persist( + join = self.persist( type="m.room.member", key=USER_ID, membership="join", prev_events=[(create.event_id, {})], ) - yield self.replicate() - yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]) + self.replicate() + self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]) - @defer.inlineCallbacks def test_redactions(self): - yield self.persist(type="m.room.create", key="", creator=USER_ID) - yield self.persist(type="m.room.member", key=USER_ID, membership="join") + self.persist(type="m.room.create", key="", creator=USER_ID) + self.persist(type="m.room.member", key=USER_ID, membership="join") - msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello") - yield self.replicate() - yield self.check("get_event", [msg.event_id], msg) + msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello") + self.replicate() + self.check("get_event", [msg.event_id], msg) - redaction = yield self.persist(type="m.room.redaction", redacts=msg.event_id) - yield self.replicate() + redaction = self.persist(type="m.room.redaction", redacts=msg.event_id) + self.replicate() msg_dict = msg.get_dict() msg_dict["content"] = {} msg_dict["unsigned"]["redacted_by"] = redaction.event_id msg_dict["unsigned"]["redacted_because"] = redaction redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) - yield self.check("get_event", [msg.event_id], redacted) + self.check("get_event", [msg.event_id], redacted) - @defer.inlineCallbacks def test_backfilled_redactions(self): - yield self.persist(type="m.room.create", key="", creator=USER_ID) - yield self.persist(type="m.room.member", key=USER_ID, membership="join") + self.persist(type="m.room.create", key="", creator=USER_ID) + self.persist(type="m.room.member", key=USER_ID, membership="join") - msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello") - yield self.replicate() - yield self.check("get_event", [msg.event_id], msg) + msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello") + self.replicate() + self.check("get_event", [msg.event_id], msg) - redaction = yield self.persist( + redaction = self.persist( type="m.room.redaction", redacts=msg.event_id, backfill=True ) - yield self.replicate() + self.replicate() msg_dict = msg.get_dict() msg_dict["content"] = {} msg_dict["unsigned"]["redacted_by"] = redaction.event_id msg_dict["unsigned"]["redacted_because"] = redaction redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) - yield self.check("get_event", [msg.event_id], redacted) + self.check("get_event", [msg.event_id], redacted) - @defer.inlineCallbacks def test_invites(self): - yield self.persist(type="m.room.create", key="", creator=USER_ID) - yield self.check("get_invited_rooms_for_user", [USER_ID_2], []) - event = yield self.persist( - type="m.room.member", key=USER_ID_2, membership="invite" - ) - yield self.replicate() - yield self.check( + self.persist(type="m.room.create", key="", creator=USER_ID) + self.check("get_invited_rooms_for_user", [USER_ID_2], []) + event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite") + + self.replicate() + + self.check( "get_invited_rooms_for_user", [USER_ID_2], [ @@ -132,37 +126,34 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): ], ) - @defer.inlineCallbacks def test_push_actions_for_user(self): - yield self.persist(type="m.room.create", key="", creator=USER_ID) - yield self.persist(type="m.room.join", key=USER_ID, membership="join") - yield self.persist( + self.persist(type="m.room.create", key="", creator=USER_ID) + self.persist(type="m.room.join", key=USER_ID, membership="join") + self.persist( type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join" ) - event1 = yield self.persist( - type="m.room.message", msgtype="m.text", body="hello" - ) - yield self.replicate() - yield self.check( + event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello") + self.replicate() + self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], {"highlight_count": 0, "notify_count": 0}, ) - yield self.persist( + self.persist( type="m.room.message", msgtype="m.text", body="world", push_actions=[(USER_ID_2, ["notify"])], ) - yield self.replicate() - yield self.check( + self.replicate() + self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], {"highlight_count": 0, "notify_count": 1}, ) - yield self.persist( + self.persist( type="m.room.message", msgtype="m.text", body="world", @@ -170,8 +161,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): (USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}]) ], ) - yield self.replicate() - yield self.check( + self.replicate() + self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], {"highlight_count": 1, "notify_count": 2}, @@ -179,7 +170,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): event_id = 0 - @defer.inlineCallbacks def persist( self, sender=USER_ID, @@ -206,8 +196,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): depth = self.event_id if not prev_events: - latest_event_ids = yield self.master_store.get_latest_event_ids_in_room( - room_id + latest_event_ids = self.get_success( + self.master_store.get_latest_event_ids_in_room(room_id) ) prev_events = [(ev_id, {}) for ev_id in latest_event_ids] @@ -240,19 +230,23 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): ) else: state_handler = self.hs.get_state_handler() - context = yield state_handler.compute_event_context(event) + context = self.get_success(state_handler.compute_event_context(event)) - yield self.master_store.add_push_actions_to_staging( + self.master_store.add_push_actions_to_staging( event.event_id, {user_id: actions for user_id, actions in push_actions} ) ordering = None if backfill: - yield self.master_store.persist_events([(event, context)], backfilled=True) + self.get_success( + self.master_store.persist_events([(event, context)], backfilled=True) + ) else: - ordering, _ = yield self.master_store.persist_event(event, context) + ordering, _ = self.get_success( + self.master_store.persist_event(event, context) + ) if ordering: event.internal_metadata.stream_ordering = ordering - defer.returnValue(event) + return event diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py index ae1adeded1..f47d94f690 100644 --- a/tests/replication/slave/storage/test_receipts.py +++ b/tests/replication/slave/storage/test_receipts.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from ._base import BaseSlavedStoreTestCase @@ -27,13 +25,10 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): STORE_TYPE = SlavedReceiptsStore - @defer.inlineCallbacks def test_receipt(self): - yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {}) - yield self.master_store.insert_receipt( - ROOM_ID, "m.read", USER_ID, [EVENT_ID], {} - ) - yield self.replicate() - yield self.check( - "get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID} + self.check("get_receipts_for_user", [USER_ID, "m.read"], {}) + self.get_success( + self.master_store.insert_receipt(ROOM_ID, "m.read", USER_ID, [EVENT_ID], {}) ) + self.replicate() + self.check("get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID}) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 9fe0760496..359f7777ff 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -22,39 +22,24 @@ from six.moves.urllib import parse as urlparse from twisted.internet import defer -import synapse.rest.client.v1.room from synapse.api.constants import Membership -from synapse.http.server import JsonResource -from synapse.types import UserID -from synapse.util import Clock +from synapse.rest.client.v1 import room from tests import unittest -from tests.server import ( - ThreadedMemoryReactorClock, - make_request, - render, - setup_test_homeserver, -) - -from .utils import RestHelper PATH_PREFIX = b"/_matrix/client/api/v1" -class RoomBase(unittest.TestCase): +class RoomBase(unittest.HomeserverTestCase): rmcreator_id = None - def setUp(self): + servlets = [room.register_servlets, room.register_deprecated_servlets] - self.clock = ThreadedMemoryReactorClock() - self.hs_clock = Clock(self.clock) + def make_homeserver(self, reactor, clock): - self.hs = setup_test_homeserver( - self.addCleanup, + self.hs = self.setup_test_homeserver( "red", http_client=None, - clock=self.hs_clock, - reactor=self.clock, federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=["send_message"]), ) @@ -63,42 +48,21 @@ class RoomBase(unittest.TestCase): self.hs.get_federation_handler = Mock(return_value=Mock()) - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.helper.auth_user_id), - "token_id": 1, - "is_guest": False, - } - - def get_user_by_req(request, allow_guest=False, rights="access"): - return synapse.types.create_requester( - UserID.from_string(self.helper.auth_user_id), 1, False, None - ) - - self.hs.get_auth().get_user_by_req = get_user_by_req - self.hs.get_auth().get_user_by_access_token = get_user_by_access_token - self.hs.get_auth().get_access_token_from_request = Mock(return_value=b"1234") - def _insert_client_ip(*args, **kwargs): return defer.succeed(None) self.hs.get_datastore().insert_client_ip = _insert_client_ip - self.resource = JsonResource(self.hs) - synapse.rest.client.v1.room.register_servlets(self.hs, self.resource) - synapse.rest.client.v1.room.register_deprecated_servlets(self.hs, self.resource) - self.helper = RestHelper(self.hs, self.resource, self.user_id) + return self.hs class RoomPermissionsTestCase(RoomBase): """ Tests room permissions. """ - user_id = b"@sid1:red" - rmcreator_id = b"@notme:red" - - def setUp(self): + user_id = "@sid1:red" + rmcreator_id = "@notme:red" - super(RoomPermissionsTestCase, self).setUp() + def prepare(self, reactor, clock, hs): self.helper.auth_user_id = self.rmcreator_id # create some rooms under the name rmcreator_id @@ -114,22 +78,20 @@ class RoomPermissionsTestCase(RoomBase): self.created_rmid_msg_path = ( "rooms/%s/send/m.room.message/a1" % (self.created_rmid) ).encode('ascii') - request, channel = make_request( - b"PUT", - self.created_rmid_msg_path, - b'{"msgtype":"m.text","body":"test msg"}', + request, channel = self.make_request( + "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' ) - render(request, self.resource, self.clock) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.render(request) + self.assertEquals(200, channel.code, channel.result) # set topic for public room - request, channel = make_request( - b"PUT", + request, channel = self.make_request( + "PUT", ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode('ascii'), b'{"topic":"Public Room Topic"}', ) - render(request, self.resource, self.clock) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.render(request) + self.assertEquals(200, channel.code, channel.result) # auth as user_id now self.helper.auth_user_id = self.user_id @@ -140,128 +102,128 @@ class RoomPermissionsTestCase(RoomBase): seq = iter(range(100)) def send_msg_path(): - return b"/rooms/%s/send/m.room.message/mid%s" % ( + return "/rooms/%s/send/m.room.message/mid%s" % ( self.created_rmid, - str(next(seq)).encode('ascii'), + str(next(seq)), ) # send message in uncreated room, expect 403 - request, channel = make_request( - b"PUT", - b"/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), + request, channel = self.make_request( + "PUT", + "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), msg_content, ) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # send message in created room not joined (no state), expect 403 - request, channel = make_request(b"PUT", send_msg_path(), msg_content) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", send_msg_path(), msg_content) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # send message in created room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) - request, channel = make_request(b"PUT", send_msg_path(), msg_content) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", send_msg_path(), msg_content) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # send message in created room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) - request, channel = make_request(b"PUT", send_msg_path(), msg_content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", send_msg_path(), msg_content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) # send message in created room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) - request, channel = make_request(b"PUT", send_msg_path(), msg_content) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", send_msg_path(), msg_content) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) def test_topic_perms(self): topic_content = b'{"topic":"My Topic Name"}' - topic_path = b"/rooms/%s/state/m.room.topic" % self.created_rmid + topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid # set/get topic in uncreated room, expect 403 - request, channel = make_request( - b"PUT", b"/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content + request, channel = self.make_request( + "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content ) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - request, channel = make_request( - b"GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + request, channel = self.make_request( + "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid ) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room not joined, expect 403 - request, channel = make_request(b"PUT", topic_path, topic_content) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - request, channel = make_request(b"GET", topic_path) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", topic_path, topic_content) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + request, channel = self.make_request("GET", topic_path) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # set topic in created PRIVATE room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) - request, channel = make_request(b"PUT", topic_path, topic_content) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", topic_path, topic_content) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # get topic in created PRIVATE room and invited, expect 403 - request, channel = make_request(b"GET", topic_path) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", topic_path) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) # Only room ops can set topic by default self.helper.auth_user_id = self.rmcreator_id - request, channel = make_request(b"PUT", topic_path, topic_content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", topic_path, topic_content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) self.helper.auth_user_id = self.user_id - request, channel = make_request(b"GET", topic_path) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assert_dict(json.loads(topic_content), channel.json_body) + request, channel = self.make_request("GET", topic_path) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assert_dict(json.loads(topic_content.decode('utf8')), channel.json_body) # set/get topic in created PRIVATE room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) - request, channel = make_request(b"PUT", topic_path, topic_content) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - request, channel = make_request(b"GET", topic_path) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", topic_path, topic_content) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + request, channel = self.make_request("GET", topic_path) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) # get topic in PUBLIC room, not joined, expect 403 - request, channel = make_request( - b"GET", b"/rooms/%s/state/m.room.topic" % self.created_public_rmid + request, channel = self.make_request( + "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid ) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # set topic in PUBLIC room, not joined, expect 403 - request, channel = make_request( - b"PUT", - b"/rooms/%s/state/m.room.topic" % self.created_public_rmid, + request, channel = self.make_request( + "PUT", + "/rooms/%s/state/m.room.topic" % self.created_public_rmid, topic_content, ) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) def _test_get_membership(self, room=None, members=[], expect_code=None): for member in members: - path = b"/rooms/%s/state/m.room.member/%s" % (room, member) - request, channel = make_request(b"GET", path) - render(request, self.resource, self.clock) - self.assertEquals(expect_code, int(channel.result["code"])) + path = "/rooms/%s/state/m.room.member/%s" % (room, member) + request, channel = self.make_request("GET", path) + self.render(request) + self.assertEquals(expect_code, channel.code) def test_membership_basic_room_perms(self): # === room does not exist === @@ -428,217 +390,211 @@ class RoomPermissionsTestCase(RoomBase): class RoomsMemberListTestCase(RoomBase): """ Tests /rooms/$room_id/members/list REST events.""" - user_id = b"@sid1:red" + user_id = "@sid1:red" def test_get_member_list(self): room_id = self.helper.create_room_as(self.user_id) - request, channel = make_request(b"GET", b"/rooms/%s/members" % room_id) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", "/rooms/%s/members" % room_id) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) def test_get_member_list_no_room(self): - request, channel = make_request(b"GET", b"/rooms/roomdoesnotexist/members") - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission(self): - room_id = self.helper.create_room_as(b"@some_other_guy:red") - request, channel = make_request(b"GET", b"/rooms/%s/members" % room_id) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + room_id = self.helper.create_room_as("@some_other_guy:red") + request, channel = self.make_request("GET", "/rooms/%s/members" % room_id) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) def test_get_member_list_mixed_memberships(self): - room_creator = b"@some_other_guy:red" + room_creator = "@some_other_guy:red" room_id = self.helper.create_room_as(room_creator) - room_path = b"/rooms/%s/members" % room_id + room_path = "/rooms/%s/members" % room_id self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) # can't see list if you're just invited. - request, channel = make_request(b"GET", room_path) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", room_path) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) self.helper.join(room=room_id, user=self.user_id) # can see list now joined - request, channel = make_request(b"GET", room_path) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", room_path) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) self.helper.leave(room=room_id, user=self.user_id) # can see old list once left - request, channel = make_request(b"GET", room_path) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", room_path) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) class RoomsCreateTestCase(RoomBase): """ Tests /rooms and /rooms/$room_id REST events. """ - user_id = b"@sid1:red" + user_id = "@sid1:red" def test_post_room_no_keys(self): # POST with no config keys, expect new room id - request, channel = make_request(b"POST", b"/createRoom", b"{}") + request, channel = self.make_request("POST", "/createRoom", "{}") - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), channel.result) + self.render(request) + self.assertEquals(200, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) def test_post_room_visibility_key(self): # POST with visibility config key, expect new room id - request, channel = make_request( - b"POST", b"/createRoom", b'{"visibility":"private"}' + request, channel = self.make_request( + "POST", "/createRoom", b'{"visibility":"private"}' ) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"])) + self.render(request) + self.assertEquals(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_custom_key(self): # POST with custom config keys, expect new room id - request, channel = make_request(b"POST", b"/createRoom", b'{"custom":"stuff"}') - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"])) + request, channel = self.make_request( + "POST", "/createRoom", b'{"custom":"stuff"}' + ) + self.render(request) + self.assertEquals(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_known_and_unknown_keys(self): # POST with custom + known config keys, expect new room id - request, channel = make_request( - b"POST", b"/createRoom", b'{"visibility":"private","custom":"things"}' + request, channel = self.make_request( + "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' ) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"])) + self.render(request) + self.assertEquals(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_invalid_content(self): # POST with invalid content / paths, expect 400 - request, channel = make_request(b"POST", b"/createRoom", b'{"visibili') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"])) + request, channel = self.make_request("POST", "/createRoom", b'{"visibili') + self.render(request) + self.assertEquals(400, channel.code) - request, channel = make_request(b"POST", b"/createRoom", b'["hello"]') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"])) + request, channel = self.make_request("POST", "/createRoom", b'["hello"]') + self.render(request) + self.assertEquals(400, channel.code) class RoomTopicTestCase(RoomBase): """ Tests /rooms/$room_id/topic REST events. """ - user_id = b"@sid1:red" - - def setUp(self): - - super(RoomTopicTestCase, self).setUp() + user_id = "@sid1:red" + def prepare(self, reactor, clock, hs): # create the room self.room_id = self.helper.create_room_as(self.user_id) - self.path = b"/rooms/%s/state/m.room.topic" % (self.room_id,) + self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,) def test_invalid_puts(self): # missing keys or invalid json - request, channel = make_request(b"PUT", self.path, '{}') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, '{}') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", self.path, '{"_name":"bob"}') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, '{"_name":"bo"}') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", self.path, '{"nao') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, '{"nao') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request( - b"PUT", self.path, '[{"_name":"bob"},{"_name":"jill"}]' + request, channel = self.make_request( + "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]' ) - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", self.path, 'text only') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, 'text only') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", self.path, '') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, '') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) # valid key, wrong type content = '{"topic":["Topic name"]}' - request, channel = make_request(b"PUT", self.path, content) - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, content) + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) def test_rooms_topic(self): # nothing should be there - request, channel = make_request(b"GET", self.path) - render(request, self.resource, self.clock) - self.assertEquals(404, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", self.path) + self.render(request) + self.assertEquals(404, channel.code, msg=channel.result["body"]) # valid put content = '{"topic":"Topic name"}' - request, channel = make_request(b"PUT", self.path, content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) # valid get - request, channel = make_request(b"GET", self.path) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", self.path) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) def test_rooms_topic_with_extra_keys(self): # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' - request, channel = make_request(b"PUT", self.path, content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) # valid get - request, channel = make_request(b"GET", self.path) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", self.path) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) class RoomMemberStateTestCase(RoomBase): """ Tests /rooms/$room_id/members/$user_id/state REST events. """ - user_id = b"@sid1:red" - - def setUp(self): + user_id = "@sid1:red" - super(RoomMemberStateTestCase, self).setUp() + def prepare(self, reactor, clock, hs): self.room_id = self.helper.create_room_as(self.user_id) - def tearDown(self): - pass - def test_invalid_puts(self): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json - request, channel = make_request(b"PUT", path, '{}') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, '{}') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, '{"_name":"bob"}') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, '{"_name":"bo"}') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, '{"nao') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, '{"nao') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request( - b"PUT", path, b'[{"_name":"bob"},{"_name":"jill"}]' + request, channel = self.make_request( + "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]' ) - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, 'text only') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, 'text only') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, '') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, '') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) # valid keys, wrong types content = '{"membership":["%s","%s","%s"]}' % ( @@ -646,9 +602,9 @@ class RoomMemberStateTestCase(RoomBase): Membership.JOIN, Membership.LEAVE, ) - request, channel = make_request(b"PUT", path, content.encode('ascii')) - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, content.encode('ascii')) + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) def test_rooms_members_self(self): path = "/rooms/%s/state/m.room.member/%s" % ( @@ -658,13 +614,13 @@ class RoomMemberStateTestCase(RoomBase): # valid join message (NOOP since we made the room) content = '{"membership":"%s"}' % Membership.JOIN - request, channel = make_request(b"PUT", path, content.encode('ascii')) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, content.encode('ascii')) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"GET", path, None) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", path, None) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) expected_response = {"membership": Membership.JOIN} self.assertEquals(expected_response, channel.json_body) @@ -678,13 +634,13 @@ class RoomMemberStateTestCase(RoomBase): # valid invite message content = '{"membership":"%s"}' % Membership.INVITE - request, channel = make_request(b"PUT", path, content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"GET", path, None) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", path, None) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assertEquals(json.loads(content), channel.json_body) def test_rooms_members_other_custom_keys(self): @@ -699,13 +655,13 @@ class RoomMemberStateTestCase(RoomBase): Membership.INVITE, "Join us!", ) - request, channel = make_request(b"PUT", path, content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"GET", path, None) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", path, None) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assertEquals(json.loads(content), channel.json_body) @@ -714,60 +670,58 @@ class RoomMessagesTestCase(RoomBase): user_id = "@sid1:red" - def setUp(self): - super(RoomMessagesTestCase, self).setUp() - + def prepare(self, reactor, clock, hs): self.room_id = self.helper.create_room_as(self.user_id) def test_invalid_puts(self): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json - request, channel = make_request(b"PUT", path, '{}') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, b'{}') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, '{"_name":"bob"}') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, b'{"_name":"bo"}') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, '{"nao') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, b'{"nao') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request( - b"PUT", path, '[{"_name":"bob"},{"_name":"jill"}]' + request, channel = self.make_request( + "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]' ) - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, 'text only') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, b'text only') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, '') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, b'') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) def test_rooms_messages_sent(self): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) - content = '{"body":"test","msgtype":{"type":"a"}}' - request, channel = make_request(b"PUT", path, content) - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + content = b'{"body":"test","msgtype":{"type":"a"}}' + request, channel = self.make_request("PUT", path, content) + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) # custom message types - content = '{"body":"test","msgtype":"test.custom.text"}' - request, channel = make_request(b"PUT", path, content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + content = b'{"body":"test","msgtype":"test.custom.text"}' + request, channel = self.make_request("PUT", path, content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) # m.text message type path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) - content = '{"body":"test2","msgtype":"m.text"}' - request, channel = make_request(b"PUT", path, content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + content = b'{"body":"test2","msgtype":"m.text"}' + request, channel = self.make_request("PUT", path, content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) class RoomInitialSyncTestCase(RoomBase): @@ -775,16 +729,16 @@ class RoomInitialSyncTestCase(RoomBase): user_id = "@sid1:red" - def setUp(self): - super(RoomInitialSyncTestCase, self).setUp() - + def prepare(self, reactor, clock, hs): # create the room self.room_id = self.helper.create_room_as(self.user_id) def test_initial_sync(self): - request, channel = make_request(b"GET", "/rooms/%s/initialSync" % self.room_id) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"])) + request, channel = self.make_request( + "GET", "/rooms/%s/initialSync" % self.room_id + ) + self.render(request) + self.assertEquals(200, channel.code) self.assertEquals(self.room_id, channel.json_body["room_id"]) self.assertEquals("join", channel.json_body["membership"]) @@ -819,17 +773,16 @@ class RoomMessageListTestCase(RoomBase): user_id = "@sid1:red" - def setUp(self): - super(RoomMessageListTestCase, self).setUp() + def prepare(self, reactor, clock, hs): self.room_id = self.helper.create_room_as(self.user_id) def test_topo_token_is_accepted(self): token = "t1-0_0_0_0_0_0_0_0_0" - request, channel = make_request( - b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) + request, channel = self.make_request( + "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"])) + self.render(request) + self.assertEquals(200, channel.code) self.assertTrue("start" in channel.json_body) self.assertEquals(token, channel.json_body['start']) self.assertTrue("chunk" in channel.json_body) @@ -837,11 +790,11 @@ class RoomMessageListTestCase(RoomBase): def test_stream_token_is_accepted_for_fwd_pagianation(self): token = "s0_0_0_0_0_0_0_0_0" - request, channel = make_request( - b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) + request, channel = self.make_request( + "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"])) + self.render(request) + self.assertEquals(200, channel.code) self.assertTrue("start" in channel.json_body) self.assertEquals(token, channel.json_body['start']) self.assertTrue("chunk" in channel.json_body) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 40dc4ea256..530dc8ba6d 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -240,7 +240,6 @@ class RestHelper(object): self.assertEquals(200, code) defer.returnValue(response) - @defer.inlineCallbacks def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -248,9 +247,16 @@ class RestHelper(object): body = "body_text_here" path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id) - content = '{"msgtype":"m.text","body":"%s"}' % body + content = {"msgtype": "m.text", "body": body} if tok: path = path + "?access_token=%s" % tok - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(expect_code, code, msg=str(response)) + request, channel = make_request("PUT", path, json.dumps(content).encode('utf8')) + render(request, self.resource, self.hs.get_reactor()) + + assert int(channel.result["code"]) == expect_code, ( + "Expected: %d, got: %d, resp: %r" + % (expect_code, int(channel.result["code"]), channel.result["body"]) + ) + + return channel.json_body diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 560b1fba96..4c30c5f258 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -62,12 +62,6 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertTrue( set( - [ - "next_batch", - "rooms", - "account_data", - "to_device", - "device_lists", - ] + ["next_batch", "rooms", "account_data", "to_device", "device_lists"] ).issubset(set(channel.json_body.keys())) ) diff --git a/tests/server.py b/tests/server.py index 7dbdb7f8ea..420ec4e088 100644 --- a/tests/server.py +++ b/tests/server.py @@ -4,9 +4,14 @@ from io import BytesIO from six import text_type import attr +from zope.interface import implementer -from twisted.internet import address, threads +from twisted.internet import address, threads, udp +from twisted.internet._resolver import HostResolution +from twisted.internet.address import IPv4Address from twisted.internet.defer import Deferred +from twisted.internet.error import DNSLookupError +from twisted.internet.interfaces import IReactorPluggableNameResolver from twisted.python.failure import Failure from twisted.test.proto_helpers import MemoryReactorClock @@ -65,7 +70,7 @@ class FakeChannel(object): def getPeer(self): # We give an address so that getClientIP returns a non null entry, # causing us to record the MAU - return address.IPv4Address(b"TCP", "127.0.0.1", 3423) + return address.IPv4Address("TCP", "127.0.0.1", 3423) def getHost(self): return None @@ -154,11 +159,46 @@ def render(request, resource, clock): wait_until_result(clock, request) +@implementer(IReactorPluggableNameResolver) class ThreadedMemoryReactorClock(MemoryReactorClock): """ A MemoryReactorClock that supports callFromThread. """ + def __init__(self): + self._udp = [] + self.lookups = {} + + class Resolver(object): + def resolveHostName( + _self, + resolutionReceiver, + hostName, + portNumber=0, + addressTypes=None, + transportSemantics='TCP', + ): + + resolution = HostResolution(hostName) + resolutionReceiver.resolutionBegan(resolution) + if hostName not in self.lookups: + raise DNSLookupError("OH NO") + + resolutionReceiver.addressResolved( + IPv4Address('TCP', self.lookups[hostName], portNumber) + ) + resolutionReceiver.resolutionComplete() + return resolution + + self.nameResolver = Resolver() + super(ThreadedMemoryReactorClock, self).__init__() + + def listenUDP(self, port, protocol, interface='', maxPacketSize=8196): + p = udp.Port(port, protocol, interface, maxPacketSize, self) + p.startListening() + self._udp.append(p) + return p + def callFromThread(self, callback, *args, **kwargs): """ Make the callback fire in the next reactor iteration. @@ -232,6 +272,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): clock.threadpool = ThreadPool() pool.threadpool = ThreadPool() + pool.running = True return d diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 5cc7fff39b..4701eedd45 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -80,12 +80,11 @@ class TestResourceLimitsServerNotices(unittest.TestCase): self._rlsn._auth.check_auth_blocking = Mock() mock_event = Mock( - type=EventTypes.Message, - content={"msgtype": ServerNoticeMsgType}, + type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} + ) + self._rlsn._store.get_events = Mock( + return_value=defer.succeed({"123": mock_event}) ) - self._rlsn._store.get_events = Mock(return_value=defer.succeed( - {"123": mock_event} - )) yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) # Would be better to check the content, but once == remove blocking event @@ -99,12 +98,11 @@ class TestResourceLimitsServerNotices(unittest.TestCase): ) mock_event = Mock( - type=EventTypes.Message, - content={"msgtype": ServerNoticeMsgType}, + type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} + ) + self._rlsn._store.get_events = Mock( + return_value=defer.succeed({"123": mock_event}) ) - self._rlsn._store.get_events = Mock(return_value=defer.succeed( - {"123": mock_event} - )) yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) self._send_notice.assert_not_called() @@ -177,13 +175,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase): @defer.inlineCallbacks def test_server_notice_only_sent_once(self): - self.store.get_monthly_active_count = Mock( - return_value=1000, - ) + self.store.get_monthly_active_count = Mock(return_value=1000) - self.store.user_last_seen_monthly_active = Mock( - return_value=1000, - ) + self.store.user_last_seen_monthly_active = Mock(return_value=1000) # Call the function multiple times to ensure we only send the notice once yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) @@ -193,12 +187,12 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase): # Now lets get the last load of messages in the service notice room and # check that there is only one server notice room_id = yield self.server_notices_manager.get_notice_room_for_user( - self.user_id, + self.user_id ) token = yield self.event_source.get_current_token() events, _ = yield self.store.get_recent_events_for_room( - room_id, limit=100, end_token=token.room_key, + room_id, limit=100, end_token=token.room_key ) count = 0 diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index c893990454..3f0083831b 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -37,18 +37,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.as_yaml_files = [] - config = Mock( - app_service_config_files=self.as_yaml_files, - event_cache_size=1, - password_providers=[], - ) hs = yield setup_test_homeserver( - self.addCleanup, - config=config, - federation_sender=Mock(), - federation_client=Mock(), + self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) + hs.config.app_service_config_files = self.as_yaml_files + hs.config.event_cache_size = 1 + hs.config.password_providers = [] + self.as_token = "token1" self.as_url = "some_url" self.as_id = "as1" @@ -58,7 +54,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts - self.store = ApplicationServiceStore(None, hs) + self.store = ApplicationServiceStore(hs.get_db_conn(), hs) def tearDown(self): # TODO: suboptimal that we need to create files for tests! @@ -105,18 +101,16 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): def setUp(self): self.as_yaml_files = [] - config = Mock( - app_service_config_files=self.as_yaml_files, - event_cache_size=1, - password_providers=[], - ) hs = yield setup_test_homeserver( - self.addCleanup, - config=config, - federation_sender=Mock(), - federation_client=Mock(), + self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) + + hs.config.app_service_config_files = self.as_yaml_files + hs.config.event_cache_size = 1 + hs.config.password_providers = [] + self.db_pool = hs.get_db_pool() + self.engine = hs.database_engine self.as_list = [ {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"}, @@ -129,7 +123,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.as_yaml_files = [] - self.store = TestTransactionStore(None, hs) + self.store = TestTransactionStore(hs.get_db_conn(), hs) def _add_service(self, url, as_token, id): as_yaml = dict( @@ -146,29 +140,35 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.as_yaml_files.append(as_token) def _set_state(self, id, state, txn=None): - return self.db_pool.runQuery( - "INSERT INTO application_services_state(as_id, state, last_txn) " - "VALUES(?,?,?)", + return self.db_pool.runOperation( + self.engine.convert_param_style( + "INSERT INTO application_services_state(as_id, state, last_txn) " + "VALUES(?,?,?)" + ), (id, state, txn), ) def _insert_txn(self, as_id, txn_id, events): - return self.db_pool.runQuery( - "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " - "VALUES(?,?,?)", + return self.db_pool.runOperation( + self.engine.convert_param_style( + "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " + "VALUES(?,?,?)" + ), (as_id, txn_id, json.dumps([e.event_id for e in events])), ) def _set_last_txn(self, as_id, txn_id): - return self.db_pool.runQuery( - "INSERT INTO application_services_state(as_id, last_txn, state) " - "VALUES(?,?,?)", + return self.db_pool.runOperation( + self.engine.convert_param_style( + "INSERT INTO application_services_state(as_id, last_txn, state) " + "VALUES(?,?,?)" + ), (as_id, txn_id, ApplicationServiceState.UP), ) @defer.inlineCallbacks def test_get_appservice_state_none(self): - service = Mock(id=999) + service = Mock(id="999") state = yield self.store.get_appservice_state(service) self.assertEquals(None, state) @@ -200,7 +200,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): service = Mock(id=self.as_list[1]["id"]) yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) rows = yield self.db_pool.runQuery( - "SELECT as_id FROM application_services_state WHERE state=?", + self.engine.convert_param_style( + "SELECT as_id FROM application_services_state WHERE state=?" + ), (ApplicationServiceState.DOWN,), ) self.assertEquals(service.id, rows[0][0]) @@ -212,7 +214,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) yield self.store.set_appservice_state(service, ApplicationServiceState.UP) rows = yield self.db_pool.runQuery( - "SELECT as_id FROM application_services_state WHERE state=?", + self.engine.convert_param_style( + "SELECT as_id FROM application_services_state WHERE state=?" + ), (ApplicationServiceState.UP,), ) self.assertEquals(service.id, rows[0][0]) @@ -279,14 +283,19 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self.store.complete_appservice_txn(txn_id=txn_id, service=service) res = yield self.db_pool.runQuery( - "SELECT last_txn FROM application_services_state WHERE as_id=?", + self.engine.convert_param_style( + "SELECT last_txn FROM application_services_state WHERE as_id=?" + ), (service.id,), ) self.assertEquals(1, len(res)) self.assertEquals(txn_id, res[0][0]) res = yield self.db_pool.runQuery( - "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,) + self.engine.convert_param_style( + "SELECT * FROM application_services_txns WHERE txn_id=?" + ), + (txn_id,), ) self.assertEquals(0, len(res)) @@ -300,7 +309,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self.store.complete_appservice_txn(txn_id=txn_id, service=service) res = yield self.db_pool.runQuery( - "SELECT last_txn, state FROM application_services_state WHERE " "as_id=?", + self.engine.convert_param_style( + "SELECT last_txn, state FROM application_services_state WHERE as_id=?" + ), (service.id,), ) self.assertEquals(1, len(res)) @@ -308,7 +319,10 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.assertEquals(ApplicationServiceState.UP, res[0][1]) res = yield self.db_pool.runQuery( - "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,) + self.engine.convert_param_style( + "SELECT * FROM application_services_txns WHERE txn_id=?" + ), + (txn_id,), ) self.assertEquals(0, len(res)) @@ -394,37 +408,31 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): f1 = self._write_config(suffix="1") f2 = self._write_config(suffix="2") - config = Mock( - app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] - ) hs = yield setup_test_homeserver( - self.addCleanup, - config=config, - datastore=Mock(), - federation_sender=Mock(), - federation_client=Mock(), + self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) - ApplicationServiceStore(None, hs) + hs.config.app_service_config_files = [f1, f2] + hs.config.event_cache_size = 1 + hs.config.password_providers = [] + + ApplicationServiceStore(hs.get_db_conn(), hs) @defer.inlineCallbacks def test_duplicate_ids(self): f1 = self._write_config(id="id", suffix="1") f2 = self._write_config(id="id", suffix="2") - config = Mock( - app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] - ) hs = yield setup_test_homeserver( - self.addCleanup, - config=config, - datastore=Mock(), - federation_sender=Mock(), - federation_client=Mock(), + self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) + hs.config.app_service_config_files = [f1, f2] + hs.config.event_cache_size = 1 + hs.config.password_providers = [] + with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(None, hs) + ApplicationServiceStore(hs.get_db_conn(), hs) e = cm.exception self.assertIn(f1, str(e)) @@ -436,19 +444,16 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): f1 = self._write_config(as_token="as_token", suffix="1") f2 = self._write_config(as_token="as_token", suffix="2") - config = Mock( - app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] - ) hs = yield setup_test_homeserver( - self.addCleanup, - config=config, - datastore=Mock(), - federation_sender=Mock(), - federation_client=Mock(), + self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) + hs.config.app_service_config_files = [f1, f2] + hs.config.event_cache_size = 1 + hs.config.password_providers = [] + with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(None, hs) + ApplicationServiceStore(hs.get_db_conn(), hs) e = cm.exception self.assertIn(f1, str(e)) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 7cb5f0e4cf..829f47d2e8 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -20,11 +20,11 @@ from mock import Mock from twisted.internet import defer -from synapse.server import HomeServer from synapse.storage._base import SQLBaseStore from synapse.storage.engines import create_engine from tests import unittest +from tests.utils import TestHomeServer class SQLBaseStoreTestCase(unittest.TestCase): @@ -51,7 +51,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): config = Mock() config.event_cache_size = 1 config.database_config = {"name": "sqlite3"} - hs = HomeServer( + hs = TestHomeServer( "test", db_pool=self.db_pool, config=config, diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index c2e88bdbaf..c9b02a062b 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -101,6 +101,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 50 user_id = "@user:server" + yield self.store.register(user_id=user_id, token="123", password_hash=None) active = yield self.store.user_last_seen_monthly_active(user_id) self.assertFalse(active) @@ -108,8 +109,5 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): yield self.store.insert_client_ip( user_id, "access_token", "ip", "user_agent", "device_id" ) - yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" - ) active = yield self.store.user_last_seen_monthly_active(user_id) self.assertTrue(active) diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index b4510c1c8d..4e128e1047 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -16,7 +16,6 @@ from twisted.internet import defer -from synapse.storage.directory import DirectoryStore from synapse.types import RoomAlias, RoomID from tests import unittest @@ -28,7 +27,7 @@ class DirectoryStoreTestCase(unittest.TestCase): def setUp(self): hs = yield setup_test_homeserver(self.addCleanup) - self.store = DirectoryStore(None, hs) + self.store = hs.get_datastore() self.room = RoomID.from_string("!abcde:test") self.alias = RoomAlias.from_string("#my-room:test") diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 2fdf34fdf6..0d4e74d637 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -37,10 +37,10 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): ( "INSERT INTO events (" " room_id, event_id, type, depth, topological_ordering," - " content, processed, outlier) " - "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)" + " content, processed, outlier, stream_ordering) " + "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?, ?)" ), - (room_id, event_id, i, i, True, False), + (room_id, event_id, i, i, True, False, i), ) txn.execute( diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index f2ed866ae7..686f12a0dc 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -12,26 +12,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock from twisted.internet import defer -import tests.unittest -import tests.utils -from tests.utils import setup_test_homeserver +from tests.unittest import HomeserverTestCase FORTY_DAYS = 40 * 24 * 60 * 60 -class MonthlyActiveUsersTestCase(tests.unittest.TestCase): - def __init__(self, *args, **kwargs): - super(MonthlyActiveUsersTestCase, self).__init__(*args, **kwargs) +class MonthlyActiveUsersTestCase(HomeserverTestCase): + def make_homeserver(self, reactor, clock): - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup) - self.store = self.hs.get_datastore() + hs = self.setup_test_homeserver() + self.store = hs.get_datastore() + hs.config.limit_usage_by_mau = True + hs.config.max_mau_value = 50 + # Advance the clock a bit + reactor.advance(FORTY_DAYS) + + return hs - @defer.inlineCallbacks def test_initialise_reserved_users(self): self.hs.config.max_mau_value = 5 user1 = "@user1:server" @@ -44,88 +45,172 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase): ] user_num = len(threepids) - yield self.store.register(user_id=user1, token="123", password_hash=None) - - yield self.store.register(user_id=user2, token="456", password_hash=None) + self.store.register(user_id=user1, token="123", password_hash=None) + self.store.register(user_id=user2, token="456", password_hash=None) + self.pump() now = int(self.hs.get_clock().time_msec()) - yield self.store.user_add_threepid(user1, "email", user1_email, now, now) - yield self.store.user_add_threepid(user2, "email", user2_email, now, now) - yield self.store.initialise_reserved_users(threepids) + self.store.user_add_threepid(user1, "email", user1_email, now, now) + self.store.user_add_threepid(user2, "email", user2_email, now, now) + self.store.initialise_reserved_users(threepids) + self.pump() - active_count = yield self.store.get_monthly_active_count() + active_count = self.store.get_monthly_active_count() # Test total counts - self.assertEquals(active_count, user_num) + self.assertEquals(self.get_success(active_count), user_num) # Test user is marked as active - - timestamp = yield self.store.user_last_seen_monthly_active(user1) - self.assertTrue(timestamp) - timestamp = yield self.store.user_last_seen_monthly_active(user2) - self.assertTrue(timestamp) + timestamp = self.store.user_last_seen_monthly_active(user1) + self.assertTrue(self.get_success(timestamp)) + timestamp = self.store.user_last_seen_monthly_active(user2) + self.assertTrue(self.get_success(timestamp)) # Test that users are never removed from the db. self.hs.config.max_mau_value = 0 - self.hs.get_clock().advance_time(FORTY_DAYS) + self.reactor.advance(FORTY_DAYS) - yield self.store.reap_monthly_active_users() + self.store.reap_monthly_active_users() + self.pump() - active_count = yield self.store.get_monthly_active_count() - self.assertEquals(active_count, user_num) + active_count = self.store.get_monthly_active_count() + self.assertEquals(self.get_success(active_count), user_num) - # Test that regalar users are removed from the db + # Test that regular users are removed from the db ru_count = 2 - yield self.store.upsert_monthly_active_user("@ru1:server") - yield self.store.upsert_monthly_active_user("@ru2:server") - active_count = yield self.store.get_monthly_active_count() + self.store.upsert_monthly_active_user("@ru1:server") + self.store.upsert_monthly_active_user("@ru2:server") + self.pump() - self.assertEqual(active_count, user_num + ru_count) + active_count = self.store.get_monthly_active_count() + self.assertEqual(self.get_success(active_count), user_num + ru_count) self.hs.config.max_mau_value = user_num - yield self.store.reap_monthly_active_users() + self.store.reap_monthly_active_users() + self.pump() - active_count = yield self.store.get_monthly_active_count() - self.assertEquals(active_count, user_num) + active_count = self.store.get_monthly_active_count() + self.assertEquals(self.get_success(active_count), user_num) - @defer.inlineCallbacks def test_can_insert_and_count_mau(self): - count = yield self.store.get_monthly_active_count() - self.assertEqual(0, count) + count = self.store.get_monthly_active_count() + self.assertEqual(0, self.get_success(count)) - yield self.store.upsert_monthly_active_user("@user:server") - count = yield self.store.get_monthly_active_count() + self.store.upsert_monthly_active_user("@user:server") + self.pump() - self.assertEqual(1, count) + count = self.store.get_monthly_active_count() + self.assertEqual(1, self.get_success(count)) - @defer.inlineCallbacks def test_user_last_seen_monthly_active(self): user_id1 = "@user1:server" user_id2 = "@user2:server" user_id3 = "@user3:server" - result = yield self.store.user_last_seen_monthly_active(user_id1) - self.assertFalse(result == 0) - yield self.store.upsert_monthly_active_user(user_id1) - yield self.store.upsert_monthly_active_user(user_id2) - result = yield self.store.user_last_seen_monthly_active(user_id1) - self.assertTrue(result > 0) - result = yield self.store.user_last_seen_monthly_active(user_id3) - self.assertFalse(result == 0) + result = self.store.user_last_seen_monthly_active(user_id1) + self.assertFalse(self.get_success(result) == 0) + + self.store.upsert_monthly_active_user(user_id1) + self.store.upsert_monthly_active_user(user_id2) + self.pump() + + result = self.store.user_last_seen_monthly_active(user_id1) + self.assertGreater(self.get_success(result), 0) + + result = self.store.user_last_seen_monthly_active(user_id3) + self.assertNotEqual(self.get_success(result), 0) - @defer.inlineCallbacks def test_reap_monthly_active_users(self): self.hs.config.max_mau_value = 5 initial_users = 10 for i in range(initial_users): - yield self.store.upsert_monthly_active_user("@user%d:server" % i) - count = yield self.store.get_monthly_active_count() - self.assertTrue(count, initial_users) - yield self.store.reap_monthly_active_users() - count = yield self.store.get_monthly_active_count() - self.assertEquals(count, initial_users - self.hs.config.max_mau_value) - - self.hs.get_clock().advance_time(FORTY_DAYS) - yield self.store.reap_monthly_active_users() - count = yield self.store.get_monthly_active_count() - self.assertEquals(count, 0) + self.store.upsert_monthly_active_user("@user%d:server" % i) + self.pump() + + count = self.store.get_monthly_active_count() + self.assertTrue(self.get_success(count), initial_users) + + self.store.reap_monthly_active_users() + self.pump() + count = self.store.get_monthly_active_count() + self.assertEquals( + self.get_success(count), initial_users - self.hs.config.max_mau_value + ) + + self.reactor.advance(FORTY_DAYS) + self.store.reap_monthly_active_users() + self.pump() + + count = self.store.get_monthly_active_count() + self.assertEquals(self.get_success(count), 0) + + def test_populate_monthly_users_is_guest(self): + # Test that guest users are not added to mau list + user_id = "user_id" + self.store.register( + user_id=user_id, token="123", password_hash=None, make_guest=True + ) + self.store.upsert_monthly_active_user = Mock() + self.store.populate_monthly_active_users(user_id) + self.pump() + self.store.upsert_monthly_active_user.assert_not_called() + + def test_populate_monthly_users_should_update(self): + self.store.upsert_monthly_active_user = Mock() + + self.store.is_trial_user = Mock( + return_value=defer.succeed(False) + ) + + self.store.user_last_seen_monthly_active = Mock( + return_value=defer.succeed(None) + ) + self.store.populate_monthly_active_users('user_id') + self.pump() + self.store.upsert_monthly_active_user.assert_called_once() + + def test_populate_monthly_users_should_not_update(self): + self.store.upsert_monthly_active_user = Mock() + + self.store.is_trial_user = Mock( + return_value=defer.succeed(False) + ) + self.store.user_last_seen_monthly_active = Mock( + return_value=defer.succeed( + self.hs.get_clock().time_msec() + ) + ) + self.store.populate_monthly_active_users('user_id') + self.pump() + self.store.upsert_monthly_active_user.assert_not_called() + + def test_get_reserved_real_user_account(self): + # Test no reserved users, or reserved threepids + count = self.store.get_registered_reserved_users_count() + self.assertEquals(self.get_success(count), 0) + # Test reserved users but no registered users + + user1 = '@user1:example.com' + user2 = '@user2:example.com' + user1_email = 'user1@example.com' + user2_email = 'user2@example.com' + threepids = [ + {'medium': 'email', 'address': user1_email}, + {'medium': 'email', 'address': user2_email}, + ] + self.hs.config.mau_limits_reserved_threepids = threepids + self.store.initialise_reserved_users(threepids) + self.pump() + count = self.store.get_registered_reserved_users_count() + self.assertEquals(self.get_success(count), 0) + + # Test reserved registed users + self.store.register(user_id=user1, token="123", password_hash=None) + self.store.register(user_id=user2, token="456", password_hash=None) + self.pump() + + now = int(self.hs.get_clock().time_msec()) + self.store.user_add_threepid(user1, "email", user1_email, now, now) + self.store.user_add_threepid(user2, "email", user2_email, now, now) + count = self.store.get_registered_reserved_users_count() + self.assertEquals(self.get_success(count), len(threepids)) diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py index b5b58ff660..c7a63f39b9 100644 --- a/tests/storage/test_presence.py +++ b/tests/storage/test_presence.py @@ -16,19 +16,18 @@ from twisted.internet import defer -from synapse.storage.presence import PresenceStore from synapse.types import UserID from tests import unittest -from tests.utils import MockClock, setup_test_homeserver +from tests.utils import setup_test_homeserver class PresenceStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver(self.addCleanup, clock=MockClock()) + hs = yield setup_test_homeserver(self.addCleanup) - self.store = PresenceStore(None, hs) + self.store = hs.get_datastore() self.u_apple = UserID.from_string("@apple:test") self.u_banana = UserID.from_string("@banana:test") diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index a1f6618bf9..45824bd3b2 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -28,7 +28,7 @@ class ProfileStoreTestCase(unittest.TestCase): def setUp(self): hs = yield setup_test_homeserver(self.addCleanup) - self.store = ProfileStore(None, hs) + self.store = ProfileStore(hs.get_db_conn(), hs) self.u_frank = UserID.from_string("@frank:test") diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py new file mode 100644 index 0000000000..f671599cb8 --- /dev/null +++ b/tests/storage/test_purge.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest.client.v1 import room + +from tests.unittest import HomeserverTestCase + + +class PurgeTests(HomeserverTestCase): + + user_id = "@red:server" + servlets = [room.register_servlets] + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver("server", http_client=None) + return hs + + def prepare(self, reactor, clock, hs): + self.room_id = self.helper.create_room_as(self.user_id) + + def test_purge(self): + """ + Purging a room will delete everything before the topological point. + """ + # Send four messages to the room + first = self.helper.send(self.room_id, body="test1") + second = self.helper.send(self.room_id, body="test2") + third = self.helper.send(self.room_id, body="test3") + last = self.helper.send(self.room_id, body="test4") + + storage = self.hs.get_datastore() + + # Get the topological token + event = storage.get_topological_token_for_event(last["event_id"]) + self.pump() + event = self.successResultOf(event) + + # Purge everything before this topological token + purge = storage.purge_history(self.room_id, event, True) + self.pump() + self.assertEqual(self.successResultOf(purge), None) + + # Try and get the events + get_first = storage.get_event(first["event_id"]) + get_second = storage.get_event(second["event_id"]) + get_third = storage.get_event(third["event_id"]) + get_last = storage.get_event(last["event_id"]) + self.pump() + + # 1-3 should fail and last will succeed, meaning that 1-3 are deleted + # and last is not. + self.failureResultOf(get_first) + self.failureResultOf(get_second) + self.failureResultOf(get_third) + self.successResultOf(get_last) + + def test_purge_wont_delete_extrems(self): + """ + Purging a room will delete everything before the topological point. + """ + # Send four messages to the room + first = self.helper.send(self.room_id, body="test1") + second = self.helper.send(self.room_id, body="test2") + third = self.helper.send(self.room_id, body="test3") + last = self.helper.send(self.room_id, body="test4") + + storage = self.hs.get_datastore() + + # Set the topological token higher than it should be + event = storage.get_topological_token_for_event(last["event_id"]) + self.pump() + event = self.successResultOf(event) + event = "t{}-{}".format( + *list(map(lambda x: x + 1, map(int, event[1:].split("-")))) + ) + + # Purge everything before this topological token + purge = storage.purge_history(self.room_id, event, True) + self.pump() + f = self.failureResultOf(purge) + self.assertIn("greater than forward", f.value.args[0]) + + # Try and get the events + get_first = storage.get_event(first["event_id"]) + get_second = storage.get_event(second["event_id"]) + get_third = storage.get_event(third["event_id"]) + get_last = storage.get_event(last["event_id"]) + self.pump() + + # Nothing is deleted. + self.successResultOf(get_first) + self.successResultOf(get_second) + self.successResultOf(get_third) + self.successResultOf(get_last) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index d717b9f94e..b910965932 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -185,8 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_some_state_from_cache correctly filters out members with types=[] (state_dict, is_all) = yield self.store._get_some_state_from_cache( - self.store._state_group_cache, - group, [], filtered_types=[EventTypes.Member] + self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member] ) self.assertEqual(is_all, True) @@ -200,19 +199,20 @@ class StateStoreTestCase(tests.unittest.TestCase): (state_dict, is_all) = yield self.store._get_some_state_from_cache( self.store._state_group_members_cache, - group, [], filtered_types=[EventTypes.Member] + group, + [], + filtered_types=[EventTypes.Member], ) self.assertEqual(is_all, True) - self.assertDictEqual( - {}, - state_dict, - ) + self.assertDictEqual({}, state_dict) # test _get_some_state_from_cache correctly filters in members with wildcard types (state_dict, is_all) = yield self.store._get_some_state_from_cache( self.store._state_group_cache, - group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] + group, + [(EventTypes.Member, None)], + filtered_types=[EventTypes.Member], ) self.assertEqual(is_all, True) @@ -226,7 +226,9 @@ class StateStoreTestCase(tests.unittest.TestCase): (state_dict, is_all) = yield self.store._get_some_state_from_cache( self.store._state_group_members_cache, - group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] + group, + [(EventTypes.Member, None)], + filtered_types=[EventTypes.Member], ) self.assertEqual(is_all, True) @@ -264,18 +266,15 @@ class StateStoreTestCase(tests.unittest.TestCase): ) self.assertEqual(is_all, True) - self.assertDictEqual( - { - (e5.type, e5.state_key): e5.event_id, - }, - state_dict, - ) + self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) # test _get_some_state_from_cache correctly filters in members with specific types # and no filtered_types (state_dict, is_all) = yield self.store._get_some_state_from_cache( self.store._state_group_members_cache, - group, [(EventTypes.Member, e5.state_key)], filtered_types=None + group, + [(EventTypes.Member, e5.state_key)], + filtered_types=None, ) self.assertEqual(is_all, True) @@ -305,9 +304,7 @@ class StateStoreTestCase(tests.unittest.TestCase): key=group, value=state_dict_ids, # list fetched keys so it knows it's partial - fetched_keys=( - (e1.type, e1.state_key), - ), + fetched_keys=((e1.type, e1.state_key),), ) (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( @@ -315,20 +312,8 @@ class StateStoreTestCase(tests.unittest.TestCase): ) self.assertEqual(is_all, False) - self.assertEqual( - known_absent, - set( - [ - (e1.type, e1.state_key), - ] - ), - ) - self.assertDictEqual( - state_dict_ids, - { - (e1.type, e1.state_key): e1.event_id, - }, - ) + self.assertEqual(known_absent, set([(e1.type, e1.state_key)])) + self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id}) ############################################ # test that things work with a partial cache @@ -336,8 +321,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_some_state_from_cache correctly filters out members with types=[] room_id = self.room.to_string() (state_dict, is_all) = yield self.store._get_some_state_from_cache( - self.store._state_group_cache, - group, [], filtered_types=[EventTypes.Member] + self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member] ) self.assertEqual(is_all, False) @@ -346,7 +330,9 @@ class StateStoreTestCase(tests.unittest.TestCase): room_id = self.room.to_string() (state_dict, is_all) = yield self.store._get_some_state_from_cache( self.store._state_group_members_cache, - group, [], filtered_types=[EventTypes.Member] + group, + [], + filtered_types=[EventTypes.Member], ) self.assertEqual(is_all, True) @@ -355,20 +341,19 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_some_state_from_cache correctly filters in members wildcard types (state_dict, is_all) = yield self.store._get_some_state_from_cache( self.store._state_group_cache, - group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] + group, + [(EventTypes.Member, None)], + filtered_types=[EventTypes.Member], ) self.assertEqual(is_all, False) - self.assertDictEqual( - { - (e1.type, e1.state_key): e1.event_id, - }, - state_dict, - ) + self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) (state_dict, is_all) = yield self.store._get_some_state_from_cache( self.store._state_group_members_cache, - group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] + group, + [(EventTypes.Member, None)], + filtered_types=[EventTypes.Member], ) self.assertEqual(is_all, True) @@ -389,12 +374,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) self.assertEqual(is_all, False) - self.assertDictEqual( - { - (e1.type, e1.state_key): e1.event_id, - }, - state_dict, - ) + self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) (state_dict, is_all) = yield self.store._get_some_state_from_cache( self.store._state_group_members_cache, @@ -404,18 +384,15 @@ class StateStoreTestCase(tests.unittest.TestCase): ) self.assertEqual(is_all, True) - self.assertDictEqual( - { - (e5.type, e5.state_key): e5.event_id, - }, - state_dict, - ) + self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) # test _get_some_state_from_cache correctly filters in members with specific types # and no filtered_types (state_dict, is_all) = yield self.store._get_some_state_from_cache( self.store._state_group_cache, - group, [(EventTypes.Member, e5.state_key)], filtered_types=None + group, + [(EventTypes.Member, e5.state_key)], + filtered_types=None, ) self.assertEqual(is_all, False) @@ -423,13 +400,10 @@ class StateStoreTestCase(tests.unittest.TestCase): (state_dict, is_all) = yield self.store._get_some_state_from_cache( self.store._state_group_members_cache, - group, [(EventTypes.Member, e5.state_key)], filtered_types=None + group, + [(EventTypes.Member, e5.state_key)], + filtered_types=None, ) self.assertEqual(is_all, True) - self.assertDictEqual( - { - (e5.type, e5.state_key): e5.event_id, - }, - state_dict, - ) + self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index b46e0ea7e2..0dde1ab2fe 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -30,7 +30,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield setup_test_homeserver(self.addCleanup) - self.store = UserDirectoryStore(None, self.hs) + self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs) # alice and bob are both in !room_id. bobby is not but shares # a homeserver with alice. diff --git a/tests/test_mau.py b/tests/test_mau.py index 0732615447..bdbacb8448 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -185,20 +185,20 @@ class TestMauLimit(unittest.TestCase): self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def create_user(self, localpart): - request_data = json.dumps({ - "username": localpart, - "password": "monkey", - "auth": {"type": LoginType.DUMMY}, - }) + request_data = json.dumps( + { + "username": localpart, + "password": "monkey", + "auth": {"type": LoginType.DUMMY}, + } + ) - request, channel = make_request(b"POST", b"/register", request_data) + request, channel = make_request("POST", "/register", request_data) render(request, self.resource, self.reactor) - if channel.result["code"] != b"200": + if channel.code != 200: raise HttpResponseException( - int(channel.result["code"]), - channel.result["reason"], - channel.result["body"], + channel.code, channel.result["reason"], channel.result["body"] ).to_synapse_error() access_token = channel.json_body["access_token"] @@ -206,12 +206,12 @@ class TestMauLimit(unittest.TestCase): return access_token def do_sync_for_user(self, token): - request, channel = make_request(b"GET", b"/sync", access_token=token) + request, channel = make_request( + "GET", "/sync", access_token=token.encode('ascii') + ) render(request, self.resource, self.reactor) - if channel.result["code"] != b"200": + if channel.code != 200: raise HttpResponseException( - int(channel.result["code"]), - channel.result["reason"], - channel.result["body"], + channel.code, channel.result["reason"], channel.result["body"] ).to_synapse_error() diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 0000000000..17897711a1 --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.metrics import InFlightGauge + +from tests import unittest + + +class TestMauLimit(unittest.TestCase): + def test_basic(self): + gauge = InFlightGauge( + "test1", "", + labels=["test_label"], + sub_metrics=["foo", "bar"], + ) + + def handle1(metrics): + metrics.foo += 2 + metrics.bar = max(metrics.bar, 5) + + def handle2(metrics): + metrics.foo += 3 + metrics.bar = max(metrics.bar, 7) + + gauge.register(("key1",), handle1) + + self.assert_dict({ + "test1_total": {("key1",): 1}, + "test1_foo": {("key1",): 2}, + "test1_bar": {("key1",): 5}, + }, self.get_metrics_from_gauge(gauge)) + + gauge.unregister(("key1",), handle1) + + self.assert_dict({ + "test1_total": {("key1",): 0}, + "test1_foo": {("key1",): 0}, + "test1_bar": {("key1",): 0}, + }, self.get_metrics_from_gauge(gauge)) + + gauge.register(("key1",), handle1) + gauge.register(("key2",), handle2) + + self.assert_dict({ + "test1_total": {("key1",): 1, ("key2",): 1}, + "test1_foo": {("key1",): 2, ("key2",): 3}, + "test1_bar": {("key1",): 5, ("key2",): 7}, + }, self.get_metrics_from_gauge(gauge)) + + gauge.unregister(("key2",), handle2) + gauge.register(("key1",), handle2) + + self.assert_dict({ + "test1_total": {("key1",): 2, ("key2",): 0}, + "test1_foo": {("key1",): 5, ("key2",): 0}, + "test1_bar": {("key1",): 7, ("key2",): 0}, + }, self.get_metrics_from_gauge(gauge)) + + def get_metrics_from_gauge(self, gauge): + results = {} + + for r in gauge.collect(): + results[r.name] = { + tuple(labels[x] for x in gauge.labels): value + for _, labels, value in r.samples + } + + return results diff --git a/tests/test_state.py b/tests/test_state.py index 452a123c3a..e20c33322a 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -180,7 +180,7 @@ class StateTestCase(unittest.TestCase): graph = Graph( nodes={ "START": DictObj( - type=EventTypes.Create, state_key="", content={}, depth=1, + type=EventTypes.Create, state_key="", content={}, depth=1 ), "A": DictObj(type=EventTypes.Message, depth=2), "B": DictObj(type=EventTypes.Message, depth=3), diff --git a/tests/test_types.py b/tests/test_types.py index be072d402b..0f5c8bfaf9 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -14,12 +14,12 @@ # limitations under the License. from synapse.api.errors import SynapseError -from synapse.server import HomeServer from synapse.types import GroupID, RoomAlias, UserID from tests import unittest +from tests.utils import TestHomeServer -mock_homeserver = HomeServer(hostname="my.domain") +mock_homeserver = TestHomeServer(hostname="my.domain") class UserIDTestCase(unittest.TestCase): diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 8d8ce0cab9..2eea3b098b 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -96,7 +96,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): events_to_filter.append(evt) # the erasey user gets erased - self.hs.get_datastore().mark_user_erased("@erased:local_hs") + yield self.hs.get_datastore().mark_user_erased("@erased:local_hs") # ... and the filtering happens. filtered = yield filter_events_for_server( diff --git a/tests/unittest.py b/tests/unittest.py index d852e2465a..a3d39920db 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -22,6 +22,7 @@ from canonicaljson import json import twisted import twisted.logger +from twisted.internet.defer import Deferred from twisted.trial import unittest from synapse.http.server import JsonResource @@ -151,6 +152,7 @@ class HomeserverTestCase(TestCase): hijack_auth (bool): Whether to hijack auth to return the user specified in user_id. """ + servlets = [] hijack_auth = True @@ -279,3 +281,15 @@ class HomeserverTestCase(TestCase): kwargs = dict(kwargs) kwargs.update(self._hs_args) return setup_test_homeserver(self.addCleanup, *args, **kwargs) + + def pump(self, by=0.0): + """ + Pump the reactor enough that Deferreds will fire. + """ + self.reactor.pump([by] * 100) + + def get_success(self, d): + if not isinstance(d, Deferred): + return d + self.pump() + return self.successResultOf(d) diff --git a/tests/utils.py b/tests/utils.py index e8ef10445c..215226debf 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -26,11 +26,12 @@ from twisted.internet import defer, reactor from synapse.api.constants import EventTypes from synapse.api.errors import CodeMessageException, cs_error +from synapse.config.server import ServerConfig from synapse.federation.transport import server from synapse.http.server import HttpServer from synapse.server import HomeServer -from synapse.storage import PostgresEngine -from synapse.storage.engines import create_engine +from synapse.storage import DataStore +from synapse.storage.engines import PostgresEngine, create_engine from synapse.storage.prepare_database import ( _get_or_create_schema_state, _setup_new_database, @@ -41,6 +42,7 @@ from synapse.util.ratelimitutils import FederationRateLimiter # set this to True to run the tests against postgres instead of sqlite. USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False) +LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False) POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres") POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),) @@ -92,10 +94,19 @@ def setupdb(): atexit.register(_cleanup) +class TestHomeServer(HomeServer): + DATASTORE_CLASS = DataStore + + @defer.inlineCallbacks def setup_test_homeserver( - cleanup_func, name="test", datastore=None, config=None, reactor=None, - homeserverToUse=HomeServer, **kargs + cleanup_func, + name="test", + datastore=None, + config=None, + reactor=None, + homeserverToUse=TestHomeServer, + **kargs ): """ Setup a homeserver suitable for running tests against. Keyword arguments @@ -141,8 +152,11 @@ def setup_test_homeserver( config.hs_disabled_message = "" config.hs_disabled_limit_type = "" config.max_mau_value = 50 + config.mau_trial_days = 0 config.mau_limits_reserved_threepids = [] config.admin_contact = None + config.rc_messages_per_second = 10000 + config.rc_message_burst_count = 10000 # we need a sane default_room_version, otherwise attempts to create rooms will # fail. @@ -152,6 +166,11 @@ def setup_test_homeserver( # background, which upsets the test runner. config.update_user_directory = False + def is_threepid_reserved(threepid): + return ServerConfig.is_threepid_reserved(config, threepid) + + config.is_threepid_reserved.side_effect = is_threepid_reserved + config.use_frozen_dicts = True config.ldap_enabled = False @@ -232,8 +251,9 @@ def setup_test_homeserver( cur.close() db_conn.close() - # Register the cleanup hook - cleanup_func(cleanup) + if not LEAVE_DB: + # Register the cleanup hook + cleanup_func(cleanup) hs.setup() else: @@ -307,7 +327,9 @@ class MockHttpResource(HttpServer): @patch('twisted.web.http.Request') @defer.inlineCallbacks - def trigger(self, http_method, path, content, mock_request, federation_auth=False): + def trigger( + self, http_method, path, content, mock_request, federation_auth_origin=None + ): """ Fire an HTTP event. Args: @@ -316,6 +338,7 @@ class MockHttpResource(HttpServer): content : The HTTP body mock_request : Mocked request to pass to the event so it can get content. + federation_auth_origin (bytes|None): domain to authenticate as, for federation Returns: A tuple of (code, response) Raises: @@ -336,8 +359,10 @@ class MockHttpResource(HttpServer): mock_request.getClientIP.return_value = "-" headers = {} - if federation_auth: - headers[b"Authorization"] = [b"X-Matrix origin=test,key=,sig="] + if federation_auth_origin is not None: + headers[b"Authorization"] = [ + b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,) + ] mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers) # return the right path if the event requires it @@ -556,16 +581,16 @@ def create_room(hs, room_id, creator_id): event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() - builder = event_builder_factory.new({ - "type": EventTypes.Create, - "state_key": "", - "sender": creator_id, - "room_id": room_id, - "content": {}, - }) - - event, context = yield event_creation_handler.create_new_client_event( - builder + builder = event_builder_factory.new( + { + "type": EventTypes.Create, + "state_key": "", + "sender": creator_id, + "room_id": room_id, + "content": {}, + } ) + event, context = yield event_creation_handler.create_new_client_event(builder) + yield store.persist_event(event, context) |