diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index d5f1777093..6bfc1970ad 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -43,7 +43,7 @@ class ConfigLoadingTestCase(unittest.TestCase):
self.generate_config()
with open(self.file, "r") as f:
- raw = yaml.load(f)
+ raw = yaml.safe_load(f)
self.assertIn("macaroon_secret_key", raw)
config = HomeServerConfig.load_config("", ["-c", self.file])
diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py
index 3dc2631523..47fffcfeb2 100644
--- a/tests/config/test_room_directory.py
+++ b/tests/config/test_room_directory.py
@@ -22,7 +22,7 @@ from tests import unittest
class RoomDirectoryConfigTestCase(unittest.TestCase):
def test_alias_creation_acl(self):
- config = yaml.load("""
+ config = yaml.safe_load("""
alias_creation_rules:
- user_id: "*bob*"
alias: "*"
@@ -74,7 +74,7 @@ class RoomDirectoryConfigTestCase(unittest.TestCase):
))
def test_room_publish_acl(self):
- config = yaml.load("""
+ config = yaml.safe_load("""
alias_creation_rules: []
room_list_publication_rules:
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 2217eb2a10..017ea0385e 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -22,8 +22,6 @@ from synapse.api.errors import ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler
from synapse.types import RoomAlias, UserID, create_requester
-from tests.utils import default_config, setup_test_homeserver
-
from .. import unittest
@@ -32,26 +30,23 @@ class RegistrationHandlers(object):
self.registration_handler = RegistrationHandler(hs)
-class RegistrationTestCase(unittest.TestCase):
+class RegistrationTestCase(unittest.HomeserverTestCase):
""" Tests the RegistrationHandler. """
- @defer.inlineCallbacks
- def setUp(self):
- self.mock_distributor = Mock()
- self.mock_distributor.declare("registered_user")
- self.mock_captcha_client = Mock()
-
- hs_config = default_config("test")
+ def make_homeserver(self, reactor, clock):
+ hs_config = self.default_config("test")
# some of the tests rely on us having a user consent version
hs_config.user_consent_version = "test_consent_version"
hs_config.max_mau_value = 50
- self.hs = yield setup_test_homeserver(
- self.addCleanup,
- config=hs_config,
- expire_access_token=True,
- )
+ hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True)
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+ self.mock_distributor = Mock()
+ self.mock_distributor.declare("registered_user")
+ self.mock_captcha_client = Mock()
self.macaroon_generator = Mock(
generate_access_token=Mock(return_value='secret')
)
@@ -63,136 +58,133 @@ class RegistrationTestCase(unittest.TestCase):
self.requester = create_requester("@requester:test")
- @defer.inlineCallbacks
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
frank = UserID.from_string("@frank:test")
user_id = frank.to_string()
requester = create_requester(user_id)
- result_user_id, result_token = yield self.handler.get_or_create_user(
- requester, frank.localpart, "Frankie"
+ result_user_id, result_token = self.get_success(
+ self.handler.get_or_create_user(requester, frank.localpart, "Frankie")
)
self.assertEquals(result_user_id, user_id)
self.assertTrue(result_token is not None)
self.assertEquals(result_token, 'secret')
- @defer.inlineCallbacks
def test_if_user_exists(self):
store = self.hs.get_datastore()
frank = UserID.from_string("@frank:test")
- yield store.register(
- user_id=frank.to_string(),
- token="jkv;g498752-43gj['eamb!-5",
- password_hash=None,
+ self.get_success(
+ store.register(
+ user_id=frank.to_string(),
+ token="jkv;g498752-43gj['eamb!-5",
+ password_hash=None,
+ )
)
local_part = frank.localpart
user_id = frank.to_string()
requester = create_requester(user_id)
- result_user_id, result_token = yield self.handler.get_or_create_user(
- requester, local_part, None
+ result_user_id, result_token = self.get_success(
+ self.handler.get_or_create_user(requester, local_part, None)
)
self.assertEquals(result_user_id, user_id)
self.assertTrue(result_token is not None)
- @defer.inlineCallbacks
def test_mau_limits_when_disabled(self):
self.hs.config.limit_usage_by_mau = False
# Ensure does not throw exception
- yield self.handler.get_or_create_user(self.requester, 'a', "display_name")
+ self.get_success(
+ self.handler.get_or_create_user(self.requester, 'a', "display_name")
+ )
- @defer.inlineCallbacks
def test_get_or_create_user_mau_not_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.count_monthly_users = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value - 1)
)
# Ensure does not throw exception
- yield self.handler.get_or_create_user(self.requester, 'c', "User")
+ self.get_success(self.handler.get_or_create_user(self.requester, 'c', "User"))
- @defer.inlineCallbacks
def test_get_or_create_user_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users)
)
- with self.assertRaises(ResourceLimitError):
- yield self.handler.get_or_create_user(self.requester, 'b', "display_name")
+ self.get_failure(
+ self.handler.get_or_create_user(self.requester, 'b', "display_name"),
+ ResourceLimitError,
+ )
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
- with self.assertRaises(ResourceLimitError):
- yield self.handler.get_or_create_user(self.requester, 'b', "display_name")
+ self.get_failure(
+ self.handler.get_or_create_user(self.requester, 'b', "display_name"),
+ ResourceLimitError,
+ )
- @defer.inlineCallbacks
def test_register_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users)
)
- with self.assertRaises(ResourceLimitError):
- yield self.handler.register(localpart="local_part")
+ self.get_failure(
+ self.handler.register(localpart="local_part"), ResourceLimitError
+ )
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
- with self.assertRaises(ResourceLimitError):
- yield self.handler.register(localpart="local_part")
+ self.get_failure(
+ self.handler.register(localpart="local_part"), ResourceLimitError
+ )
- @defer.inlineCallbacks
def test_auto_create_auto_join_rooms(self):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
- res = yield self.handler.register(localpart='jeff')
- rooms = yield self.store.get_rooms_for_user(res[0])
+ res = self.get_success(self.handler.register(localpart='jeff'))
+ rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
directory_handler = self.hs.get_handlers().directory_handler
room_alias = RoomAlias.from_string(room_alias_str)
- room_id = yield directory_handler.get_association(room_alias)
+ room_id = self.get_success(directory_handler.get_association(room_alias))
self.assertTrue(room_id['room_id'] in rooms)
self.assertEqual(len(rooms), 1)
- @defer.inlineCallbacks
def test_auto_create_auto_join_rooms_with_no_rooms(self):
self.hs.config.auto_join_rooms = []
frank = UserID.from_string("@frank:test")
- res = yield self.handler.register(frank.localpart)
+ res = self.get_success(self.handler.register(frank.localpart))
self.assertEqual(res[0], frank.to_string())
- rooms = yield self.store.get_rooms_for_user(res[0])
+ rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
self.assertEqual(len(rooms), 0)
- @defer.inlineCallbacks
def test_auto_create_auto_join_where_room_is_another_domain(self):
self.hs.config.auto_join_rooms = ["#room:another"]
frank = UserID.from_string("@frank:test")
- res = yield self.handler.register(frank.localpart)
+ res = self.get_success(self.handler.register(frank.localpart))
self.assertEqual(res[0], frank.to_string())
- rooms = yield self.store.get_rooms_for_user(res[0])
+ rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
self.assertEqual(len(rooms), 0)
- @defer.inlineCallbacks
def test_auto_create_auto_join_where_auto_create_is_false(self):
self.hs.config.autocreate_auto_join_rooms = False
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
- res = yield self.handler.register(localpart='jeff')
- rooms = yield self.store.get_rooms_for_user(res[0])
+ res = self.get_success(self.handler.register(localpart='jeff'))
+ rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
self.assertEqual(len(rooms), 0)
- @defer.inlineCallbacks
def test_auto_create_auto_join_rooms_when_support_user_exists(self):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
self.store.is_support_user = Mock(return_value=True)
- res = yield self.handler.register(localpart='support')
- rooms = yield self.store.get_rooms_for_user(res[0])
+ res = self.get_success(self.handler.register(localpart='support'))
+ rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
self.assertEqual(len(rooms), 0)
directory_handler = self.hs.get_handlers().directory_handler
room_alias = RoomAlias.from_string(room_alias_str)
- with self.assertRaises(SynapseError):
- yield directory_handler.get_association(room_alias)
+ self.get_failure(directory_handler.get_association(room_alias), SynapseError)
- @defer.inlineCallbacks
def test_auto_create_auto_join_where_no_consent(self):
"""Test to ensure that the first user is not auto-joined to a room if
they have not given general consent.
@@ -208,27 +200,27 @@ class RegistrationTestCase(unittest.TestCase):
# (Messing with the internals of event_creation_handler is fragile
# but can't see a better way to do this. One option could be to subclass
# the test with custom config.)
- event_creation_handler._block_events_without_consent_error = ("Error")
+ event_creation_handler._block_events_without_consent_error = "Error"
event_creation_handler._consent_uri_builder = Mock()
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
# When:-
# * the user is registered and post consent actions are called
- res = yield self.handler.register(localpart='jeff')
- yield self.handler.post_consent_actions(res[0])
+ res = self.get_success(self.handler.register(localpart='jeff'))
+ self.get_success(self.handler.post_consent_actions(res[0]))
# Then:-
# * Ensure that they have not been joined to the room
- rooms = yield self.store.get_rooms_for_user(res[0])
+ rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
self.assertEqual(len(rooms), 0)
- @defer.inlineCallbacks
def test_register_support_user(self):
- res = yield self.handler.register(localpart='user', user_type=UserTypes.SUPPORT)
+ res = self.get_success(
+ self.handler.register(localpart='user', user_type=UserTypes.SUPPORT)
+ )
self.assertTrue(self.store.is_support_user(res[0]))
- @defer.inlineCallbacks
def test_register_not_support_user(self):
- res = yield self.handler.register(localpart='user')
+ res = self.get_success(self.handler.register(localpart='user'))
self.assertFalse(self.store.is_support_user(res[0]))
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 13486930fb..6460cbc708 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -180,7 +180,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
put_json = self.hs.get_http_client().put_json
put_json.assert_called_once_with(
"farm",
- path="/_matrix/federation/v1/send/1000000/",
+ path="/_matrix/federation/v1/send/1000000",
data=_expect_edu_transaction(
"m.typing",
content={
@@ -192,6 +192,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
json_data_callback=ANY,
long_retries=True,
backoff_on_404=True,
+ try_trailing_slash_on_400=True,
)
def test_started_typing_remote_recv(self):
@@ -201,7 +202,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
(request, channel) = self.make_request(
"PUT",
- "/_matrix/federation/v1/send/1000000/",
+ "/_matrix/federation/v1/send/1000000",
_make_edu_transaction_json(
"m.typing",
content={
@@ -257,7 +258,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
put_json = self.hs.get_http_client().put_json
put_json.assert_called_once_with(
"farm",
- path="/_matrix/federation/v1/send/1000000/",
+ path="/_matrix/federation/v1/send/1000000",
data=_expect_edu_transaction(
"m.typing",
content={
@@ -269,6 +270,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
json_data_callback=ANY,
long_retries=True,
backoff_on_404=True,
+ try_trailing_slash_on_400=True,
)
self.assertEquals(self.event_source.get_current_key(), 1)
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index b03b37affe..cd8e086f86 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -268,6 +268,105 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, TimeoutError)
+ def test_client_requires_trailing_slashes(self):
+ """
+ If a connection is made to a client but the client rejects it due to
+ requiring a trailing slash. We need to retry the request with a
+ trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622.
+ """
+ d = self.cl.get_json(
+ "testserv:8008", "foo/bar", try_trailing_slash_on_400=True,
+ )
+
+ # Send the request
+ self.pump()
+
+ # there should have been a call to connectTCP
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (_host, _port, factory, _timeout, _bindAddress) = clients[0]
+
+ # complete the connection and wire it up to a fake transport
+ client = factory.buildProtocol(None)
+ conn = StringTransport()
+ client.makeConnection(conn)
+
+ # that should have made it send the request to the connection
+ self.assertRegex(conn.value(), b"^GET /foo/bar")
+
+ # Clear the original request data before sending a response
+ conn.clear()
+
+ # Send the HTTP response
+ client.dataReceived(
+ b"HTTP/1.1 400 Bad Request\r\n"
+ b"Content-Type: application/json\r\n"
+ b"Content-Length: 59\r\n"
+ b"\r\n"
+ b'{"errcode":"M_UNRECOGNIZED","error":"Unrecognized request"}'
+ )
+
+ # We should get another request with a trailing slash
+ self.assertRegex(conn.value(), b"^GET /foo/bar/")
+
+ # Send a happy response this time
+ client.dataReceived(
+ b"HTTP/1.1 200 OK\r\n"
+ b"Content-Type: application/json\r\n"
+ b"Content-Length: 2\r\n"
+ b"\r\n"
+ b'{}'
+ )
+
+ # We should get a successful response
+ r = self.successResultOf(d)
+ self.assertEqual(r, {})
+
+ def test_client_does_not_retry_on_400_plus(self):
+ """
+ Another test for trailing slashes but now test that we don't retry on
+ trailing slashes on a non-400/M_UNRECOGNIZED response.
+
+ See test_client_requires_trailing_slashes() for context.
+ """
+ d = self.cl.get_json(
+ "testserv:8008", "foo/bar", try_trailing_slash_on_400=True,
+ )
+
+ # Send the request
+ self.pump()
+
+ # there should have been a call to connectTCP
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (_host, _port, factory, _timeout, _bindAddress) = clients[0]
+
+ # complete the connection and wire it up to a fake transport
+ client = factory.buildProtocol(None)
+ conn = StringTransport()
+ client.makeConnection(conn)
+
+ # that should have made it send the request to the connection
+ self.assertRegex(conn.value(), b"^GET /foo/bar")
+
+ # Clear the original request data before sending a response
+ conn.clear()
+
+ # Send the HTTP response
+ client.dataReceived(
+ b"HTTP/1.1 404 Not Found\r\n"
+ b"Content-Type: application/json\r\n"
+ b"Content-Length: 2\r\n"
+ b"\r\n"
+ b"{}"
+ )
+
+ # We should not get another request
+ self.assertEqual(conn.value(), b"")
+
+ # We should get a 404 failure response
+ self.failureResultOf(d)
+
def test_client_sends_body(self):
self.cl.post_json(
"testserv:8008", "foo/bar", timeout=10000,
diff --git a/tests/replication/tcp/__init__.py b/tests/replication/tcp/__init__.py
new file mode 100644
index 0000000000..1453d04571
--- /dev/null
+++ b/tests/replication/tcp/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/replication/tcp/streams/__init__.py b/tests/replication/tcp/streams/__init__.py
new file mode 100644
index 0000000000..1453d04571
--- /dev/null
+++ b/tests/replication/tcp/streams/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
new file mode 100644
index 0000000000..38b368a972
--- /dev/null
+++ b/tests/replication/tcp/streams/_base.py
@@ -0,0 +1,74 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.replication.tcp.commands import ReplicateCommand
+from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+
+from tests import unittest
+from tests.server import FakeTransport
+
+
+class BaseStreamTestCase(unittest.HomeserverTestCase):
+ """Base class for tests of the replication streams"""
+ def prepare(self, reactor, clock, hs):
+ # build a replication server
+ server_factory = ReplicationStreamProtocolFactory(self.hs)
+ self.streamer = server_factory.streamer
+ server = server_factory.buildProtocol(None)
+
+ # build a replication client, with a dummy handler
+ self.test_handler = TestReplicationClientHandler()
+ self.client = ClientReplicationStreamProtocol(
+ "client", "test", clock, self.test_handler
+ )
+
+ # wire them together
+ self.client.makeConnection(FakeTransport(server, reactor))
+ server.makeConnection(FakeTransport(self.client, reactor))
+
+ def replicate(self):
+ """Tell the master side of replication that something has happened, and then
+ wait for the replication to occur.
+ """
+ self.streamer.on_notifier_poke()
+ self.pump(0.1)
+
+ def replicate_stream(self, stream, token="NOW"):
+ """Make the client end a REPLICATE command to set up a subscription to a stream"""
+ self.client.send_command(ReplicateCommand(stream, token))
+
+
+class TestReplicationClientHandler(object):
+ """Drop-in for ReplicationClientHandler which just collects RDATA rows"""
+ def __init__(self):
+ self.received_rdata_rows = []
+
+ def get_streams_to_replicate(self):
+ return {}
+
+ def get_currently_syncing_users(self):
+ return []
+
+ def update_connection(self, connection):
+ pass
+
+ def finished_connecting(self):
+ pass
+
+ def on_rdata(self, stream_name, token, rows):
+ for r in rows:
+ self.received_rdata_rows.append(
+ (stream_name, token, r)
+ )
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
new file mode 100644
index 0000000000..9aa9dfe82e
--- /dev/null
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.replication.tcp.streams import ReceiptsStreamRow
+
+from tests.replication.tcp.streams._base import BaseStreamTestCase
+
+USER_ID = "@feeling:blue"
+ROOM_ID = "!room:blue"
+EVENT_ID = "$event:blue"
+
+
+class ReceiptsStreamTestCase(BaseStreamTestCase):
+ def test_receipt(self):
+ # make the client subscribe to the receipts stream
+ self.replicate_stream("receipts", "NOW")
+
+ # tell the master to send a new receipt
+ self.get_success(
+ self.hs.get_datastore().insert_receipt(
+ ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1}
+ )
+ )
+ self.replicate()
+
+ # there should be one RDATA command
+ rdata_rows = self.test_handler.received_rdata_rows
+ self.assertEqual(1, len(rdata_rows))
+ self.assertEqual(rdata_rows[0][0], "receipts")
+ row = rdata_rows[0][2] # type: ReceiptsStreamRow
+ self.assertEqual(ROOM_ID, row.room_id)
+ self.assertEqual("m.read", row.receipt_type)
+ self.assertEqual(USER_ID, row.user_id)
+ self.assertEqual(EVENT_ID, row.event_id)
+ self.assertEqual({"a": 1}, row.data)
diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py
index 0caa4aa802..ef38473bd6 100644
--- a/tests/rest/client/v1/test_admin.py
+++ b/tests/rest/client/v1/test_admin.py
@@ -20,7 +20,7 @@ import json
from mock import Mock
from synapse.api.constants import UserTypes
-from synapse.rest.client.v1 import admin, login, room
+from synapse.rest.client.v1 import admin, events, login, room
from tests import unittest
@@ -359,7 +359,9 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
+ events.register_servlets,
room.register_servlets,
+ room.register_deprecated_servlets,
]
def prepare(self, reactor, clock, hs):
@@ -426,3 +428,65 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
self.store.get_users_in_room(room_id),
)
self.assertEqual([], users_in_room)
+
+ @unittest.DEBUG
+ def test_shutdown_room_block_peek(self):
+ """Test that a world_readable room can no longer be peeked into after
+ it has been shut down.
+ """
+
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
+
+ # Enable world readable
+ url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
+ request, channel = self.make_request(
+ "PUT",
+ url.encode('ascii'),
+ json.dumps({"history_visibility": "world_readable"}),
+ access_token=self.other_user_token,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Test that the admin can still send shutdown
+ url = "admin/shutdown_room/" + room_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode('ascii'),
+ json.dumps({"new_room_user_id": self.admin_user}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Assert we can no longer peek into the room
+ self._assert_peek(room_id, expect_code=403)
+
+ def _assert_peek(self, room_id, expect_code):
+ """Assert that the admin user can (or cannot) peek into the room.
+ """
+
+ url = "rooms/%s/initialSync" % (room_id,)
+ request, channel = self.make_request(
+ "GET",
+ url.encode('ascii'),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"],
+ )
+
+ url = "events?timeout=0&room_id=" + room_id
+ request, channel = self.make_request(
+ "GET",
+ url.encode('ascii'),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"],
+ )
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 9c401bf300..05b0143c42 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -18,136 +18,11 @@ import time
import attr
-from twisted.internet import defer
-
from synapse.api.constants import Membership
-from tests import unittest
from tests.server import make_request, render
-class RestTestCase(unittest.TestCase):
- """Contains extra helper functions to quickly and clearly perform a given
- REST action, which isn't the focus of the test.
-
- This subclass assumes there are mock_resource and auth_user_id attributes.
- """
-
- def __init__(self, *args, **kwargs):
- super(RestTestCase, self).__init__(*args, **kwargs)
- self.mock_resource = None
- self.auth_user_id = None
-
- @defer.inlineCallbacks
- def create_room_as(self, room_creator, is_public=True, tok=None):
- temp_id = self.auth_user_id
- self.auth_user_id = room_creator
- path = "/createRoom"
- content = "{}"
- if not is_public:
- content = '{"visibility":"private"}'
- if tok:
- path = path + "?access_token=%s" % tok
- (code, response) = yield self.mock_resource.trigger("POST", path, content)
- self.assertEquals(200, code, msg=str(response))
- self.auth_user_id = temp_id
- defer.returnValue(response["room_id"])
-
- @defer.inlineCallbacks
- def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
- yield self.change_membership(
- room=room,
- src=src,
- targ=targ,
- tok=tok,
- membership=Membership.INVITE,
- expect_code=expect_code,
- )
-
- @defer.inlineCallbacks
- def join(self, room=None, user=None, expect_code=200, tok=None):
- yield self.change_membership(
- room=room,
- src=user,
- targ=user,
- tok=tok,
- membership=Membership.JOIN,
- expect_code=expect_code,
- )
-
- @defer.inlineCallbacks
- def leave(self, room=None, user=None, expect_code=200, tok=None):
- yield self.change_membership(
- room=room,
- src=user,
- targ=user,
- tok=tok,
- membership=Membership.LEAVE,
- expect_code=expect_code,
- )
-
- @defer.inlineCallbacks
- def change_membership(self, room, src, targ, membership, tok=None, expect_code=200):
- temp_id = self.auth_user_id
- self.auth_user_id = src
-
- path = "/rooms/%s/state/m.room.member/%s" % (room, targ)
- if tok:
- path = path + "?access_token=%s" % tok
-
- data = {"membership": membership}
-
- (code, response) = yield self.mock_resource.trigger(
- "PUT", path, json.dumps(data)
- )
- self.assertEquals(
- expect_code,
- code,
- msg="Expected: %d, got: %d, resp: %r" % (expect_code, code, response),
- )
-
- self.auth_user_id = temp_id
-
- @defer.inlineCallbacks
- def register(self, user_id):
- (code, response) = yield self.mock_resource.trigger(
- "POST",
- "/register",
- json.dumps(
- {"user": user_id, "password": "test", "type": "m.login.password"}
- ),
- )
- self.assertEquals(200, code, msg=response)
- defer.returnValue(response)
-
- @defer.inlineCallbacks
- def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
- if txn_id is None:
- txn_id = "m%s" % (str(time.time()))
- if body is None:
- body = "body_text_here"
-
- path = "/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
- content = '{"msgtype":"m.text","body":"%s"}' % body
- if tok:
- path = path + "?access_token=%s" % tok
-
- (code, response) = yield self.mock_resource.trigger("PUT", path, content)
- self.assertEquals(expect_code, code, msg=str(response))
-
- def assert_dict(self, required, actual):
- """Does a partial assert of a dict.
-
- Args:
- required (dict): The keys and value which MUST be in 'actual'.
- actual (dict): The test result. Extra keys will not be checked.
- """
- for key in required:
- self.assertEquals(
- required[key], actual[key], msg="%s mismatch. %s" % (key, actual)
- )
-
-
@attr.s
class RestHelper(object):
"""Contains extra helper functions to quickly and clearly perform a given
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 3bd9f1e9c1..be73e718c2 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -1,3 +1,18 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018, 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from mock import Mock
from twisted.internet import defer
@@ -9,16 +24,18 @@ from synapse.server_notices.resource_limits_server_notices import (
)
from tests import unittest
-from tests.utils import default_config, setup_test_homeserver
-class TestResourceLimitsServerNotices(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs_config = default_config(name="test")
+class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
+
+ def make_homeserver(self, reactor, clock):
+ hs_config = self.default_config("test")
hs_config.server_notices_mxid = "@server:test"
- self.hs = yield setup_test_homeserver(self.addCleanup, config=hs_config)
+ hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True)
+ return hs
+
+ def prepare(self, reactor, clock, hs):
self.server_notices_sender = self.hs.get_server_notices_sender()
# relying on [1] is far from ideal, but the only case where
@@ -53,23 +70,21 @@ class TestResourceLimitsServerNotices(unittest.TestCase):
self._rlsn._store.get_tags_for_room = Mock(return_value={})
self.hs.config.admin_contact = "mailto:user@test.com"
- @defer.inlineCallbacks
def test_maybe_send_server_notice_to_user_flag_off(self):
"""Tests cases where the flags indicate nothing to do"""
# test hs disabled case
self.hs.config.hs_disabled = True
- yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+ self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
# Test when mau limiting disabled
self.hs.config.hs_disabled = False
self.hs.limit_usage_by_mau = False
- yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+ self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
- @defer.inlineCallbacks
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
"""Test when user has blocked notice, but should have it removed"""
@@ -81,13 +96,14 @@ class TestResourceLimitsServerNotices(unittest.TestCase):
return_value=defer.succeed({"123": mock_event})
)
- yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+ self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event
self._send_notice.assert_called_once()
- @defer.inlineCallbacks
def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self):
- """Test when user has blocked notice, but notice ought to be there (NOOP)"""
+ """
+ Test when user has blocked notice, but notice ought to be there (NOOP)
+ """
self._rlsn._auth.check_auth_blocking = Mock(
side_effect=ResourceLimitError(403, 'foo')
)
@@ -98,52 +114,49 @@ class TestResourceLimitsServerNotices(unittest.TestCase):
self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event})
)
- yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+ self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
- @defer.inlineCallbacks
def test_maybe_send_server_notice_to_user_add_blocked_notice(self):
- """Test when user does not have blocked notice, but should have one"""
+ """
+ Test when user does not have blocked notice, but should have one
+ """
self._rlsn._auth.check_auth_blocking = Mock(
side_effect=ResourceLimitError(403, 'foo')
)
- yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+ self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check contents, but 2 calls == set blocking event
self.assertTrue(self._send_notice.call_count == 2)
- @defer.inlineCallbacks
def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self):
- """Test when user does not have blocked notice, nor should they (NOOP)"""
-
+ """
+ Test when user does not have blocked notice, nor should they (NOOP)
+ """
self._rlsn._auth.check_auth_blocking = Mock()
- yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+ self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
- @defer.inlineCallbacks
def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self):
-
- """Test when user is not part of the MAU cohort - this should not ever
+ """
+ Test when user is not part of the MAU cohort - this should not ever
happen - but ...
"""
-
self._rlsn._auth.check_auth_blocking = Mock()
self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(None)
)
- yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+ self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
-class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- self.hs = yield setup_test_homeserver(self.addCleanup)
+class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = self.hs.get_datastore()
self.server_notices_sender = self.hs.get_server_notices_sender()
self.server_notices_manager = self.hs.get_server_notices_manager()
@@ -168,26 +181,27 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase):
self.hs.config.admin_contact = "mailto:user@test.com"
- @defer.inlineCallbacks
def test_server_notice_only_sent_once(self):
self.store.get_monthly_active_count = Mock(return_value=1000)
self.store.user_last_seen_monthly_active = Mock(return_value=1000)
# Call the function multiple times to ensure we only send the notice once
- yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
- yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
- yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+ self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
+ self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
+ self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Now lets get the last load of messages in the service notice room and
# check that there is only one server notice
- room_id = yield self.server_notices_manager.get_notice_room_for_user(
- self.user_id
+ room_id = self.get_success(
+ self.server_notices_manager.get_notice_room_for_user(self.user_id)
)
- token = yield self.event_source.get_current_token()
- events, _ = yield self.store.get_recent_events_for_room(
- room_id, limit=100, end_token=token.room_key
+ token = self.get_success(self.event_source.get_current_token())
+ events, _ = self.get_success(
+ self.store.get_recent_events_for_room(
+ room_id, limit=100, end_token=token.room_key
+ )
)
count = 0
diff --git a/tests/unittest.py b/tests/unittest.py
index 7772a47078..27403de908 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -314,6 +314,9 @@ class HomeserverTestCase(TestCase):
"""
kwargs = dict(kwargs)
kwargs.update(self._hs_args)
+ if "config" not in kwargs:
+ config = self.default_config()
+ kwargs["config"] = config
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
stor = hs.get_datastore()
@@ -336,6 +339,15 @@ class HomeserverTestCase(TestCase):
self.pump(by=by)
return self.successResultOf(d)
+ def get_failure(self, d, exc):
+ """
+ Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
+ """
+ if not isinstance(d, Deferred):
+ return d
+ self.pump()
+ return self.failureResultOf(d, exc)
+
def register_user(self, username, password, admin=False):
"""
Register a user. Requires the Admin API be registered.
diff --git a/tests/utils.py b/tests/utils.py
index 67e99a0e40..1b8eeb5167 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018-2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -43,6 +44,10 @@ from synapse.util.logcontext import LoggingContext
from synapse.util.ratelimitutils import FederationRateLimiter
# set this to True to run the tests against postgres instead of sqlite.
+#
+# When running under postgres, we first create a base database with the name
+# POSTGRES_BASE_DB and update it to the current schema. Then, for each test case, we
+# create another unique database, using the base database as a template.
USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False)
LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False)
POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", None)
@@ -50,28 +55,20 @@ POSTGRES_HOST = os.environ.get("SYNAPSE_POSTGRES_HOST", None)
POSTGRES_PASSWORD = os.environ.get("SYNAPSE_POSTGRES_PASSWORD", None)
POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
+# the dbname we will connect to in order to create the base database.
+POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres"
-def setupdb():
+def setupdb():
# If we're using PostgreSQL, set up the db once
if USE_POSTGRES_FOR_TESTS:
- pgconfig = {
- "name": "psycopg2",
- "args": {
- "database": POSTGRES_BASE_DB,
- "user": POSTGRES_USER,
- "host": POSTGRES_HOST,
- "password": POSTGRES_PASSWORD,
- "cp_min": 1,
- "cp_max": 5,
- },
- }
- config = Mock()
- config.password_providers = []
- config.database_config = pgconfig
- db_engine = create_engine(pgconfig)
+ # create a PostgresEngine
+ db_engine = create_engine({"name": "psycopg2", "args": {}})
+
+ # connect to postgres to create the base database.
db_conn = db_engine.module.connect(
- user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD
+ user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD,
+ dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
)
db_conn.autocommit = True
cur = db_conn.cursor()
@@ -96,7 +93,8 @@ def setupdb():
def _cleanup():
db_conn = db_engine.module.connect(
- user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD
+ user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD,
+ dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
)
db_conn.autocommit = True
cur = db_conn.cursor()
|